Merge remote-tracking branch 'upstream/master'
This commit is contained in:
commit
04fed4c2af
@ -1,54 +1,28 @@
|
|||||||
BasedOnStyle: LLVM
|
BasedOnStyle: Google
|
||||||
IndentWidth: 4
|
IndentWidth: 4
|
||||||
UseTab: Never
|
UseTab: Never
|
||||||
|
ColumnLimit: 120
|
||||||
|
|
||||||
Language: Cpp
|
Language: Cpp
|
||||||
Standard: Cpp11
|
Standard: Cpp11
|
||||||
|
|
||||||
AccessModifierOffset: -4
|
AccessModifierOffset: -4
|
||||||
|
AlignConsecutiveMacros: true
|
||||||
AlignConsecutiveDeclarations: false
|
AllowAllArgumentsOnNextLine: false
|
||||||
AlignConsecutiveAssignments: false
|
AllowAllConstructorInitializersOnNextLine: false
|
||||||
AlignTrailingComments: true
|
AllowAllParametersOfDeclarationOnNextLine: false
|
||||||
|
AllowShortFunctionsOnASingleLine: Empty
|
||||||
AllowShortBlocksOnASingleLine: true
|
AllowShortIfStatementsOnASingleLine: Never
|
||||||
AllowShortCaseLabelsOnASingleLine: true
|
AllowShortLambdasOnASingleLine: Empty
|
||||||
AllowShortFunctionsOnASingleLine: Inline
|
AllowShortLoopsOnASingleLine: false
|
||||||
|
AlwaysBreakBeforeMultilineStrings: false
|
||||||
AlwaysBreakBeforeMultilineStrings: true
|
|
||||||
AlwaysBreakTemplateDeclarations: true
|
|
||||||
|
|
||||||
BinPackArguments: false
|
BinPackArguments: false
|
||||||
BinPackParameters: false
|
BinPackParameters: false
|
||||||
|
CommentPragmas: '^#'
|
||||||
BreakBeforeBraces: Allman
|
DerivePointerAlignment: false
|
||||||
BreakConstructorInitializersBeforeComma: true
|
|
||||||
|
|
||||||
ColumnLimit: 100
|
|
||||||
|
|
||||||
IndentCaseLabels: false
|
|
||||||
IndentWrappedFunctionNames: true
|
|
||||||
|
|
||||||
KeepEmptyLinesAtTheStartOfBlocks: false
|
|
||||||
NamespaceIndentation: All
|
|
||||||
|
|
||||||
PointerAlignment: Left
|
|
||||||
SpaceAfterCStyleCast: false
|
|
||||||
SpaceBeforeAssignmentOperators: true
|
|
||||||
SpaceBeforeParens: ControlStatements
|
|
||||||
SpaceInEmptyParentheses: false
|
|
||||||
SpacesInAngles: false
|
|
||||||
SpacesInCStyleCastParentheses: false
|
|
||||||
SpacesInParentheses: false
|
|
||||||
SpacesInSquareBrackets: false
|
|
||||||
|
|
||||||
SortIncludes: false
|
|
||||||
ReflowComments: true
|
|
||||||
|
|
||||||
IncludeCategories:
|
|
||||||
- Regex: '^".*'
|
|
||||||
Priority: 3
|
|
||||||
- Regex: '^<.*'
|
|
||||||
Priority: 2
|
|
||||||
SortIncludes: true
|
|
||||||
|
|
||||||
FixNamespaceComments: true
|
FixNamespaceComments: true
|
||||||
|
IndentCaseLabels: false
|
||||||
|
IndentPPDirectives: AfterHash
|
||||||
|
ForEachMacros:
|
||||||
|
- foreach
|
||||||
|
- FOREACH_CHILD
|
||||||
|
@ -12,15 +12,11 @@
|
|||||||
#include "ngraph/op/broadcast.hpp"
|
#include "ngraph/op/broadcast.hpp"
|
||||||
#include "ngraph/op/constant.hpp"
|
#include "ngraph/op/constant.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace builder {
|
||||||
namespace builder
|
class numpy_autobroadcast_incompatible_shapes : public ngraph::ngraph_error {
|
||||||
{
|
|
||||||
class numpy_autobroadcast_incompatible_shapes : public ngraph::ngraph_error
|
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
numpy_autobroadcast_incompatible_shapes(const ngraph::Shape& shape1,
|
numpy_autobroadcast_incompatible_shapes(const ngraph::Shape& shape1, const ngraph::Shape& shape2);
|
||||||
const ngraph::Shape& shape2);
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const ngraph::Shape m_shape1;
|
const ngraph::Shape m_shape1;
|
||||||
@ -84,8 +80,8 @@ namespace ngraph
|
|||||||
/// elements point to ngraph::Node objects whose output values have the same shape.
|
/// elements point to ngraph::Node objects whose output values have the same shape.
|
||||||
///
|
///
|
||||||
/// \exception ngraph::builder::numpy_autobroadcast_incompatible_shapes
|
/// \exception ngraph::builder::numpy_autobroadcast_incompatible_shapes
|
||||||
std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>>
|
std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> numpy_broadcast(
|
||||||
numpy_broadcast(const std::pair<Output<Node>, Output<Node>>& args);
|
const std::pair<Output<Node>, Output<Node>>& args);
|
||||||
|
|
||||||
/// \brief Broadcast shape of two nodes to make them compatible for a matrix
|
/// \brief Broadcast shape of two nodes to make them compatible for a matrix
|
||||||
/// multiplication.
|
/// multiplication.
|
||||||
@ -103,8 +99,7 @@ namespace ngraph
|
|||||||
///
|
///
|
||||||
/// \return The vector containing both outputs broadcasted.
|
/// \return The vector containing both outputs broadcasted.
|
||||||
///
|
///
|
||||||
OutputVector numpy_broadcast_for_matmul_operation(const Output<Node>& left,
|
OutputVector numpy_broadcast_for_matmul_operation(const Output<Node>& left, const Output<Node>& right);
|
||||||
const Output<Node>& right);
|
|
||||||
|
|
||||||
/// \brief Cast shape of all input nodes for an element-wise operation that requires
|
/// \brief Cast shape of all input nodes for an element-wise operation that requires
|
||||||
/// shape-compatibility
|
/// shape-compatibility
|
||||||
@ -149,23 +144,19 @@ namespace ngraph
|
|||||||
/// \return A pair that contains the target shape as its first object and a vector of
|
/// \return A pair that contains the target shape as its first object and a vector of
|
||||||
/// padded input shapes ready to be broadcasted as the second object
|
/// padded input shapes ready to be broadcasted as the second object
|
||||||
///
|
///
|
||||||
std::pair<Shape, std::vector<Shape>>
|
std::pair<Shape, std::vector<Shape>> get_numpy_broadcast_shapes(const std::vector<Shape>& input_shapes);
|
||||||
get_numpy_broadcast_shapes(const std::vector<Shape>& input_shapes);
|
|
||||||
|
|
||||||
inline std::shared_ptr<Node> make_broadcast_node(const Output<Node>& value,
|
inline std::shared_ptr<Node> make_broadcast_node(const Output<Node>& value,
|
||||||
const Shape& new_shape,
|
const Shape& new_shape,
|
||||||
std::size_t start_match_axis)
|
std::size_t start_match_axis) {
|
||||||
{
|
auto shape_const = op::Constant::create(element::u64, Shape{new_shape.size()}, new_shape);
|
||||||
auto shape_const =
|
|
||||||
op::Constant::create(element::u64, Shape{new_shape.size()}, new_shape);
|
|
||||||
return std::make_shared<op::v1::Broadcast>(
|
return std::make_shared<op::v1::Broadcast>(
|
||||||
value,
|
value,
|
||||||
shape_const,
|
shape_const,
|
||||||
calculate_broadcast_axes(new_shape, value.get_shape(), start_match_axis));
|
calculate_broadcast_axes(new_shape, value.get_shape(), start_match_axis));
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace opset1
|
namespace opset1 {
|
||||||
{
|
|
||||||
///
|
///
|
||||||
/// \brief Broadcast right node to left node's shape using legacy scheme.
|
/// \brief Broadcast right node to left node's shape using legacy scheme.
|
||||||
///
|
///
|
||||||
@ -189,8 +180,7 @@ namespace ngraph
|
|||||||
///
|
///
|
||||||
/// \return The vector with axes indexes mapping .
|
/// \return The vector with axes indexes mapping .
|
||||||
///
|
///
|
||||||
std::vector<std::size_t> get_axes_mapping(const Shape& output_shape,
|
std::vector<std::size_t> get_axes_mapping(const Shape& output_shape, const AxisSet& broadcast_axes);
|
||||||
const AxisSet& broadcast_axes);
|
|
||||||
|
|
||||||
///
|
///
|
||||||
/// \brief Creates Node returning the axes mapping for Broadcast:v1 operation.
|
/// \brief Creates Node returning the axes mapping for Broadcast:v1 operation.
|
||||||
@ -202,9 +192,7 @@ namespace ngraph
|
|||||||
///
|
///
|
||||||
/// \return Returns the Output object pointing to node with the axes mapping.
|
/// \return Returns the Output object pointing to node with the axes mapping.
|
||||||
///
|
///
|
||||||
Output<Node> get_axes_mapping_output(const Shape& output_shape,
|
Output<Node> get_axes_mapping_output(const Shape& output_shape, const Shape& input_shape, std::size_t start_match_axis);
|
||||||
const Shape& input_shape,
|
|
||||||
std::size_t start_match_axis);
|
|
||||||
|
|
||||||
///
|
///
|
||||||
/// \brief Creates Node returning the axes mapping for Broadcast operation.
|
/// \brief Creates Node returning the axes mapping for Broadcast operation.
|
||||||
@ -230,16 +218,11 @@ namespace ngraph
|
|||||||
///
|
///
|
||||||
/// \return The Output object with Node returning axes mapping.
|
/// \return The Output object with Node returning axes mapping.
|
||||||
///
|
///
|
||||||
Output<Node> get_axes_mapping_output(const Shape& output_shape,
|
Output<Node> get_axes_mapping_output(const Shape& output_shape, const AxisSet& broadcast_axes);
|
||||||
const AxisSet& broadcast_axes);
|
|
||||||
|
|
||||||
Output<Node> make_broadcast(const Output<Node>& node,
|
Output<Node> make_broadcast(const Output<Node>& node, const Shape& target_shape, const AxisSet& broadcast_axes);
|
||||||
const Shape& target_shape,
|
|
||||||
const AxisSet& broadcast_axes);
|
|
||||||
|
|
||||||
Output<Node> make_broadcast(const Output<Node>& node,
|
Output<Node> make_broadcast(const Output<Node>& node, const Shape& target_shape, std::size_t start_match_axis);
|
||||||
const Shape& target_shape,
|
|
||||||
std::size_t start_match_axis);
|
|
||||||
|
|
||||||
} // namespace opset1
|
} // namespace opset1
|
||||||
} // namespace builder
|
} // namespace builder
|
||||||
|
@ -10,14 +10,10 @@
|
|||||||
#include "ngraph/op/constant.hpp"
|
#include "ngraph/op/constant.hpp"
|
||||||
#include "ngraph/type/float16.hpp"
|
#include "ngraph/type/float16.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace builder {
|
||||||
namespace builder
|
|
||||||
{
|
|
||||||
template <class T>
|
template <class T>
|
||||||
std::shared_ptr<Node>
|
std::shared_ptr<Node> make_constant(const element::Type& type, const Shape& shape, const T& num) {
|
||||||
make_constant(const element::Type& type, const Shape& shape, const T& num)
|
|
||||||
{
|
|
||||||
std::shared_ptr<Node> val = nullptr;
|
std::shared_ptr<Node> val = nullptr;
|
||||||
|
|
||||||
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
|
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
|
||||||
@ -25,15 +21,15 @@ namespace ngraph
|
|||||||
# pragma GCC diagnostic error "-Wswitch"
|
# pragma GCC diagnostic error "-Wswitch"
|
||||||
# pragma GCC diagnostic error "-Wswitch-enum"
|
# pragma GCC diagnostic error "-Wswitch-enum"
|
||||||
#endif
|
#endif
|
||||||
switch (type)
|
switch (type) {
|
||||||
{
|
|
||||||
case element::Type_t::f32:
|
case element::Type_t::f32:
|
||||||
val = std::make_shared<ngraph::op::Constant>(
|
val =
|
||||||
type, ngraph::Shape{}, std::vector<float>{static_cast<float>(num)});
|
std::make_shared<ngraph::op::Constant>(type, ngraph::Shape{}, std::vector<float>{static_cast<float>(num)});
|
||||||
break;
|
break;
|
||||||
case element::Type_t::f64:
|
case element::Type_t::f64:
|
||||||
val = std::make_shared<ngraph::op::Constant>(
|
val = std::make_shared<ngraph::op::Constant>(type,
|
||||||
type, ngraph::Shape{}, std::vector<double>{static_cast<double>(num)});
|
ngraph::Shape{},
|
||||||
|
std::vector<double>{static_cast<double>(num)});
|
||||||
break;
|
break;
|
||||||
case element::Type_t::f16:
|
case element::Type_t::f16:
|
||||||
val = std::make_shared<ngraph::op::Constant>(
|
val = std::make_shared<ngraph::op::Constant>(
|
||||||
@ -48,36 +44,44 @@ namespace ngraph
|
|||||||
std::vector<ngraph::bfloat16>{ngraph::bfloat16(static_cast<float>(num))});
|
std::vector<ngraph::bfloat16>{ngraph::bfloat16(static_cast<float>(num))});
|
||||||
break;
|
break;
|
||||||
case element::Type_t::i64:
|
case element::Type_t::i64:
|
||||||
val = std::make_shared<ngraph::op::Constant>(
|
val = std::make_shared<ngraph::op::Constant>(type,
|
||||||
type, ngraph::Shape{}, std::vector<int64_t>{static_cast<int64_t>(num)});
|
ngraph::Shape{},
|
||||||
|
std::vector<int64_t>{static_cast<int64_t>(num)});
|
||||||
break;
|
break;
|
||||||
case element::Type_t::i32:
|
case element::Type_t::i32:
|
||||||
val = std::make_shared<ngraph::op::Constant>(
|
val = std::make_shared<ngraph::op::Constant>(type,
|
||||||
type, ngraph::Shape{}, std::vector<int32_t>{static_cast<int32_t>(num)});
|
ngraph::Shape{},
|
||||||
|
std::vector<int32_t>{static_cast<int32_t>(num)});
|
||||||
break;
|
break;
|
||||||
case element::Type_t::i16:
|
case element::Type_t::i16:
|
||||||
val = std::make_shared<ngraph::op::Constant>(
|
val = std::make_shared<ngraph::op::Constant>(type,
|
||||||
type, ngraph::Shape{}, std::vector<int16_t>{static_cast<int16_t>(num)});
|
ngraph::Shape{},
|
||||||
|
std::vector<int16_t>{static_cast<int16_t>(num)});
|
||||||
break;
|
break;
|
||||||
case element::Type_t::i8:
|
case element::Type_t::i8:
|
||||||
val = std::make_shared<ngraph::op::Constant>(
|
val = std::make_shared<ngraph::op::Constant>(type,
|
||||||
type, ngraph::Shape{}, std::vector<int8_t>{static_cast<int8_t>(num)});
|
ngraph::Shape{},
|
||||||
|
std::vector<int8_t>{static_cast<int8_t>(num)});
|
||||||
break;
|
break;
|
||||||
case element::Type_t::u64:
|
case element::Type_t::u64:
|
||||||
val = std::make_shared<ngraph::op::Constant>(
|
val = std::make_shared<ngraph::op::Constant>(type,
|
||||||
type, ngraph::Shape{}, std::vector<uint64_t>{static_cast<uint64_t>(num)});
|
ngraph::Shape{},
|
||||||
|
std::vector<uint64_t>{static_cast<uint64_t>(num)});
|
||||||
break;
|
break;
|
||||||
case element::Type_t::u32:
|
case element::Type_t::u32:
|
||||||
val = std::make_shared<ngraph::op::Constant>(
|
val = std::make_shared<ngraph::op::Constant>(type,
|
||||||
type, ngraph::Shape{}, std::vector<uint32_t>{static_cast<uint32_t>(num)});
|
ngraph::Shape{},
|
||||||
|
std::vector<uint32_t>{static_cast<uint32_t>(num)});
|
||||||
break;
|
break;
|
||||||
case element::Type_t::u16:
|
case element::Type_t::u16:
|
||||||
val = std::make_shared<ngraph::op::Constant>(
|
val = std::make_shared<ngraph::op::Constant>(type,
|
||||||
type, ngraph::Shape{}, std::vector<uint16_t>{static_cast<uint16_t>(num)});
|
ngraph::Shape{},
|
||||||
|
std::vector<uint16_t>{static_cast<uint16_t>(num)});
|
||||||
break;
|
break;
|
||||||
case element::Type_t::u8:
|
case element::Type_t::u8:
|
||||||
val = std::make_shared<ngraph::op::Constant>(
|
val = std::make_shared<ngraph::op::Constant>(type,
|
||||||
type, ngraph::Shape{}, std::vector<uint8_t>{static_cast<uint8_t>(num)});
|
ngraph::Shape{},
|
||||||
|
std::vector<uint8_t>{static_cast<uint8_t>(num)});
|
||||||
break;
|
break;
|
||||||
case element::Type_t::dynamic:
|
case element::Type_t::dynamic:
|
||||||
throw ngraph_error("make_constant: Unsupported element type 'dynamic'");
|
throw ngraph_error("make_constant: Unsupported element type 'dynamic'");
|
||||||
@ -96,11 +100,9 @@ namespace ngraph
|
|||||||
# pragma GCC diagnostic pop
|
# pragma GCC diagnostic pop
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
if (shape.size() > 0)
|
if (shape.size() > 0) {
|
||||||
{
|
|
||||||
ngraph::AxisSet axes;
|
ngraph::AxisSet axes;
|
||||||
for (size_t i = 0; i < shape.size(); i++)
|
for (size_t i = 0; i < shape.size(); i++) {
|
||||||
{
|
|
||||||
axes.insert(i);
|
axes.insert(i);
|
||||||
}
|
}
|
||||||
val = builder::opset1::make_broadcast(val, shape, axes).get_node_shared_ptr();
|
val = builder::opset1::make_broadcast(val, shape, axes).get_node_shared_ptr();
|
||||||
@ -119,7 +121,6 @@ namespace ngraph
|
|||||||
///
|
///
|
||||||
/// \return The Constant node which have expected type, shape and value.
|
/// \return The Constant node which have expected type, shape and value.
|
||||||
///
|
///
|
||||||
std::shared_ptr<Node>
|
std::shared_ptr<Node> make_constant_from_double(const element::Type& type, const Shape& shape, double num);
|
||||||
make_constant_from_double(const element::Type& type, const Shape& shape, double num);
|
|
||||||
} // namespace builder
|
} // namespace builder
|
||||||
} // namespace ngraph
|
} // namespace ngraph
|
||||||
|
@ -10,21 +10,17 @@
|
|||||||
#include "ngraph/axis_set.hpp"
|
#include "ngraph/axis_set.hpp"
|
||||||
#include "ngraph/node.hpp"
|
#include "ngraph/node.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace builder {
|
||||||
namespace builder
|
|
||||||
{
|
|
||||||
/// \brief Specifies method of bias application to avoid numerical problems
|
/// \brief Specifies method of bias application to avoid numerical problems
|
||||||
enum class BiasMode
|
enum class BiasMode {
|
||||||
{
|
|
||||||
// Add bias to intermediate result
|
// Add bias to intermediate result
|
||||||
ADD,
|
ADD,
|
||||||
// Calculate max of intermediate result and bias
|
// Calculate max of intermediate result and bias
|
||||||
MAX
|
MAX
|
||||||
};
|
};
|
||||||
|
|
||||||
namespace opset1
|
namespace opset1 {
|
||||||
{
|
|
||||||
/// \brief Calculates L-0 norm of input tensor.
|
/// \brief Calculates L-0 norm of input tensor.
|
||||||
///
|
///
|
||||||
/// \note The L-0 norm represents the cardinality of elements different
|
/// \note The L-0 norm represents the cardinality of elements different
|
||||||
@ -36,9 +32,7 @@ namespace ngraph
|
|||||||
///
|
///
|
||||||
/// \return L-0 norm of value. The output sub-graph is composed of v1 ops.
|
/// \return L-0 norm of value. The output sub-graph is composed of v1 ops.
|
||||||
///
|
///
|
||||||
std::shared_ptr<Node> l0_norm(const Output<Node>& value,
|
std::shared_ptr<Node> l0_norm(const Output<Node>& value, const Output<Node>& reduction_axes, bool keep_dims = false);
|
||||||
const Output<Node>& reduction_axes,
|
|
||||||
bool keep_dims = false);
|
|
||||||
|
|
||||||
/// \brief Calculates L-1 norm of a value.
|
/// \brief Calculates L-1 norm of a value.
|
||||||
///
|
///
|
||||||
|
@ -7,12 +7,9 @@
|
|||||||
#include "ngraph/axis_set.hpp"
|
#include "ngraph/axis_set.hpp"
|
||||||
#include "ngraph/node.hpp"
|
#include "ngraph/node.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace builder {
|
||||||
namespace builder
|
namespace opset1 {
|
||||||
{
|
|
||||||
namespace opset1
|
|
||||||
{
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
/// \brief Sum-based Mean of a Tensor.
|
/// \brief Sum-based Mean of a Tensor.
|
||||||
///
|
///
|
||||||
@ -36,13 +33,9 @@ namespace ngraph
|
|||||||
/// | ----------------------------------------- | ---------------------------------------------------------------------------------------------------------------- |
|
/// | ----------------------------------------- | ---------------------------------------------------------------------------------------------------------------- |
|
||||||
/// | \f$E[\textit{delete}(A,d_1,\dots,d_n)]\f$ | The tensor \f$T\f$, where \f$T\f$ is the input tensor with the `reduction_axes` \f$A\f$ eliminated by reduction. |
|
/// | \f$E[\textit{delete}(A,d_1,\dots,d_n)]\f$ | The tensor \f$T\f$, where \f$T\f$ is the input tensor with the `reduction_axes` \f$A\f$ eliminated by reduction. |
|
||||||
// clang-format on
|
// clang-format on
|
||||||
std::shared_ptr<Node> mean(const Output<Node>& node,
|
std::shared_ptr<Node> mean(const Output<Node>& node, const AxisSet& reduction_axes, bool keep_dims = false);
|
||||||
const AxisSet& reduction_axes,
|
|
||||||
bool keep_dims = false);
|
|
||||||
|
|
||||||
std::shared_ptr<Node> mean(const Output<Node>& node,
|
std::shared_ptr<Node> mean(const Output<Node>& node, const Output<Node>& reduction_axes, bool keep_dims = false);
|
||||||
const Output<Node>& reduction_axes,
|
|
||||||
bool keep_dims = false);
|
|
||||||
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
/// \brief Sum-based Variance of a Tensor.
|
/// \brief Sum-based Variance of a Tensor.
|
||||||
|
@ -11,12 +11,9 @@
|
|||||||
#include "ngraph/node.hpp"
|
#include "ngraph/node.hpp"
|
||||||
#include "ngraph/shape.hpp"
|
#include "ngraph/shape.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace builder {
|
||||||
namespace builder
|
namespace opset1 {
|
||||||
{
|
|
||||||
namespace opset1
|
|
||||||
{
|
|
||||||
/// \brief Change shape of a value
|
/// \brief Change shape of a value
|
||||||
///
|
///
|
||||||
/// \param[in] value The value to be reshaped.
|
/// \param[in] value The value to be reshaped.
|
||||||
@ -31,8 +28,7 @@ namespace ngraph
|
|||||||
/// \param axes_order The permutation of axes.
|
/// \param axes_order The permutation of axes.
|
||||||
///
|
///
|
||||||
/// \return Transpose:v1 op.
|
/// \return Transpose:v1 op.
|
||||||
std::shared_ptr<Node> reorder_axes(const Output<Node>& value,
|
std::shared_ptr<Node> reorder_axes(const Output<Node>& value, std::vector<size_t> axes_order = {});
|
||||||
std::vector<size_t> axes_order = {});
|
|
||||||
|
|
||||||
/// \brief Return transposed value (with axes in reversed order).
|
/// \brief Return transposed value (with axes in reversed order).
|
||||||
///
|
///
|
||||||
@ -66,8 +62,7 @@ namespace ngraph
|
|||||||
/// \param[in] axes The vector defining indexes of axes to be removed.
|
/// \param[in] axes The vector defining indexes of axes to be removed.
|
||||||
///
|
///
|
||||||
/// \return Reshape:v1 op.
|
/// \return Reshape:v1 op.
|
||||||
std::shared_ptr<Node> squeeze(const Output<Node>& value,
|
std::shared_ptr<Node> squeeze(const Output<Node>& value, std::vector<std::size_t> axes = {0});
|
||||||
std::vector<std::size_t> axes = {0});
|
|
||||||
|
|
||||||
/// \brief Collapse specified axes into single one.
|
/// \brief Collapse specified axes into single one.
|
||||||
///
|
///
|
||||||
@ -79,9 +74,7 @@ namespace ngraph
|
|||||||
///
|
///
|
||||||
/// \return The node with collapsed specified axes.
|
/// \return The node with collapsed specified axes.
|
||||||
///
|
///
|
||||||
std::shared_ptr<Node> collapse(const Output<Node>& value,
|
std::shared_ptr<Node> collapse(const Output<Node>& value, const std::size_t start_axis, const std::size_t end_axis);
|
||||||
const std::size_t start_axis,
|
|
||||||
const std::size_t end_axis);
|
|
||||||
} // namespace opset1
|
} // namespace opset1
|
||||||
} // namespace builder
|
} // namespace builder
|
||||||
} // namespace ngraph
|
} // namespace ngraph
|
||||||
|
@ -7,10 +7,8 @@
|
|||||||
|
|
||||||
#include "ngraph/node.hpp"
|
#include "ngraph/node.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace builder {
|
||||||
namespace builder
|
|
||||||
{
|
|
||||||
/// \brief Split value on specified axis into multiple parts.
|
/// \brief Split value on specified axis into multiple parts.
|
||||||
///
|
///
|
||||||
/// \param value The value to be split.
|
/// \param value The value to be split.
|
||||||
@ -20,9 +18,7 @@ namespace ngraph
|
|||||||
/// \return The vector containing multiple nodes we split input node into.
|
/// \return The vector containing multiple nodes we split input node into.
|
||||||
///
|
///
|
||||||
NGRAPH_DEPRECATED("This builder was deprecated.")
|
NGRAPH_DEPRECATED("This builder was deprecated.")
|
||||||
OutputVector split(const Output<Node>& value,
|
OutputVector split(const Output<Node>& value, const std::vector<size_t>& length_parts, int64_t axis = 0);
|
||||||
const std::vector<size_t>& length_parts,
|
|
||||||
int64_t axis = 0);
|
|
||||||
|
|
||||||
/// \brief Split node on specified axis into multiple parts.
|
/// \brief Split node on specified axis into multiple parts.
|
||||||
///
|
///
|
||||||
@ -41,8 +37,7 @@ namespace ngraph
|
|||||||
NGRAPH_DEPRECATED("This builder was deprecated.")
|
NGRAPH_DEPRECATED("This builder was deprecated.")
|
||||||
OutputVector split(const Output<Node>& value, size_t split_parts, int axis = 0);
|
OutputVector split(const Output<Node>& value, size_t split_parts, int axis = 0);
|
||||||
|
|
||||||
namespace opset1
|
namespace opset1 {
|
||||||
{
|
|
||||||
/// \brief Split value on specified axis into multiple parts.
|
/// \brief Split value on specified axis into multiple parts.
|
||||||
///
|
///
|
||||||
/// \param value The value to be split.
|
/// \param value The value to be split.
|
||||||
@ -56,9 +51,7 @@ namespace ngraph
|
|||||||
/// \return The vector containing multiple outputs we split input node into.
|
/// \return The vector containing multiple outputs we split input node into.
|
||||||
/// The vector is output of Split:v1 op
|
/// The vector is output of Split:v1 op
|
||||||
///
|
///
|
||||||
OutputVector split(const Output<Node>& value,
|
OutputVector split(const Output<Node>& value, const std::vector<size_t>& split_lengths, int64_t axis = 0);
|
||||||
const std::vector<size_t>& split_lengths,
|
|
||||||
int64_t axis = 0);
|
|
||||||
|
|
||||||
/// \brief Split value on specified axis into multiple parts.
|
/// \brief Split value on specified axis into multiple parts.
|
||||||
///
|
///
|
||||||
|
@ -18,21 +18,15 @@
|
|||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace builder {
|
||||||
namespace builder
|
numpy_autobroadcast_incompatible_shapes::numpy_autobroadcast_incompatible_shapes(const Shape& shape1,
|
||||||
{
|
|
||||||
numpy_autobroadcast_incompatible_shapes::numpy_autobroadcast_incompatible_shapes(
|
|
||||||
const Shape& shape1, const Shape& shape2)
|
|
||||||
: ngraph_error(error_str(shape1, shape2))
|
|
||||||
, m_shape1(shape1)
|
|
||||||
, m_shape2(shape2)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
string numpy_autobroadcast_incompatible_shapes::error_str(const Shape& shape1,
|
|
||||||
const Shape& shape2)
|
const Shape& shape2)
|
||||||
{
|
: ngraph_error(error_str(shape1, shape2)),
|
||||||
|
m_shape1(shape1),
|
||||||
|
m_shape2(shape2) {}
|
||||||
|
|
||||||
|
string numpy_autobroadcast_incompatible_shapes::error_str(const Shape& shape1, const Shape& shape2) {
|
||||||
ostringstream os;
|
ostringstream os;
|
||||||
os << "Auto-broadcast not possible for these input shapes:"
|
os << "Auto-broadcast not possible for these input shapes:"
|
||||||
<< " shape1=" << vector_to_string(shape1) << " shape2=" << vector_to_string(shape2);
|
<< " shape1=" << vector_to_string(shape1) << " shape2=" << vector_to_string(shape2);
|
||||||
@ -52,8 +46,7 @@ namespace ngraph
|
|||||||
///
|
///
|
||||||
/// \return Broadcast shape of input shapes.
|
/// \return Broadcast shape of input shapes.
|
||||||
///
|
///
|
||||||
static Shape calculate_broadcast_shape(Shape lhs_shape, Shape rhs_shape)
|
static Shape calculate_broadcast_shape(Shape lhs_shape, Shape rhs_shape) {
|
||||||
{
|
|
||||||
Shape result;
|
Shape result;
|
||||||
auto lhs_rank = lhs_shape.size();
|
auto lhs_rank = lhs_shape.size();
|
||||||
auto rhs_rank = rhs_shape.size();
|
auto rhs_rank = rhs_shape.size();
|
||||||
@ -64,13 +57,11 @@ namespace ngraph
|
|||||||
// left-pad the rhs_shape with ones
|
// left-pad the rhs_shape with ones
|
||||||
rhs_shape.insert(begin(rhs_shape), max_rank - rhs_rank, 1);
|
rhs_shape.insert(begin(rhs_shape), max_rank - rhs_rank, 1);
|
||||||
|
|
||||||
for (size_t index = 0; index < max_rank; ++index)
|
for (size_t index = 0; index < max_rank; ++index) {
|
||||||
{
|
|
||||||
size_t lhs_dim = lhs_shape.at(index);
|
size_t lhs_dim = lhs_shape.at(index);
|
||||||
size_t rhs_dim = rhs_shape.at(index);
|
size_t rhs_dim = rhs_shape.at(index);
|
||||||
|
|
||||||
if (lhs_dim != rhs_dim && lhs_dim != 1 && rhs_dim != 1)
|
if (lhs_dim != rhs_dim && lhs_dim != 1 && rhs_dim != 1) {
|
||||||
{
|
|
||||||
throw numpy_autobroadcast_incompatible_shapes(lhs_shape, rhs_shape);
|
throw numpy_autobroadcast_incompatible_shapes(lhs_shape, rhs_shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -80,29 +71,23 @@ namespace ngraph
|
|||||||
return result;
|
return result;
|
||||||
};
|
};
|
||||||
|
|
||||||
pair<Shape, vector<Shape>> get_numpy_broadcast_shapes(const vector<Shape>& input_shapes)
|
pair<Shape, vector<Shape>> get_numpy_broadcast_shapes(const vector<Shape>& input_shapes) {
|
||||||
{
|
Shape target_shape = accumulate(begin(input_shapes), end(input_shapes), Shape{}, calculate_broadcast_shape);
|
||||||
Shape target_shape = accumulate(
|
|
||||||
begin(input_shapes), end(input_shapes), Shape{}, calculate_broadcast_shape);
|
|
||||||
|
|
||||||
vector<Shape> full_shapes;
|
vector<Shape> full_shapes;
|
||||||
for (const Shape& input : input_shapes)
|
for (const Shape& input : input_shapes) {
|
||||||
{
|
|
||||||
Shape padded_shape{input};
|
Shape padded_shape{input};
|
||||||
padded_shape.insert(
|
padded_shape.insert(begin(padded_shape), target_shape.size() - padded_shape.size(), 1);
|
||||||
begin(padded_shape), target_shape.size() - padded_shape.size(), 1);
|
|
||||||
full_shapes.push_back(move(padded_shape));
|
full_shapes.push_back(move(padded_shape));
|
||||||
}
|
}
|
||||||
|
|
||||||
return {target_shape, full_shapes};
|
return {target_shape, full_shapes};
|
||||||
}
|
}
|
||||||
|
|
||||||
static pair<Shape, vector<Shape>> get_numpy_broadcast_shapes(const OutputVector& values)
|
static pair<Shape, vector<Shape>> get_numpy_broadcast_shapes(const OutputVector& values) {
|
||||||
{
|
|
||||||
vector<Shape> input_shapes;
|
vector<Shape> input_shapes;
|
||||||
|
|
||||||
for (const auto& input : values)
|
for (const auto& input : values) {
|
||||||
{
|
|
||||||
input_shapes.push_back(input.get_shape());
|
input_shapes.push_back(input.get_shape());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -125,12 +110,10 @@ namespace ngraph
|
|||||||
///
|
///
|
||||||
static shared_ptr<Node> numpy_broadcast_node(const Output<Node>& value,
|
static shared_ptr<Node> numpy_broadcast_node(const Output<Node>& value,
|
||||||
const Shape& output_shape,
|
const Shape& output_shape,
|
||||||
const Shape& source_shape)
|
const Shape& source_shape) {
|
||||||
{
|
|
||||||
shared_ptr<Node> broadcasted_node = value.get_node_shared_ptr();
|
shared_ptr<Node> broadcasted_node = value.get_node_shared_ptr();
|
||||||
// If node already has the required shape, return original node
|
// If node already has the required shape, return original node
|
||||||
if (output_shape == value.get_shape())
|
if (output_shape == value.get_shape()) {
|
||||||
{
|
|
||||||
return broadcasted_node;
|
return broadcasted_node;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -145,29 +128,22 @@ namespace ngraph
|
|||||||
// Positions of axes which have length of 1 are needed to calculate broadcast_axes
|
// Positions of axes which have length of 1 are needed to calculate broadcast_axes
|
||||||
// for nGraph broadcast operation. We need to remove ones from source shape
|
// for nGraph broadcast operation. We need to remove ones from source shape
|
||||||
// to avoid broadcasting axis conflict.
|
// to avoid broadcasting axis conflict.
|
||||||
for (size_t index = 0; index < output_shape.size(); ++index)
|
for (size_t index = 0; index < output_shape.size(); ++index) {
|
||||||
{
|
if (source_shape.at(index) == 1 && output_shape.at(index) != 1) {
|
||||||
if (source_shape.at(index) == 1 && output_shape.at(index) != 1)
|
|
||||||
{
|
|
||||||
broadcast_axes.push_back(index);
|
broadcast_axes.push_back(index);
|
||||||
}
|
} else {
|
||||||
else
|
|
||||||
{
|
|
||||||
squeezed_shape.push_back(source_shape.at(index));
|
squeezed_shape.push_back(source_shape.at(index));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (squeezed_shape != value.get_shape())
|
if (squeezed_shape != value.get_shape()) {
|
||||||
{
|
|
||||||
broadcasted_node = builder::opset1::reshape(value, squeezed_shape);
|
broadcasted_node = builder::opset1::reshape(value, squeezed_shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!broadcast_axes.empty())
|
if (!broadcast_axes.empty()) {
|
||||||
{
|
auto shape_const = op::Constant::create(element::u64, Shape{output_shape.size()}, output_shape);
|
||||||
auto shape_const =
|
broadcasted_node =
|
||||||
op::Constant::create(element::u64, Shape{output_shape.size()}, output_shape);
|
make_shared<op::v1::Broadcast>(broadcasted_node,
|
||||||
broadcasted_node = make_shared<op::v1::Broadcast>(
|
|
||||||
broadcasted_node,
|
|
||||||
shape_const,
|
shape_const,
|
||||||
opset1::get_axes_mapping_output(output_shape, broadcast_axes));
|
opset1::get_axes_mapping_output(output_shape, broadcast_axes));
|
||||||
}
|
}
|
||||||
@ -183,57 +159,45 @@ namespace ngraph
|
|||||||
///
|
///
|
||||||
/// \return The broadcasted Node.
|
/// \return The broadcasted Node.
|
||||||
///
|
///
|
||||||
static shared_ptr<Node> broadcast_value_pdpd_style(const Output<Node>& value,
|
static shared_ptr<Node> broadcast_value_pdpd_style(const Output<Node>& value, const Shape& output_shape, int64_t axis) {
|
||||||
const Shape& output_shape,
|
|
||||||
int64_t axis)
|
|
||||||
{
|
|
||||||
auto value_shape = value.get_shape();
|
auto value_shape = value.get_shape();
|
||||||
|
|
||||||
// If node already has the required shape, return original node
|
// If node already has the required shape, return original node
|
||||||
if (output_shape == value_shape)
|
if (output_shape == value_shape) {
|
||||||
{
|
|
||||||
return value.get_node_shared_ptr();
|
return value.get_node_shared_ptr();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (axis == -1)
|
if (axis == -1) {
|
||||||
{
|
|
||||||
axis = output_shape.size() - value_shape.size();
|
axis = output_shape.size() - value_shape.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
auto trimmed_value_shape = value_shape;
|
auto trimmed_value_shape = value_shape;
|
||||||
while (trimmed_value_shape.size() > 0 && trimmed_value_shape.back() == 1)
|
while (trimmed_value_shape.size() > 0 && trimmed_value_shape.back() == 1) {
|
||||||
{
|
|
||||||
trimmed_value_shape.pop_back();
|
trimmed_value_shape.pop_back();
|
||||||
}
|
}
|
||||||
|
|
||||||
AxisSet axes;
|
AxisSet axes;
|
||||||
for (int64_t i = 0; i < axis; ++i)
|
for (int64_t i = 0; i < axis; ++i) {
|
||||||
{
|
|
||||||
axes.insert(static_cast<size_t>(i));
|
axes.insert(static_cast<size_t>(i));
|
||||||
}
|
}
|
||||||
|
|
||||||
for (size_t i = axis + trimmed_value_shape.size(); i < output_shape.size(); ++i)
|
for (size_t i = axis + trimmed_value_shape.size(); i < output_shape.size(); ++i) {
|
||||||
{
|
|
||||||
axes.insert(i);
|
axes.insert(i);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto trimmed_value = value;
|
auto trimmed_value = value;
|
||||||
if (value_shape != trimmed_value_shape)
|
if (value_shape != trimmed_value_shape) {
|
||||||
{
|
|
||||||
trimmed_value = builder::opset1::reshape(value, trimmed_value_shape);
|
trimmed_value = builder::opset1::reshape(value, trimmed_value_shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto shape_const =
|
auto shape_const = op::Constant::create(element::u64, Shape{output_shape.size()}, output_shape);
|
||||||
op::Constant::create(element::u64, Shape{output_shape.size()}, output_shape);
|
auto value_bcast =
|
||||||
auto value_bcast = make_shared<op::v1::Broadcast>(
|
make_shared<op::v1::Broadcast>(trimmed_value, shape_const, opset1::get_axes_mapping_output(output_shape, axes));
|
||||||
trimmed_value, shape_const, opset1::get_axes_mapping_output(output_shape, axes));
|
|
||||||
|
|
||||||
return move(value_bcast);
|
return move(value_bcast);
|
||||||
}
|
}
|
||||||
|
|
||||||
pair<shared_ptr<Node>, shared_ptr<Node>>
|
pair<shared_ptr<Node>, shared_ptr<Node>> numpy_broadcast(const pair<Output<Node>, Output<Node>>& args) {
|
||||||
numpy_broadcast(const pair<Output<Node>, Output<Node>>& args)
|
|
||||||
{
|
|
||||||
NGRAPH_CHECK(args.first.get_node());
|
NGRAPH_CHECK(args.first.get_node());
|
||||||
NGRAPH_CHECK(args.second.get_node());
|
NGRAPH_CHECK(args.second.get_node());
|
||||||
|
|
||||||
@ -241,22 +205,17 @@ namespace ngraph
|
|||||||
const Shape& arg2_in_shape = args.second.get_shape();
|
const Shape& arg2_in_shape = args.second.get_shape();
|
||||||
|
|
||||||
// Handle the trivial case...
|
// Handle the trivial case...
|
||||||
if (arg1_in_shape == arg2_in_shape)
|
if (arg1_in_shape == arg2_in_shape) {
|
||||||
{
|
return make_pair(args.first.get_node_shared_ptr(), args.second.get_node_shared_ptr());
|
||||||
return make_pair(args.first.get_node_shared_ptr(),
|
|
||||||
args.second.get_node_shared_ptr());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
NodeVector bcasted_outputs =
|
NodeVector bcasted_outputs = as_node_vector(numpy_broadcast_outputs({args.first, args.second}));
|
||||||
as_node_vector(numpy_broadcast_outputs({args.first, args.second}));
|
|
||||||
|
|
||||||
return make_pair(bcasted_outputs.at(0), bcasted_outputs.at(1));
|
return make_pair(bcasted_outputs.at(0), bcasted_outputs.at(1));
|
||||||
}
|
}
|
||||||
|
|
||||||
OutputVector numpy_broadcast_outputs(const OutputVector& values)
|
OutputVector numpy_broadcast_outputs(const OutputVector& values) {
|
||||||
{
|
if (values.size() <= 1) {
|
||||||
if (values.size() <= 1)
|
|
||||||
{
|
|
||||||
return values;
|
return values;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -264,37 +223,29 @@ namespace ngraph
|
|||||||
auto bcast_shapes = get_numpy_broadcast_shapes(values);
|
auto bcast_shapes = get_numpy_broadcast_shapes(values);
|
||||||
|
|
||||||
OutputVector broadcasted_inputs;
|
OutputVector broadcasted_inputs;
|
||||||
for (size_t i = 0; i < values.size(); ++i)
|
for (size_t i = 0; i < values.size(); ++i) {
|
||||||
{
|
broadcasted_inputs.push_back(numpy_broadcast_node(values[i], bcast_shapes.first, bcast_shapes.second[i]));
|
||||||
broadcasted_inputs.push_back(
|
|
||||||
numpy_broadcast_node(values[i], bcast_shapes.first, bcast_shapes.second[i]));
|
|
||||||
}
|
}
|
||||||
return broadcasted_inputs;
|
return broadcasted_inputs;
|
||||||
}
|
}
|
||||||
|
|
||||||
shared_ptr<Node> numpy_broadcast(const Output<Node>& value, const Shape& shape)
|
shared_ptr<Node> numpy_broadcast(const Output<Node>& value, const Shape& shape) {
|
||||||
{
|
|
||||||
auto bcast_shape = get_numpy_broadcast_shapes({value.get_shape(), shape});
|
auto bcast_shape = get_numpy_broadcast_shapes({value.get_shape(), shape});
|
||||||
return numpy_broadcast_node(value, bcast_shape.first, bcast_shape.second[0]);
|
return numpy_broadcast_node(value, bcast_shape.first, bcast_shape.second[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
OutputVector numpy_broadcast_for_matmul_operation(const Output<Node>& left,
|
OutputVector numpy_broadcast_for_matmul_operation(const Output<Node>& left, const Output<Node>& right) {
|
||||||
const Output<Node>& right)
|
|
||||||
{
|
|
||||||
const auto& left_shape = left.get_shape();
|
const auto& left_shape = left.get_shape();
|
||||||
const auto& right_shape = right.get_shape();
|
const auto& right_shape = right.get_shape();
|
||||||
// Broadcast only _stack of matrices_ axes.
|
// Broadcast only _stack of matrices_ axes.
|
||||||
const auto& numpy_shapes =
|
const auto& numpy_shapes = get_numpy_broadcast_shapes(
|
||||||
get_numpy_broadcast_shapes({Shape{begin(left_shape), next(end(left_shape), -2)},
|
{Shape{begin(left_shape), next(end(left_shape), -2)}, Shape{begin(right_shape), next(end(right_shape), -2)}});
|
||||||
Shape{begin(right_shape), next(end(right_shape), -2)}});
|
|
||||||
|
|
||||||
// Prepare tensors output shapes with broadcasted _stack of matrices_ axes.
|
// Prepare tensors output shapes with broadcasted _stack of matrices_ axes.
|
||||||
auto left_output_shape = numpy_shapes.first;
|
auto left_output_shape = numpy_shapes.first;
|
||||||
auto right_output_shape = numpy_shapes.first;
|
auto right_output_shape = numpy_shapes.first;
|
||||||
// Append the last two axes original dimensions.
|
// Append the last two axes original dimensions.
|
||||||
left_output_shape.insert(end(left_output_shape),
|
left_output_shape.insert(end(left_output_shape), next(begin(left_shape), left_shape.size() - 2), end(left_shape));
|
||||||
next(begin(left_shape), left_shape.size() - 2),
|
|
||||||
end(left_shape));
|
|
||||||
right_output_shape.insert(end(right_output_shape),
|
right_output_shape.insert(end(right_output_shape),
|
||||||
next(begin(right_shape), right_shape.size() - 2),
|
next(begin(right_shape), right_shape.size() - 2),
|
||||||
end(right_shape));
|
end(right_shape));
|
||||||
@ -302,37 +253,28 @@ namespace ngraph
|
|||||||
auto left_full_shape = numpy_shapes.second.at(0);
|
auto left_full_shape = numpy_shapes.second.at(0);
|
||||||
auto right_full_shape = numpy_shapes.second.at(1);
|
auto right_full_shape = numpy_shapes.second.at(1);
|
||||||
// Append the last two axes original dimensions.
|
// Append the last two axes original dimensions.
|
||||||
left_full_shape.insert(end(left_full_shape),
|
left_full_shape.insert(end(left_full_shape), next(begin(left_shape), left_shape.size() - 2), end(left_shape));
|
||||||
next(begin(left_shape), left_shape.size() - 2),
|
right_full_shape.insert(end(right_full_shape), next(begin(right_shape), right_shape.size() - 2), end(right_shape));
|
||||||
end(left_shape));
|
|
||||||
right_full_shape.insert(end(right_full_shape),
|
|
||||||
next(begin(right_shape), right_shape.size() - 2),
|
|
||||||
end(right_shape));
|
|
||||||
|
|
||||||
return {numpy_broadcast_node(left, left_output_shape, left_full_shape),
|
return {numpy_broadcast_node(left, left_output_shape, left_full_shape),
|
||||||
numpy_broadcast_node(right, right_output_shape, right_full_shape)};
|
numpy_broadcast_node(right, right_output_shape, right_full_shape)};
|
||||||
}
|
}
|
||||||
|
|
||||||
OutputVector pdpd_broadcast(const OutputVector& inputs, int64_t axis)
|
OutputVector pdpd_broadcast(const OutputVector& inputs, int64_t axis) {
|
||||||
{
|
if (inputs.size() <= 1) {
|
||||||
if (inputs.size() <= 1)
|
|
||||||
{
|
|
||||||
return inputs;
|
return inputs;
|
||||||
}
|
}
|
||||||
|
|
||||||
OutputVector broadcasted_inputs{inputs[0]};
|
OutputVector broadcasted_inputs{inputs[0]};
|
||||||
for (size_t i = 1; i < inputs.size(); ++i)
|
for (size_t i = 1; i < inputs.size(); ++i) {
|
||||||
{
|
broadcasted_inputs.push_back(broadcast_value_pdpd_style(inputs[i], inputs[0].get_shape(), axis));
|
||||||
broadcasted_inputs.push_back(
|
|
||||||
broadcast_value_pdpd_style(inputs[i], inputs[0].get_shape(), axis));
|
|
||||||
}
|
}
|
||||||
return broadcasted_inputs;
|
return broadcasted_inputs;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<Node> calculate_broadcast_axes(const Shape& output_shape,
|
std::shared_ptr<Node> calculate_broadcast_axes(const Shape& output_shape,
|
||||||
const Shape& input_shape,
|
const Shape& input_shape,
|
||||||
size_t start_match_axis)
|
size_t start_match_axis) {
|
||||||
{
|
|
||||||
vector<size_t> axes(output_shape.size() - input_shape.size());
|
vector<size_t> axes(output_shape.size() - input_shape.size());
|
||||||
// Populate the axes vector with monotonic increasing series from 0 until
|
// Populate the axes vector with monotonic increasing series from 0 until
|
||||||
// output_shape_size, excluding values in range:
|
// output_shape_size, excluding values in range:
|
||||||
@ -344,53 +286,41 @@ namespace ngraph
|
|||||||
return op::Constant::create(element::i64, Shape{axes_mapping.size()}, axes_mapping);
|
return op::Constant::create(element::i64, Shape{axes_mapping.size()}, axes_mapping);
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace opset1
|
namespace opset1 {
|
||||||
{
|
|
||||||
Output<Node> legacy_broadcast_for_binary_operation(const Output<Node>& left,
|
Output<Node> legacy_broadcast_for_binary_operation(const Output<Node>& left,
|
||||||
const Output<Node>& right,
|
const Output<Node>& right,
|
||||||
size_t start_match_axis)
|
size_t start_match_axis) {
|
||||||
{
|
|
||||||
const auto& left_shape = left.get_shape();
|
const auto& left_shape = left.get_shape();
|
||||||
const auto& right_shape = right.get_shape();
|
const auto& right_shape = right.get_shape();
|
||||||
|
|
||||||
bool dimensions_identical = (left_shape == right_shape);
|
bool dimensions_identical = (left_shape == right_shape);
|
||||||
if (dimensions_identical)
|
if (dimensions_identical) {
|
||||||
{
|
|
||||||
return right;
|
return right;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prepare new shape of right operand for broadcasting
|
// Prepare new shape of right operand for broadcasting
|
||||||
// Remove dimensions with length=1 from back
|
// Remove dimensions with length=1 from back
|
||||||
auto new_right_shape = right_shape;
|
auto new_right_shape = right_shape;
|
||||||
for (int dimension = new_right_shape.size() - 1; dimension >= 0; --dimension)
|
for (int dimension = new_right_shape.size() - 1; dimension >= 0; --dimension) {
|
||||||
{
|
if (new_right_shape.at(dimension) == 1) {
|
||||||
if (new_right_shape.at(dimension) == 1)
|
|
||||||
{
|
|
||||||
new_right_shape.pop_back();
|
new_right_shape.pop_back();
|
||||||
}
|
} else {
|
||||||
else
|
|
||||||
{
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find first dimensions at front with length different from 1
|
// Find first dimensions at front with length different from 1
|
||||||
size_t num_ones = 0;
|
size_t num_ones = 0;
|
||||||
for (size_t dimension : new_right_shape)
|
for (size_t dimension : new_right_shape) {
|
||||||
{
|
if (dimension == 1) {
|
||||||
if (dimension == 1)
|
|
||||||
{
|
|
||||||
++num_ones;
|
++num_ones;
|
||||||
}
|
} else {
|
||||||
else
|
|
||||||
{
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove dimensions with length=1 from front
|
// Remove dimensions with length=1 from front
|
||||||
new_right_shape.erase(begin(new_right_shape),
|
new_right_shape.erase(begin(new_right_shape), next(begin(new_right_shape), num_ones));
|
||||||
next(begin(new_right_shape), num_ones));
|
|
||||||
|
|
||||||
auto reshape_right = reshape(right, new_right_shape);
|
auto reshape_right = reshape(right, new_right_shape);
|
||||||
|
|
||||||
@ -400,14 +330,11 @@ namespace ngraph
|
|||||||
return make_broadcast(reshape_right, left_shape, start_match_axis);
|
return make_broadcast(reshape_right, left_shape, start_match_axis);
|
||||||
}
|
}
|
||||||
|
|
||||||
vector<size_t> get_axes_mapping(const Shape& output_shape,
|
vector<size_t> get_axes_mapping(const Shape& output_shape, const AxisSet& broadcast_axes) {
|
||||||
const AxisSet& broadcast_axes)
|
|
||||||
{
|
|
||||||
NGRAPH_CHECK((broadcast_axes.size() <= output_shape.size()));
|
NGRAPH_CHECK((broadcast_axes.size() <= output_shape.size()));
|
||||||
vector<size_t> axes_mapping(output_shape.size());
|
vector<size_t> axes_mapping(output_shape.size());
|
||||||
iota(axes_mapping.begin(), axes_mapping.end(), 0);
|
iota(axes_mapping.begin(), axes_mapping.end(), 0);
|
||||||
for (auto i = broadcast_axes.rbegin(); i != broadcast_axes.rend(); ++i)
|
for (auto i = broadcast_axes.rbegin(); i != broadcast_axes.rend(); ++i) {
|
||||||
{
|
|
||||||
axes_mapping.erase(axes_mapping.begin() + *i);
|
axes_mapping.erase(axes_mapping.begin() + *i);
|
||||||
}
|
}
|
||||||
return axes_mapping;
|
return axes_mapping;
|
||||||
@ -415,13 +342,11 @@ namespace ngraph
|
|||||||
|
|
||||||
Output<Node> get_axes_mapping_output(const PartialShape& output_shape,
|
Output<Node> get_axes_mapping_output(const PartialShape& output_shape,
|
||||||
const PartialShape& input_shape,
|
const PartialShape& input_shape,
|
||||||
std::size_t start_match_axis)
|
std::size_t start_match_axis) {
|
||||||
{
|
|
||||||
NGRAPH_CHECK((input_shape.rank().is_static() && output_shape.rank().is_static()),
|
NGRAPH_CHECK((input_shape.rank().is_static() && output_shape.rank().is_static()),
|
||||||
"Tensor's rank has to be static.");
|
"Tensor's rank has to be static.");
|
||||||
NGRAPH_CHECK(
|
NGRAPH_CHECK(
|
||||||
(input_shape.rank().get_length() + static_cast<int64_t>(start_match_axis) <=
|
(input_shape.rank().get_length() + static_cast<int64_t>(start_match_axis) <= output_shape.rank().get_length()),
|
||||||
output_shape.rank().get_length()),
|
|
||||||
"Unable to figure out axes mapping.");
|
"Unable to figure out axes mapping.");
|
||||||
|
|
||||||
vector<int64_t> mapping(input_shape.rank().get_length());
|
vector<int64_t> mapping(input_shape.rank().get_length());
|
||||||
@ -430,54 +355,40 @@ namespace ngraph
|
|||||||
return op::Constant::create(element::i64, Shape{mapping.size()}, mapping);
|
return op::Constant::create(element::i64, Shape{mapping.size()}, mapping);
|
||||||
}
|
}
|
||||||
|
|
||||||
Output<Node> get_axes_mapping_output(const Shape& output_shape,
|
Output<Node> get_axes_mapping_output(const Shape& output_shape, const AxisSet& broadcast_axes) {
|
||||||
const AxisSet& broadcast_axes)
|
|
||||||
{
|
|
||||||
vector<size_t> axes_mapping{get_axes_mapping(output_shape, broadcast_axes)};
|
vector<size_t> axes_mapping{get_axes_mapping(output_shape, broadcast_axes)};
|
||||||
return op::Constant::create(element::i64, Shape{axes_mapping.size()}, axes_mapping);
|
return op::Constant::create(element::i64, Shape{axes_mapping.size()}, axes_mapping);
|
||||||
}
|
}
|
||||||
|
|
||||||
Output<Node> get_axes_mapping_output(const PartialShape& output_shape,
|
Output<Node> get_axes_mapping_output(const PartialShape& output_shape,
|
||||||
const Output<Node>& input_shape,
|
const Output<Node>& input_shape,
|
||||||
std::size_t start_match_axis)
|
std::size_t start_match_axis) {
|
||||||
{
|
|
||||||
const auto one_node = opset7::Constant::create(element::i64, Shape{}, {1});
|
const auto one_node = opset7::Constant::create(element::i64, Shape{}, {1});
|
||||||
const auto zero_node = opset7::Constant::create(element::i64, Shape{}, {0});
|
const auto zero_node = opset7::Constant::create(element::i64, Shape{}, {0});
|
||||||
const auto start_match_axis_node =
|
const auto start_match_axis_node = opset7::Constant::create(element::i64, Shape{}, {start_match_axis});
|
||||||
opset7::Constant::create(element::i64, Shape{}, {start_match_axis});
|
const auto target_shape_rank_node =
|
||||||
const auto target_shape_rank_node = builder::opset1::reshape(
|
builder::opset1::reshape(std::make_shared<opset7::ShapeOf>(input_shape), Shape{});
|
||||||
std::make_shared<opset7::ShapeOf>(input_shape), Shape{});
|
|
||||||
|
|
||||||
const auto range_node = std::make_shared<opset7::Range>(
|
const auto range_node = std::make_shared<opset7::Range>(zero_node, target_shape_rank_node, one_node, element::i64);
|
||||||
zero_node, target_shape_rank_node, one_node, element::i64);
|
|
||||||
|
|
||||||
// workaround for GPU plugin type incompatibility
|
// workaround for GPU plugin type incompatibility
|
||||||
const auto range_node_converted = std::make_shared<opset7::Convert>(
|
const auto range_node_converted =
|
||||||
range_node, start_match_axis_node->get_element_type());
|
std::make_shared<opset7::Convert>(range_node, start_match_axis_node->get_element_type());
|
||||||
// end of workaround
|
// end of workaround
|
||||||
|
|
||||||
const auto result =
|
const auto result = std::make_shared<opset7::Add>(range_node_converted, start_match_axis_node);
|
||||||
std::make_shared<opset7::Add>(range_node_converted, start_match_axis_node);
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
Output<Node> make_broadcast(const Output<Node>& node,
|
Output<Node> make_broadcast(const Output<Node>& node, const Shape& target_shape, const AxisSet& broadcast_axes) {
|
||||||
const Shape& target_shape,
|
return make_shared<op::v1::Broadcast>(node,
|
||||||
const AxisSet& broadcast_axes)
|
|
||||||
{
|
|
||||||
return make_shared<op::v1::Broadcast>(
|
|
||||||
node,
|
|
||||||
op::Constant::create(element::i64, Shape{target_shape.size()}, target_shape),
|
op::Constant::create(element::i64, Shape{target_shape.size()}, target_shape),
|
||||||
get_axes_mapping_output(target_shape, broadcast_axes));
|
get_axes_mapping_output(target_shape, broadcast_axes));
|
||||||
}
|
}
|
||||||
|
|
||||||
Output<Node> make_broadcast(const Output<Node>& node,
|
Output<Node> make_broadcast(const Output<Node>& node, const Shape& target_shape, size_t start_match_axis) {
|
||||||
const Shape& target_shape,
|
|
||||||
size_t start_match_axis)
|
|
||||||
{
|
|
||||||
const auto node_shape = std::make_shared<opset7::ShapeOf>(node);
|
const auto node_shape = std::make_shared<opset7::ShapeOf>(node);
|
||||||
return make_shared<op::v1::Broadcast>(
|
return make_shared<op::v1::Broadcast>(node,
|
||||||
node,
|
|
||||||
op::Constant::create(element::i64, Shape{target_shape.size()}, target_shape),
|
op::Constant::create(element::i64, Shape{target_shape.size()}, target_shape),
|
||||||
get_axes_mapping_output(target_shape, node_shape, start_match_axis));
|
get_axes_mapping_output(target_shape, node_shape, start_match_axis));
|
||||||
}
|
}
|
||||||
|
@ -4,83 +4,60 @@
|
|||||||
|
|
||||||
#include "ngraph/builder/make_constant.hpp"
|
#include "ngraph/builder/make_constant.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace builder {
|
||||||
namespace builder
|
std::shared_ptr<Node> make_constant_from_double(const element::Type& type, const Shape& shape, double num) {
|
||||||
{
|
auto ceil_func = [](double x) {
|
||||||
std::shared_ptr<Node>
|
return ceil(x);
|
||||||
make_constant_from_double(const element::Type& type, const Shape& shape, double num)
|
};
|
||||||
{
|
|
||||||
auto ceil_func = [](double x) { return ceil(x); };
|
|
||||||
|
|
||||||
std::shared_ptr<ngraph::Node> result = nullptr;
|
std::shared_ptr<ngraph::Node> result = nullptr;
|
||||||
switch (type)
|
switch (type) {
|
||||||
{
|
case element::Type_t::i8: {
|
||||||
case element::Type_t::i8:
|
result = std::make_shared<ngraph::op::Constant>(type, shape, double_to_int<int8_t>(num, ceil_func));
|
||||||
{
|
|
||||||
result = std::make_shared<ngraph::op::Constant>(
|
|
||||||
type, shape, double_to_int<int8_t>(num, ceil_func));
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case element::Type_t::i16:
|
case element::Type_t::i16: {
|
||||||
{
|
result = std::make_shared<ngraph::op::Constant>(type, shape, double_to_int<int16_t>(num, ceil_func));
|
||||||
result = std::make_shared<ngraph::op::Constant>(
|
|
||||||
type, shape, double_to_int<int16_t>(num, ceil_func));
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case element::Type_t::i32:
|
case element::Type_t::i32: {
|
||||||
{
|
result = std::make_shared<ngraph::op::Constant>(type, shape, double_to_int<int32_t>(num, ceil_func));
|
||||||
result = std::make_shared<ngraph::op::Constant>(
|
|
||||||
type, shape, double_to_int<int32_t>(num, ceil_func));
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case element::Type_t::i64:
|
case element::Type_t::i64: {
|
||||||
{
|
result = std::make_shared<ngraph::op::Constant>(type, shape, double_to_int<int64_t>(num, ceil_func));
|
||||||
result = std::make_shared<ngraph::op::Constant>(
|
|
||||||
type, shape, double_to_int<int64_t>(num, ceil_func));
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case element::Type_t::u8:
|
case element::Type_t::u8: {
|
||||||
{
|
result = std::make_shared<ngraph::op::Constant>(type, shape, double_to_int<uint8_t>(num, ceil_func));
|
||||||
result = std::make_shared<ngraph::op::Constant>(
|
|
||||||
type, shape, double_to_int<uint8_t>(num, ceil_func));
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case element::Type_t::u16:
|
case element::Type_t::u16: {
|
||||||
{
|
result = std::make_shared<ngraph::op::Constant>(type, shape, double_to_int<uint16_t>(num, ceil_func));
|
||||||
result = std::make_shared<ngraph::op::Constant>(
|
|
||||||
type, shape, double_to_int<uint16_t>(num, ceil_func));
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case element::Type_t::u32:
|
case element::Type_t::u32: {
|
||||||
{
|
result = std::make_shared<ngraph::op::Constant>(type, shape, double_to_int<uint32_t>(num, ceil_func));
|
||||||
result = std::make_shared<ngraph::op::Constant>(
|
|
||||||
type, shape, double_to_int<uint32_t>(num, ceil_func));
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case element::Type_t::u64:
|
case element::Type_t::u64: {
|
||||||
{
|
result = std::make_shared<ngraph::op::Constant>(type, shape, double_to_int<uint64_t>(num, ceil_func));
|
||||||
result = std::make_shared<ngraph::op::Constant>(
|
|
||||||
type, shape, double_to_int<uint64_t>(num, ceil_func));
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case element::Type_t::f16:
|
case element::Type_t::f16: {
|
||||||
{
|
|
||||||
result = builder::make_constant(type, shape, static_cast<float16>(num));
|
result = builder::make_constant(type, shape, static_cast<float16>(num));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case element::Type_t::bf16:
|
case element::Type_t::bf16: {
|
||||||
{
|
|
||||||
result = builder::make_constant(type, shape, static_cast<bfloat16>(num));
|
result = builder::make_constant(type, shape, static_cast<bfloat16>(num));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case element::Type_t::f32:
|
case element::Type_t::f32: {
|
||||||
{
|
|
||||||
result = builder::make_constant(type, shape, static_cast<float>(num));
|
result = builder::make_constant(type, shape, static_cast<float>(num));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case element::Type_t::f64:
|
case element::Type_t::f64: {
|
||||||
{
|
|
||||||
result = builder::make_constant(type, shape, num);
|
result = builder::make_constant(type, shape, num);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
//
|
//
|
||||||
|
|
||||||
#include "ngraph/builder/norm.hpp"
|
#include "ngraph/builder/norm.hpp"
|
||||||
|
|
||||||
#include "ngraph/op/abs.hpp"
|
#include "ngraph/op/abs.hpp"
|
||||||
#include "ngraph/op/add.hpp"
|
#include "ngraph/op/add.hpp"
|
||||||
#include "ngraph/op/constant.hpp"
|
#include "ngraph/op/constant.hpp"
|
||||||
@ -18,105 +19,84 @@
|
|||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace builder {
|
||||||
namespace builder
|
namespace detail {
|
||||||
{
|
namespace opset1 {
|
||||||
namespace detail
|
|
||||||
{
|
|
||||||
namespace opset1
|
|
||||||
{
|
|
||||||
shared_ptr<Node> lp_norm(const Output<Node>& value,
|
shared_ptr<Node> lp_norm(const Output<Node>& value,
|
||||||
size_t p_norm,
|
size_t p_norm,
|
||||||
const Output<Node>& reduction_axes,
|
const Output<Node>& reduction_axes,
|
||||||
float bias,
|
float bias,
|
||||||
bool keep_dims)
|
bool keep_dims) {
|
||||||
{
|
|
||||||
// In general "entrywise" lp-norm for matrix `A` is defined as following double
|
// In general "entrywise" lp-norm for matrix `A` is defined as following double
|
||||||
// sum:
|
// sum:
|
||||||
// ||A||_p = ||vec(A)||_p = [sum_{i=1}^m sum_{j=1}^n abs(a_{i,j})^p]^{1/p}
|
// ||A||_p = ||vec(A)||_p = [sum_{i=1}^m sum_{j=1}^n abs(a_{i,j})^p]^{1/p}
|
||||||
shared_ptr<Node> abs_values{make_shared<ngraph::opset1::Abs>(value)};
|
shared_ptr<Node> abs_values{make_shared<ngraph::opset1::Abs>(value)};
|
||||||
shared_ptr<Node> p_node = ngraph::opset1::Constant::create(
|
shared_ptr<Node> p_node = ngraph::opset1::Constant::create(value.get_element_type(), Shape{}, {p_norm});
|
||||||
value.get_element_type(), Shape{}, {p_norm});
|
|
||||||
|
|
||||||
// Get inner part of equation: abs_values^p_node, then sum over reduction_axes.
|
// Get inner part of equation: abs_values^p_node, then sum over reduction_axes.
|
||||||
shared_ptr<Node> values{make_shared<ngraph::opset1::Power>(abs_values, p_node)};
|
shared_ptr<Node> values{make_shared<ngraph::opset1::Power>(abs_values, p_node)};
|
||||||
values =
|
values = make_shared<ngraph::opset1::ReduceSum>(values, reduction_axes, keep_dims);
|
||||||
make_shared<ngraph::opset1::ReduceSum>(values, reduction_axes, keep_dims);
|
|
||||||
|
|
||||||
shared_ptr<Node> bias_node{ngraph::opset1::Constant::create(
|
shared_ptr<Node> bias_node{ngraph::opset1::Constant::create(values->get_element_type(), Shape{}, {bias})};
|
||||||
values->get_element_type(), Shape{}, {bias})};
|
|
||||||
|
|
||||||
values = make_shared<ngraph::opset1::Add>(values, bias_node);
|
values = make_shared<ngraph::opset1::Add>(values, bias_node);
|
||||||
|
|
||||||
// Get outer part of equation: raise values to 1/p_norm exponent.
|
// Get outer part of equation: raise values to 1/p_norm exponent.
|
||||||
shared_ptr<Node> inv_p_node = ngraph::opset1::Constant::create(
|
shared_ptr<Node> inv_p_node = ngraph::opset1::Constant::create(values->get_element_type(), Shape{}, {1.f / p_norm});
|
||||||
values->get_element_type(), Shape{}, {1.f / p_norm});
|
|
||||||
|
|
||||||
return {make_shared<ngraph::opset1::Power>(values, inv_p_node)
|
return {make_shared<ngraph::opset1::Power>(values, inv_p_node)->add_provenance_group_members_above({value})};
|
||||||
->add_provenance_group_members_above({value})};
|
|
||||||
}
|
}
|
||||||
} // namespace opset1
|
} // namespace opset1
|
||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
|
||||||
shared_ptr<Node> builder::opset1::l0_norm(const Output<Node>& value,
|
shared_ptr<Node> builder::opset1::l0_norm(const Output<Node>& value,
|
||||||
const Output<Node>& reduction_axes,
|
const Output<Node>& reduction_axes,
|
||||||
bool keep_dims)
|
bool keep_dims) {
|
||||||
{
|
|
||||||
// L0 norm returns number of elements different from zero.
|
// L0 norm returns number of elements different from zero.
|
||||||
const shared_ptr<Node> zero_node{
|
const shared_ptr<Node> zero_node{ngraph::opset1::Constant::create(value.get_element_type(), Shape{}, {0.f})};
|
||||||
ngraph::opset1::Constant::create(value.get_element_type(), Shape{}, {0.f})};
|
|
||||||
|
|
||||||
// Convert bool values to input node data type.
|
// Convert bool values to input node data type.
|
||||||
const shared_ptr<Node> non_zero_values = make_shared<ngraph::opset1::Convert>(
|
const shared_ptr<Node> non_zero_values =
|
||||||
make_shared<ngraph::opset1::NotEqual>(value, zero_node), value.get_element_type());
|
make_shared<ngraph::opset1::Convert>(make_shared<ngraph::opset1::NotEqual>(value, zero_node),
|
||||||
|
value.get_element_type());
|
||||||
|
|
||||||
return make_shared<ngraph::opset1::ReduceSum>(
|
return make_shared<ngraph::opset1::ReduceSum>(non_zero_values, reduction_axes, keep_dims)
|
||||||
non_zero_values, reduction_axes, keep_dims)
|
|
||||||
->add_provenance_group_members_above({value});
|
->add_provenance_group_members_above({value});
|
||||||
}
|
}
|
||||||
|
|
||||||
shared_ptr<Node> builder::opset1::l1_norm(const Output<Node>& value,
|
shared_ptr<Node> builder::opset1::l1_norm(const Output<Node>& value,
|
||||||
const Output<Node>& reduction_axes,
|
const Output<Node>& reduction_axes,
|
||||||
float bias,
|
float bias,
|
||||||
bool keep_dims)
|
bool keep_dims) {
|
||||||
{
|
const shared_ptr<Node> values{
|
||||||
const shared_ptr<Node> values{make_shared<ngraph::opset1::ReduceSum>(
|
make_shared<ngraph::opset1::ReduceSum>(make_shared<ngraph::opset1::Abs>(value), reduction_axes, keep_dims)};
|
||||||
make_shared<ngraph::opset1::Abs>(value), reduction_axes, keep_dims)};
|
|
||||||
|
|
||||||
const shared_ptr<Node> bias_node{
|
const shared_ptr<Node> bias_node{ngraph::opset1::Constant::create(values->get_element_type(), Shape{}, {bias})};
|
||||||
ngraph::opset1::Constant::create(values->get_element_type(), Shape{}, {bias})};
|
|
||||||
|
|
||||||
return make_shared<ngraph::opset1::Add>(values, bias_node)
|
return make_shared<ngraph::opset1::Add>(values, bias_node)->add_provenance_group_members_above({value});
|
||||||
->add_provenance_group_members_above({value});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
shared_ptr<Node> builder::opset1::l2_norm(const Output<Node>& value,
|
shared_ptr<Node> builder::opset1::l2_norm(const Output<Node>& value,
|
||||||
const Output<Node>& reduction_axes,
|
const Output<Node>& reduction_axes,
|
||||||
float bias,
|
float bias,
|
||||||
BiasMode bias_mode,
|
BiasMode bias_mode,
|
||||||
bool keep_dims)
|
bool keep_dims) {
|
||||||
{
|
shared_ptr<Node> pow =
|
||||||
shared_ptr<Node> pow = make_shared<ngraph::opset1::Power>(
|
make_shared<ngraph::opset1::Power>(value,
|
||||||
value, make_shared<ngraph::opset1::Constant>(value.get_element_type(), Shape{}, 2));
|
make_shared<ngraph::opset1::Constant>(value.get_element_type(), Shape{}, 2));
|
||||||
shared_ptr<Node> values{
|
shared_ptr<Node> values{make_shared<ngraph::opset1::ReduceSum>(pow, reduction_axes, keep_dims)};
|
||||||
make_shared<ngraph::opset1::ReduceSum>(pow, reduction_axes, keep_dims)};
|
|
||||||
|
|
||||||
shared_ptr<Node> bias_node{
|
shared_ptr<Node> bias_node{ngraph::opset1::Constant::create(values->get_element_type(), Shape{}, {bias})};
|
||||||
ngraph::opset1::Constant::create(values->get_element_type(), Shape{}, {bias})};
|
|
||||||
shared_ptr<Node> result;
|
shared_ptr<Node> result;
|
||||||
switch (bias_mode)
|
switch (bias_mode) {
|
||||||
{
|
case BiasMode::MAX: {
|
||||||
case BiasMode::MAX:
|
result = make_shared<ngraph::opset1::Sqrt>(make_shared<ngraph::opset1::Maximum>(values, bias_node));
|
||||||
{
|
|
||||||
result = make_shared<ngraph::opset1::Sqrt>(
|
|
||||||
make_shared<ngraph::opset1::Maximum>(values, bias_node));
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case BiasMode::ADD:
|
case BiasMode::ADD:
|
||||||
default:
|
default:
|
||||||
result = make_shared<ngraph::opset1::Sqrt>(
|
result = make_shared<ngraph::opset1::Sqrt>(make_shared<ngraph::opset1::Add>(values, bias_node));
|
||||||
make_shared<ngraph::opset1::Add>(values, bias_node));
|
|
||||||
}
|
}
|
||||||
return result->add_provenance_group_members_above({value});
|
return result->add_provenance_group_members_above({value});
|
||||||
}
|
}
|
||||||
@ -125,26 +105,21 @@ namespace ngraph
|
|||||||
const Output<Node>& reduction_axes,
|
const Output<Node>& reduction_axes,
|
||||||
size_t p_norm,
|
size_t p_norm,
|
||||||
float bias,
|
float bias,
|
||||||
bool keep_dims)
|
bool keep_dims) {
|
||||||
{
|
|
||||||
// The number of non-zero elements
|
// The number of non-zero elements
|
||||||
if (p_norm == 0)
|
if (p_norm == 0) {
|
||||||
{
|
|
||||||
return opset1::l0_norm(value, reduction_axes, keep_dims);
|
return opset1::l0_norm(value, reduction_axes, keep_dims);
|
||||||
}
|
}
|
||||||
// sum of absolute values.
|
// sum of absolute values.
|
||||||
else if (p_norm == 1)
|
else if (p_norm == 1) {
|
||||||
{
|
|
||||||
return opset1::l1_norm(value, reduction_axes, bias, keep_dims);
|
return opset1::l1_norm(value, reduction_axes, bias, keep_dims);
|
||||||
}
|
}
|
||||||
// sqrt of sum of squares - Euclidean norm
|
// sqrt of sum of squares - Euclidean norm
|
||||||
else if (p_norm == 2)
|
else if (p_norm == 2) {
|
||||||
{
|
|
||||||
return opset1::l2_norm(value, reduction_axes, bias, BiasMode::ADD, keep_dims);
|
return opset1::l2_norm(value, reduction_axes, bias, BiasMode::ADD, keep_dims);
|
||||||
}
|
}
|
||||||
// generic case
|
// generic case
|
||||||
else
|
else {
|
||||||
{
|
|
||||||
return detail::opset1::lp_norm(value, p_norm, reduction_axes, bias, keep_dims);
|
return detail::opset1::lp_norm(value, p_norm, reduction_axes, bias, keep_dims);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2,11 +2,12 @@
|
|||||||
// SPDX-License-Identifier: Apache-2.0
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
//
|
//
|
||||||
|
|
||||||
|
#include "ngraph/builder/reduce_ops.hpp"
|
||||||
|
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
|
|
||||||
#include "ngraph/axis_set.hpp"
|
#include "ngraph/axis_set.hpp"
|
||||||
#include "ngraph/builder/autobroadcast.hpp"
|
#include "ngraph/builder/autobroadcast.hpp"
|
||||||
#include "ngraph/builder/reduce_ops.hpp"
|
|
||||||
#include "ngraph/op/constant.hpp"
|
#include "ngraph/op/constant.hpp"
|
||||||
#include "ngraph/op/divide.hpp"
|
#include "ngraph/op/divide.hpp"
|
||||||
#include "ngraph/op/multiply.hpp"
|
#include "ngraph/op/multiply.hpp"
|
||||||
@ -16,54 +17,39 @@
|
|||||||
#include "ngraph/opsets/opset1.hpp"
|
#include "ngraph/opsets/opset1.hpp"
|
||||||
#include "ngraph/util.hpp"
|
#include "ngraph/util.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace builder {
|
||||||
namespace builder
|
size_t get_num_elements(const Shape& shape, const AxisSet& reduction_axes) {
|
||||||
{
|
|
||||||
size_t get_num_elements(const Shape& shape, const AxisSet& reduction_axes)
|
|
||||||
{
|
|
||||||
size_t N = 1;
|
size_t N = 1;
|
||||||
for (auto a : reduction_axes)
|
for (auto a : reduction_axes) {
|
||||||
{
|
|
||||||
N *= shape[a];
|
N *= shape[a];
|
||||||
}
|
}
|
||||||
return N;
|
return N;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<Node> get_num_elements(const Output<Node>& value,
|
std::shared_ptr<Node> get_num_elements(const Output<Node>& value, const Output<Node>& reduction_axes) {
|
||||||
const Output<Node>& reduction_axes)
|
|
||||||
{
|
|
||||||
const auto value_shape = std::make_shared<ngraph::opset1::ShapeOf>(value);
|
const auto value_shape = std::make_shared<ngraph::opset1::ShapeOf>(value);
|
||||||
const auto dim_values = std::make_shared<ngraph::opset1::Gather>(
|
const auto dim_values =
|
||||||
value_shape,
|
std::make_shared<ngraph::opset1::Gather>(value_shape,
|
||||||
reduction_axes,
|
reduction_axes,
|
||||||
ngraph::opset1::Constant::create(element::i64, {}, {0}));
|
ngraph::opset1::Constant::create(element::i64, {}, {0}));
|
||||||
|
|
||||||
return std::make_shared<ngraph::opset1::ReduceProd>(
|
return std::make_shared<ngraph::opset1::ReduceProd>(dim_values,
|
||||||
dim_values, ngraph::opset1::Constant::create(element::i64, {}, {0}));
|
ngraph::opset1::Constant::create(element::i64, {}, {0}));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<Node> builder::opset1::mean(const Output<Node>& value,
|
std::shared_ptr<Node> builder::opset1::mean(const Output<Node>& value, const AxisSet& reduction_axes, bool keep_dims) {
|
||||||
const AxisSet& reduction_axes,
|
|
||||||
bool keep_dims)
|
|
||||||
{
|
|
||||||
std::shared_ptr<Node> elems_number;
|
std::shared_ptr<Node> elems_number;
|
||||||
const auto value_elem_type = value.get_element_type();
|
const auto value_elem_type = value.get_element_type();
|
||||||
const auto reduction_axes_const = ngraph::opset1::Constant::create(
|
const auto reduction_axes_const =
|
||||||
element::i64, Shape{reduction_axes.size()}, reduction_axes.to_vector());
|
ngraph::opset1::Constant::create(element::i64, Shape{reduction_axes.size()}, reduction_axes.to_vector());
|
||||||
const auto value_elems_sum =
|
const auto value_elems_sum = std::make_shared<ngraph::opset1::ReduceSum>(value, reduction_axes_const, keep_dims);
|
||||||
std::make_shared<ngraph::opset1::ReduceSum>(value, reduction_axes_const, keep_dims);
|
if (value.get_partial_shape().is_static()) {
|
||||||
if (value.get_partial_shape().is_static())
|
|
||||||
{
|
|
||||||
const auto elems_number_value = get_num_elements(value.get_shape(), reduction_axes);
|
const auto elems_number_value = get_num_elements(value.get_shape(), reduction_axes);
|
||||||
elems_number = ngraph::opset1::Constant::create(
|
elems_number = ngraph::opset1::Constant::create(value_elem_type, Shape{}, {elems_number_value});
|
||||||
value_elem_type, Shape{}, {elems_number_value});
|
} else {
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
elems_number = get_num_elements(value, reduction_axes_const);
|
elems_number = get_num_elements(value, reduction_axes_const);
|
||||||
elems_number =
|
elems_number = std::make_shared<ngraph::opset1::Convert>(elems_number, value_elem_type);
|
||||||
std::make_shared<ngraph::opset1::Convert>(elems_number, value_elem_type);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return std::make_shared<ngraph::opset1::Divide>(value_elems_sum, elems_number)
|
return std::make_shared<ngraph::opset1::Divide>(value_elems_sum, elems_number)
|
||||||
@ -72,12 +58,10 @@ namespace ngraph
|
|||||||
|
|
||||||
std::shared_ptr<Node> builder::opset1::mean(const Output<Node>& value,
|
std::shared_ptr<Node> builder::opset1::mean(const Output<Node>& value,
|
||||||
const Output<Node>& reduction_axes,
|
const Output<Node>& reduction_axes,
|
||||||
bool keep_dims)
|
bool keep_dims) {
|
||||||
{
|
|
||||||
std::shared_ptr<Node> elems_number;
|
std::shared_ptr<Node> elems_number;
|
||||||
const auto value_elem_type = value.get_element_type();
|
const auto value_elem_type = value.get_element_type();
|
||||||
const auto value_elems_sum =
|
const auto value_elems_sum = std::make_shared<ngraph::opset1::ReduceSum>(value, reduction_axes, keep_dims);
|
||||||
std::make_shared<ngraph::opset1::ReduceSum>(value, reduction_axes, keep_dims);
|
|
||||||
elems_number = get_num_elements(value, reduction_axes);
|
elems_number = get_num_elements(value, reduction_axes);
|
||||||
elems_number = std::make_shared<ngraph::opset1::Convert>(elems_number, value_elem_type);
|
elems_number = std::make_shared<ngraph::opset1::Convert>(elems_number, value_elem_type);
|
||||||
|
|
||||||
@ -87,8 +71,7 @@ namespace ngraph
|
|||||||
|
|
||||||
std::shared_ptr<Node> builder::opset1::variance(const Output<Node>& value,
|
std::shared_ptr<Node> builder::opset1::variance(const Output<Node>& value,
|
||||||
const AxisSet& reduction_axes,
|
const AxisSet& reduction_axes,
|
||||||
const bool bessel_correction)
|
const bool bessel_correction) {
|
||||||
{
|
|
||||||
const bool keep_dims = true;
|
const bool keep_dims = true;
|
||||||
std::shared_ptr<Node> mu = opset1::mean(value, reduction_axes, keep_dims);
|
std::shared_ptr<Node> mu = opset1::mean(value, reduction_axes, keep_dims);
|
||||||
|
|
||||||
@ -96,21 +79,17 @@ namespace ngraph
|
|||||||
|
|
||||||
diff = std::make_shared<ngraph::opset1::ReduceSum>(
|
diff = std::make_shared<ngraph::opset1::ReduceSum>(
|
||||||
std::make_shared<ngraph::opset1::Multiply>(diff, diff),
|
std::make_shared<ngraph::opset1::Multiply>(diff, diff),
|
||||||
ngraph::opset1::Constant::create(
|
ngraph::opset1::Constant::create(element::i64, Shape{reduction_axes.size()}, reduction_axes.to_vector()),
|
||||||
element::i64, Shape{reduction_axes.size()}, reduction_axes.to_vector()),
|
|
||||||
false);
|
false);
|
||||||
|
|
||||||
const auto& et = value.get_element_type();
|
const auto& et = value.get_element_type();
|
||||||
const auto N = get_num_elements(value.get_shape(), reduction_axes);
|
const auto N = get_num_elements(value.get_shape(), reduction_axes);
|
||||||
|
|
||||||
std::shared_ptr<Node> result;
|
std::shared_ptr<Node> result;
|
||||||
if (bessel_correction)
|
if (bessel_correction) {
|
||||||
{
|
|
||||||
const auto N1const = ngraph::opset1::Constant::create(et, Shape{}, {N - 1});
|
const auto N1const = ngraph::opset1::Constant::create(et, Shape{}, {N - 1});
|
||||||
result = std::make_shared<ngraph::opset1::Divide>(diff, N1const);
|
result = std::make_shared<ngraph::opset1::Divide>(diff, N1const);
|
||||||
}
|
} else {
|
||||||
else
|
|
||||||
{
|
|
||||||
const auto Nconst = ngraph::opset1::Constant::create(et, Shape{}, {N});
|
const auto Nconst = ngraph::opset1::Constant::create(et, Shape{}, {N});
|
||||||
result = std::make_shared<ngraph::opset1::Divide>(diff, Nconst);
|
result = std::make_shared<ngraph::opset1::Divide>(diff, Nconst);
|
||||||
}
|
}
|
||||||
@ -120,22 +99,21 @@ namespace ngraph
|
|||||||
std::shared_ptr<Node> builder::opset1::variance(const Output<Node>& value,
|
std::shared_ptr<Node> builder::opset1::variance(const Output<Node>& value,
|
||||||
const Output<Node>& reduction_axes,
|
const Output<Node>& reduction_axes,
|
||||||
bool keep_dims,
|
bool keep_dims,
|
||||||
bool bessel_correction)
|
bool bessel_correction) {
|
||||||
{
|
|
||||||
std::shared_ptr<Node> mu = opset1::mean(value, reduction_axes, keep_dims);
|
std::shared_ptr<Node> mu = opset1::mean(value, reduction_axes, keep_dims);
|
||||||
|
|
||||||
Output<Node> diff = std::make_shared<ngraph::opset1::Subtract>(value, mu);
|
Output<Node> diff = std::make_shared<ngraph::opset1::Subtract>(value, mu);
|
||||||
|
|
||||||
diff = std::make_shared<ngraph::opset1::ReduceSum>(
|
diff = std::make_shared<ngraph::opset1::ReduceSum>(std::make_shared<ngraph::opset1::Multiply>(diff, diff),
|
||||||
std::make_shared<ngraph::opset1::Multiply>(diff, diff), reduction_axes, keep_dims);
|
reduction_axes,
|
||||||
|
keep_dims);
|
||||||
|
|
||||||
const auto& et = value.get_element_type();
|
const auto& et = value.get_element_type();
|
||||||
auto N = get_num_elements(value, reduction_axes);
|
auto N = get_num_elements(value, reduction_axes);
|
||||||
N = std::make_shared<ngraph::opset1::Convert>(N, et);
|
N = std::make_shared<ngraph::opset1::Convert>(N, et);
|
||||||
|
|
||||||
std::shared_ptr<Node> result;
|
std::shared_ptr<Node> result;
|
||||||
if (bessel_correction)
|
if (bessel_correction) {
|
||||||
{
|
|
||||||
const auto one = std::make_shared<ngraph::opset1::Constant>(et, Shape{}, 1);
|
const auto one = std::make_shared<ngraph::opset1::Constant>(et, Shape{}, 1);
|
||||||
N = std::make_shared<ngraph::opset1::Subtract>(N, one);
|
N = std::make_shared<ngraph::opset1::Subtract>(N, one);
|
||||||
}
|
}
|
||||||
|
@ -2,13 +2,14 @@
|
|||||||
// SPDX-License-Identifier: Apache-2.0
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
//
|
//
|
||||||
|
|
||||||
|
#include "ngraph/builder/reshape.hpp"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <iterator>
|
#include <iterator>
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
|
|
||||||
#include "ngraph/axis_vector.hpp"
|
#include "ngraph/axis_vector.hpp"
|
||||||
#include "ngraph/builder/reshape.hpp"
|
|
||||||
#include "ngraph/op/concat.hpp"
|
#include "ngraph/op/concat.hpp"
|
||||||
#include "ngraph/op/constant.hpp"
|
#include "ngraph/op/constant.hpp"
|
||||||
#include "ngraph/op/reduce_prod.hpp"
|
#include "ngraph/op/reduce_prod.hpp"
|
||||||
@ -24,71 +25,54 @@
|
|||||||
using namespace ngraph;
|
using namespace ngraph;
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
shared_ptr<Node> builder::opset1::reshape(const Output<Node>& value, const Shape& shape)
|
shared_ptr<Node> builder::opset1::reshape(const Output<Node>& value, const Shape& shape) {
|
||||||
{
|
if (value.get_partial_shape().same_scheme(shape)) {
|
||||||
if (value.get_partial_shape().same_scheme(shape))
|
|
||||||
{
|
|
||||||
return value.get_node_shared_ptr();
|
return value.get_node_shared_ptr();
|
||||||
}
|
} else if (is_scalar(shape)) {
|
||||||
else if (is_scalar(shape))
|
|
||||||
{
|
|
||||||
auto value_rank = value.get_shape().size();
|
auto value_rank = value.get_shape().size();
|
||||||
AxisVector axes_vector(value_rank);
|
AxisVector axes_vector(value_rank);
|
||||||
std::iota(axes_vector.begin(), axes_vector.end(), 0);
|
std::iota(axes_vector.begin(), axes_vector.end(), 0);
|
||||||
auto axes = op::Constant::create(element::i64, Shape{value_rank}, axes_vector);
|
auto axes = op::Constant::create(element::i64, Shape{value_rank}, axes_vector);
|
||||||
return std::make_shared<op::Squeeze>(value, axes);
|
return std::make_shared<op::Squeeze>(value, axes);
|
||||||
}
|
} else {
|
||||||
else
|
auto out_pattern =
|
||||||
{
|
op::Constant::create(element::i64, Shape{shape.size()}, vector<int64_t>(shape.begin(), shape.end()));
|
||||||
auto out_pattern = op::Constant::create(
|
|
||||||
element::i64, Shape{shape.size()}, vector<int64_t>(shape.begin(), shape.end()));
|
|
||||||
|
|
||||||
return make_shared<ngraph::opset1::Reshape>(value, out_pattern, false)
|
return make_shared<ngraph::opset1::Reshape>(value, out_pattern, false)
|
||||||
->add_provenance_group_members_above({value});
|
->add_provenance_group_members_above({value});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
shared_ptr<Node> builder::opset1::reorder_axes(const Output<Node>& value, vector<size_t> axes_order)
|
shared_ptr<Node> builder::opset1::reorder_axes(const Output<Node>& value, vector<size_t> axes_order) {
|
||||||
{
|
const auto axes_order_const = op::Constant::create(element::i64,
|
||||||
const auto axes_order_const =
|
|
||||||
op::Constant::create(element::i64,
|
|
||||||
Shape{axes_order.size()},
|
Shape{axes_order.size()},
|
||||||
vector<int64_t>(axes_order.begin(), axes_order.end()));
|
vector<int64_t>(axes_order.begin(), axes_order.end()));
|
||||||
return make_shared<ngraph::opset1::Transpose>(value, axes_order_const)
|
return make_shared<ngraph::opset1::Transpose>(value, axes_order_const)->add_provenance_group_members_above({value});
|
||||||
->add_provenance_group_members_above({value});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
shared_ptr<Node> builder::opset1::transpose(const Output<Node>& value)
|
shared_ptr<Node> builder::opset1::transpose(const Output<Node>& value) {
|
||||||
{
|
|
||||||
// This part is left to preserve backward compatibility and ensure passing ONNX tests.
|
// This part is left to preserve backward compatibility and ensure passing ONNX tests.
|
||||||
if (value.get_partial_shape().is_static())
|
if (value.get_partial_shape().is_static()) {
|
||||||
{
|
|
||||||
vector<size_t> axes_order(value.get_shape().size());
|
vector<size_t> axes_order(value.get_shape().size());
|
||||||
iota(begin(axes_order), end(axes_order), 0);
|
iota(begin(axes_order), end(axes_order), 0);
|
||||||
reverse(begin(axes_order), end(axes_order));
|
reverse(begin(axes_order), end(axes_order));
|
||||||
return builder::opset1::reorder_axes(value, axes_order);
|
return builder::opset1::reorder_axes(value, axes_order);
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto input_rank =
|
const auto input_rank = std::make_shared<ngraph::opset1::ShapeOf>(std::make_shared<ngraph::opset1::ShapeOf>(value));
|
||||||
std::make_shared<ngraph::opset1::ShapeOf>(std::make_shared<ngraph::opset1::ShapeOf>(value));
|
|
||||||
const auto neg_one = ngraph::opset1::Constant::create(element::i64, Shape{}, {-1});
|
const auto neg_one = ngraph::opset1::Constant::create(element::i64, Shape{}, {-1});
|
||||||
const auto start_node = std::make_shared<ngraph::opset1::Add>(input_rank, neg_one);
|
const auto start_node = std::make_shared<ngraph::opset1::Add>(input_rank, neg_one);
|
||||||
const auto reverse_axes_order =
|
const auto reverse_axes_order = std::make_shared<ngraph::opset1::Range>(reshape(start_node, Shape{}), // start
|
||||||
std::make_shared<ngraph::opset1::Range>(reshape(start_node, Shape{}), // start
|
|
||||||
neg_one, // stop (exclusive)
|
neg_one, // stop (exclusive)
|
||||||
neg_one); // step
|
neg_one); // step
|
||||||
return std::make_shared<ngraph::opset1::Transpose>(value, reverse_axes_order)
|
return std::make_shared<ngraph::opset1::Transpose>(value, reverse_axes_order)
|
||||||
->add_provenance_group_members_above({value});
|
->add_provenance_group_members_above({value});
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace builder {
|
||||||
namespace builder
|
namespace opset1 {
|
||||||
{
|
namespace {
|
||||||
namespace opset1
|
|
||||||
{
|
|
||||||
namespace
|
|
||||||
{
|
|
||||||
///
|
///
|
||||||
/// \brief Return the node representing normalized axis with respect to
|
/// \brief Return the node representing normalized axis with respect to
|
||||||
/// provided rank.
|
/// provided rank.
|
||||||
@ -98,14 +82,10 @@ namespace ngraph
|
|||||||
///
|
///
|
||||||
/// \return The new Constant node representing normalized axis value.
|
/// \return The new Constant node representing normalized axis value.
|
||||||
///
|
///
|
||||||
std::shared_ptr<Node>
|
std::shared_ptr<Node> get_normalized_axis_node(const std::shared_ptr<Node> node_rank, int64_t axis) {
|
||||||
get_normalized_axis_node(const std::shared_ptr<Node> node_rank, int64_t axis)
|
auto axis_node = ngraph::opset1::Constant::create(element::i64, Shape{1}, {axis});
|
||||||
{
|
|
||||||
auto axis_node =
|
|
||||||
ngraph::opset1::Constant::create(element::i64, Shape{1}, {axis});
|
|
||||||
// shortcut for alredy positive value
|
// shortcut for alredy positive value
|
||||||
if (axis >= 0)
|
if (axis >= 0) {
|
||||||
{
|
|
||||||
return axis_node;
|
return axis_node;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -118,47 +98,40 @@ namespace ngraph
|
|||||||
} // namespace builder
|
} // namespace builder
|
||||||
} // namespace ngraph
|
} // namespace ngraph
|
||||||
|
|
||||||
shared_ptr<Node> builder::opset1::flatten(const Output<Node>& value, int axis)
|
shared_ptr<Node> builder::opset1::flatten(const Output<Node>& value, int axis) {
|
||||||
{
|
|
||||||
// First dimension of output tensor is the product of [d_0, ... d_{axis-1}] dimensions of
|
// First dimension of output tensor is the product of [d_0, ... d_{axis-1}] dimensions of
|
||||||
// input tensor. The last dimension is the product of the rest of input tensor dimensions:
|
// input tensor. The last dimension is the product of the rest of input tensor dimensions:
|
||||||
// [d_{axis}, ..., d_n]
|
// [d_{axis}, ..., d_n]
|
||||||
shared_ptr<Node> output_shape;
|
shared_ptr<Node> output_shape;
|
||||||
if (axis == 0)
|
if (axis == 0) {
|
||||||
{
|
|
||||||
output_shape = ngraph::opset1::Constant::create(element::i64, Shape{2}, {1, -1});
|
output_shape = ngraph::opset1::Constant::create(element::i64, Shape{2}, {1, -1});
|
||||||
}
|
} else if (axis == 1) {
|
||||||
else if (axis == 1)
|
|
||||||
{
|
|
||||||
output_shape = ngraph::opset1::Constant::create(element::i64, Shape{2}, {0, -1});
|
output_shape = ngraph::opset1::Constant::create(element::i64, Shape{2}, {0, -1});
|
||||||
}
|
} else {
|
||||||
else
|
|
||||||
{
|
|
||||||
const auto value_shape = make_shared<ngraph::opset1::ShapeOf>(value);
|
const auto value_shape = make_shared<ngraph::opset1::ShapeOf>(value);
|
||||||
const auto value_rank = make_shared<ngraph::opset1::ShapeOf>(value_shape);
|
const auto value_rank = make_shared<ngraph::opset1::ShapeOf>(value_shape);
|
||||||
const auto axis_node = get_normalized_axis_node(value_rank, axis);
|
const auto axis_node = get_normalized_axis_node(value_rank, axis);
|
||||||
|
|
||||||
const auto first_part_dims = make_shared<ngraph::opset1::StridedSlice>(
|
const auto first_part_dims =
|
||||||
value_shape,
|
make_shared<ngraph::opset1::StridedSlice>(value_shape,
|
||||||
ngraph::opset1::Constant::create(element::i64, {1}, {0}),
|
ngraph::opset1::Constant::create(element::i64, {1}, {0}),
|
||||||
axis_node,
|
axis_node,
|
||||||
vector<int64_t>{},
|
vector<int64_t>{},
|
||||||
vector<int64_t>{});
|
vector<int64_t>{});
|
||||||
const auto first_part_dims_length = make_shared<ngraph::opset1::ReduceProd>(
|
const auto first_part_dims_length =
|
||||||
first_part_dims, ngraph::opset1::Constant::create(element::i64, {}, {0}), true);
|
make_shared<ngraph::opset1::ReduceProd>(first_part_dims,
|
||||||
|
ngraph::opset1::Constant::create(element::i64, {}, {0}),
|
||||||
|
true);
|
||||||
|
|
||||||
const auto remaining_part_length =
|
const auto remaining_part_length = ngraph::opset1::Constant::create(element::i64, {1}, {-1});
|
||||||
ngraph::opset1::Constant::create(element::i64, {1}, {-1});
|
|
||||||
|
|
||||||
output_shape = make_shared<ngraph::opset1::Concat>(
|
output_shape =
|
||||||
OutputVector{first_part_dims_length, remaining_part_length}, 0);
|
make_shared<ngraph::opset1::Concat>(OutputVector{first_part_dims_length, remaining_part_length}, 0);
|
||||||
}
|
}
|
||||||
return make_shared<ngraph::opset1::Reshape>(value, output_shape, true)
|
return make_shared<ngraph::opset1::Reshape>(value, output_shape, true)->add_provenance_group_members_above({value});
|
||||||
->add_provenance_group_members_above({value});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
shared_ptr<Node> builder::opset1::expand_dims(const Output<Node>& value, size_t axis)
|
shared_ptr<Node> builder::opset1::expand_dims(const Output<Node>& value, size_t axis) {
|
||||||
{
|
|
||||||
Shape output_shape(value.get_shape());
|
Shape output_shape(value.get_shape());
|
||||||
// Add empty axis at specified position.
|
// Add empty axis at specified position.
|
||||||
auto empty_axis_it = begin(output_shape);
|
auto empty_axis_it = begin(output_shape);
|
||||||
@ -167,40 +140,30 @@ shared_ptr<Node> builder::opset1::expand_dims(const Output<Node>& value, size_t
|
|||||||
return builder::opset1::reshape(value, output_shape);
|
return builder::opset1::reshape(value, output_shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
shared_ptr<Node> builder::opset1::squeeze(const Output<Node>& value, vector<size_t> axes)
|
shared_ptr<Node> builder::opset1::squeeze(const Output<Node>& value, vector<size_t> axes) {
|
||||||
{
|
if (axes.empty()) {
|
||||||
if (axes.empty())
|
|
||||||
{
|
|
||||||
return value.get_node_shared_ptr();
|
return value.get_node_shared_ptr();
|
||||||
}
|
}
|
||||||
|
|
||||||
Shape in_shape{value.get_shape()};
|
Shape in_shape{value.get_shape()};
|
||||||
for (size_t idx = 0; idx < axes.size(); ++idx)
|
for (size_t idx = 0; idx < axes.size(); ++idx) {
|
||||||
{
|
|
||||||
in_shape.at(axes.at(idx)) = 0;
|
in_shape.at(axes.at(idx)) = 0;
|
||||||
}
|
}
|
||||||
Shape output_shape;
|
Shape output_shape;
|
||||||
for (auto axis : in_shape)
|
for (auto axis : in_shape) {
|
||||||
{
|
if (axis != 0) {
|
||||||
if (axis != 0)
|
|
||||||
{
|
|
||||||
output_shape.push_back(axis);
|
output_shape.push_back(axis);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return builder::opset1::reshape(value, output_shape);
|
return builder::opset1::reshape(value, output_shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
shared_ptr<Node> builder::opset1::collapse(const Output<Node>& value,
|
shared_ptr<Node> builder::opset1::collapse(const Output<Node>& value, const size_t start_axis, const size_t end_axis) {
|
||||||
const size_t start_axis,
|
if (start_axis == end_axis) {
|
||||||
const size_t end_axis)
|
|
||||||
{
|
|
||||||
if (start_axis == end_axis)
|
|
||||||
{
|
|
||||||
return value.get_node_shared_ptr();
|
return value.get_node_shared_ptr();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (value.get_partial_shape().is_static())
|
if (value.get_partial_shape().is_static()) {
|
||||||
{
|
|
||||||
auto shape = value.get_shape();
|
auto shape = value.get_shape();
|
||||||
// Multiply all alements of shape from start_axis to end_axis inclusive
|
// Multiply all alements of shape from start_axis to end_axis inclusive
|
||||||
size_t collapsed_axis_size = accumulate(next(begin(shape), start_axis),
|
size_t collapsed_axis_size = accumulate(next(begin(shape), start_axis),
|
||||||
@ -220,22 +183,20 @@ shared_ptr<Node> builder::opset1::collapse(const Output<Node>& value,
|
|||||||
// Split lengths used in VariadicSplit
|
// Split lengths used in VariadicSplit
|
||||||
const auto start_axis_node = ngraph::opset1::Constant::create(element::i64, {1}, {start_axis});
|
const auto start_axis_node = ngraph::opset1::Constant::create(element::i64, {1}, {start_axis});
|
||||||
const auto end_axis_node = ngraph::opset1::Constant::create(element::i64, {1}, {end_axis + 1});
|
const auto end_axis_node = ngraph::opset1::Constant::create(element::i64, {1}, {end_axis + 1});
|
||||||
const auto collapsed_axis =
|
const auto collapsed_axis = make_shared<ngraph::opset1::Subtract>(end_axis_node, start_axis_node);
|
||||||
make_shared<ngraph::opset1::Subtract>(end_axis_node, start_axis_node);
|
|
||||||
const auto post_axis = make_shared<ngraph::opset1::Subtract>(rank, end_axis_node);
|
const auto post_axis = make_shared<ngraph::opset1::Subtract>(rank, end_axis_node);
|
||||||
|
|
||||||
const auto split_lengths = make_shared<ngraph::opset1::Concat>(
|
const auto split_lengths =
|
||||||
OutputVector{start_axis_node, collapsed_axis, post_axis}, 0);
|
make_shared<ngraph::opset1::Concat>(OutputVector{start_axis_node, collapsed_axis, post_axis}, 0);
|
||||||
const auto split_axis = ngraph::opset1::Constant::create(element::i64, {}, {0});
|
const auto split_axis = ngraph::opset1::Constant::create(element::i64, {}, {0});
|
||||||
const auto split_node =
|
const auto split_node = make_shared<ngraph::opset1::VariadicSplit>(shape, split_axis, split_lengths);
|
||||||
make_shared<ngraph::opset1::VariadicSplit>(shape, split_axis, split_lengths);
|
|
||||||
|
|
||||||
const auto reduced_axis = ngraph::opset1::Constant::create(element::i64, {1}, {0});
|
const auto reduced_axis = ngraph::opset1::Constant::create(element::i64, {1}, {0});
|
||||||
const auto collapsed_axis_size =
|
const auto collapsed_axis_size = make_shared<ngraph::opset1::ReduceProd>(split_node->output(1), reduced_axis, true);
|
||||||
make_shared<ngraph::opset1::ReduceProd>(split_node->output(1), reduced_axis, true);
|
|
||||||
|
|
||||||
const auto collapsed_shape = make_shared<ngraph::opset1::Concat>(
|
const auto collapsed_shape = make_shared<ngraph::opset1::Concat>(
|
||||||
OutputVector{split_node->output(0), collapsed_axis_size, split_node->output(2)}, 0);
|
OutputVector{split_node->output(0), collapsed_axis_size, split_node->output(2)},
|
||||||
|
0);
|
||||||
|
|
||||||
return make_shared<ngraph::opset1::Reshape>(value, collapsed_shape, false);
|
return make_shared<ngraph::opset1::Reshape>(value, collapsed_shape, false);
|
||||||
}
|
}
|
||||||
|
@ -3,25 +3,21 @@
|
|||||||
//
|
//
|
||||||
|
|
||||||
#include "ngraph/builder/split.hpp"
|
#include "ngraph/builder/split.hpp"
|
||||||
|
|
||||||
#include "ngraph/opsets/opset1.hpp"
|
#include "ngraph/opsets/opset1.hpp"
|
||||||
|
|
||||||
using namespace ngraph;
|
using namespace ngraph;
|
||||||
|
|
||||||
OutputVector builder::opset1::split(const Output<Node>& value,
|
OutputVector builder::opset1::split(const Output<Node>& value, const std::vector<size_t>& split_lengths, int64_t axis) {
|
||||||
const std::vector<size_t>& split_lengths,
|
|
||||||
int64_t axis)
|
|
||||||
{
|
|
||||||
const auto axis_node = ngraph::opset1::Constant::create(element::i64, Shape{}, {axis});
|
const auto axis_node = ngraph::opset1::Constant::create(element::i64, Shape{}, {axis});
|
||||||
const auto split_lengths_node =
|
const auto split_lengths_node =
|
||||||
ngraph::opset1::Constant::create(element::u64, Shape{split_lengths.size()}, split_lengths);
|
ngraph::opset1::Constant::create(element::u64, Shape{split_lengths.size()}, split_lengths);
|
||||||
const auto variadic_split =
|
const auto variadic_split = std::make_shared<ngraph::opset1::VariadicSplit>(value, axis_node, split_lengths_node);
|
||||||
std::make_shared<ngraph::opset1::VariadicSplit>(value, axis_node, split_lengths_node);
|
|
||||||
|
|
||||||
return variadic_split->outputs();
|
return variadic_split->outputs();
|
||||||
}
|
}
|
||||||
|
|
||||||
OutputVector builder::opset1::split(const Output<Node>& value, size_t num_splits, int64_t axis)
|
OutputVector builder::opset1::split(const Output<Node>& value, size_t num_splits, int64_t axis) {
|
||||||
{
|
|
||||||
const auto axis_node = ngraph::opset1::Constant::create(element::i64, Shape{}, {axis});
|
const auto axis_node = ngraph::opset1::Constant::create(element::i64, Shape{}, {axis});
|
||||||
const auto split = std::make_shared<ngraph::opset1::Split>(value, axis_node, num_splits);
|
const auto split = std::make_shared<ngraph::opset1::Split>(value, axis_node, num_splits);
|
||||||
|
|
||||||
|
@ -12,8 +12,7 @@
|
|||||||
#include "ngraph/type.hpp"
|
#include "ngraph/type.hpp"
|
||||||
|
|
||||||
///
|
///
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
|
||||||
class AttributeVisitor;
|
class AttributeVisitor;
|
||||||
|
|
||||||
/// \brief Provides access to an attribute of type AT as a value accessor type VAT
|
/// \brief Provides access to an attribute of type AT as a value accessor type VAT
|
||||||
@ -26,8 +25,7 @@ namespace ngraph
|
|||||||
/// All ValueAccessors must be derived from ValueAccessor<void> so that an AttributeVisitor
|
/// All ValueAccessors must be derived from ValueAccessor<void> so that an AttributeVisitor
|
||||||
/// only needs to implement a subset of the on_adapter methods.
|
/// only needs to implement a subset of the on_adapter methods.
|
||||||
template <>
|
template <>
|
||||||
class NGRAPH_API ValueAccessor<void>
|
class NGRAPH_API ValueAccessor<void> {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
/// \brief type info enables identification of the value accessor, as well as is_type and
|
/// \brief type info enables identification of the value accessor, as well as is_type and
|
||||||
/// as_type.
|
/// as_type.
|
||||||
@ -45,8 +43,7 @@ namespace ngraph
|
|||||||
/// changed.
|
/// changed.
|
||||||
/// \tparam VAT The adapter value type; may be wider than the value being accessed.
|
/// \tparam VAT The adapter value type; may be wider than the value being accessed.
|
||||||
template <typename VAT>
|
template <typename VAT>
|
||||||
class ValueAccessor : public ValueAccessor<void>
|
class ValueAccessor : public ValueAccessor<void> {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
/// Returns the value
|
/// Returns the value
|
||||||
virtual const VAT& get() = 0;
|
virtual const VAT& get() = 0;
|
||||||
@ -55,50 +52,41 @@ namespace ngraph
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
class ValueAccessor<void*> : public ValueAccessor<void>
|
class ValueAccessor<void*> : public ValueAccessor<void> {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
virtual void* get_ptr() = 0;
|
virtual void* get_ptr() = 0;
|
||||||
virtual size_t size() = 0;
|
virtual size_t size() = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename AT>
|
template <typename AT>
|
||||||
class DirectValueAccessor : public ValueAccessor<AT>
|
class DirectValueAccessor : public ValueAccessor<AT> {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
DirectValueAccessor(AT& ref)
|
DirectValueAccessor(AT& ref) : m_ref(ref) {}
|
||||||
: m_ref(ref)
|
const AT& get() override {
|
||||||
{
|
return m_ref;
|
||||||
|
}
|
||||||
|
void set(const AT& value) override {
|
||||||
|
m_ref = value;
|
||||||
}
|
}
|
||||||
const AT& get() override { return m_ref; }
|
|
||||||
void set(const AT& value) override { m_ref = value; }
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
AT& m_ref;
|
AT& m_ref;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename AT, typename VAT>
|
template <typename AT, typename VAT>
|
||||||
class IndirectScalarValueAccessor : public ValueAccessor<VAT>
|
class IndirectScalarValueAccessor : public ValueAccessor<VAT> {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
IndirectScalarValueAccessor(AT& ref)
|
IndirectScalarValueAccessor(AT& ref) : m_ref(ref), m_buffer() {}
|
||||||
: m_ref(ref)
|
|
||||||
, m_buffer()
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
const VAT& get() override
|
const VAT& get() override {
|
||||||
{
|
if (!m_buffer_valid) {
|
||||||
if (!m_buffer_valid)
|
|
||||||
{
|
|
||||||
m_buffer = static_cast<VAT>(m_ref);
|
m_buffer = static_cast<VAT>(m_ref);
|
||||||
m_buffer_valid = true;
|
m_buffer_valid = true;
|
||||||
}
|
}
|
||||||
return m_buffer;
|
return m_buffer;
|
||||||
}
|
}
|
||||||
|
|
||||||
void set(const VAT& value) override
|
void set(const VAT& value) override {
|
||||||
{
|
|
||||||
m_ref = static_cast<AT>(value);
|
m_ref = static_cast<AT>(value);
|
||||||
m_buffer_valid = false;
|
m_buffer_valid = false;
|
||||||
}
|
}
|
||||||
@ -110,43 +98,35 @@ namespace ngraph
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <typename A, typename B>
|
template <typename A, typename B>
|
||||||
A copy_from(B& b)
|
A copy_from(B& b) {
|
||||||
{
|
|
||||||
A result(b.size());
|
A result(b.size());
|
||||||
for (size_t i = 0; i < b.size(); ++i)
|
for (size_t i = 0; i < b.size(); ++i) {
|
||||||
{
|
result[i] = static_cast<typename std::remove_reference<decltype(result[i])>::type>(b[i]);
|
||||||
result[i] =
|
|
||||||
static_cast<typename std::remove_reference<decltype(result[i])>::type>(b[i]);
|
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename AT, typename VAT>
|
template <typename AT, typename VAT>
|
||||||
class IndirectVectorValueAccessor : public ValueAccessor<VAT>
|
class IndirectVectorValueAccessor : public ValueAccessor<VAT> {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
IndirectVectorValueAccessor(AT& ref)
|
IndirectVectorValueAccessor(AT& ref) : m_ref(ref) {}
|
||||||
: m_ref(ref)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
const VAT& get() override
|
const VAT& get() override {
|
||||||
{
|
if (!m_buffer_valid) {
|
||||||
if (!m_buffer_valid)
|
|
||||||
{
|
|
||||||
m_buffer = copy_from<typename std::remove_cv<VAT>::type>(m_ref);
|
m_buffer = copy_from<typename std::remove_cv<VAT>::type>(m_ref);
|
||||||
m_buffer_valid = true;
|
m_buffer_valid = true;
|
||||||
}
|
}
|
||||||
return m_buffer;
|
return m_buffer;
|
||||||
}
|
}
|
||||||
|
|
||||||
void set(const VAT& value) override
|
void set(const VAT& value) override {
|
||||||
{
|
|
||||||
m_ref = copy_from<AT>(value);
|
m_ref = copy_from<AT>(value);
|
||||||
m_buffer_valid = false;
|
m_buffer_valid = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
operator AT&() { return m_ref; }
|
operator AT&() {
|
||||||
|
return m_ref;
|
||||||
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
AT& m_ref;
|
AT& m_ref;
|
||||||
@ -157,236 +137,202 @@ namespace ngraph
|
|||||||
/// \brief An AttributeAdapter "captures" an attribute as an AT& and makes it available as a
|
/// \brief An AttributeAdapter "captures" an attribute as an AT& and makes it available as a
|
||||||
/// ValueAccessor<VAT>.
|
/// ValueAccessor<VAT>.
|
||||||
template <typename AT>
|
template <typename AT>
|
||||||
class AttributeAdapter
|
class AttributeAdapter {};
|
||||||
{
|
|
||||||
};
|
|
||||||
|
|
||||||
/// \brief Access an enum via a string
|
/// \brief Access an enum via a string
|
||||||
/// \tparam AT The attribute type enum class
|
/// \tparam AT The attribute type enum class
|
||||||
template <typename AT>
|
template <typename AT>
|
||||||
class EnumAttributeAdapterBase : public ValueAccessor<std::string>
|
class EnumAttributeAdapterBase : public ValueAccessor<std::string> {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
EnumAttributeAdapterBase(AT& value)
|
EnumAttributeAdapterBase(AT& value) : m_ref(value) {}
|
||||||
: m_ref(value)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
const std::string& get() override { return as_string(m_ref); }
|
const std::string& get() override {
|
||||||
void set(const std::string& value) override { m_ref = as_enum<AT>(value); }
|
return as_string(m_ref);
|
||||||
operator AT&() { return m_ref; }
|
}
|
||||||
|
void set(const std::string& value) override {
|
||||||
|
m_ref = as_enum<AT>(value);
|
||||||
|
}
|
||||||
|
operator AT&() {
|
||||||
|
return m_ref;
|
||||||
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
AT& m_ref;
|
AT& m_ref;
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Adapters will see visitor
|
/// Adapters will see visitor
|
||||||
class VisitorAdapter : public ValueAccessor<void>
|
class VisitorAdapter : public ValueAccessor<void> {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
virtual bool visit_attributes(AttributeVisitor& visitor) = 0;
|
virtual bool visit_attributes(AttributeVisitor& visitor) = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
class NGRAPH_API AttributeAdapter<float> : public IndirectScalarValueAccessor<float, double>
|
class NGRAPH_API AttributeAdapter<float> : public IndirectScalarValueAccessor<float, double> {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
AttributeAdapter(float& value)
|
AttributeAdapter(float& value) : IndirectScalarValueAccessor<float, double>(value) {}
|
||||||
: IndirectScalarValueAccessor<float, double>(value)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<float>", 0};
|
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<float>", 0};
|
||||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
const DiscreteTypeInfo& get_type_info() const override {
|
||||||
|
return type_info;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// \brief Access a double as a double
|
/// \brief Access a double as a double
|
||||||
template <>
|
template <>
|
||||||
class NGRAPH_API AttributeAdapter<double> : public DirectValueAccessor<double>
|
class NGRAPH_API AttributeAdapter<double> : public DirectValueAccessor<double> {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
AttributeAdapter(double& value)
|
AttributeAdapter(double& value) : DirectValueAccessor<double>(value) {}
|
||||||
: DirectValueAccessor<double>(value)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<double>", 0};
|
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<double>", 0};
|
||||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
const DiscreteTypeInfo& get_type_info() const override {
|
||||||
|
return type_info;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// \brief Access a string as a string
|
/// \brief Access a string as a string
|
||||||
template <>
|
template <>
|
||||||
class NGRAPH_API AttributeAdapter<std::string> : public DirectValueAccessor<std::string>
|
class NGRAPH_API AttributeAdapter<std::string> : public DirectValueAccessor<std::string> {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
AttributeAdapter(std::string& value)
|
AttributeAdapter(std::string& value) : DirectValueAccessor<std::string>(value) {}
|
||||||
: DirectValueAccessor<std::string>(value)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<string>", 0};
|
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<string>", 0};
|
||||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
const DiscreteTypeInfo& get_type_info() const override {
|
||||||
|
return type_info;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// \brief Access a bool as a bool
|
/// \brief Access a bool as a bool
|
||||||
template <>
|
template <>
|
||||||
class NGRAPH_API AttributeAdapter<bool> : public DirectValueAccessor<bool>
|
class NGRAPH_API AttributeAdapter<bool> : public DirectValueAccessor<bool> {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
AttributeAdapter(bool& value)
|
AttributeAdapter(bool& value) : DirectValueAccessor<bool>(value) {}
|
||||||
: DirectValueAccessor<bool>(value)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<bool>", 0};
|
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<bool>", 0};
|
||||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
const DiscreteTypeInfo& get_type_info() const override {
|
||||||
|
return type_info;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// \brief Access an int8_t and an int64_t
|
/// \brief Access an int8_t and an int64_t
|
||||||
template <>
|
template <>
|
||||||
class NGRAPH_API AttributeAdapter<int8_t> : public IndirectScalarValueAccessor<int8_t, int64_t>
|
class NGRAPH_API AttributeAdapter<int8_t> : public IndirectScalarValueAccessor<int8_t, int64_t> {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
AttributeAdapter(int8_t& value)
|
AttributeAdapter(int8_t& value) : IndirectScalarValueAccessor<int8_t, int64_t>(value) {}
|
||||||
: IndirectScalarValueAccessor<int8_t, int64_t>(value)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<int8_t>", 0};
|
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<int8_t>", 0};
|
||||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
const DiscreteTypeInfo& get_type_info() const override {
|
||||||
|
return type_info;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// \brief Access an int16_t as an int64_t
|
/// \brief Access an int16_t as an int64_t
|
||||||
template <>
|
template <>
|
||||||
class NGRAPH_API AttributeAdapter<int16_t>
|
class NGRAPH_API AttributeAdapter<int16_t> : public IndirectScalarValueAccessor<int16_t, int64_t> {
|
||||||
: public IndirectScalarValueAccessor<int16_t, int64_t>
|
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
AttributeAdapter(int16_t& value)
|
AttributeAdapter(int16_t& value) : IndirectScalarValueAccessor<int16_t, int64_t>(value) {}
|
||||||
: IndirectScalarValueAccessor<int16_t, int64_t>(value)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<int16_t>", 0};
|
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<int16_t>", 0};
|
||||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
const DiscreteTypeInfo& get_type_info() const override {
|
||||||
|
return type_info;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// \brief Access an int32_t as an int64_t
|
/// \brief Access an int32_t as an int64_t
|
||||||
template <>
|
template <>
|
||||||
class NGRAPH_API AttributeAdapter<int32_t>
|
class NGRAPH_API AttributeAdapter<int32_t> : public IndirectScalarValueAccessor<int32_t, int64_t> {
|
||||||
: public IndirectScalarValueAccessor<int32_t, int64_t>
|
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
AttributeAdapter(int32_t& value)
|
AttributeAdapter(int32_t& value) : IndirectScalarValueAccessor<int32_t, int64_t>(value) {}
|
||||||
: IndirectScalarValueAccessor<int32_t, int64_t>(value)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<int32_t>", 0};
|
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<int32_t>", 0};
|
||||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
const DiscreteTypeInfo& get_type_info() const override {
|
||||||
|
return type_info;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// \brief Access an int64_t as an int64_t
|
/// \brief Access an int64_t as an int64_t
|
||||||
template <>
|
template <>
|
||||||
class NGRAPH_API AttributeAdapter<int64_t> : public DirectValueAccessor<int64_t>
|
class NGRAPH_API AttributeAdapter<int64_t> : public DirectValueAccessor<int64_t> {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
AttributeAdapter(int64_t& value)
|
AttributeAdapter(int64_t& value) : DirectValueAccessor<int64_t>(value) {}
|
||||||
: DirectValueAccessor<int64_t>(value)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<int64_t>", 0};
|
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<int64_t>", 0};
|
||||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
const DiscreteTypeInfo& get_type_info() const override {
|
||||||
|
return type_info;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// \brief Access a uint8_t as an int64_t
|
/// \brief Access a uint8_t as an int64_t
|
||||||
template <>
|
template <>
|
||||||
class NGRAPH_API AttributeAdapter<uint8_t>
|
class NGRAPH_API AttributeAdapter<uint8_t> : public IndirectScalarValueAccessor<uint8_t, int64_t> {
|
||||||
: public IndirectScalarValueAccessor<uint8_t, int64_t>
|
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
AttributeAdapter(uint8_t& value)
|
AttributeAdapter(uint8_t& value) : IndirectScalarValueAccessor<uint8_t, int64_t>(value) {}
|
||||||
: IndirectScalarValueAccessor<uint8_t, int64_t>(value)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<uint8_t>", 0};
|
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<uint8_t>", 0};
|
||||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
const DiscreteTypeInfo& get_type_info() const override {
|
||||||
|
return type_info;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// \brief Access a uint16_t as an int64_t
|
/// \brief Access a uint16_t as an int64_t
|
||||||
template <>
|
template <>
|
||||||
class NGRAPH_API AttributeAdapter<uint16_t>
|
class NGRAPH_API AttributeAdapter<uint16_t> : public IndirectScalarValueAccessor<uint16_t, int64_t> {
|
||||||
: public IndirectScalarValueAccessor<uint16_t, int64_t>
|
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
AttributeAdapter(uint16_t& value)
|
AttributeAdapter(uint16_t& value) : IndirectScalarValueAccessor<uint16_t, int64_t>(value) {}
|
||||||
: IndirectScalarValueAccessor<uint16_t, int64_t>(value)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<uint16_t>", 0};
|
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<uint16_t>", 0};
|
||||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
const DiscreteTypeInfo& get_type_info() const override {
|
||||||
|
return type_info;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// \brief Access a uint32_t as an int64_t
|
/// \brief Access a uint32_t as an int64_t
|
||||||
template <>
|
template <>
|
||||||
class NGRAPH_API AttributeAdapter<uint32_t>
|
class NGRAPH_API AttributeAdapter<uint32_t> : public IndirectScalarValueAccessor<uint32_t, int64_t> {
|
||||||
: public IndirectScalarValueAccessor<uint32_t, int64_t>
|
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
AttributeAdapter(uint32_t& value)
|
AttributeAdapter(uint32_t& value) : IndirectScalarValueAccessor<uint32_t, int64_t>(value) {}
|
||||||
: IndirectScalarValueAccessor<uint32_t, int64_t>(value)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<uint32_t>", 0};
|
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<uint32_t>", 0};
|
||||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
const DiscreteTypeInfo& get_type_info() const override {
|
||||||
|
return type_info;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// \brief Access a uint64_t as an int64_t
|
/// \brief Access a uint64_t as an int64_t
|
||||||
template <>
|
template <>
|
||||||
class NGRAPH_API AttributeAdapter<uint64_t>
|
class NGRAPH_API AttributeAdapter<uint64_t> : public IndirectScalarValueAccessor<uint64_t, int64_t> {
|
||||||
: public IndirectScalarValueAccessor<uint64_t, int64_t>
|
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
AttributeAdapter(uint64_t& value)
|
AttributeAdapter(uint64_t& value) : IndirectScalarValueAccessor<uint64_t, int64_t>(value) {}
|
||||||
: IndirectScalarValueAccessor<uint64_t, int64_t>(value)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<uint64_t>", 0};
|
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<uint64_t>", 0};
|
||||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
const DiscreteTypeInfo& get_type_info() const override {
|
||||||
|
return type_info;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
#ifdef __APPLE__
|
#ifdef __APPLE__
|
||||||
// size_t is one of the uint types on _WIN32
|
// size_t is one of the uint types on _WIN32
|
||||||
template <>
|
template <>
|
||||||
class NGRAPH_API AttributeAdapter<size_t> : public IndirectScalarValueAccessor<size_t, int64_t>
|
class NGRAPH_API AttributeAdapter<size_t> : public IndirectScalarValueAccessor<size_t, int64_t> {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
AttributeAdapter(size_t& value)
|
AttributeAdapter(size_t& value) : IndirectScalarValueAccessor<size_t, int64_t>(value) {}
|
||||||
: IndirectScalarValueAccessor<size_t, int64_t>(value)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<size_t>", 0};
|
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<size_t>", 0};
|
||||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
const DiscreteTypeInfo& get_type_info() const override {
|
||||||
|
return type_info;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
class NGRAPH_API AttributeAdapter<std::vector<size_t>>
|
class NGRAPH_API AttributeAdapter<std::vector<size_t>>
|
||||||
: public IndirectVectorValueAccessor<std::vector<size_t>, std::vector<int64_t>>
|
: public IndirectVectorValueAccessor<std::vector<size_t>, std::vector<int64_t>> {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
AttributeAdapter(std::vector<size_t>& value)
|
AttributeAdapter(std::vector<size_t>& value)
|
||||||
: IndirectVectorValueAccessor<std::vector<size_t>, std::vector<int64_t>>(value)
|
: IndirectVectorValueAccessor<std::vector<size_t>, std::vector<int64_t>>(value) {}
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<size_t>>", 0};
|
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<size_t>>", 0};
|
||||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
const DiscreteTypeInfo& get_type_info() const override {
|
||||||
|
return type_info;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
@ -395,166 +341,133 @@ namespace ngraph
|
|||||||
|
|
||||||
/// \brief Access a vector<int8_t>
|
/// \brief Access a vector<int8_t>
|
||||||
template <>
|
template <>
|
||||||
class NGRAPH_API AttributeAdapter<std::vector<int8_t>>
|
class NGRAPH_API AttributeAdapter<std::vector<int8_t>> : public DirectValueAccessor<std::vector<int8_t>> {
|
||||||
: public DirectValueAccessor<std::vector<int8_t>>
|
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
AttributeAdapter(std::vector<int8_t>& value)
|
AttributeAdapter(std::vector<int8_t>& value) : DirectValueAccessor<std::vector<int8_t>>(value) {}
|
||||||
: DirectValueAccessor<std::vector<int8_t>>(value)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<int8_t>>", 0};
|
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<int8_t>>", 0};
|
||||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
const DiscreteTypeInfo& get_type_info() const override {
|
||||||
|
return type_info;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// \brief Access a vector<int16_t>
|
/// \brief Access a vector<int16_t>
|
||||||
template <>
|
template <>
|
||||||
class NGRAPH_API AttributeAdapter<std::vector<int16_t>>
|
class NGRAPH_API AttributeAdapter<std::vector<int16_t>> : public DirectValueAccessor<std::vector<int16_t>> {
|
||||||
: public DirectValueAccessor<std::vector<int16_t>>
|
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
AttributeAdapter(std::vector<int16_t>& value)
|
AttributeAdapter(std::vector<int16_t>& value) : DirectValueAccessor<std::vector<int16_t>>(value) {}
|
||||||
: DirectValueAccessor<std::vector<int16_t>>(value)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<int16_t>>", 0};
|
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<int16_t>>", 0};
|
||||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
const DiscreteTypeInfo& get_type_info() const override {
|
||||||
|
return type_info;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// \brief Access a vector<int32_t>
|
/// \brief Access a vector<int32_t>
|
||||||
template <>
|
template <>
|
||||||
class NGRAPH_API AttributeAdapter<std::vector<int32_t>>
|
class NGRAPH_API AttributeAdapter<std::vector<int32_t>> : public DirectValueAccessor<std::vector<int32_t>> {
|
||||||
: public DirectValueAccessor<std::vector<int32_t>>
|
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
AttributeAdapter(std::vector<int32_t>& value)
|
AttributeAdapter(std::vector<int32_t>& value) : DirectValueAccessor<std::vector<int32_t>>(value) {}
|
||||||
: DirectValueAccessor<std::vector<int32_t>>(value)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<int32_t>>", 0};
|
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<int32_t>>", 0};
|
||||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
const DiscreteTypeInfo& get_type_info() const override {
|
||||||
|
return type_info;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// \brief Access a vector<int64_t>
|
/// \brief Access a vector<int64_t>
|
||||||
template <>
|
template <>
|
||||||
class NGRAPH_API AttributeAdapter<std::vector<int64_t>>
|
class NGRAPH_API AttributeAdapter<std::vector<int64_t>> : public DirectValueAccessor<std::vector<int64_t>> {
|
||||||
: public DirectValueAccessor<std::vector<int64_t>>
|
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
AttributeAdapter(std::vector<int64_t>& value)
|
AttributeAdapter(std::vector<int64_t>& value) : DirectValueAccessor<std::vector<int64_t>>(value) {}
|
||||||
: DirectValueAccessor<std::vector<int64_t>>(value)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<int64_t>>", 0};
|
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<int64_t>>", 0};
|
||||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
const DiscreteTypeInfo& get_type_info() const override {
|
||||||
|
return type_info;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// \brief Access a vector<uint8_t>
|
/// \brief Access a vector<uint8_t>
|
||||||
template <>
|
template <>
|
||||||
class NGRAPH_API AttributeAdapter<std::vector<uint8_t>>
|
class NGRAPH_API AttributeAdapter<std::vector<uint8_t>> : public DirectValueAccessor<std::vector<uint8_t>> {
|
||||||
: public DirectValueAccessor<std::vector<uint8_t>>
|
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
AttributeAdapter(std::vector<uint8_t>& value)
|
AttributeAdapter(std::vector<uint8_t>& value) : DirectValueAccessor<std::vector<uint8_t>>(value) {}
|
||||||
: DirectValueAccessor<std::vector<uint8_t>>(value)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<uint8_t>>", 0};
|
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<uint8_t>>", 0};
|
||||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
const DiscreteTypeInfo& get_type_info() const override {
|
||||||
|
return type_info;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// \brief Access a vector<uint16_t>
|
/// \brief Access a vector<uint16_t>
|
||||||
template <>
|
template <>
|
||||||
class NGRAPH_API AttributeAdapter<std::vector<uint16_t>>
|
class NGRAPH_API AttributeAdapter<std::vector<uint16_t>> : public DirectValueAccessor<std::vector<uint16_t>> {
|
||||||
: public DirectValueAccessor<std::vector<uint16_t>>
|
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
AttributeAdapter(std::vector<uint16_t>& value)
|
AttributeAdapter(std::vector<uint16_t>& value) : DirectValueAccessor<std::vector<uint16_t>>(value) {}
|
||||||
: DirectValueAccessor<std::vector<uint16_t>>(value)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<uint16_t>>", 0};
|
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<uint16_t>>", 0};
|
||||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
const DiscreteTypeInfo& get_type_info() const override {
|
||||||
|
return type_info;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// \brief Access a vector<uint32_t>
|
/// \brief Access a vector<uint32_t>
|
||||||
template <>
|
template <>
|
||||||
class NGRAPH_API AttributeAdapter<std::vector<uint32_t>>
|
class NGRAPH_API AttributeAdapter<std::vector<uint32_t>> : public DirectValueAccessor<std::vector<uint32_t>> {
|
||||||
: public DirectValueAccessor<std::vector<uint32_t>>
|
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
AttributeAdapter(std::vector<uint32_t>& value)
|
AttributeAdapter(std::vector<uint32_t>& value) : DirectValueAccessor<std::vector<uint32_t>>(value) {}
|
||||||
: DirectValueAccessor<std::vector<uint32_t>>(value)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<uint32_t>>", 0};
|
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<uint32_t>>", 0};
|
||||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
const DiscreteTypeInfo& get_type_info() const override {
|
||||||
|
return type_info;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// \brief Access a vector<uint64_t>
|
/// \brief Access a vector<uint64_t>
|
||||||
template <>
|
template <>
|
||||||
class NGRAPH_API AttributeAdapter<std::vector<uint64_t>>
|
class NGRAPH_API AttributeAdapter<std::vector<uint64_t>> : public DirectValueAccessor<std::vector<uint64_t>> {
|
||||||
: public DirectValueAccessor<std::vector<uint64_t>>
|
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
AttributeAdapter(std::vector<uint64_t>& value)
|
AttributeAdapter(std::vector<uint64_t>& value) : DirectValueAccessor<std::vector<uint64_t>>(value) {}
|
||||||
: DirectValueAccessor<std::vector<uint64_t>>(value)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<uint64_t>>", 0};
|
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<uint64_t>>", 0};
|
||||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
const DiscreteTypeInfo& get_type_info() const override {
|
||||||
|
return type_info;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// \brief Access a vector<float>
|
/// \brief Access a vector<float>
|
||||||
template <>
|
template <>
|
||||||
class NGRAPH_API AttributeAdapter<std::vector<float>>
|
class NGRAPH_API AttributeAdapter<std::vector<float>> : public DirectValueAccessor<std::vector<float>> {
|
||||||
: public DirectValueAccessor<std::vector<float>>
|
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
AttributeAdapter(std::vector<float>& value)
|
AttributeAdapter(std::vector<float>& value) : DirectValueAccessor<std::vector<float>>(value) {}
|
||||||
: DirectValueAccessor<std::vector<float>>(value)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<float>>", 0};
|
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<float>>", 0};
|
||||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
const DiscreteTypeInfo& get_type_info() const override {
|
||||||
|
return type_info;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// \brief Access a vector<double>
|
/// \brief Access a vector<double>
|
||||||
template <>
|
template <>
|
||||||
class NGRAPH_API AttributeAdapter<std::vector<double>>
|
class NGRAPH_API AttributeAdapter<std::vector<double>> : public DirectValueAccessor<std::vector<double>> {
|
||||||
: public DirectValueAccessor<std::vector<double>>
|
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
AttributeAdapter(std::vector<double>& value)
|
AttributeAdapter(std::vector<double>& value) : DirectValueAccessor<std::vector<double>>(value) {}
|
||||||
: DirectValueAccessor<std::vector<double>>(value)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<double>>", 0};
|
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<double>>", 0};
|
||||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
const DiscreteTypeInfo& get_type_info() const override {
|
||||||
|
return type_info;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// \brief Access a vector<string>
|
/// \brief Access a vector<string>
|
||||||
template <>
|
template <>
|
||||||
class NGRAPH_API AttributeAdapter<std::vector<std::string>>
|
class NGRAPH_API AttributeAdapter<std::vector<std::string>> : public DirectValueAccessor<std::vector<std::string>> {
|
||||||
: public DirectValueAccessor<std::vector<std::string>>
|
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
AttributeAdapter(std::vector<std::string>& value)
|
AttributeAdapter(std::vector<std::string>& value) : DirectValueAccessor<std::vector<std::string>>(value) {}
|
||||||
: DirectValueAccessor<std::vector<std::string>>(value)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<string>>", 0};
|
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<string>>", 0};
|
||||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
const DiscreteTypeInfo& get_type_info() const override {
|
||||||
|
return type_info;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
} // namespace ngraph
|
} // namespace ngraph
|
||||||
|
@ -12,8 +12,7 @@
|
|||||||
#include "ngraph/type.hpp"
|
#include "ngraph/type.hpp"
|
||||||
#include "ngraph/type/element_type.hpp"
|
#include "ngraph/type/element_type.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
class ValueAccessor;
|
class ValueAccessor;
|
||||||
class VisitorAdapter;
|
class VisitorAdapter;
|
||||||
@ -55,8 +54,7 @@ namespace ngraph
|
|||||||
/// registered with the visitor using register_node, which needs a shared pointer to a node and
|
/// registered with the visitor using register_node, which needs a shared pointer to a node and
|
||||||
/// a string ID. The ID string will be used to serialize the node or find the node during
|
/// a string ID. The ID string will be used to serialize the node or find the node during
|
||||||
/// deserialization.
|
/// deserialization.
|
||||||
class NGRAPH_API AttributeVisitor
|
class NGRAPH_API AttributeVisitor {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
virtual ~AttributeVisitor() {}
|
virtual ~AttributeVisitor() {}
|
||||||
// Must implement these methods
|
// Must implement these methods
|
||||||
@ -80,49 +78,38 @@ namespace ngraph
|
|||||||
virtual void on_adapter(const std::string& name, ValueAccessor<uint64_t>& adapter);
|
virtual void on_adapter(const std::string& name, ValueAccessor<uint64_t>& adapter);
|
||||||
virtual void on_adapter(const std::string& name, ValueAccessor<float>& adapter);
|
virtual void on_adapter(const std::string& name, ValueAccessor<float>& adapter);
|
||||||
virtual void on_adapter(const std::string& name, ValueAccessor<double>& adapter);
|
virtual void on_adapter(const std::string& name, ValueAccessor<double>& adapter);
|
||||||
virtual void on_adapter(const std::string& name,
|
virtual void on_adapter(const std::string& name, ValueAccessor<std::vector<int8_t>>& adapter);
|
||||||
ValueAccessor<std::vector<int8_t>>& adapter);
|
virtual void on_adapter(const std::string& name, ValueAccessor<std::vector<int16_t>>& adapter);
|
||||||
virtual void on_adapter(const std::string& name,
|
virtual void on_adapter(const std::string& name, ValueAccessor<std::vector<int32_t>>& adapter);
|
||||||
ValueAccessor<std::vector<int16_t>>& adapter);
|
virtual void on_adapter(const std::string& name, ValueAccessor<std::vector<int64_t>>& adapter);
|
||||||
virtual void on_adapter(const std::string& name,
|
virtual void on_adapter(const std::string& name, ValueAccessor<std::vector<uint8_t>>& adapter);
|
||||||
ValueAccessor<std::vector<int32_t>>& adapter);
|
virtual void on_adapter(const std::string& name, ValueAccessor<std::vector<uint16_t>>& adapter);
|
||||||
virtual void on_adapter(const std::string& name,
|
virtual void on_adapter(const std::string& name, ValueAccessor<std::vector<uint32_t>>& adapter);
|
||||||
ValueAccessor<std::vector<int64_t>>& adapter);
|
virtual void on_adapter(const std::string& name, ValueAccessor<std::vector<uint64_t>>& adapter);
|
||||||
virtual void on_adapter(const std::string& name,
|
virtual void on_adapter(const std::string& name, ValueAccessor<std::vector<float>>& adapter);
|
||||||
ValueAccessor<std::vector<uint8_t>>& adapter);
|
virtual void on_adapter(const std::string& name, ValueAccessor<std::vector<double>>& adapter);
|
||||||
virtual void on_adapter(const std::string& name,
|
virtual void on_adapter(const std::string& name, ValueAccessor<std::vector<std::string>>& adapter);
|
||||||
ValueAccessor<std::vector<uint16_t>>& adapter);
|
|
||||||
virtual void on_adapter(const std::string& name,
|
|
||||||
ValueAccessor<std::vector<uint32_t>>& adapter);
|
|
||||||
virtual void on_adapter(const std::string& name,
|
|
||||||
ValueAccessor<std::vector<uint64_t>>& adapter);
|
|
||||||
virtual void on_adapter(const std::string& name,
|
|
||||||
ValueAccessor<std::vector<float>>& adapter);
|
|
||||||
virtual void on_adapter(const std::string& name,
|
|
||||||
ValueAccessor<std::vector<double>>& adapter);
|
|
||||||
virtual void on_adapter(const std::string& name,
|
|
||||||
ValueAccessor<std::vector<std::string>>& adapter);
|
|
||||||
/// \brief Hook for adapters that need visitor access
|
/// \brief Hook for adapters that need visitor access
|
||||||
virtual void on_adapter(const std::string& name, VisitorAdapter& adapter);
|
virtual void on_adapter(const std::string& name, VisitorAdapter& adapter);
|
||||||
|
|
||||||
/// \brief Provides API to handle nGraph Function attribute type, accessed as ValueAccessor
|
/// \brief Provides API to handle nGraph Function attribute type, accessed as ValueAccessor
|
||||||
/// \param name attribute name
|
/// \param name attribute name
|
||||||
/// \param adapter reference to a Function ValueAccessor<VAT>
|
/// \param adapter reference to a Function ValueAccessor<VAT>
|
||||||
virtual void on_adapter(const std::string& name,
|
virtual void on_adapter(const std::string& name, ValueAccessor<std::shared_ptr<Function>>& adapter);
|
||||||
ValueAccessor<std::shared_ptr<Function>>& adapter);
|
|
||||||
|
|
||||||
/// The generic visitor. There must be a definition of AttributeAdapter<T> that can convert
|
/// The generic visitor. There must be a definition of AttributeAdapter<T> that can convert
|
||||||
/// to a ValueAccessor<U> for one of the on_adpater methods.
|
/// to a ValueAccessor<U> for one of the on_adpater methods.
|
||||||
template <typename AT>
|
template <typename AT>
|
||||||
void on_attribute(const std::string& name, AT& value)
|
void on_attribute(const std::string& name, AT& value) {
|
||||||
{
|
|
||||||
AttributeAdapter<AT> adapter(value);
|
AttributeAdapter<AT> adapter(value);
|
||||||
start_structure(name);
|
start_structure(name);
|
||||||
on_adapter(get_name_with_context(), adapter);
|
on_adapter(get_name_with_context(), adapter);
|
||||||
finish_structure();
|
finish_structure();
|
||||||
}
|
}
|
||||||
/// \returns The nested context of visits
|
/// \returns The nested context of visits
|
||||||
const std::vector<std::string>& get_context() const { return m_context; }
|
const std::vector<std::string>& get_context() const {
|
||||||
|
return m_context;
|
||||||
|
}
|
||||||
/// \returns context prepended to names
|
/// \returns context prepended to names
|
||||||
virtual std::string get_name_with_context();
|
virtual std::string get_name_with_context();
|
||||||
/// \brief Start visiting a nested structure
|
/// \brief Start visiting a nested structure
|
||||||
@ -135,8 +122,7 @@ namespace ngraph
|
|||||||
///
|
///
|
||||||
/// No node may be used as an attribute unless it has already been registered with an ID.
|
/// No node may be used as an attribute unless it has already been registered with an ID.
|
||||||
/// References to nodes are visited with a ValueAccessor of their ID.
|
/// References to nodes are visited with a ValueAccessor of their ID.
|
||||||
virtual void register_node(const std::shared_ptr<Node>& node,
|
virtual void register_node(const std::shared_ptr<Node>& node, node_id_t id = invalid_node_id);
|
||||||
node_id_t id = invalid_node_id);
|
|
||||||
/// Returns the node with the given id, or nullptr if there is no registered node
|
/// Returns the node with the given id, or nullptr if there is no registered node
|
||||||
virtual std::shared_ptr<Node> get_registered_node(node_id_t id);
|
virtual std::shared_ptr<Node> get_registered_node(node_id_t id);
|
||||||
/// Returns the id for the node, or -1 if the node is not registered
|
/// Returns the id for the node, or -1 if the node is not registered
|
||||||
|
@ -12,11 +12,9 @@
|
|||||||
#include "ngraph/attribute_adapter.hpp"
|
#include "ngraph/attribute_adapter.hpp"
|
||||||
#include "ngraph/ngraph_visibility.hpp"
|
#include "ngraph/ngraph_visibility.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
|
||||||
/// \brief A set of axes.
|
/// \brief A set of axes.
|
||||||
class AxisSet : public std::set<size_t>
|
class AxisSet : public std::set<size_t> {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_API AxisSet();
|
NGRAPH_API AxisSet();
|
||||||
|
|
||||||
@ -36,19 +34,19 @@ namespace ngraph
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
class NGRAPH_API AttributeAdapter<AxisSet> : public ValueAccessor<std::vector<int64_t>>
|
class NGRAPH_API AttributeAdapter<AxisSet> : public ValueAccessor<std::vector<int64_t>> {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
AttributeAdapter(AxisSet& value)
|
AttributeAdapter(AxisSet& value) : m_ref(value) {}
|
||||||
: m_ref(value)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
const std::vector<int64_t>& get() override;
|
const std::vector<int64_t>& get() override;
|
||||||
void set(const std::vector<int64_t>& value) override;
|
void set(const std::vector<int64_t>& value) override;
|
||||||
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<AxisSet>", 0};
|
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<AxisSet>", 0};
|
||||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
const DiscreteTypeInfo& get_type_info() const override {
|
||||||
operator AxisSet&() { return m_ref; }
|
return type_info;
|
||||||
|
}
|
||||||
|
operator AxisSet&() {
|
||||||
|
return m_ref;
|
||||||
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
AxisSet& m_ref;
|
AxisSet& m_ref;
|
||||||
|
@ -11,11 +11,9 @@
|
|||||||
#include "ngraph/attribute_adapter.hpp"
|
#include "ngraph/attribute_adapter.hpp"
|
||||||
#include "ngraph/ngraph_visibility.hpp"
|
#include "ngraph/ngraph_visibility.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
|
||||||
/// \brief A vector of axes.
|
/// \brief A vector of axes.
|
||||||
class AxisVector : public std::vector<size_t>
|
class AxisVector : public std::vector<size_t> {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_API AxisVector(const std::initializer_list<size_t>& axes);
|
NGRAPH_API AxisVector(const std::initializer_list<size_t>& axes);
|
||||||
|
|
||||||
@ -26,10 +24,7 @@ namespace ngraph
|
|||||||
NGRAPH_API explicit AxisVector(size_t n);
|
NGRAPH_API explicit AxisVector(size_t n);
|
||||||
|
|
||||||
template <class InputIterator>
|
template <class InputIterator>
|
||||||
AxisVector(InputIterator first, InputIterator last)
|
AxisVector(InputIterator first, InputIterator last) : std::vector<size_t>(first, last) {}
|
||||||
: std::vector<size_t>(first, last)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
NGRAPH_API AxisVector();
|
NGRAPH_API AxisVector();
|
||||||
|
|
||||||
@ -41,17 +36,14 @@ namespace ngraph
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
class NGRAPH_API AttributeAdapter<AxisVector>
|
class NGRAPH_API AttributeAdapter<AxisVector> : public IndirectVectorValueAccessor<AxisVector, std::vector<int64_t>> {
|
||||||
: public IndirectVectorValueAccessor<AxisVector, std::vector<int64_t>>
|
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
AttributeAdapter(AxisVector& value)
|
AttributeAdapter(AxisVector& value) : IndirectVectorValueAccessor<AxisVector, std::vector<int64_t>>(value) {}
|
||||||
: IndirectVectorValueAccessor<AxisVector, std::vector<int64_t>>(value)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<AxisVector>", 0};
|
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<AxisVector>", 0};
|
||||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
const DiscreteTypeInfo& get_type_info() const override {
|
||||||
|
return type_info;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
NGRAPH_API
|
NGRAPH_API
|
||||||
|
@ -10,32 +10,26 @@
|
|||||||
|
|
||||||
#include "ngraph/except.hpp"
|
#include "ngraph/except.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
static inline std::ostream& write_all_to_stream(std::ostream& str) {
|
||||||
static inline std::ostream& write_all_to_stream(std::ostream& str) { return str; }
|
return str;
|
||||||
|
}
|
||||||
template <typename T, typename... TS>
|
template <typename T, typename... TS>
|
||||||
static inline std::ostream& write_all_to_stream(std::ostream& str, const T& arg, TS&&... args)
|
static inline std::ostream& write_all_to_stream(std::ostream& str, const T& arg, TS&&... args) {
|
||||||
{
|
|
||||||
return write_all_to_stream(str << arg, args...);
|
return write_all_to_stream(str << arg, args...);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct CheckLocInfo
|
struct CheckLocInfo {
|
||||||
{
|
|
||||||
const char* file;
|
const char* file;
|
||||||
int line;
|
int line;
|
||||||
const char* check_string;
|
const char* check_string;
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Base class for check failure exceptions.
|
/// Base class for check failure exceptions.
|
||||||
class NGRAPH_API CheckFailure : public ngraph_error
|
class NGRAPH_API CheckFailure : public ngraph_error {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
CheckFailure(const CheckLocInfo& check_loc_info,
|
CheckFailure(const CheckLocInfo& check_loc_info, const std::string& context_info, const std::string& explanation)
|
||||||
const std::string& context_info,
|
: ngraph_error(make_what(check_loc_info, context_info, explanation)) {}
|
||||||
const std::string& explanation)
|
|
||||||
: ngraph_error(make_what(check_loc_info, context_info, explanation))
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static std::string make_what(const CheckLocInfo& check_loc_info,
|
static std::string make_what(const CheckLocInfo& check_loc_info,
|
||||||
@ -110,22 +104,17 @@ namespace ngraph
|
|||||||
// variable (ss___) and risk shadowing.
|
// variable (ss___) and risk shadowing.
|
||||||
//
|
//
|
||||||
#define NGRAPH_CHECK_HELPER2(exc_class, ctx, check, ...) \
|
#define NGRAPH_CHECK_HELPER2(exc_class, ctx, check, ...) \
|
||||||
do \
|
do { \
|
||||||
{ \
|
if (!(check)) { \
|
||||||
if (!(check)) \
|
|
||||||
{ \
|
|
||||||
::std::stringstream ss___; \
|
::std::stringstream ss___; \
|
||||||
::ngraph::write_all_to_stream(ss___, __VA_ARGS__); \
|
::ngraph::write_all_to_stream(ss___, __VA_ARGS__); \
|
||||||
throw exc_class( \
|
throw exc_class((::ngraph::CheckLocInfo{__FILE__, __LINE__, #check}), (ctx), ss___.str()); \
|
||||||
(::ngraph::CheckLocInfo{__FILE__, __LINE__, #check}), (ctx), ss___.str()); \
|
|
||||||
} \
|
} \
|
||||||
} while (0)
|
} while (0)
|
||||||
|
|
||||||
#define NGRAPH_CHECK_HELPER1(exc_class, ctx, check) \
|
#define NGRAPH_CHECK_HELPER1(exc_class, ctx, check) \
|
||||||
do \
|
do { \
|
||||||
{ \
|
if (!(check)) { \
|
||||||
if (!(check)) \
|
|
||||||
{ \
|
|
||||||
throw exc_class((::ngraph::CheckLocInfo{__FILE__, __LINE__, #check}), (ctx), ""); \
|
throw exc_class((::ngraph::CheckLocInfo{__FILE__, __LINE__, #check}), (ctx), ""); \
|
||||||
} \
|
} \
|
||||||
} while (0)
|
} while (0)
|
||||||
@ -143,8 +132,7 @@ namespace ngraph
|
|||||||
/// \param ... Additional error message that should describe why that execution path is unreachable.
|
/// \param ... Additional error message that should describe why that execution path is unreachable.
|
||||||
/// \throws ::ngraph::CheckFailure if the macro is executed.
|
/// \throws ::ngraph::CheckFailure if the macro is executed.
|
||||||
#define NGRAPH_UNREACHABLE(...) NGRAPH_CHECK(false, "Unreachable: ", __VA_ARGS__)
|
#define NGRAPH_UNREACHABLE(...) NGRAPH_CHECK(false, "Unreachable: ", __VA_ARGS__)
|
||||||
#define NGRAPH_CHECK_HELPER(exc_class, ctx, ...) \
|
#define NGRAPH_CHECK_HELPER(exc_class, ctx, ...) CALL_OVERLOAD(NGRAPH_CHECK_HELPER, exc_class, ctx, __VA_ARGS__)
|
||||||
CALL_OVERLOAD(NGRAPH_CHECK_HELPER, exc_class, ctx, __VA_ARGS__)
|
|
||||||
|
|
||||||
#define GLUE(x, y) x y
|
#define GLUE(x, y) x y
|
||||||
|
|
||||||
@ -178,33 +166,7 @@ namespace ngraph
|
|||||||
count
|
count
|
||||||
#define EXPAND_ARGS(args) RETURN_ARG_COUNT args
|
#define EXPAND_ARGS(args) RETURN_ARG_COUNT args
|
||||||
#define COUNT_ARGS_MAXN(...) \
|
#define COUNT_ARGS_MAXN(...) \
|
||||||
EXPAND_ARGS((__VA_ARGS__, \
|
EXPAND_ARGS((__VA_ARGS__, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 0))
|
||||||
2, \
|
|
||||||
2, \
|
|
||||||
2, \
|
|
||||||
2, \
|
|
||||||
2, \
|
|
||||||
2, \
|
|
||||||
2, \
|
|
||||||
2, \
|
|
||||||
2, \
|
|
||||||
2, \
|
|
||||||
2, \
|
|
||||||
2, \
|
|
||||||
2, \
|
|
||||||
2, \
|
|
||||||
2, \
|
|
||||||
2, \
|
|
||||||
2, \
|
|
||||||
2, \
|
|
||||||
2, \
|
|
||||||
2, \
|
|
||||||
2, \
|
|
||||||
2, \
|
|
||||||
2, \
|
|
||||||
2, \
|
|
||||||
1, \
|
|
||||||
0))
|
|
||||||
|
|
||||||
#define OVERLOAD_MACRO2(name, count) name##count
|
#define OVERLOAD_MACRO2(name, count) name##count
|
||||||
#define OVERLOAD_MACRO1(name, count) OVERLOAD_MACRO2(name, count)
|
#define OVERLOAD_MACRO1(name, count) OVERLOAD_MACRO2(name, count)
|
||||||
|
@ -11,11 +11,9 @@
|
|||||||
#include "ngraph/axis_set.hpp"
|
#include "ngraph/axis_set.hpp"
|
||||||
#include "ngraph/shape.hpp"
|
#include "ngraph/shape.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
|
||||||
/// \brief Coordinates for a tensor element
|
/// \brief Coordinates for a tensor element
|
||||||
class Coordinate : public std::vector<size_t>
|
class Coordinate : public std::vector<size_t> {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_API Coordinate();
|
NGRAPH_API Coordinate();
|
||||||
NGRAPH_API Coordinate(const std::initializer_list<size_t>& axes);
|
NGRAPH_API Coordinate(const std::initializer_list<size_t>& axes);
|
||||||
@ -31,10 +29,7 @@ namespace ngraph
|
|||||||
NGRAPH_API ~Coordinate();
|
NGRAPH_API ~Coordinate();
|
||||||
|
|
||||||
template <class InputIterator>
|
template <class InputIterator>
|
||||||
Coordinate(InputIterator first, InputIterator last)
|
Coordinate(InputIterator first, InputIterator last) : std::vector<size_t>(first, last) {}
|
||||||
: std::vector<size_t>(first, last)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
NGRAPH_API Coordinate& operator=(const Coordinate& v);
|
NGRAPH_API Coordinate& operator=(const Coordinate& v);
|
||||||
|
|
||||||
@ -42,17 +37,14 @@ namespace ngraph
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
class NGRAPH_API AttributeAdapter<Coordinate>
|
class NGRAPH_API AttributeAdapter<Coordinate> : public IndirectVectorValueAccessor<Coordinate, std::vector<int64_t>> {
|
||||||
: public IndirectVectorValueAccessor<Coordinate, std::vector<int64_t>>
|
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
AttributeAdapter(Coordinate& value)
|
AttributeAdapter(Coordinate& value) : IndirectVectorValueAccessor<Coordinate, std::vector<int64_t>>(value) {}
|
||||||
: IndirectVectorValueAccessor<Coordinate, std::vector<int64_t>>(value)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<Coordinate>", 0};
|
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<Coordinate>", 0};
|
||||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
const DiscreteTypeInfo& get_type_info() const override {
|
||||||
|
return type_info;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
NGRAPH_API
|
NGRAPH_API
|
||||||
|
@ -11,11 +11,9 @@
|
|||||||
#include "ngraph/attribute_adapter.hpp"
|
#include "ngraph/attribute_adapter.hpp"
|
||||||
#include "ngraph/ngraph_visibility.hpp"
|
#include "ngraph/ngraph_visibility.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
|
||||||
/// \brief A difference (signed) of tensor element coordinates.
|
/// \brief A difference (signed) of tensor element coordinates.
|
||||||
class CoordinateDiff : public std::vector<std::ptrdiff_t>
|
class CoordinateDiff : public std::vector<std::ptrdiff_t> {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_API CoordinateDiff(const std::initializer_list<std::ptrdiff_t>& diffs);
|
NGRAPH_API CoordinateDiff(const std::initializer_list<std::ptrdiff_t>& diffs);
|
||||||
|
|
||||||
@ -26,10 +24,7 @@ namespace ngraph
|
|||||||
NGRAPH_API explicit CoordinateDiff(size_t n, std::ptrdiff_t initial_value = 0);
|
NGRAPH_API explicit CoordinateDiff(size_t n, std::ptrdiff_t initial_value = 0);
|
||||||
|
|
||||||
template <class InputIterator>
|
template <class InputIterator>
|
||||||
CoordinateDiff(InputIterator first, InputIterator last)
|
CoordinateDiff(InputIterator first, InputIterator last) : std::vector<std::ptrdiff_t>(first, last) {}
|
||||||
: std::vector<std::ptrdiff_t>(first, last)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
NGRAPH_API ~CoordinateDiff();
|
NGRAPH_API ~CoordinateDiff();
|
||||||
|
|
||||||
@ -47,12 +42,12 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
AttributeAdapter(CoordinateDiff& value)
|
AttributeAdapter(CoordinateDiff& value)
|
||||||
: IndirectVectorValueAccessor<CoordinateDiff, std::vector<int64_t>>(value)
|
: IndirectVectorValueAccessor<CoordinateDiff, std::vector<int64_t>>(value) {}
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<CoordinateDiff>", 0};
|
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<CoordinateDiff>", 0};
|
||||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
const DiscreteTypeInfo& get_type_info() const override {
|
||||||
|
return type_info;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
NGRAPH_API
|
NGRAPH_API
|
||||||
|
@ -10,17 +10,14 @@
|
|||||||
#include "ngraph/descriptor/tensor.hpp"
|
#include "ngraph/descriptor/tensor.hpp"
|
||||||
#include "ngraph/variant.hpp"
|
#include "ngraph/variant.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
|
||||||
class Node;
|
class Node;
|
||||||
|
|
||||||
namespace descriptor
|
namespace descriptor {
|
||||||
{
|
|
||||||
class Output;
|
class Output;
|
||||||
|
|
||||||
// Describes a tensor that is an input to an op, directly or indirectly via a tuple
|
// Describes a tensor that is an input to an op, directly or indirectly via a tuple
|
||||||
class NGRAPH_API Input
|
class NGRAPH_API Input {
|
||||||
{
|
|
||||||
friend class ngraph::Node;
|
friend class ngraph::Node;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
@ -38,23 +35,37 @@ namespace ngraph
|
|||||||
std::shared_ptr<Node> get_node() const;
|
std::shared_ptr<Node> get_node() const;
|
||||||
|
|
||||||
/// \return the raw pointer to the node that this is an input of
|
/// \return the raw pointer to the node that this is an input of
|
||||||
Node* get_raw_pointer_node() const { return m_node; }
|
Node* get_raw_pointer_node() const {
|
||||||
|
return m_node;
|
||||||
|
}
|
||||||
/// \return the position within all supplied tensors of this input
|
/// \return the position within all supplied tensors of this input
|
||||||
size_t get_index() const { return m_index; }
|
size_t get_index() const {
|
||||||
|
return m_index;
|
||||||
|
}
|
||||||
/// \return the connected output
|
/// \return the connected output
|
||||||
const Output& get_output() const { return *m_output; }
|
const Output& get_output() const {
|
||||||
|
return *m_output;
|
||||||
|
}
|
||||||
/// \return the connected output
|
/// \return the connected output
|
||||||
Output& get_output() { return *m_output; }
|
Output& get_output() {
|
||||||
|
return *m_output;
|
||||||
|
}
|
||||||
/// \return true if an output is connected to the input.
|
/// \return true if an output is connected to the input.
|
||||||
bool has_output() const { return m_output != nullptr; }
|
bool has_output() const {
|
||||||
|
return m_output != nullptr;
|
||||||
|
}
|
||||||
/// \return the tensor of the connected output
|
/// \return the tensor of the connected output
|
||||||
const Tensor& get_tensor() const;
|
const Tensor& get_tensor() const;
|
||||||
|
|
||||||
/// \return the tensor of the connected output
|
/// \return the tensor of the connected output
|
||||||
Tensor& get_tensor();
|
Tensor& get_tensor();
|
||||||
|
|
||||||
RTMap& get_rt_info() { return m_rt_info; }
|
RTMap& get_rt_info() {
|
||||||
const RTMap& get_rt_info() const { return m_rt_info; }
|
return m_rt_info;
|
||||||
|
}
|
||||||
|
const RTMap& get_rt_info() const {
|
||||||
|
return m_rt_info;
|
||||||
|
}
|
||||||
|
|
||||||
/// \brief Replace the current output that supplies a value for this input with output i
|
/// \brief Replace the current output that supplies a value for this input with output i
|
||||||
/// of node
|
/// of node
|
||||||
@ -69,12 +80,16 @@ namespace ngraph
|
|||||||
/// corresponding node. (Usually this is false.)
|
/// corresponding node. (Usually this is false.)
|
||||||
///
|
///
|
||||||
/// See Node::set_input_is_relevant_to_shape for more details.
|
/// See Node::set_input_is_relevant_to_shape for more details.
|
||||||
bool get_is_relevant_to_shape() const { return m_is_relevant_to_shape; }
|
bool get_is_relevant_to_shape() const {
|
||||||
|
return m_is_relevant_to_shape;
|
||||||
|
}
|
||||||
/// \return true if the value of this input is relevant to the output value of the
|
/// \return true if the value of this input is relevant to the output value of the
|
||||||
/// corresponding node. (Usually this is true.)
|
/// corresponding node. (Usually this is true.)
|
||||||
///
|
///
|
||||||
/// See Node::set_input_is_relevant_to_value for more details.
|
/// See Node::set_input_is_relevant_to_value for more details.
|
||||||
bool get_is_relevant_to_value() const { return m_is_relevant_to_value; }
|
bool get_is_relevant_to_value() const {
|
||||||
|
return m_is_relevant_to_value;
|
||||||
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
/// \return the tensor for the connected output
|
/// \return the tensor for the connected output
|
||||||
|
@ -15,8 +15,7 @@
|
|||||||
#include "ngraph/node_output.hpp"
|
#include "ngraph/node_output.hpp"
|
||||||
#include "ngraph/variant.hpp"
|
#include "ngraph/variant.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
|
||||||
// The forward declaration of Node is needed here because Node has a deque of
|
// The forward declaration of Node is needed here because Node has a deque of
|
||||||
// Outputs, and Output is an incomplete type at this point. STL containers of
|
// Outputs, and Output is an incomplete type at this point. STL containers of
|
||||||
// incomplete type have undefined behavior according to the C++11 standard, and
|
// incomplete type have undefined behavior according to the C++11 standard, and
|
||||||
@ -24,19 +23,11 @@ namespace ngraph
|
|||||||
// systems (namely macOS).
|
// systems (namely macOS).
|
||||||
class Node;
|
class Node;
|
||||||
|
|
||||||
namespace descriptor
|
namespace descriptor {
|
||||||
{
|
|
||||||
// Describes an output tensor of an op
|
// Describes an output tensor of an op
|
||||||
class NGRAPH_API Output
|
class NGRAPH_API Output {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
Output()
|
Output() : m_node(nullptr), m_index(0), m_tensor(nullptr), m_inputs() {}
|
||||||
: m_node(nullptr)
|
|
||||||
, m_index(0)
|
|
||||||
, m_tensor(nullptr)
|
|
||||||
, m_inputs()
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
/// \param node Node that owns this output.
|
/// \param node Node that owns this output.
|
||||||
/// \param index Position of the output tensor in all output tensors
|
/// \param index Position of the output tensor in all output tensors
|
||||||
@ -44,17 +35,29 @@ namespace ngraph
|
|||||||
Output(Node* node, size_t index, const std::shared_ptr<Tensor>& tensor);
|
Output(Node* node, size_t index, const std::shared_ptr<Tensor>& tensor);
|
||||||
|
|
||||||
std::shared_ptr<Node> get_node() const;
|
std::shared_ptr<Node> get_node() const;
|
||||||
size_t get_index() const { return m_index; }
|
size_t get_index() const {
|
||||||
|
return m_index;
|
||||||
|
}
|
||||||
ngraph::Output<Node> get_output() const;
|
ngraph::Output<Node> get_output() const;
|
||||||
std::shared_ptr<Tensor> get_tensor_ptr() const { return m_tensor; }
|
std::shared_ptr<Tensor> get_tensor_ptr() const {
|
||||||
void set_tensor_ptr(const std::shared_ptr<Tensor>& tensor) { m_tensor = tensor; }
|
return m_tensor;
|
||||||
|
}
|
||||||
|
void set_tensor_ptr(const std::shared_ptr<Tensor>& tensor) {
|
||||||
|
m_tensor = tensor;
|
||||||
|
}
|
||||||
void add_input(Input* input);
|
void add_input(Input* input);
|
||||||
void remove_input(Input* input);
|
void remove_input(Input* input);
|
||||||
const std::vector<Input*>& get_inputs() const { return m_inputs; }
|
const std::vector<Input*>& get_inputs() const {
|
||||||
|
return m_inputs;
|
||||||
|
}
|
||||||
Tensor& get_tensor() const;
|
Tensor& get_tensor() const;
|
||||||
|
|
||||||
RTMap& get_rt_info() { return m_rt_info; }
|
RTMap& get_rt_info() {
|
||||||
const RTMap& get_rt_info() const { return m_rt_info; }
|
return m_rt_info;
|
||||||
|
}
|
||||||
|
const RTMap& get_rt_info() const {
|
||||||
|
return m_rt_info;
|
||||||
|
}
|
||||||
/// \return the shape of the output
|
/// \return the shape of the output
|
||||||
const Shape& get_shape() const;
|
const Shape& get_shape() const;
|
||||||
|
|
||||||
|
@ -14,31 +14,22 @@
|
|||||||
#include "ngraph/shape.hpp"
|
#include "ngraph/shape.hpp"
|
||||||
#include "ngraph/type/element_type.hpp"
|
#include "ngraph/type/element_type.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
|
||||||
class Node;
|
class Node;
|
||||||
|
|
||||||
namespace runtime
|
namespace runtime {
|
||||||
{
|
|
||||||
class HostTensor;
|
class HostTensor;
|
||||||
}
|
}
|
||||||
using HostTensorPtr = std::shared_ptr<runtime::HostTensor>;
|
using HostTensorPtr = std::shared_ptr<runtime::HostTensor>;
|
||||||
namespace descriptor
|
namespace descriptor {
|
||||||
{
|
|
||||||
/// \brief Compile-time descriptor of a first-class value that is a tensor.
|
/// \brief Compile-time descriptor of a first-class value that is a tensor.
|
||||||
class NGRAPH_API Tensor
|
class NGRAPH_API Tensor {
|
||||||
{
|
|
||||||
Tensor(const Tensor&) = delete;
|
Tensor(const Tensor&) = delete;
|
||||||
Tensor& operator=(const Tensor&) = delete;
|
Tensor& operator=(const Tensor&) = delete;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
Tensor(const element::Type& element_type,
|
Tensor(const element::Type& element_type, const PartialShape& pshape, const std::string& name);
|
||||||
const PartialShape& pshape,
|
Tensor(const element::Type& element_type, const PartialShape& pshape, Node* node, size_t node_output_number);
|
||||||
const std::string& name);
|
|
||||||
Tensor(const element::Type& element_type,
|
|
||||||
const PartialShape& pshape,
|
|
||||||
Node* node,
|
|
||||||
size_t node_output_number);
|
|
||||||
|
|
||||||
NGRAPH_DEPRECATED("get_name() is deprecated! Please use get_names() instead.")
|
NGRAPH_DEPRECATED("get_name() is deprecated! Please use get_names() instead.")
|
||||||
const std::string& get_name() const;
|
const std::string& get_name() const;
|
||||||
@ -59,16 +50,23 @@ namespace ngraph
|
|||||||
/// \brief unsets bound value descriptions
|
/// \brief unsets bound value descriptions
|
||||||
void invalidate_values();
|
void invalidate_values();
|
||||||
|
|
||||||
const element::Type& get_element_type() const { return m_element_type; }
|
const element::Type& get_element_type() const {
|
||||||
|
return m_element_type;
|
||||||
|
}
|
||||||
const Shape& get_shape() const;
|
const Shape& get_shape() const;
|
||||||
const PartialShape& get_partial_shape() const { return m_partial_shape; }
|
const PartialShape& get_partial_shape() const {
|
||||||
|
return m_partial_shape;
|
||||||
|
}
|
||||||
/// \brief gets lower bound value description
|
/// \brief gets lower bound value description
|
||||||
HostTensorPtr get_lower_value() const { return m_lower_value; }
|
HostTensorPtr get_lower_value() const {
|
||||||
|
return m_lower_value;
|
||||||
|
}
|
||||||
/// \brief gets upper bound value description
|
/// \brief gets upper bound value description
|
||||||
HostTensorPtr get_upper_value() const { return m_upper_value; }
|
HostTensorPtr get_upper_value() const {
|
||||||
|
return m_upper_value;
|
||||||
|
}
|
||||||
/// \brief checks if lower and upper bound are set and point to the same HostTensor
|
/// \brief checks if lower and upper bound are set and point to the same HostTensor
|
||||||
bool has_and_set_bound() const
|
bool has_and_set_bound() const {
|
||||||
{
|
|
||||||
return m_upper_value != nullptr && m_upper_value == m_lower_value;
|
return m_upper_value != nullptr && m_upper_value == m_lower_value;
|
||||||
}
|
}
|
||||||
size_t size() const;
|
size_t size() const;
|
||||||
|
@ -4,23 +4,22 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <limits>
|
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
|
|
||||||
|
#include <limits>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
|
|
||||||
#include "ngraph/deprecated.hpp"
|
#include "ngraph/deprecated.hpp"
|
||||||
#include "ngraph/interval.hpp"
|
#include "ngraph/interval.hpp"
|
||||||
#include "ngraph/ngraph_visibility.hpp"
|
#include "ngraph/ngraph_visibility.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
|
||||||
/// \brief Class representing a dimension, which may be dynamic (undetermined until runtime),
|
/// \brief Class representing a dimension, which may be dynamic (undetermined until runtime),
|
||||||
/// in a shape or shape-like object.
|
/// in a shape or shape-like object.
|
||||||
///
|
///
|
||||||
/// Static dimensions may be implicitly converted from value_type. A dynamic dimension is
|
/// Static dimensions may be implicitly converted from value_type. A dynamic dimension is
|
||||||
/// constructed with Dimension() or Dimension::dynamic().
|
/// constructed with Dimension() or Dimension::dynamic().
|
||||||
class NGRAPH_API Dimension
|
class NGRAPH_API Dimension {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
using value_type = int64_t;
|
using value_type = int64_t;
|
||||||
|
|
||||||
@ -36,20 +35,22 @@ namespace ngraph
|
|||||||
/// \brief Construct a dynamic dimension with range [0, ...]
|
/// \brief Construct a dynamic dimension with range [0, ...]
|
||||||
Dimension() = default;
|
Dimension() = default;
|
||||||
|
|
||||||
bool operator==(const Dimension& dimension) const
|
bool operator==(const Dimension& dimension) const {
|
||||||
{
|
|
||||||
return m_dimension == dimension.m_dimension;
|
return m_dimension == dimension.m_dimension;
|
||||||
}
|
}
|
||||||
bool operator!=(const Dimension& dimension) const
|
bool operator!=(const Dimension& dimension) const {
|
||||||
{
|
|
||||||
return m_dimension != dimension.m_dimension;
|
return m_dimension != dimension.m_dimension;
|
||||||
}
|
}
|
||||||
/// \brief Check whether this dimension is static.
|
/// \brief Check whether this dimension is static.
|
||||||
/// \return `true` if the dimension is static, else `false`.
|
/// \return `true` if the dimension is static, else `false`.
|
||||||
bool is_static() const { return m_dimension.size() == 1; }
|
bool is_static() const {
|
||||||
|
return m_dimension.size() == 1;
|
||||||
|
}
|
||||||
/// \brief Check whether this dimension is dynamic.
|
/// \brief Check whether this dimension is dynamic.
|
||||||
/// \return `false` if the dimension is static, else `true`.
|
/// \return `false` if the dimension is static, else `true`.
|
||||||
bool is_dynamic() const { return m_dimension.size() != 1; }
|
bool is_dynamic() const {
|
||||||
|
return m_dimension.size() != 1;
|
||||||
|
}
|
||||||
/// \brief Convert this dimension to `value_type`. This dimension must be static and
|
/// \brief Convert this dimension to `value_type`. This dimension must be static and
|
||||||
/// non-negative.
|
/// non-negative.
|
||||||
/// \throws std::invalid_argument If this dimension is dynamic or negative.
|
/// \throws std::invalid_argument If this dimension is dynamic or negative.
|
||||||
@ -59,8 +60,12 @@ namespace ngraph
|
|||||||
value_type get_max_length() const;
|
value_type get_max_length() const;
|
||||||
|
|
||||||
/// \brief Return the interval of valid lengths
|
/// \brief Return the interval of valid lengths
|
||||||
const Interval& get_interval() const { return m_dimension; }
|
const Interval& get_interval() const {
|
||||||
Interval& get_interval() { return m_dimension; }
|
return m_dimension;
|
||||||
|
}
|
||||||
|
Interval& get_interval() {
|
||||||
|
return m_dimension;
|
||||||
|
}
|
||||||
/// \brief Check whether this dimension represents the same scheme as the argument (both
|
/// \brief Check whether this dimension represents the same scheme as the argument (both
|
||||||
/// dynamic, or equal).
|
/// dynamic, or equal).
|
||||||
/// \param dim The other dimension to compare this dimension to.
|
/// \param dim The other dimension to compare this dimension to.
|
||||||
@ -115,7 +120,9 @@ namespace ngraph
|
|||||||
|
|
||||||
/// \brief Create a dynamic dimension.
|
/// \brief Create a dynamic dimension.
|
||||||
/// \return A dynamic dimension.
|
/// \return A dynamic dimension.
|
||||||
static Dimension dynamic() { return Dimension(); }
|
static Dimension dynamic() {
|
||||||
|
return Dimension();
|
||||||
|
}
|
||||||
/// \brief Addition operator for Dimension.
|
/// \brief Addition operator for Dimension.
|
||||||
/// \param dim Right operand for addition.
|
/// \param dim Right operand for addition.
|
||||||
/// \return Smallest interval dimension enclosing inputs
|
/// \return Smallest interval dimension enclosing inputs
|
||||||
@ -135,21 +142,22 @@ namespace ngraph
|
|||||||
/// \brief Add-into operator for Dimension.
|
/// \brief Add-into operator for Dimension.
|
||||||
/// \param dim Right operand for addition.
|
/// \param dim Right operand for addition.
|
||||||
/// \return A reference to `*this`, after updating `*this` to the value `*this + dim`.
|
/// \return A reference to `*this`, after updating `*this` to the value `*this + dim`.
|
||||||
Dimension& operator+=(const Dimension& dim) { return (*this = *this + dim); }
|
Dimension& operator+=(const Dimension& dim) {
|
||||||
|
return (*this = *this + dim);
|
||||||
|
}
|
||||||
/// \brief Multiply-into operator for Dimension.
|
/// \brief Multiply-into operator for Dimension.
|
||||||
/// \param dim Right operand for multiplication.
|
/// \param dim Right operand for multiplication.
|
||||||
/// \return A reference to `*this`, after updating `*this` to the value `*this * dim`.
|
/// \return A reference to `*this`, after updating `*this` to the value `*this * dim`.
|
||||||
Dimension& operator*=(const Dimension& dim) { return (*this = *this * dim); }
|
Dimension& operator*=(const Dimension& dim) {
|
||||||
|
return (*this = *this * dim);
|
||||||
|
}
|
||||||
/// \brief Intersection of dimensions
|
/// \brief Intersection of dimensions
|
||||||
Dimension operator&(const Dimension& dim) const;
|
Dimension operator&(const Dimension& dim) const;
|
||||||
/// \brief Intersection of dimensions
|
/// \brief Intersection of dimensions
|
||||||
Dimension& operator&=(const Dimension& dim);
|
Dimension& operator&=(const Dimension& dim);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Dimension(const Interval& interval)
|
Dimension(const Interval& interval) : m_dimension(interval) {}
|
||||||
: m_dimension(interval)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
// The actual numerical value of the dimension.
|
// The actual numerical value of the dimension.
|
||||||
Interval m_dimension{};
|
Interval m_dimension{};
|
||||||
|
@ -12,12 +12,9 @@
|
|||||||
#include "ngraph/type.hpp"
|
#include "ngraph/type.hpp"
|
||||||
#include "ngraph/type/element_type.hpp"
|
#include "ngraph/type/element_type.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace reduction {
|
||||||
namespace reduction
|
enum class Type {
|
||||||
{
|
|
||||||
enum class Type
|
|
||||||
{
|
|
||||||
SUM,
|
SUM,
|
||||||
PROD,
|
PROD,
|
||||||
MIN,
|
MIN,
|
||||||
@ -29,16 +26,13 @@ namespace ngraph
|
|||||||
} // namespace reduction
|
} // namespace reduction
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
class NGRAPH_API AttributeAdapter<reduction::Type>
|
class NGRAPH_API AttributeAdapter<reduction::Type> : public EnumAttributeAdapterBase<reduction::Type> {
|
||||||
: public EnumAttributeAdapterBase<reduction::Type>
|
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
AttributeAdapter(reduction::Type& value)
|
AttributeAdapter(reduction::Type& value) : EnumAttributeAdapterBase<reduction::Type>(value) {}
|
||||||
: EnumAttributeAdapterBase<reduction::Type>(value)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<reduction::Type>", 0};
|
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<reduction::Type>", 0};
|
||||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
const DiscreteTypeInfo& get_type_info() const override {
|
||||||
|
return type_info;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
} // namespace ngraph
|
} // namespace ngraph
|
||||||
|
@ -10,17 +10,14 @@
|
|||||||
|
|
||||||
#include "ngraph/check.hpp"
|
#include "ngraph/check.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
|
||||||
/// Uses a pairings defined by EnumTypes::get() to convert between strings
|
/// Uses a pairings defined by EnumTypes::get() to convert between strings
|
||||||
/// and enum values.
|
/// and enum values.
|
||||||
template <typename EnumType>
|
template <typename EnumType>
|
||||||
class EnumNames
|
class EnumNames {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
/// Converts strings to enum values
|
/// Converts strings to enum values
|
||||||
static EnumType as_enum(const std::string& name)
|
static EnumType as_enum(const std::string& name) {
|
||||||
{
|
|
||||||
auto to_lower = [](const std::string& s) {
|
auto to_lower = [](const std::string& s) {
|
||||||
std::string rc = s;
|
std::string rc = s;
|
||||||
std::transform(rc.begin(), rc.end(), rc.begin(), [](char c) {
|
std::transform(rc.begin(), rc.end(), rc.begin(), [](char c) {
|
||||||
@ -28,10 +25,8 @@ namespace ngraph
|
|||||||
});
|
});
|
||||||
return rc;
|
return rc;
|
||||||
};
|
};
|
||||||
for (const auto& p : get().m_string_enums)
|
for (const auto& p : get().m_string_enums) {
|
||||||
{
|
if (to_lower(p.first) == to_lower(name)) {
|
||||||
if (to_lower(p.first) == to_lower(name))
|
|
||||||
{
|
|
||||||
return p.second;
|
return p.second;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -39,12 +34,9 @@ namespace ngraph
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Converts enum values to strings
|
/// Converts enum values to strings
|
||||||
static const std::string& as_string(EnumType e)
|
static const std::string& as_string(EnumType e) {
|
||||||
{
|
for (const auto& p : get().m_string_enums) {
|
||||||
for (const auto& p : get().m_string_enums)
|
if (p.second == e) {
|
||||||
{
|
|
||||||
if (p.second == e)
|
|
||||||
{
|
|
||||||
return p.first;
|
return p.first;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -53,12 +45,9 @@ namespace ngraph
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
/// Creates the mapping.
|
/// Creates the mapping.
|
||||||
EnumNames(const std::string& enum_name,
|
EnumNames(const std::string& enum_name, const std::vector<std::pair<std::string, EnumType>> string_enums)
|
||||||
const std::vector<std::pair<std::string, EnumType>> string_enums)
|
: m_enum_name(enum_name),
|
||||||
: m_enum_name(enum_name)
|
m_string_enums(string_enums) {}
|
||||||
, m_string_enums(string_enums)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Must be defined to returns a singleton for each supported enum class
|
/// Must be defined to returns a singleton for each supported enum class
|
||||||
static EnumNames<EnumType>& get();
|
static EnumNames<EnumType>& get();
|
||||||
@ -69,16 +58,13 @@ namespace ngraph
|
|||||||
|
|
||||||
/// Returns the enum value matching the string
|
/// Returns the enum value matching the string
|
||||||
template <typename Type, typename Value>
|
template <typename Type, typename Value>
|
||||||
typename std::enable_if<std::is_convertible<Value, std::string>::value, Type>::type
|
typename std::enable_if<std::is_convertible<Value, std::string>::value, Type>::type as_enum(const Value& value) {
|
||||||
as_enum(const Value& value)
|
|
||||||
{
|
|
||||||
return EnumNames<Type>::as_enum(value);
|
return EnumNames<Type>::as_enum(value);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the string matching the enum value
|
/// Returns the string matching the enum value
|
||||||
template <typename Value>
|
template <typename Value>
|
||||||
const std::string& as_string(Value value)
|
const std::string& as_string(Value value) {
|
||||||
{
|
|
||||||
return EnumNames<Value>::as_string(value);
|
return EnumNames<Value>::as_string(value);
|
||||||
}
|
}
|
||||||
} // namespace ngraph
|
} // namespace ngraph
|
||||||
|
@ -5,12 +5,10 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
#include <ngraph/ngraph_visibility.hpp>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
#include <ngraph/ngraph_visibility.hpp>
|
namespace ngraph {
|
||||||
|
|
||||||
namespace ngraph
|
|
||||||
{
|
|
||||||
/// \brief Get the names environment variable as a string.
|
/// \brief Get the names environment variable as a string.
|
||||||
/// \param env_var The string name of the environment variable to get.
|
/// \param env_var The string name of the environment variable to get.
|
||||||
/// \return Returns string by value or an empty string if the environment
|
/// \return Returns string by value or an empty string if the environment
|
||||||
|
@ -12,14 +12,12 @@
|
|||||||
#include "ngraph/shape.hpp"
|
#include "ngraph/shape.hpp"
|
||||||
#include "ngraph/type/element_type_traits.hpp"
|
#include "ngraph/type/element_type_traits.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
|
||||||
/// \brief Execute handlers on a subgraph to compute values
|
/// \brief Execute handlers on a subgraph to compute values
|
||||||
///
|
///
|
||||||
///
|
///
|
||||||
template <typename V>
|
template <typename V>
|
||||||
class Evaluator
|
class Evaluator {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
/// \brief values we compute for outputs
|
/// \brief values we compute for outputs
|
||||||
using value_map = std::map<RawNodeOutput, V>;
|
using value_map = std::map<RawNodeOutput, V>;
|
||||||
@ -41,37 +39,40 @@ namespace ngraph
|
|||||||
/// Evaluator::get_value_map().
|
/// Evaluator::get_value_map().
|
||||||
///
|
///
|
||||||
/// \param Handlers for ops. Pairs of Node::type_info_t and handler functions.
|
/// \param Handlers for ops. Pairs of Node::type_info_t and handler functions.
|
||||||
Evaluator(const op_handler_map& handlers, value_map& values)
|
Evaluator(const op_handler_map& handlers, value_map& values) : m_handlers(handlers), m_value_map(values) {}
|
||||||
: m_handlers(handlers)
|
|
||||||
, m_value_map(values)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
/// \brief Retrieves the value_map, which holds all Output<Node> value associations.
|
/// \brief Retrieves the value_map, which holds all Output<Node> value associations.
|
||||||
value_map& get_value_map() { return m_value_map; }
|
value_map& get_value_map() {
|
||||||
const value_map& get_value_map() const { return m_value_map; }
|
return m_value_map;
|
||||||
|
}
|
||||||
|
const value_map& get_value_map() const {
|
||||||
|
return m_value_map;
|
||||||
|
}
|
||||||
/// \brief If set, handles all ops
|
/// \brief If set, handles all ops
|
||||||
const op_handler& get_univeral_handler() const { return m_universal_handler; }
|
const op_handler& get_univeral_handler() const {
|
||||||
|
return m_universal_handler;
|
||||||
|
}
|
||||||
/// \brief If set, handles all ops not in the handlers
|
/// \brief If set, handles all ops not in the handlers
|
||||||
const op_handler& get_default_handler() const { return m_default_handler; }
|
const op_handler& get_default_handler() const {
|
||||||
|
return m_default_handler;
|
||||||
|
}
|
||||||
/// \brief If set, handles all ops
|
/// \brief If set, handles all ops
|
||||||
void set_univeral_handler(const op_handler& handler) { m_universal_handler = handler; }
|
void set_univeral_handler(const op_handler& handler) {
|
||||||
|
m_universal_handler = handler;
|
||||||
|
}
|
||||||
/// \brief If set, handles all ops not in the handlers
|
/// \brief If set, handles all ops not in the handlers
|
||||||
void set_default_handler(const op_handler& handler) { m_default_handler = handler; }
|
void set_default_handler(const op_handler& handler) {
|
||||||
|
m_default_handler = handler;
|
||||||
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
op_handler get_handler(Node* node)
|
op_handler get_handler(Node* node) {
|
||||||
{
|
|
||||||
op_handler handler = m_universal_handler;
|
op_handler handler = m_universal_handler;
|
||||||
if (!handler)
|
if (!handler) {
|
||||||
{
|
|
||||||
auto it = m_handlers.find(node->get_type_info());
|
auto it = m_handlers.find(node->get_type_info());
|
||||||
if (it == m_handlers.end())
|
if (it == m_handlers.end()) {
|
||||||
{
|
|
||||||
handler = m_default_handler;
|
handler = m_default_handler;
|
||||||
}
|
} else {
|
||||||
else
|
|
||||||
{
|
|
||||||
handler = it->second;
|
handler = it->second;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -83,56 +84,39 @@ namespace ngraph
|
|||||||
using InstStack = std::stack<InstPtr>;
|
using InstStack = std::stack<InstPtr>;
|
||||||
|
|
||||||
/// \brief Intstructions for evaluations state machine
|
/// \brief Intstructions for evaluations state machine
|
||||||
class Inst
|
class Inst {
|
||||||
{
|
|
||||||
protected:
|
protected:
|
||||||
Inst(Node* node)
|
Inst(Node* node) : m_node(node) {}
|
||||||
: m_node(node)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
virtual ~Inst() {}
|
virtual ~Inst() {}
|
||||||
virtual void handle(Evaluator& evaluator, InstStack& inst_stack, Node* node) = 0;
|
virtual void handle(Evaluator& evaluator, InstStack& inst_stack, Node* node) = 0;
|
||||||
Node* get_node() { return m_node; }
|
Node* get_node() {
|
||||||
|
return m_node;
|
||||||
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
Node* m_node;
|
Node* m_node;
|
||||||
};
|
};
|
||||||
|
|
||||||
/// \brief Ensure value has been analyzed
|
/// \brief Ensure value has been analyzed
|
||||||
class ValueInst : public Inst
|
class ValueInst : public Inst {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
ValueInst(const Output<Node>& value)
|
ValueInst(const Output<Node>& value) : Inst(value.get_node()), m_index(value.get_index()) {}
|
||||||
: Inst(value.get_node())
|
|
||||||
, m_index(value.get_index())
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
ValueInst(const RawNodeOutput& value)
|
ValueInst(const RawNodeOutput& value) : Inst(value.node), m_index(value.index) {}
|
||||||
: Inst(value.node)
|
|
||||||
, m_index(value.index)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
void handle(Evaluator& evaluator, InstStack& inst_stack, Node* node) override
|
void handle(Evaluator& evaluator, InstStack& inst_stack, Node* node) override {
|
||||||
{
|
|
||||||
// Request to analyze this value if we can
|
// Request to analyze this value if we can
|
||||||
if (auto handler = evaluator.get_handler(node))
|
if (auto handler = evaluator.get_handler(node)) {
|
||||||
{
|
|
||||||
// Ensure the inputs are processed and then execute the op handler
|
// Ensure the inputs are processed and then execute the op handler
|
||||||
inst_stack.push(InstPtr(new ExecuteInst(node, handler)));
|
inst_stack.push(InstPtr(new ExecuteInst(node, handler)));
|
||||||
for (auto v : node->input_values())
|
for (auto v : node->input_values()) {
|
||||||
{
|
|
||||||
inst_stack.push(InstPtr(new ValueInst(v)));
|
inst_stack.push(InstPtr(new ValueInst(v)));
|
||||||
}
|
}
|
||||||
}
|
} else {
|
||||||
else
|
|
||||||
{
|
|
||||||
// We don't know how to handle this op, so mark the outputs as unknown
|
// We don't know how to handle this op, so mark the outputs as unknown
|
||||||
for (auto output : node->outputs())
|
for (auto output : node->outputs()) {
|
||||||
{
|
|
||||||
evaluator.get_value_map()[output] = V();
|
evaluator.get_value_map()[output] = V();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -143,27 +127,19 @@ namespace ngraph
|
|||||||
};
|
};
|
||||||
|
|
||||||
/// \brief All arguments have been handled; execute the node handler
|
/// \brief All arguments have been handled; execute the node handler
|
||||||
class ExecuteInst : public Inst
|
class ExecuteInst : public Inst {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
ExecuteInst(Node* node, op_handler& handler)
|
ExecuteInst(Node* node, op_handler& handler) : Inst(node), m_handler(handler) {}
|
||||||
: Inst(node)
|
|
||||||
, m_handler(handler)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
void handle(Evaluator& evaluator, InstStack& inst_stack, Node* node) override
|
void handle(Evaluator& evaluator, InstStack& inst_stack, Node* node) override {
|
||||||
{
|
|
||||||
// Request to execute the handleer. Pass what we know about the inputs to the
|
// Request to execute the handleer. Pass what we know about the inputs to the
|
||||||
// handler and associate the results with the outputs
|
// handler and associate the results with the outputs
|
||||||
std::vector<V> inputs;
|
std::vector<V> inputs;
|
||||||
for (auto v : node->input_values())
|
for (auto v : node->input_values()) {
|
||||||
{
|
|
||||||
inputs.push_back(evaluator.get_value_map().at(v));
|
inputs.push_back(evaluator.get_value_map().at(v));
|
||||||
}
|
}
|
||||||
std::vector<V> outputs = m_handler(node, inputs);
|
std::vector<V> outputs = m_handler(node, inputs);
|
||||||
for (size_t i = 0; i < outputs.size(); ++i)
|
for (size_t i = 0; i < outputs.size(); ++i) {
|
||||||
{
|
|
||||||
evaluator.get_value_map()[node->output(i)] = outputs[i];
|
evaluator.get_value_map()[node->output(i)] = outputs[i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -174,18 +150,15 @@ namespace ngraph
|
|||||||
|
|
||||||
public:
|
public:
|
||||||
/// \brief Determine information about value
|
/// \brief Determine information about value
|
||||||
V evaluate(const Output<Node>& value)
|
V evaluate(const Output<Node>& value) {
|
||||||
{
|
|
||||||
InstStack inst_stack;
|
InstStack inst_stack;
|
||||||
inst_stack.push(InstPtr(new ValueInst(value)));
|
inst_stack.push(InstPtr(new ValueInst(value)));
|
||||||
while (!inst_stack.empty())
|
while (!inst_stack.empty()) {
|
||||||
{
|
|
||||||
InstPtr inst;
|
InstPtr inst;
|
||||||
std::swap(inst_stack.top(), inst);
|
std::swap(inst_stack.top(), inst);
|
||||||
inst_stack.pop();
|
inst_stack.pop();
|
||||||
auto node = inst->get_node();
|
auto node = inst->get_node();
|
||||||
if (m_value_map.find(node->output(0)) != m_value_map.end())
|
if (m_value_map.find(node->output(0)) != m_value_map.end()) {
|
||||||
{
|
|
||||||
// Already computed
|
// Already computed
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
@ -4,39 +4,23 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <ngraph/ngraph_visibility.hpp>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
|
|
||||||
#include <ngraph/ngraph_visibility.hpp>
|
namespace ngraph {
|
||||||
|
|
||||||
namespace ngraph
|
|
||||||
{
|
|
||||||
/// Base error for ngraph runtime errors.
|
/// Base error for ngraph runtime errors.
|
||||||
class NGRAPH_API ngraph_error : public std::runtime_error
|
class NGRAPH_API ngraph_error : public std::runtime_error {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
explicit ngraph_error(const std::string& what_arg)
|
explicit ngraph_error(const std::string& what_arg) : std::runtime_error(what_arg) {}
|
||||||
: std::runtime_error(what_arg)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
explicit ngraph_error(const char* what_arg)
|
explicit ngraph_error(const char* what_arg) : std::runtime_error(what_arg) {}
|
||||||
: std::runtime_error(what_arg)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
explicit ngraph_error(const std::stringstream& what_arg)
|
explicit ngraph_error(const std::stringstream& what_arg) : std::runtime_error(what_arg.str()) {}
|
||||||
: std::runtime_error(what_arg.str())
|
|
||||||
{
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
class NGRAPH_API unsupported_op : public std::runtime_error
|
class NGRAPH_API unsupported_op : public std::runtime_error {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
unsupported_op(const std::string& what_arg)
|
unsupported_op(const std::string& what_arg) : std::runtime_error(what_arg) {}
|
||||||
: std::runtime_error(what_arg)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
} // namespace ngraph
|
} // namespace ngraph
|
||||||
|
@ -10,63 +10,56 @@
|
|||||||
|
|
||||||
#include "ngraph/ngraph_visibility.hpp"
|
#include "ngraph/ngraph_visibility.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
|
||||||
NGRAPH_API std::mutex& get_registry_mutex();
|
NGRAPH_API std::mutex& get_registry_mutex();
|
||||||
|
|
||||||
/// \brief Registry of factories that can construct objects derived from BASE_TYPE
|
/// \brief Registry of factories that can construct objects derived from BASE_TYPE
|
||||||
template <typename BASE_TYPE>
|
template <typename BASE_TYPE>
|
||||||
class FactoryRegistry
|
class FactoryRegistry {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
using Factory = std::function<BASE_TYPE*()>;
|
using Factory = std::function<BASE_TYPE*()>;
|
||||||
using FactoryMap = std::unordered_map<typename BASE_TYPE::type_info_t, Factory>;
|
using FactoryMap = std::unordered_map<typename BASE_TYPE::type_info_t, Factory>;
|
||||||
|
|
||||||
// \brief Get the default factory for DERIVED_TYPE. Specialize as needed.
|
// \brief Get the default factory for DERIVED_TYPE. Specialize as needed.
|
||||||
template <typename DERIVED_TYPE>
|
template <typename DERIVED_TYPE>
|
||||||
static Factory get_default_factory()
|
static Factory get_default_factory() {
|
||||||
{
|
return []() {
|
||||||
return []() { return new DERIVED_TYPE(); };
|
return new DERIVED_TYPE();
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
/// \brief Register a custom factory for type_info
|
/// \brief Register a custom factory for type_info
|
||||||
void register_factory(const typename BASE_TYPE::type_info_t& type_info, Factory factory)
|
void register_factory(const typename BASE_TYPE::type_info_t& type_info, Factory factory) {
|
||||||
{
|
|
||||||
std::lock_guard<std::mutex> guard(get_registry_mutex());
|
std::lock_guard<std::mutex> guard(get_registry_mutex());
|
||||||
m_factory_map[type_info] = factory;
|
m_factory_map[type_info] = factory;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// \brief Register a custom factory for DERIVED_TYPE
|
/// \brief Register a custom factory for DERIVED_TYPE
|
||||||
template <typename DERIVED_TYPE>
|
template <typename DERIVED_TYPE>
|
||||||
void register_factory(Factory factory)
|
void register_factory(Factory factory) {
|
||||||
{
|
|
||||||
register_factory(DERIVED_TYPE::type_info, factory);
|
register_factory(DERIVED_TYPE::type_info, factory);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// \brief Register the defualt constructor factory for DERIVED_TYPE
|
/// \brief Register the defualt constructor factory for DERIVED_TYPE
|
||||||
template <typename DERIVED_TYPE>
|
template <typename DERIVED_TYPE>
|
||||||
void register_factory()
|
void register_factory() {
|
||||||
{
|
|
||||||
register_factory<DERIVED_TYPE>(get_default_factory<DERIVED_TYPE>());
|
register_factory<DERIVED_TYPE>(get_default_factory<DERIVED_TYPE>());
|
||||||
}
|
}
|
||||||
|
|
||||||
/// \brief Check to see if a factory is registered
|
/// \brief Check to see if a factory is registered
|
||||||
bool has_factory(const typename BASE_TYPE::type_info_t& info)
|
bool has_factory(const typename BASE_TYPE::type_info_t& info) {
|
||||||
{
|
|
||||||
std::lock_guard<std::mutex> guard(get_registry_mutex());
|
std::lock_guard<std::mutex> guard(get_registry_mutex());
|
||||||
return m_factory_map.find(info) != m_factory_map.end();
|
return m_factory_map.find(info) != m_factory_map.end();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// \brief Check to see if DERIVED_TYPE has a registered factory
|
/// \brief Check to see if DERIVED_TYPE has a registered factory
|
||||||
template <typename DERIVED_TYPE>
|
template <typename DERIVED_TYPE>
|
||||||
bool has_factory()
|
bool has_factory() {
|
||||||
{
|
|
||||||
return has_factory(DERIVED_TYPE::type_info);
|
return has_factory(DERIVED_TYPE::type_info);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// \brief Create an instance for type_info
|
/// \brief Create an instance for type_info
|
||||||
BASE_TYPE* create(const typename BASE_TYPE::type_info_t& type_info) const
|
BASE_TYPE* create(const typename BASE_TYPE::type_info_t& type_info) const {
|
||||||
{
|
|
||||||
std::lock_guard<std::mutex> guard(get_registry_mutex());
|
std::lock_guard<std::mutex> guard(get_registry_mutex());
|
||||||
auto it = m_factory_map.find(type_info);
|
auto it = m_factory_map.find(type_info);
|
||||||
return it == m_factory_map.end() ? nullptr : it->second();
|
return it == m_factory_map.end() ? nullptr : it->second();
|
||||||
@ -74,8 +67,7 @@ namespace ngraph
|
|||||||
|
|
||||||
/// \brief Create an instance using factory for DERIVED_TYPE
|
/// \brief Create an instance using factory for DERIVED_TYPE
|
||||||
template <typename DERIVED_TYPE>
|
template <typename DERIVED_TYPE>
|
||||||
BASE_TYPE* create() const
|
BASE_TYPE* create() const {
|
||||||
{
|
|
||||||
return create(DERIVED_TYPE::type_info);
|
return create(DERIVED_TYPE::type_info);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -8,37 +8,32 @@
|
|||||||
#include "ngraph/attribute_visitor.hpp"
|
#include "ngraph/attribute_visitor.hpp"
|
||||||
#include "ngraph/factory.hpp"
|
#include "ngraph/factory.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
|
||||||
template <typename BASE_TYPE>
|
template <typename BASE_TYPE>
|
||||||
class FactoryAttributeAdapter : public VisitorAdapter
|
class FactoryAttributeAdapter : public VisitorAdapter {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
FactoryAttributeAdapter(std::shared_ptr<BASE_TYPE>& ref)
|
FactoryAttributeAdapter(std::shared_ptr<BASE_TYPE>& ref) : m_ref(ref) {}
|
||||||
: m_ref(ref)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
/// \brief Hook for extra processing before other attributes
|
/// \brief Hook for extra processing before other attributes
|
||||||
virtual bool on_start(AttributeVisitor& /* visitor */) { return true; }
|
virtual bool on_start(AttributeVisitor& /* visitor */) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
/// \brief Hook for extra processing after other attributes
|
/// \brief Hook for extra processing after other attributes
|
||||||
virtual bool on_finish(AttributeVisitor& /* visitor */) { return true; }
|
virtual bool on_finish(AttributeVisitor& /* visitor */) {
|
||||||
bool visit_attributes(AttributeVisitor& visitor) override
|
return true;
|
||||||
{
|
}
|
||||||
if (on_start(visitor))
|
bool visit_attributes(AttributeVisitor& visitor) override {
|
||||||
{
|
if (on_start(visitor)) {
|
||||||
std::string type_info_name;
|
std::string type_info_name;
|
||||||
uint64_t type_info_version;
|
uint64_t type_info_version;
|
||||||
if (m_ref)
|
if (m_ref) {
|
||||||
{
|
|
||||||
auto& type_info = m_ref->get_type_info();
|
auto& type_info = m_ref->get_type_info();
|
||||||
type_info_name = type_info.name;
|
type_info_name = type_info.name;
|
||||||
type_info_version = type_info.version;
|
type_info_version = type_info.version;
|
||||||
}
|
}
|
||||||
visitor.on_attribute("name", type_info_name);
|
visitor.on_attribute("name", type_info_name);
|
||||||
visitor.on_attribute("version", type_info_version);
|
visitor.on_attribute("version", type_info_version);
|
||||||
if (m_ref)
|
if (m_ref) {
|
||||||
{
|
|
||||||
visitor.start_structure("value");
|
visitor.start_structure("value");
|
||||||
m_ref->visit_attributes(visitor);
|
m_ref->visit_attributes(visitor);
|
||||||
visitor.finish_structure();
|
visitor.finish_structure();
|
||||||
|
@ -5,15 +5,12 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <functional>
|
#include <functional>
|
||||||
|
#include <ngraph/ngraph_visibility.hpp>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include <ngraph/ngraph_visibility.hpp>
|
namespace ngraph {
|
||||||
|
namespace file_util {
|
||||||
namespace ngraph
|
|
||||||
{
|
|
||||||
namespace file_util
|
|
||||||
{
|
|
||||||
/// \brief Returns the name with extension for a given path
|
/// \brief Returns the name with extension for a given path
|
||||||
/// \param path The path to the output file
|
/// \param path The path to the output file
|
||||||
NGRAPH_API
|
NGRAPH_API
|
||||||
@ -37,10 +34,7 @@ namespace ngraph
|
|||||||
NGRAPH_API
|
NGRAPH_API
|
||||||
std::string path_join(const std::string& s1, const std::string& s2, const std::string& s3);
|
std::string path_join(const std::string& s1, const std::string& s2, const std::string& s3);
|
||||||
NGRAPH_API
|
NGRAPH_API
|
||||||
std::string path_join(const std::string& s1,
|
std::string path_join(const std::string& s1, const std::string& s2, const std::string& s3, const std::string& s4);
|
||||||
const std::string& s2,
|
|
||||||
const std::string& s3,
|
|
||||||
const std::string& s4);
|
|
||||||
|
|
||||||
/// \brief Iterate through files and optionally directories. Symbolic links are skipped.
|
/// \brief Iterate through files and optionally directories. Symbolic links are skipped.
|
||||||
/// \param path The path to iterate over
|
/// \param path The path to iterate over
|
||||||
|
@ -20,29 +20,21 @@
|
|||||||
#include "ngraph/op/sink.hpp"
|
#include "ngraph/op/sink.hpp"
|
||||||
#include "ngraph/op/util/variable.hpp"
|
#include "ngraph/op/util/variable.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
|
||||||
/// A user-defined function.
|
/// A user-defined function.
|
||||||
class NGRAPH_API Function
|
class NGRAPH_API Function {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
static constexpr DiscreteTypeInfo type_info{"Function", 0};
|
static constexpr DiscreteTypeInfo type_info{"Function", 0};
|
||||||
const DiscreteTypeInfo& get_type_info() const { return type_info; }
|
const DiscreteTypeInfo& get_type_info() const {
|
||||||
Function(const NodeVector& results,
|
return type_info;
|
||||||
const ParameterVector& parameters,
|
}
|
||||||
const std::string& name = "");
|
Function(const NodeVector& results, const ParameterVector& parameters, const std::string& name = "");
|
||||||
|
|
||||||
Function(const OutputVector& results,
|
Function(const OutputVector& results, const ParameterVector& parameters, const std::string& name = "");
|
||||||
const ParameterVector& parameters,
|
|
||||||
const std::string& name = "");
|
|
||||||
|
|
||||||
Function(const std::shared_ptr<Node>& result,
|
Function(const std::shared_ptr<Node>& result, const ParameterVector& parameters, const std::string& name = "");
|
||||||
const ParameterVector& parameters,
|
|
||||||
const std::string& name = "");
|
|
||||||
|
|
||||||
Function(const ResultVector& results,
|
Function(const ResultVector& results, const ParameterVector& parameters, const std::string& name = "");
|
||||||
const ParameterVector& parameters,
|
|
||||||
const std::string& name = "");
|
|
||||||
|
|
||||||
Function(const ResultVector& results,
|
Function(const ResultVector& results,
|
||||||
const SinkVector& sinks,
|
const SinkVector& sinks,
|
||||||
@ -82,9 +74,7 @@ namespace ngraph
|
|||||||
|
|
||||||
/// Constructs a Function. Lists of parameters and variables will be generated automatically
|
/// Constructs a Function. Lists of parameters and variables will be generated automatically
|
||||||
/// based on traversing the graph from the results and the sinks.
|
/// based on traversing the graph from the results and the sinks.
|
||||||
Function(const OutputVector& results,
|
Function(const OutputVector& results, const SinkVector& sinks, const std::string& name = "");
|
||||||
const SinkVector& sinks,
|
|
||||||
const std::string& name = "");
|
|
||||||
|
|
||||||
virtual ~Function() = default;
|
virtual ~Function() = default;
|
||||||
/// Return the number of outputs for this function.
|
/// Return the number of outputs for this function.
|
||||||
@ -147,19 +137,22 @@ namespace ngraph
|
|||||||
///
|
///
|
||||||
/// \param parameter_index The index of the parameter to replace.
|
/// \param parameter_index The index of the parameter to replace.
|
||||||
/// \param parameter The parameter to substitute for the `parameter_index`th parameter.
|
/// \param parameter The parameter to substitute for the `parameter_index`th parameter.
|
||||||
void replace_parameter(size_t parameter_index,
|
void replace_parameter(size_t parameter_index, const std::shared_ptr<op::Parameter>& parameter);
|
||||||
const std::shared_ptr<op::Parameter>& parameter);
|
|
||||||
|
|
||||||
using topological_sort_t = std::function<std::vector<std::shared_ptr<Node>>(
|
using topological_sort_t =
|
||||||
const std::vector<std::shared_ptr<Node>>& root_nodes)>;
|
std::function<std::vector<std::shared_ptr<Node>>(const std::vector<std::shared_ptr<Node>>& root_nodes)>;
|
||||||
void set_topological_sort(topological_sort_t);
|
void set_topological_sort(topological_sort_t);
|
||||||
|
|
||||||
virtual bool visit_attributes(AttributeVisitor& visitor);
|
virtual bool visit_attributes(AttributeVisitor& visitor);
|
||||||
|
|
||||||
/// Return the function parameters
|
/// Return the function parameters
|
||||||
const ParameterVector& get_parameters() const { return m_parameters; };
|
const ParameterVector& get_parameters() const {
|
||||||
|
return m_parameters;
|
||||||
|
};
|
||||||
/// Return a list of function's outputs
|
/// Return a list of function's outputs
|
||||||
const ResultVector& get_results() const { return m_results; };
|
const ResultVector& get_results() const {
|
||||||
|
return m_results;
|
||||||
|
};
|
||||||
/// Index for parameter, or -1
|
/// Index for parameter, or -1
|
||||||
int64_t get_parameter_index(const std::shared_ptr<op::Parameter>& parameter) const;
|
int64_t get_parameter_index(const std::shared_ptr<op::Parameter>& parameter) const;
|
||||||
|
|
||||||
@ -176,7 +169,9 @@ namespace ngraph
|
|||||||
EvaluationContext evaluation_context = EvaluationContext()) const;
|
EvaluationContext evaluation_context = EvaluationContext()) const;
|
||||||
|
|
||||||
/// \brief Return a list of function's sinks.
|
/// \brief Return a list of function's sinks.
|
||||||
const SinkVector& get_sinks() const { return m_sinks; }
|
const SinkVector& get_sinks() const {
|
||||||
|
return m_sinks;
|
||||||
|
}
|
||||||
/// \brief Add new sink nodes to the list. Method doesn't validate graph, it should be done
|
/// \brief Add new sink nodes to the list. Method doesn't validate graph, it should be done
|
||||||
/// manually after all changes.
|
/// manually after all changes.
|
||||||
/// \param sinks new sink nodes
|
/// \param sinks new sink nodes
|
||||||
@ -234,7 +229,9 @@ namespace ngraph
|
|||||||
void remove_variable(const VariablePtr& variable);
|
void remove_variable(const VariablePtr& variable);
|
||||||
|
|
||||||
/// \brief Return a list of function's variables.
|
/// \brief Return a list of function's variables.
|
||||||
const VariableVector& get_variables() const { return m_variables; }
|
const VariableVector& get_variables() const {
|
||||||
|
return m_variables;
|
||||||
|
}
|
||||||
|
|
||||||
/// \brief Return a variable by specified variable_id.
|
/// \brief Return a variable by specified variable_id.
|
||||||
VariablePtr get_variable_by_id(const std::string& variable_id) const;
|
VariablePtr get_variable_by_id(const std::string& variable_id) const;
|
||||||
@ -268,17 +265,13 @@ namespace ngraph
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
class NGRAPH_API AttributeAdapter<std::shared_ptr<Function>>
|
class NGRAPH_API AttributeAdapter<std::shared_ptr<Function>> : public DirectValueAccessor<std::shared_ptr<Function>> {
|
||||||
: public DirectValueAccessor<std::shared_ptr<Function>>
|
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
AttributeAdapter(std::shared_ptr<Function>& value)
|
AttributeAdapter(std::shared_ptr<Function>& value) : DirectValueAccessor<std::shared_ptr<Function>>(value) {}
|
||||||
: DirectValueAccessor<std::shared_ptr<Function>>(value)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<std::shared_ptr<Function>>",
|
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<std::shared_ptr<Function>>", 0};
|
||||||
0};
|
const DiscreteTypeInfo& get_type_info() const override {
|
||||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
return type_info;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
} // namespace ngraph
|
} // namespace ngraph
|
||||||
|
@ -18,25 +18,20 @@
|
|||||||
#include "ngraph/function.hpp"
|
#include "ngraph/function.hpp"
|
||||||
#include "ngraph/node.hpp"
|
#include "ngraph/node.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace descriptor {
|
||||||
namespace descriptor
|
|
||||||
{
|
|
||||||
class Input;
|
class Input;
|
||||||
class Output;
|
class Output;
|
||||||
} // namespace descriptor
|
} // namespace descriptor
|
||||||
|
|
||||||
namespace op
|
namespace op {
|
||||||
{
|
namespace v0 {
|
||||||
namespace v0
|
|
||||||
{
|
|
||||||
class Parameter;
|
class Parameter;
|
||||||
}
|
}
|
||||||
} // namespace op
|
} // namespace op
|
||||||
|
|
||||||
NGRAPH_API
|
NGRAPH_API
|
||||||
void traverse_nodes(const std::shared_ptr<const Function> p,
|
void traverse_nodes(const std::shared_ptr<const Function> p, std::function<void(std::shared_ptr<Node>)> f);
|
||||||
std::function<void(std::shared_ptr<Node>)> f);
|
|
||||||
|
|
||||||
NGRAPH_API
|
NGRAPH_API
|
||||||
void traverse_nodes(const Function* p, std::function<void(std::shared_ptr<Node>)> f);
|
void traverse_nodes(const Function* p, std::function<void(std::shared_ptr<Node>)> f);
|
||||||
@ -228,62 +223,49 @@ namespace ngraph
|
|||||||
/// - If a parameter node appears as a key in both `parameter_replacement_map` _and_ in
|
/// - If a parameter node appears as a key in both `parameter_replacement_map` _and_ in
|
||||||
/// `body_replacement_map`, behavior is unspecified.
|
/// `body_replacement_map`, behavior is unspecified.
|
||||||
NGRAPH_API
|
NGRAPH_API
|
||||||
void replace_nodes(
|
void replace_nodes(const std::shared_ptr<Function>& f,
|
||||||
const std::shared_ptr<Function>& f,
|
const std::unordered_map<std::shared_ptr<op::v0::Parameter>, std::shared_ptr<op::v0::Parameter>>&
|
||||||
const std::unordered_map<std::shared_ptr<op::v0::Parameter>,
|
parameter_replacement_map,
|
||||||
std::shared_ptr<op::v0::Parameter>>& parameter_replacement_map,
|
const std::unordered_map<std::shared_ptr<Node>, std::shared_ptr<Node>>& body_replacement_map);
|
||||||
const std::unordered_map<std::shared_ptr<Node>, std::shared_ptr<Node>>&
|
|
||||||
body_replacement_map);
|
|
||||||
|
|
||||||
NGRAPH_API
|
NGRAPH_API
|
||||||
NodeVector find_common_args(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement);
|
NodeVector find_common_args(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement);
|
||||||
|
|
||||||
/// Topological sort of nodes needed to compute root_nodes
|
/// Topological sort of nodes needed to compute root_nodes
|
||||||
template <typename T>
|
template <typename T>
|
||||||
std::vector<std::shared_ptr<Node>> topological_sort(T root_nodes)
|
std::vector<std::shared_ptr<Node>> topological_sort(T root_nodes) {
|
||||||
{
|
|
||||||
std::stack<Node*, std::vector<Node*>> nodes_to_do;
|
std::stack<Node*, std::vector<Node*>> nodes_to_do;
|
||||||
std::unordered_set<Node*> nodes_done;
|
std::unordered_set<Node*> nodes_done;
|
||||||
std::vector<std::shared_ptr<Node>> result;
|
std::vector<std::shared_ptr<Node>> result;
|
||||||
|
|
||||||
for (auto& node : root_nodes)
|
for (auto& node : root_nodes) {
|
||||||
{
|
|
||||||
nodes_to_do.push(node.get());
|
nodes_to_do.push(node.get());
|
||||||
}
|
}
|
||||||
while (nodes_to_do.size() > 0)
|
while (nodes_to_do.size() > 0) {
|
||||||
{
|
|
||||||
Node* node = nodes_to_do.top();
|
Node* node = nodes_to_do.top();
|
||||||
if (nodes_done.count(node) == 0)
|
if (nodes_done.count(node) == 0) {
|
||||||
{
|
|
||||||
bool can_add = true;
|
bool can_add = true;
|
||||||
size_t arg_count = node->get_input_size();
|
size_t arg_count = node->get_input_size();
|
||||||
for (size_t i = 0; i < arg_count; ++i)
|
for (size_t i = 0; i < arg_count; ++i) {
|
||||||
{
|
|
||||||
Node* dep = node->get_input_node_ptr(arg_count - i - 1);
|
Node* dep = node->get_input_node_ptr(arg_count - i - 1);
|
||||||
if (nodes_done.count(dep) == 0)
|
if (nodes_done.count(dep) == 0) {
|
||||||
{
|
|
||||||
can_add = false;
|
can_add = false;
|
||||||
nodes_to_do.push(dep);
|
nodes_to_do.push(dep);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (auto& depptr : node->get_control_dependencies())
|
for (auto& depptr : node->get_control_dependencies()) {
|
||||||
{
|
|
||||||
Node* dep = depptr.get();
|
Node* dep = depptr.get();
|
||||||
if (nodes_done.count(dep) == 0)
|
if (nodes_done.count(dep) == 0) {
|
||||||
{
|
|
||||||
can_add = false;
|
can_add = false;
|
||||||
nodes_to_do.push(dep);
|
nodes_to_do.push(dep);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (can_add)
|
if (can_add) {
|
||||||
{
|
|
||||||
result.push_back(node->shared_from_this());
|
result.push_back(node->shared_from_this());
|
||||||
nodes_to_do.pop();
|
nodes_to_do.pop();
|
||||||
nodes_done.insert(node);
|
nodes_done.insert(node);
|
||||||
}
|
}
|
||||||
}
|
} else {
|
||||||
else
|
|
||||||
{
|
|
||||||
nodes_to_do.pop();
|
nodes_to_do.pop();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -292,49 +274,39 @@ namespace ngraph
|
|||||||
|
|
||||||
/// Topological sort of just nodes
|
/// Topological sort of just nodes
|
||||||
template <typename T>
|
template <typename T>
|
||||||
std::vector<std::shared_ptr<Node>> subgraph_topological_sort(T nodes)
|
std::vector<std::shared_ptr<Node>> subgraph_topological_sort(T nodes) {
|
||||||
{
|
|
||||||
std::stack<Node*, std::vector<Node*>> nodes_to_do;
|
std::stack<Node*, std::vector<Node*>> nodes_to_do;
|
||||||
std::unordered_set<Node*> nodes_done;
|
std::unordered_set<Node*> nodes_done;
|
||||||
std::unordered_set<Node*> nodes_to_emit;
|
std::unordered_set<Node*> nodes_to_emit;
|
||||||
std::vector<std::shared_ptr<Node>> result;
|
std::vector<std::shared_ptr<Node>> result;
|
||||||
|
|
||||||
for (auto& node : nodes)
|
for (auto& node : nodes) {
|
||||||
{
|
|
||||||
nodes_to_emit.insert(node.get());
|
nodes_to_emit.insert(node.get());
|
||||||
nodes_to_do.push(node.get());
|
nodes_to_do.push(node.get());
|
||||||
}
|
}
|
||||||
// NB: Some centos versions implement std::list::size() by counting elements
|
// NB: Some centos versions implement std::list::size() by counting elements
|
||||||
size_t nodes_remaining = nodes_to_emit.size();
|
size_t nodes_remaining = nodes_to_emit.size();
|
||||||
while (nodes_to_do.size() > 0 && nodes_remaining > 0)
|
while (nodes_to_do.size() > 0 && nodes_remaining > 0) {
|
||||||
{
|
|
||||||
Node* node = nodes_to_do.top();
|
Node* node = nodes_to_do.top();
|
||||||
if (nodes_done.count(node) == 0)
|
if (nodes_done.count(node) == 0) {
|
||||||
{
|
|
||||||
bool can_add = true;
|
bool can_add = true;
|
||||||
size_t arg_count = node->get_input_size();
|
size_t arg_count = node->get_input_size();
|
||||||
for (size_t i = 0; i < arg_count; ++i)
|
for (size_t i = 0; i < arg_count; ++i) {
|
||||||
{
|
|
||||||
Node* dep = node->get_input_node_ptr(arg_count - i - 1);
|
Node* dep = node->get_input_node_ptr(arg_count - i - 1);
|
||||||
if (nodes_done.count(dep) == 0 && nodes_to_emit.count(node) != 0)
|
if (nodes_done.count(dep) == 0 && nodes_to_emit.count(node) != 0) {
|
||||||
{
|
|
||||||
can_add = false;
|
can_add = false;
|
||||||
nodes_to_do.push(dep);
|
nodes_to_do.push(dep);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (auto& depptr : node->get_control_dependencies())
|
for (auto& depptr : node->get_control_dependencies()) {
|
||||||
{
|
|
||||||
Node* dep = depptr.get();
|
Node* dep = depptr.get();
|
||||||
if (nodes_done.count(dep) == 0)
|
if (nodes_done.count(dep) == 0) {
|
||||||
{
|
|
||||||
can_add = false;
|
can_add = false;
|
||||||
nodes_to_do.push(dep);
|
nodes_to_do.push(dep);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (can_add)
|
if (can_add) {
|
||||||
{
|
if (nodes_to_emit.count(node) != 0) {
|
||||||
if (nodes_to_emit.count(node) != 0)
|
|
||||||
{
|
|
||||||
result.push_back(node->shared_from_this());
|
result.push_back(node->shared_from_this());
|
||||||
nodes_remaining--;
|
nodes_remaining--;
|
||||||
}
|
}
|
||||||
@ -343,8 +315,7 @@ namespace ngraph
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
else
|
else {
|
||||||
{
|
|
||||||
nodes_to_do.pop();
|
nodes_to_do.pop();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -352,10 +323,8 @@ namespace ngraph
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void validate_nodes_and_infer_types(const T& nodes)
|
void validate_nodes_and_infer_types(const T& nodes) {
|
||||||
{
|
for (auto& node : subgraph_topological_sort(nodes)) {
|
||||||
for (auto& node : subgraph_topological_sort(nodes))
|
|
||||||
{
|
|
||||||
node->revalidate_and_infer_types();
|
node->revalidate_and_infer_types();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -365,38 +334,35 @@ namespace ngraph
|
|||||||
bool is_post_dominated(Node* X, Node* Y);
|
bool is_post_dominated(Node* X, Node* Y);
|
||||||
|
|
||||||
NGRAPH_API
|
NGRAPH_API
|
||||||
bool is_equal_to_const_value(const std::string& const_value,
|
bool is_equal_to_const_value(const std::string& const_value, const Output<Node>& reduce_constant);
|
||||||
const Output<Node>& reduce_constant);
|
|
||||||
|
|
||||||
// input nodes are cloned and returned
|
// input nodes are cloned and returned
|
||||||
// NodeMap input may contain default node mapping i.e. pre-cloned nodes
|
// NodeMap input may contain default node mapping i.e. pre-cloned nodes
|
||||||
// NodeMap output (by reference) fully maps input and cloned nodes
|
// NodeMap output (by reference) fully maps input and cloned nodes
|
||||||
NGRAPH_API
|
NGRAPH_API
|
||||||
std::vector<std::shared_ptr<ngraph::Node>>
|
std::vector<std::shared_ptr<ngraph::Node>> clone_nodes(const std::vector<std::shared_ptr<ngraph::Node>>& nodes,
|
||||||
clone_nodes(const std::vector<std::shared_ptr<ngraph::Node>>& nodes, NodeMap& node_map);
|
NodeMap& node_map);
|
||||||
|
|
||||||
// input nodes are cloned and returned
|
// input nodes are cloned and returned
|
||||||
// NodeMap input may contain default node mapping i.e. pre-cloned nodes
|
// NodeMap input may contain default node mapping i.e. pre-cloned nodes
|
||||||
// NodeMap output (by reference) fully maps input and cloned nodes
|
// NodeMap output (by reference) fully maps input and cloned nodes
|
||||||
NGRAPH_API
|
NGRAPH_API
|
||||||
std::list<std::shared_ptr<ngraph::Node>>
|
std::list<std::shared_ptr<ngraph::Node>> clone_nodes(const std::vector<std::shared_ptr<ngraph::Node>>& nodes,
|
||||||
clone_nodes(const std::vector<std::shared_ptr<ngraph::Node>>& nodes,
|
|
||||||
RawNodeOutputMap& node_map);
|
RawNodeOutputMap& node_map);
|
||||||
|
|
||||||
// input function is cloned and returned
|
// input function is cloned and returned
|
||||||
// NodeMap input may contain default node mapping i.e. pre-cloned nodes
|
// NodeMap input may contain default node mapping i.e. pre-cloned nodes
|
||||||
// NodeMap output (by reference) fully maps input and cloned function ops
|
// NodeMap output (by reference) fully maps input and cloned function ops
|
||||||
NGRAPH_API
|
NGRAPH_API
|
||||||
std::shared_ptr<ngraph::Function> clone_function(const ngraph::Function& func,
|
std::shared_ptr<ngraph::Function> clone_function(const ngraph::Function& func, NodeMap& node_map);
|
||||||
NodeMap& node_map);
|
|
||||||
|
|
||||||
// input function is cloned and returned
|
// input function is cloned and returned
|
||||||
NGRAPH_API
|
NGRAPH_API
|
||||||
std::shared_ptr<ngraph::Function> clone_function(const ngraph::Function& func);
|
std::shared_ptr<ngraph::Function> clone_function(const ngraph::Function& func);
|
||||||
|
|
||||||
NGRAPH_API
|
NGRAPH_API
|
||||||
std::pair<std::shared_ptr<op::Result>, std::shared_ptr<op::v0::Parameter>>
|
std::pair<std::shared_ptr<op::Result>, std::shared_ptr<op::v0::Parameter>> insert_result_parameter_split(
|
||||||
insert_result_parameter_split(const std::shared_ptr<Node>& src_node,
|
const std::shared_ptr<Node>& src_node,
|
||||||
const std::shared_ptr<Node>& dst_node);
|
const std::shared_ptr<Node>& dst_node);
|
||||||
|
|
||||||
NGRAPH_API
|
NGRAPH_API
|
||||||
@ -408,9 +374,7 @@ namespace ngraph
|
|||||||
std::shared_ptr<Node> make_zero(const element::Type& element_type, const Shape& shape);
|
std::shared_ptr<Node> make_zero(const element::Type& element_type, const Shape& shape);
|
||||||
|
|
||||||
NGRAPH_API
|
NGRAPH_API
|
||||||
std::shared_ptr<Node> make_constant_from_string(std::string val,
|
std::shared_ptr<Node> make_constant_from_string(std::string val, const element::Type& element_type, const Shape& shape);
|
||||||
const element::Type& element_type,
|
|
||||||
const Shape& shape);
|
|
||||||
|
|
||||||
NGRAPH_API
|
NGRAPH_API
|
||||||
bool is_zero(const Output<Node>& reduce_constant);
|
bool is_zero(const Output<Node>& reduce_constant);
|
||||||
@ -454,8 +418,7 @@ namespace ngraph
|
|||||||
bool is_valid_rank(const std::shared_ptr<Node>& node, std::vector<size_t> valid_ranks);
|
bool is_valid_rank(const std::shared_ptr<Node>& node, std::vector<size_t> valid_ranks);
|
||||||
|
|
||||||
NGRAPH_API
|
NGRAPH_API
|
||||||
void plot_graph(
|
void plot_graph(std::shared_ptr<Function> f,
|
||||||
std::shared_ptr<Function> f,
|
|
||||||
const std::string& filename,
|
const std::string& filename,
|
||||||
std::function<void(const Node& node, std::vector<std::string>& attributes)> = nullptr);
|
std::function<void(const Node& node, std::vector<std::string>& attributes)> = nullptr);
|
||||||
|
|
||||||
@ -472,9 +435,7 @@ namespace ngraph
|
|||||||
/// going forward.
|
/// going forward.
|
||||||
/// It returns true if a cycle is found and the first cycle encountered.
|
/// It returns true if a cycle is found and the first cycle encountered.
|
||||||
NGRAPH_API
|
NGRAPH_API
|
||||||
bool check_for_cycles(const ngraph::Function* func,
|
bool check_for_cycles(const ngraph::Function* func, ngraph::NodeVector& cycle_nodes, bool& is_bkwd_cycle);
|
||||||
ngraph::NodeVector& cycle_nodes,
|
|
||||||
bool& is_bkwd_cycle);
|
|
||||||
|
|
||||||
NGRAPH_API
|
NGRAPH_API
|
||||||
bool replace_output_update_name(Output<Node> node, const Output<Node>& node_input);
|
bool replace_output_update_name(Output<Node> node, const Output<Node>& node_input);
|
||||||
|
@ -12,8 +12,7 @@
|
|||||||
|
|
||||||
#include "ngraph/ngraph_visibility.hpp"
|
#include "ngraph/ngraph_visibility.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
|
||||||
/// \brief Interval arithmetic
|
/// \brief Interval arithmetic
|
||||||
///
|
///
|
||||||
/// An interval is the set of integers from m_min_val through m_max_val.
|
/// An interval is the set of integers from m_min_val through m_max_val.
|
||||||
@ -21,8 +20,7 @@ namespace ngraph
|
|||||||
/// addition, subtraction, or multiplication of intervals is the smallest interval
|
/// addition, subtraction, or multiplication of intervals is the smallest interval
|
||||||
/// containing the sums, differences, or products of elements of the two intervals. An empty
|
/// containing the sums, differences, or products of elements of the two intervals. An empty
|
||||||
/// interval is canonicalized to [s_max, s_max].
|
/// interval is canonicalized to [s_max, s_max].
|
||||||
class NGRAPH_API Interval
|
class NGRAPH_API Interval {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
using value_type = std::int64_t;
|
using value_type = std::int64_t;
|
||||||
using size_type = std::uint64_t;
|
using size_type = std::uint64_t;
|
||||||
@ -41,26 +39,36 @@ namespace ngraph
|
|||||||
Interval& operator=(const Interval& interval) = default;
|
Interval& operator=(const Interval& interval) = default;
|
||||||
|
|
||||||
/// \brief The number of elements in the interval. Zero if max < min.
|
/// \brief The number of elements in the interval. Zero if max < min.
|
||||||
size_type size() const
|
size_type size() const {
|
||||||
{
|
if (m_max_val == s_max) {
|
||||||
if (m_max_val == s_max)
|
|
||||||
{
|
|
||||||
return m_min_val == s_max ? 0 : s_max;
|
return m_min_val == s_max ? 0 : s_max;
|
||||||
}
|
}
|
||||||
return m_max_val - m_min_val + 1;
|
return m_max_val - m_min_val + 1;
|
||||||
}
|
}
|
||||||
/// \brief Returns true if the interval has no elements
|
/// \brief Returns true if the interval has no elements
|
||||||
bool empty() const { return m_min_val == s_max; }
|
bool empty() const {
|
||||||
|
return m_min_val == s_max;
|
||||||
|
}
|
||||||
/// \brief the inclusive lower bound of the interval
|
/// \brief the inclusive lower bound of the interval
|
||||||
value_type get_min_val() const { return m_min_val; }
|
value_type get_min_val() const {
|
||||||
|
return m_min_val;
|
||||||
|
}
|
||||||
/// \brief Set the inclusive lower bound of the interval
|
/// \brief Set the inclusive lower bound of the interval
|
||||||
void set_min_val(value_type val) { m_min_val = val; }
|
void set_min_val(value_type val) {
|
||||||
|
m_min_val = val;
|
||||||
|
}
|
||||||
/// \brief the inclusive upper bound of the interval
|
/// \brief the inclusive upper bound of the interval
|
||||||
value_type get_max_val() const { return m_max_val; }
|
value_type get_max_val() const {
|
||||||
|
return m_max_val;
|
||||||
|
}
|
||||||
/// \brief Set the inclusive upper bound of the interval
|
/// \brief Set the inclusive upper bound of the interval
|
||||||
void set_max_val(value_type val) { m_max_val = val; }
|
void set_max_val(value_type val) {
|
||||||
|
m_max_val = val;
|
||||||
|
}
|
||||||
/// \brief True if the upper bound is finite
|
/// \brief True if the upper bound is finite
|
||||||
bool has_upper_bound() const { return m_max_val != s_max; }
|
bool has_upper_bound() const {
|
||||||
|
return m_max_val != s_max;
|
||||||
|
}
|
||||||
/// \brief True if min and max bounds match
|
/// \brief True if min and max bounds match
|
||||||
bool operator==(const Interval& interval) const;
|
bool operator==(const Interval& interval) const;
|
||||||
bool operator!=(const Interval& interval) const;
|
bool operator!=(const Interval& interval) const;
|
||||||
@ -91,7 +99,9 @@ namespace ngraph
|
|||||||
Interval& operator&=(const Interval& interval);
|
Interval& operator&=(const Interval& interval);
|
||||||
|
|
||||||
/// \brief True if this interval includes value
|
/// \brief True if this interval includes value
|
||||||
bool contains(value_type value) const { return m_min_val <= value && value <= m_max_val; }
|
bool contains(value_type value) const {
|
||||||
|
return m_min_val <= value && value <= m_max_val;
|
||||||
|
}
|
||||||
/// \brief True if this interval includes all the values in interval
|
/// \brief True if this interval includes all the values in interval
|
||||||
bool contains(const Interval& interval) const;
|
bool contains(const Interval& interval) const;
|
||||||
|
|
||||||
|
@ -16,80 +16,67 @@
|
|||||||
# include <sys/time.h>
|
# include <sys/time.h>
|
||||||
# include <unistd.h>
|
# include <unistd.h>
|
||||||
#endif
|
#endif
|
||||||
|
#include <ngraph/ngraph_visibility.hpp>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include <ngraph/ngraph_visibility.hpp>
|
namespace ngraph {
|
||||||
|
class ConstString {
|
||||||
namespace ngraph
|
|
||||||
{
|
|
||||||
class ConstString
|
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
template <size_t SIZE>
|
template <size_t SIZE>
|
||||||
constexpr ConstString(const char (&p)[SIZE])
|
constexpr ConstString(const char (&p)[SIZE]) : m_string(p),
|
||||||
: m_string(p)
|
m_size(SIZE) {}
|
||||||
, m_size(SIZE)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
constexpr char operator[](size_t i) const
|
constexpr char operator[](size_t i) const {
|
||||||
{
|
|
||||||
return i < m_size ? m_string[i] : throw std::out_of_range("");
|
return i < m_size ? m_string[i] : throw std::out_of_range("");
|
||||||
}
|
}
|
||||||
constexpr const char* get_ptr(size_t offset) const
|
constexpr const char* get_ptr(size_t offset) const {
|
||||||
{
|
|
||||||
return offset < m_size ? &m_string[offset] : m_string;
|
return offset < m_size ? &m_string[offset] : m_string;
|
||||||
}
|
}
|
||||||
constexpr size_t size() const { return m_size; }
|
constexpr size_t size() const {
|
||||||
|
return m_size;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const char* m_string;
|
const char* m_string;
|
||||||
size_t m_size;
|
size_t m_size;
|
||||||
};
|
};
|
||||||
|
|
||||||
constexpr const char* find_last(ConstString s, size_t offset, char ch)
|
constexpr const char* find_last(ConstString s, size_t offset, char ch) {
|
||||||
{
|
return offset == 0 ? s.get_ptr(0) : (s[offset] == ch ? s.get_ptr(offset + 1) : find_last(s, offset - 1, ch));
|
||||||
return offset == 0
|
|
||||||
? s.get_ptr(0)
|
|
||||||
: (s[offset] == ch ? s.get_ptr(offset + 1) : find_last(s, offset - 1, ch));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
constexpr const char* find_last(ConstString s, char ch)
|
constexpr const char* find_last(ConstString s, char ch) {
|
||||||
{
|
|
||||||
return find_last(s, s.size() - 1, ch);
|
return find_last(s, s.size() - 1, ch);
|
||||||
}
|
}
|
||||||
|
|
||||||
constexpr const char* get_file_name(ConstString s) { return find_last(s, '/'); }
|
constexpr const char* get_file_name(ConstString s) {
|
||||||
constexpr const char* trim_file_name(ConstString root, ConstString s)
|
return find_last(s, '/');
|
||||||
{
|
}
|
||||||
|
constexpr const char* trim_file_name(ConstString root, ConstString s) {
|
||||||
return s.get_ptr(root.size());
|
return s.get_ptr(root.size());
|
||||||
}
|
}
|
||||||
enum class LOG_TYPE
|
enum class LOG_TYPE {
|
||||||
{
|
|
||||||
_LOG_TYPE_ERROR,
|
_LOG_TYPE_ERROR,
|
||||||
_LOG_TYPE_WARNING,
|
_LOG_TYPE_WARNING,
|
||||||
_LOG_TYPE_INFO,
|
_LOG_TYPE_INFO,
|
||||||
_LOG_TYPE_DEBUG,
|
_LOG_TYPE_DEBUG,
|
||||||
};
|
};
|
||||||
|
|
||||||
class NGRAPH_API LogHelper
|
class NGRAPH_API LogHelper {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
LogHelper(LOG_TYPE,
|
LogHelper(LOG_TYPE, const char* file, int line, std::function<void(const std::string&)> m_handler_func);
|
||||||
const char* file,
|
|
||||||
int line,
|
|
||||||
std::function<void(const std::string&)> m_handler_func);
|
|
||||||
~LogHelper();
|
~LogHelper();
|
||||||
|
|
||||||
std::ostream& stream() { return m_stream; }
|
std::ostream& stream() {
|
||||||
|
return m_stream;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::function<void(const std::string&)> m_handler_func;
|
std::function<void(const std::string&)> m_handler_func;
|
||||||
std::stringstream m_stream;
|
std::stringstream m_stream;
|
||||||
};
|
};
|
||||||
|
|
||||||
class Logger
|
class Logger {
|
||||||
{
|
|
||||||
friend class LogHelper;
|
friend class LogHelper;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
@ -142,28 +129,21 @@ namespace ngraph
|
|||||||
.stream()
|
.stream()
|
||||||
#else
|
#else
|
||||||
|
|
||||||
struct NullLogger
|
struct NullLogger {};
|
||||||
{
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
NullLogger&& operator<<(NullLogger&& logger, T&&)
|
NullLogger&& operator<<(NullLogger&& logger, T&&) {
|
||||||
{
|
|
||||||
return std::move(logger);
|
return std::move(logger);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
NullLogger&& operator<<(NullLogger&& logger, const T&)
|
NullLogger&& operator<<(NullLogger&& logger, const T&) {
|
||||||
{
|
|
||||||
return std::move(logger);
|
return std::move(logger);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline NullLogger&&
|
inline NullLogger&& operator<<(
|
||||||
operator<<(NullLogger&& logger,
|
NullLogger&& logger,
|
||||||
std::basic_ostream<char, std::char_traits<char>>& (&)(std::basic_ostream<
|
std::basic_ostream<char, std::char_traits<char>>& (&)(std::basic_ostream<char, std::char_traits<char>>&)) {
|
||||||
char,
|
|
||||||
std::char_traits<char>>&))
|
|
||||||
{
|
|
||||||
return std::move(logger);
|
return std::move(logger);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -18,8 +18,7 @@
|
|||||||
|
|
||||||
extern "C" NGRAPH_API const char* get_ngraph_version_string();
|
extern "C" NGRAPH_API const char* get_ngraph_version_string();
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
|
||||||
/// \brief Function to query parsed version information of the version of ngraph which
|
/// \brief Function to query parsed version information of the version of ngraph which
|
||||||
/// contains this function. Version information strictly follows Semantic Versioning
|
/// contains this function. Version information strictly follows Semantic Versioning
|
||||||
/// http://semver.org
|
/// http://semver.org
|
||||||
|
@ -28,8 +28,7 @@
|
|||||||
# if defined __INTEL_COMPILER || defined _MSC_VER
|
# if defined __INTEL_COMPILER || defined _MSC_VER
|
||||||
# define ENABLE_UNICODE_PATH_SUPPORT
|
# define ENABLE_UNICODE_PATH_SUPPORT
|
||||||
# endif
|
# endif
|
||||||
#elif defined(__GNUC__) && (__GNUC__ > 5 || (__GNUC__ == 5 && __GNUC_MINOR__ > 2)) || \
|
# elif defined(__GNUC__) && (__GNUC__ > 5 || (__GNUC__ == 5 && __GNUC_MINOR__ > 2)) || defined(__clang__)
|
||||||
defined(__clang__)
|
|
||||||
# define ENABLE_UNICODE_PATH_SUPPORT
|
# define ENABLE_UNICODE_PATH_SUPPORT
|
||||||
# endif
|
# endif
|
||||||
#endif
|
#endif
|
||||||
|
@ -37,8 +37,7 @@
|
|||||||
#include "ngraph/type.hpp"
|
#include "ngraph/type.hpp"
|
||||||
#include "ngraph/variant.hpp"
|
#include "ngraph/variant.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
|
||||||
template <typename NodeType>
|
template <typename NodeType>
|
||||||
class Input;
|
class Input;
|
||||||
|
|
||||||
@ -50,8 +49,7 @@ namespace ngraph
|
|||||||
|
|
||||||
class Function;
|
class Function;
|
||||||
|
|
||||||
namespace runtime
|
namespace runtime {
|
||||||
{
|
|
||||||
class HostTensor;
|
class HostTensor;
|
||||||
}
|
}
|
||||||
using HostTensor = runtime::HostTensor;
|
using HostTensor = runtime::HostTensor;
|
||||||
@ -62,18 +60,15 @@ namespace ngraph
|
|||||||
/// environment) for evaluating ngraph::function.
|
/// environment) for evaluating ngraph::function.
|
||||||
using EvaluationContext = std::map<std::string, std::shared_ptr<Variant>>;
|
using EvaluationContext = std::map<std::string, std::shared_ptr<Variant>>;
|
||||||
|
|
||||||
namespace op
|
namespace op {
|
||||||
{
|
|
||||||
struct AutoBroadcastSpec;
|
struct AutoBroadcastSpec;
|
||||||
|
|
||||||
namespace v0
|
namespace v0 {
|
||||||
{
|
|
||||||
class Result;
|
class Result;
|
||||||
}
|
}
|
||||||
} // namespace op
|
} // namespace op
|
||||||
|
|
||||||
namespace pattern
|
namespace pattern {
|
||||||
{
|
|
||||||
class Matcher;
|
class Matcher;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -82,13 +77,11 @@ namespace ngraph
|
|||||||
NGRAPH_API
|
NGRAPH_API
|
||||||
std::string node_validation_failure_loc_string(const Node* node);
|
std::string node_validation_failure_loc_string(const Node* node);
|
||||||
|
|
||||||
const std::shared_ptr<Node>& check_single_output_arg(const std::shared_ptr<Node>& node,
|
const std::shared_ptr<Node>& check_single_output_arg(const std::shared_ptr<Node>& node, size_t i);
|
||||||
size_t i);
|
|
||||||
NGRAPH_API
|
NGRAPH_API
|
||||||
const NodeVector& check_single_output_args(const NodeVector& args);
|
const NodeVector& check_single_output_args(const NodeVector& args);
|
||||||
|
|
||||||
const std::shared_ptr<Node>& check_single_output_arg(const std::shared_ptr<Node>& node,
|
const std::shared_ptr<Node>& check_single_output_arg(const std::shared_ptr<Node>& node, size_t i);
|
||||||
size_t i);
|
|
||||||
|
|
||||||
NGRAPH_API
|
NGRAPH_API
|
||||||
OutputVector as_output_vector(const NodeVector& args);
|
OutputVector as_output_vector(const NodeVector& args);
|
||||||
@ -118,13 +111,13 @@ namespace ngraph
|
|||||||
/// a runtime error.
|
/// a runtime error.
|
||||||
|
|
||||||
#define TYPE_CASE(a) \
|
#define TYPE_CASE(a) \
|
||||||
case element::Type_t::a: rc = evaluate<element::Type_t::a>
|
case element::Type_t::a: \
|
||||||
|
rc = evaluate<element::Type_t::a>
|
||||||
|
|
||||||
/// Nodes are the backbone of the graph of Value dataflow. Every node has
|
/// Nodes are the backbone of the graph of Value dataflow. Every node has
|
||||||
/// zero or more nodes as arguments and one value, which is either a tensor
|
/// zero or more nodes as arguments and one value, which is either a tensor
|
||||||
/// or a (possibly empty) tuple of values.
|
/// or a (possibly empty) tuple of values.
|
||||||
class NGRAPH_API Node : public std::enable_shared_from_this<Node>
|
class NGRAPH_API Node : public std::enable_shared_from_this<Node> {
|
||||||
{
|
|
||||||
// For access to m_outputs.
|
// For access to m_outputs.
|
||||||
friend class descriptor::Input;
|
friend class descriptor::Input;
|
||||||
|
|
||||||
@ -196,7 +189,9 @@ namespace ngraph
|
|||||||
public:
|
public:
|
||||||
virtual ~Node();
|
virtual ~Node();
|
||||||
|
|
||||||
virtual bool visit_attributes(AttributeVisitor&) { return false; }
|
virtual bool visit_attributes(AttributeVisitor&) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
/// \returns the autobroadcasr spec
|
/// \returns the autobroadcasr spec
|
||||||
virtual const op::AutoBroadcastSpec& get_autob() const;
|
virtual const op::AutoBroadcastSpec& get_autob() const;
|
||||||
|
|
||||||
@ -208,8 +203,7 @@ namespace ngraph
|
|||||||
/// \param output_values Tensors for the outputs to compute. One for each result
|
/// \param output_values Tensors for the outputs to compute. One for each result
|
||||||
/// \param input_values Tensors for the inputs. One for each inputs.
|
/// \param input_values Tensors for the inputs. One for each inputs.
|
||||||
/// \returns true if successful
|
/// \returns true if successful
|
||||||
virtual bool evaluate(const HostTensorVector& output_values,
|
virtual bool evaluate(const HostTensorVector& output_values, const HostTensorVector& input_values) const;
|
||||||
const HostTensorVector& input_values) const;
|
|
||||||
/// \brief Evaluates the op on input_values putting results in output_values
|
/// \brief Evaluates the op on input_values putting results in output_values
|
||||||
/// \param output_values Tensors for the outputs to compute. One for each result
|
/// \param output_values Tensors for the outputs to compute. One for each result
|
||||||
/// \param input_values Tensors for the inputs. One for each inputs.
|
/// \param input_values Tensors for the inputs. One for each inputs.
|
||||||
@ -227,12 +221,16 @@ namespace ngraph
|
|||||||
///
|
///
|
||||||
/// \return A vector of nodes comprising the sub-graph. The order of output
|
/// \return A vector of nodes comprising the sub-graph. The order of output
|
||||||
/// tensors must match the match output tensors of the FusedOp
|
/// tensors must match the match output tensors of the FusedOp
|
||||||
virtual OutputVector decompose_op() const { return OutputVector(); }
|
virtual OutputVector decompose_op() const {
|
||||||
|
return OutputVector();
|
||||||
|
}
|
||||||
/// Returns the NodeTypeInfo for the node's class.
|
/// Returns the NodeTypeInfo for the node's class.
|
||||||
/// During transition to type_info, returns a dummy type_info for Node if the class
|
/// During transition to type_info, returns a dummy type_info for Node if the class
|
||||||
/// has not been updated yet.
|
/// has not been updated yet.
|
||||||
virtual const type_info_t& get_type_info() const = 0;
|
virtual const type_info_t& get_type_info() const = 0;
|
||||||
const char* get_type_name() const { return get_type_info().name; }
|
const char* get_type_name() const {
|
||||||
|
return get_type_info().name;
|
||||||
|
}
|
||||||
/// Sets/replaces the arguments with new arguments.
|
/// Sets/replaces the arguments with new arguments.
|
||||||
void set_arguments(const NodeVector& arguments);
|
void set_arguments(const NodeVector& arguments);
|
||||||
/// Sets/replaces the arguments with new arguments.
|
/// Sets/replaces the arguments with new arguments.
|
||||||
@ -240,16 +238,13 @@ namespace ngraph
|
|||||||
/// Sets/replaces the arguments with new arguments.
|
/// Sets/replaces the arguments with new arguments.
|
||||||
void set_argument(size_t position, const Output<Node>& argument);
|
void set_argument(size_t position, const Output<Node>& argument);
|
||||||
|
|
||||||
void set_output_type(size_t i,
|
void set_output_type(size_t i, const element::Type& element_type, const PartialShape& pshape);
|
||||||
const element::Type& element_type,
|
|
||||||
const PartialShape& pshape);
|
|
||||||
|
|
||||||
/// Sets the number of outputs
|
/// Sets the number of outputs
|
||||||
void set_output_size(size_t output_size);
|
void set_output_size(size_t output_size);
|
||||||
|
|
||||||
void invalidate_values();
|
void invalidate_values();
|
||||||
virtual void revalidate_and_infer_types()
|
virtual void revalidate_and_infer_types() {
|
||||||
{
|
|
||||||
invalidate_values();
|
invalidate_values();
|
||||||
validate_and_infer_types();
|
validate_and_infer_types();
|
||||||
}
|
}
|
||||||
@ -273,7 +268,9 @@ namespace ngraph
|
|||||||
const std::string& get_friendly_name() const;
|
const std::string& get_friendly_name() const;
|
||||||
|
|
||||||
virtual bool is_dynamic() const;
|
virtual bool is_dynamic() const;
|
||||||
size_t get_instance_id() const { return m_instance_id; }
|
size_t get_instance_id() const {
|
||||||
|
return m_instance_id;
|
||||||
|
}
|
||||||
/// \brief Writes a description of a node to a stream
|
/// \brief Writes a description of a node to a stream
|
||||||
/// \param os The stream; should be returned
|
/// \param os The stream; should be returned
|
||||||
/// \param depth How many levels of inputs to describe
|
/// \param depth How many levels of inputs to describe
|
||||||
@ -346,8 +343,7 @@ namespace ngraph
|
|||||||
descriptor::Tensor& get_input_tensor(size_t i) const;
|
descriptor::Tensor& get_input_tensor(size_t i) const;
|
||||||
|
|
||||||
/// Returns the tensor name for output i
|
/// Returns the tensor name for output i
|
||||||
NGRAPH_DEPRECATED(
|
NGRAPH_DEPRECATED("The tensor name was deprecated. Use get_output_tensor(i).get_names() instead.")
|
||||||
"The tensor name was deprecated. Use get_output_tensor(i).get_names() instead.")
|
|
||||||
const std::string& get_output_tensor_name(size_t i) const;
|
const std::string& get_output_tensor_name(size_t i) const;
|
||||||
|
|
||||||
std::set<Input<Node>> get_output_target_inputs(size_t i) const;
|
std::set<Input<Node>> get_output_target_inputs(size_t i) const;
|
||||||
@ -368,8 +364,7 @@ namespace ngraph
|
|||||||
const PartialShape& get_input_partial_shape(size_t i) const;
|
const PartialShape& get_input_partial_shape(size_t i) const;
|
||||||
|
|
||||||
/// Returns the tensor name for input i
|
/// Returns the tensor name for input i
|
||||||
NGRAPH_DEPRECATED(
|
NGRAPH_DEPRECATED("The tensor name was deprecated. Use get_input_tensor(i).get_names() instead.")
|
||||||
"The tensor name was deprecated. Use get_input_tensor(i).get_names() instead.")
|
|
||||||
const std::string& get_input_tensor_name(size_t i) const;
|
const std::string& get_input_tensor_name(size_t i) const;
|
||||||
|
|
||||||
std::unordered_set<descriptor::Tensor*> liveness_new_list;
|
std::unordered_set<descriptor::Tensor*> liveness_new_list;
|
||||||
@ -384,8 +379,7 @@ namespace ngraph
|
|||||||
|
|
||||||
std::shared_ptr<Node> copy_with_new_inputs(const OutputVector& new_args) const;
|
std::shared_ptr<Node> copy_with_new_inputs(const OutputVector& new_args) const;
|
||||||
|
|
||||||
std::shared_ptr<Node> copy_with_new_inputs(
|
std::shared_ptr<Node> copy_with_new_inputs(const OutputVector& inputs,
|
||||||
const OutputVector& inputs,
|
|
||||||
const std::vector<std::shared_ptr<Node>>& control_dependencies) const;
|
const std::vector<std::shared_ptr<Node>>& control_dependencies) const;
|
||||||
|
|
||||||
/// True if this and node have one output with same element type and shape
|
/// True if this and node have one output with same element type and shape
|
||||||
@ -393,21 +387,22 @@ namespace ngraph
|
|||||||
|
|
||||||
using RTMap = std::map<std::string, std::shared_ptr<Variant>>;
|
using RTMap = std::map<std::string, std::shared_ptr<Variant>>;
|
||||||
|
|
||||||
RTMap& get_rt_info() { return m_rt_info; }
|
RTMap& get_rt_info() {
|
||||||
const RTMap& get_rt_info() const { return m_rt_info; }
|
return m_rt_info;
|
||||||
|
}
|
||||||
|
const RTMap& get_rt_info() const {
|
||||||
|
return m_rt_info;
|
||||||
|
}
|
||||||
const std::unordered_set<std::string>& get_provenance_tags() const;
|
const std::unordered_set<std::string>& get_provenance_tags() const;
|
||||||
void add_provenance_tag(const std::string& tag);
|
void add_provenance_tag(const std::string& tag);
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void add_provenance_tags(T tag_set)
|
void add_provenance_tags(T tag_set) {
|
||||||
{
|
for (auto tag : tag_set) {
|
||||||
for (auto tag : tag_set)
|
|
||||||
{
|
|
||||||
add_provenance_tag(tag);
|
add_provenance_tag(tag);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
/// \brief Adds tag_set to this node and all intermediate nodes above base
|
/// \brief Adds tag_set to this node and all intermediate nodes above base
|
||||||
void add_provenance_tags_above(const OutputVector& base,
|
void add_provenance_tags_above(const OutputVector& base, const std::unordered_set<std::string>& tag_set);
|
||||||
const std::unordered_set<std::string>& tag_set);
|
|
||||||
void remove_provenance_tag(const std::string& tag);
|
void remove_provenance_tag(const std::string& tag);
|
||||||
/// \brief Add node to additional nodes that receive tags
|
/// \brief Add node to additional nodes that receive tags
|
||||||
void add_provenance_group_member(const std::shared_ptr<Node>& node);
|
void add_provenance_group_member(const std::shared_ptr<Node>& node);
|
||||||
@ -433,12 +428,18 @@ namespace ngraph
|
|||||||
NodeVector get_users(bool check_is_used = false) const;
|
NodeVector get_users(bool check_is_used = false) const;
|
||||||
|
|
||||||
/// \return Version of this node
|
/// \return Version of this node
|
||||||
virtual size_t get_version() const { return get_type_info().version; }
|
virtual size_t get_version() const {
|
||||||
|
return get_type_info().version;
|
||||||
|
}
|
||||||
|
|
||||||
NGRAPH_DEPRECATED("This method is deprecated and will be removed soon.")
|
NGRAPH_DEPRECATED("This method is deprecated and will be removed soon.")
|
||||||
virtual std::shared_ptr<Node> get_default_value() const { return nullptr; }
|
virtual std::shared_ptr<Node> get_default_value() const {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
/// Use instance ids for comparison instead of memory addresses to improve determinism
|
/// Use instance ids for comparison instead of memory addresses to improve determinism
|
||||||
bool operator<(const Node& other) const { return m_instance_id < other.m_instance_id; }
|
bool operator<(const Node& other) const {
|
||||||
|
return m_instance_id < other.m_instance_id;
|
||||||
|
}
|
||||||
/// \return A vector containing a handle for each of this node's inputs, in order.
|
/// \return A vector containing a handle for each of this node's inputs, in order.
|
||||||
// TODO: Rename to get_inputs()?
|
// TODO: Rename to get_inputs()?
|
||||||
std::vector<Input<Node>> inputs();
|
std::vector<Input<Node>> inputs();
|
||||||
@ -474,12 +475,10 @@ namespace ngraph
|
|||||||
/// \throw std::out_of_range if the node does not have at least `output_index+1` outputs.
|
/// \throw std::out_of_range if the node does not have at least `output_index+1` outputs.
|
||||||
Output<const Node> output(size_t output_index) const;
|
Output<const Node> output(size_t output_index) const;
|
||||||
|
|
||||||
void set_op_annotations(std::shared_ptr<ngraph::op::util::OpAnnotations> op_annotations)
|
void set_op_annotations(std::shared_ptr<ngraph::op::util::OpAnnotations> op_annotations) {
|
||||||
{
|
|
||||||
m_op_annotations = op_annotations;
|
m_op_annotations = op_annotations;
|
||||||
}
|
}
|
||||||
std::shared_ptr<ngraph::op::util::OpAnnotations> get_op_annotations() const
|
std::shared_ptr<ngraph::op::util::OpAnnotations> get_op_annotations() const {
|
||||||
{
|
|
||||||
return m_op_annotations;
|
return m_op_annotations;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -558,23 +557,21 @@ namespace ngraph
|
|||||||
static const ::ngraph::Node::type_info_t& get_type_info_static()
|
static const ::ngraph::Node::type_info_t& get_type_info_static()
|
||||||
|
|
||||||
#define _NGRAPH_RTTI_DEFINITION_COMMON(CLASS) \
|
#define _NGRAPH_RTTI_DEFINITION_COMMON(CLASS) \
|
||||||
const ::ngraph::Node::type_info_t& CLASS::get_type_info() const \
|
const ::ngraph::Node::type_info_t& CLASS::get_type_info() const { \
|
||||||
{ \
|
|
||||||
return get_type_info_static(); \
|
return get_type_info_static(); \
|
||||||
} \
|
} \
|
||||||
const ::ngraph::Node::type_info_t CLASS::type_info = CLASS::get_type_info_static()
|
const ::ngraph::Node::type_info_t CLASS::type_info = CLASS::get_type_info_static()
|
||||||
#define _NGRAPH_RTTI_DEFINITION_WITH_PARENT(CLASS, TYPE_NAME, _VERSION_INDEX, PARENT_CLASS) \
|
#define _NGRAPH_RTTI_DEFINITION_WITH_PARENT(CLASS, TYPE_NAME, _VERSION_INDEX, PARENT_CLASS) \
|
||||||
const ::ngraph::Node::type_info_t& CLASS::get_type_info_static() \
|
const ::ngraph::Node::type_info_t& CLASS::get_type_info_static() { \
|
||||||
{ \
|
static const ::ngraph::Node::type_info_t type_info_static{TYPE_NAME, \
|
||||||
static const ::ngraph::Node::type_info_t type_info_static{ \
|
_VERSION_INDEX, \
|
||||||
TYPE_NAME, _VERSION_INDEX, &PARENT_CLASS::get_type_info_static()}; \
|
&PARENT_CLASS::get_type_info_static()}; \
|
||||||
return type_info_static; \
|
return type_info_static; \
|
||||||
} \
|
} \
|
||||||
_NGRAPH_RTTI_DEFINITION_COMMON(CLASS)
|
_NGRAPH_RTTI_DEFINITION_COMMON(CLASS)
|
||||||
|
|
||||||
#define _NGRAPH_RTTI_DEFINITION_NO_PARENT(CLASS, TYPE_NAME, _VERSION_INDEX) \
|
#define _NGRAPH_RTTI_DEFINITION_NO_PARENT(CLASS, TYPE_NAME, _VERSION_INDEX) \
|
||||||
const ::ngraph::Node::type_info_t& CLASS::get_type_info_static() \
|
const ::ngraph::Node::type_info_t& CLASS::get_type_info_static() { \
|
||||||
{ \
|
|
||||||
static const ::ngraph::Node::type_info_t type_info_static{TYPE_NAME, _VERSION_INDEX}; \
|
static const ::ngraph::Node::type_info_t type_info_static{TYPE_NAME, _VERSION_INDEX}; \
|
||||||
return type_info_static; \
|
return type_info_static; \
|
||||||
} \
|
} \
|
||||||
@ -610,23 +607,14 @@ namespace ngraph
|
|||||||
/// For convenience, TYPE_NAME and CLASS name are recommended to be the same.
|
/// For convenience, TYPE_NAME and CLASS name are recommended to be the same.
|
||||||
///
|
///
|
||||||
#define NGRAPH_RTTI_DEFINITION(...) \
|
#define NGRAPH_RTTI_DEFINITION(...) \
|
||||||
_NGRAPH_RTTI_EXPAND(_NGRAPH_RTTI_DEFINITION_SELECTOR( \
|
_NGRAPH_RTTI_EXPAND(_NGRAPH_RTTI_DEFINITION_SELECTOR(__VA_ARGS__, \
|
||||||
__VA_ARGS__, _NGRAPH_RTTI_DEFINITION_WITH_PARENT, _NGRAPH_RTTI_DEFINITION_NO_PARENT)( \
|
_NGRAPH_RTTI_DEFINITION_WITH_PARENT, \
|
||||||
__VA_ARGS__))
|
_NGRAPH_RTTI_DEFINITION_NO_PARENT)(__VA_ARGS__))
|
||||||
|
|
||||||
// Like an Output but with a Node* instead of a shared_ptr<Node>
|
// Like an Output but with a Node* instead of a shared_ptr<Node>
|
||||||
struct RawNodeOutput
|
struct RawNodeOutput {
|
||||||
{
|
RawNodeOutput(const Output<Node>& value) : node(value.get_node()), index(value.get_index()) {}
|
||||||
RawNodeOutput(const Output<Node>& value)
|
RawNodeOutput(Node* node, size_t index) : node(node), index(index) {}
|
||||||
: node(value.get_node())
|
|
||||||
, index(value.get_index())
|
|
||||||
{
|
|
||||||
}
|
|
||||||
RawNodeOutput(Node* node, size_t index)
|
|
||||||
: node(node)
|
|
||||||
, index(index)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
RawNodeOutput(const RawNodeOutput&) = default;
|
RawNodeOutput(const RawNodeOutput&) = default;
|
||||||
RawNodeOutput() = default;
|
RawNodeOutput() = default;
|
||||||
RawNodeOutput& operator=(const RawNodeOutput&) = default;
|
RawNodeOutput& operator=(const RawNodeOutput&) = default;
|
||||||
@ -634,49 +622,56 @@ namespace ngraph
|
|||||||
Node* node;
|
Node* node;
|
||||||
size_t index{0};
|
size_t index{0};
|
||||||
|
|
||||||
operator Output<Node>() { return Output<Node>(node->shared_from_this(), index); }
|
operator Output<Node>() {
|
||||||
bool operator==(const RawNodeOutput& other) const
|
return Output<Node>(node->shared_from_this(), index);
|
||||||
{
|
}
|
||||||
|
bool operator==(const RawNodeOutput& other) const {
|
||||||
return node == other.node && index == other.index;
|
return node == other.node && index == other.index;
|
||||||
}
|
}
|
||||||
bool operator!=(const RawNodeOutput& other) const { return !(*this == other); }
|
bool operator!=(const RawNodeOutput& other) const {
|
||||||
bool operator<(const RawNodeOutput& other) const
|
return !(*this == other);
|
||||||
{
|
}
|
||||||
|
bool operator<(const RawNodeOutput& other) const {
|
||||||
return node < other.node || (node == other.node && index < other.index);
|
return node < other.node || (node == other.node && index < other.index);
|
||||||
}
|
}
|
||||||
bool operator>(const RawNodeOutput& other) const
|
bool operator>(const RawNodeOutput& other) const {
|
||||||
{
|
|
||||||
return node > other.node || (node == other.node && index > other.index);
|
return node > other.node || (node == other.node && index > other.index);
|
||||||
}
|
}
|
||||||
bool operator<=(const RawNodeOutput& other) const { return !(*this > other); }
|
bool operator<=(const RawNodeOutput& other) const {
|
||||||
bool operator>=(const RawNodeOutput& other) const { return !(*this < other); }
|
return !(*this > other);
|
||||||
|
}
|
||||||
|
bool operator>=(const RawNodeOutput& other) const {
|
||||||
|
return !(*this < other);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// \brief Visits a reference to a node that has been registered with the visitor.
|
/// \brief Visits a reference to a node that has been registered with the visitor.
|
||||||
template <>
|
template <>
|
||||||
class NGRAPH_API AttributeAdapter<std::shared_ptr<Node>> : public VisitorAdapter
|
class NGRAPH_API AttributeAdapter<std::shared_ptr<Node>> : public VisitorAdapter {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
AttributeAdapter(std::shared_ptr<Node>& value);
|
AttributeAdapter(std::shared_ptr<Node>& value);
|
||||||
|
|
||||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||||
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<std::shared_ptr<Node>>", 0};
|
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<std::shared_ptr<Node>>", 0};
|
||||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
const DiscreteTypeInfo& get_type_info() const override {
|
||||||
|
return type_info;
|
||||||
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
std::shared_ptr<Node>& m_ref;
|
std::shared_ptr<Node>& m_ref;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
class NGRAPH_API AttributeAdapter<NodeVector> : public VisitorAdapter
|
class NGRAPH_API AttributeAdapter<NodeVector> : public VisitorAdapter {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
AttributeAdapter(NodeVector& ref);
|
AttributeAdapter(NodeVector& ref);
|
||||||
|
|
||||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||||
|
|
||||||
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<NodeVector>", 0};
|
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<NodeVector>", 0};
|
||||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
const DiscreteTypeInfo& get_type_info() const override {
|
||||||
|
return type_info;
|
||||||
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
NodeVector& m_ref;
|
NodeVector& m_ref;
|
||||||
@ -684,25 +679,17 @@ namespace ngraph
|
|||||||
|
|
||||||
using RawNodeOutputMap = std::map<RawNodeOutput, Output<Node>>;
|
using RawNodeOutputMap = std::map<RawNodeOutput, Output<Node>>;
|
||||||
|
|
||||||
class NGRAPH_API NodeValidationFailure : public CheckFailure
|
class NGRAPH_API NodeValidationFailure : public CheckFailure {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NodeValidationFailure(const CheckLocInfo& check_loc_info,
|
NodeValidationFailure(const CheckLocInfo& check_loc_info, const Node* node, const std::string& explanation)
|
||||||
const Node* node,
|
: CheckFailure(check_loc_info, node_validation_failure_loc_string(node), explanation) {}
|
||||||
const std::string& explanation)
|
|
||||||
: CheckFailure(check_loc_info, node_validation_failure_loc_string(node), explanation)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
} // namespace ngraph
|
} // namespace ngraph
|
||||||
#define NODE_VALIDATION_CHECK(node, ...) \
|
#define NODE_VALIDATION_CHECK(node, ...) NGRAPH_CHECK_HELPER(::ngraph::NodeValidationFailure, (node), __VA_ARGS__)
|
||||||
NGRAPH_CHECK_HELPER(::ngraph::NodeValidationFailure, (node), __VA_ARGS__)
|
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void check_new_args_count(const Node* node, T new_args)
|
void check_new_args_count(const Node* node, T new_args) {
|
||||||
{
|
|
||||||
NODE_VALIDATION_CHECK(node,
|
NODE_VALIDATION_CHECK(node,
|
||||||
new_args.size() == node->input_values().size(),
|
new_args.size() == node->input_values().size(),
|
||||||
"clone_with_new_inputs() expected ",
|
"clone_with_new_inputs() expected ",
|
||||||
|
@ -13,22 +13,18 @@
|
|||||||
#include "ngraph/type/element_type.hpp"
|
#include "ngraph/type/element_type.hpp"
|
||||||
#include "ngraph/variant.hpp"
|
#include "ngraph/variant.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
|
||||||
class Node;
|
class Node;
|
||||||
|
|
||||||
template <typename NodeType>
|
template <typename NodeType>
|
||||||
class Output;
|
class Output;
|
||||||
|
|
||||||
template <typename NodeType>
|
template <typename NodeType>
|
||||||
class Input
|
class Input {};
|
||||||
{
|
|
||||||
};
|
|
||||||
|
|
||||||
/// \brief A handle for one of a node's inputs.
|
/// \brief A handle for one of a node's inputs.
|
||||||
template <>
|
template <>
|
||||||
class NGRAPH_API Input<Node>
|
class NGRAPH_API Input<Node> {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
/// \brief Constructs a Input.
|
/// \brief Constructs a Input.
|
||||||
/// \param node Pointer to the node for the input handle.
|
/// \param node Pointer to the node for the input handle.
|
||||||
@ -79,8 +75,7 @@ namespace ngraph
|
|||||||
|
|
||||||
/// \brief A handle for one of a node's inputs.
|
/// \brief A handle for one of a node's inputs.
|
||||||
template <>
|
template <>
|
||||||
class NGRAPH_API Input<const Node>
|
class NGRAPH_API Input<const Node> {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
/// \brief Constructs a Input.
|
/// \brief Constructs a Input.
|
||||||
/// \param node Pointer to the node for the input handle.
|
/// \param node Pointer to the node for the input handle.
|
||||||
|
@ -14,22 +14,18 @@
|
|||||||
#include "ngraph/type/element_type.hpp"
|
#include "ngraph/type/element_type.hpp"
|
||||||
#include "ngraph/variant.hpp"
|
#include "ngraph/variant.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
|
||||||
class Node;
|
class Node;
|
||||||
|
|
||||||
template <typename NodeType>
|
template <typename NodeType>
|
||||||
class Input;
|
class Input;
|
||||||
|
|
||||||
template <typename NodeType>
|
template <typename NodeType>
|
||||||
class Output
|
class Output {};
|
||||||
{
|
|
||||||
};
|
|
||||||
|
|
||||||
/// \brief A handle for one of a node's outputs.
|
/// \brief A handle for one of a node's outputs.
|
||||||
template <>
|
template <>
|
||||||
class NGRAPH_API Output<Node>
|
class NGRAPH_API Output<Node> {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
/// \brief Constructs a Output.
|
/// \brief Constructs a Output.
|
||||||
/// \param node A pointer to the node for the output handle.
|
/// \param node A pointer to the node for the output handle.
|
||||||
@ -46,10 +42,7 @@ namespace ngraph
|
|||||||
/// \brief Constructs a Output, referencing the zeroth output of the node.
|
/// \brief Constructs a Output, referencing the zeroth output of the node.
|
||||||
/// \param node A `shared_ptr` to the node for the output handle.
|
/// \param node A `shared_ptr` to the node for the output handle.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
Output(const std::shared_ptr<T>& node)
|
Output(const std::shared_ptr<T>& node) : Output(node ? node->get_default_output() : Output<Node>()) {}
|
||||||
: Output(node ? node->get_default_output() : Output<Node>())
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A null output
|
/// A null output
|
||||||
Output() = default;
|
Output() = default;
|
||||||
@ -109,8 +102,7 @@ namespace ngraph
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
class NGRAPH_API Output<const Node>
|
class NGRAPH_API Output<const Node> {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
/// \brief Constructs a Output.
|
/// \brief Constructs a Output.
|
||||||
/// \param node A pointer to the node for the output handle.
|
/// \param node A pointer to the node for the output handle.
|
||||||
@ -127,10 +119,7 @@ namespace ngraph
|
|||||||
/// \brief Constructs a Output, referencing the zeroth output of the node.
|
/// \brief Constructs a Output, referencing the zeroth output of the node.
|
||||||
/// \param node A `shared_ptr` to the node for the output handle.
|
/// \param node A `shared_ptr` to the node for the output handle.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
Output(const std::shared_ptr<T>& node)
|
Output(const std::shared_ptr<T>& node) : Output(node ? node->get_default_output() : Output<const Node>()) {}
|
||||||
: Output(node ? node->get_default_output() : Output<const Node>())
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A null output
|
/// A null output
|
||||||
Output() = default;
|
Output() = default;
|
||||||
|
@ -8,22 +8,22 @@
|
|||||||
|
|
||||||
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
|
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v0 {
|
||||||
{
|
|
||||||
namespace v0
|
|
||||||
{
|
|
||||||
/// \brief Elementwise absolute value operation.
|
/// \brief Elementwise absolute value operation.
|
||||||
///
|
///
|
||||||
class NGRAPH_API Abs : public util::UnaryElementwiseArithmetic
|
class NGRAPH_API Abs : public util::UnaryElementwiseArithmetic {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
static constexpr NodeTypeInfo type_info{"Abs", 0};
|
static constexpr NodeTypeInfo type_info{"Abs", 0};
|
||||||
const NodeTypeInfo& get_type_info() const override { return type_info; }
|
const NodeTypeInfo& get_type_info() const override {
|
||||||
|
return type_info;
|
||||||
|
}
|
||||||
/// \brief Constructs an absolute value operation.
|
/// \brief Constructs an absolute value operation.
|
||||||
Abs() = default;
|
Abs() = default;
|
||||||
bool visit_attributes(AttributeVisitor&) override { return true; }
|
bool visit_attributes(AttributeVisitor&) override {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
/// \brief Constructs an absolute value operation.
|
/// \brief Constructs an absolute value operation.
|
||||||
///
|
///
|
||||||
/// \param arg Output that produces the input tensor.<br>
|
/// \param arg Output that produces the input tensor.<br>
|
||||||
@ -33,11 +33,9 @@ namespace ngraph
|
|||||||
///
|
///
|
||||||
Abs(const Output<Node>& arg);
|
Abs(const Output<Node>& arg);
|
||||||
|
|
||||||
std::shared_ptr<Node>
|
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
|
||||||
|
|
||||||
bool evaluate(const HostTensorVector& outputs,
|
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||||
const HostTensorVector& inputs) const override;
|
|
||||||
bool has_evaluate() const override;
|
bool has_evaluate() const override;
|
||||||
};
|
};
|
||||||
} // namespace v0
|
} // namespace v0
|
||||||
|
@ -8,19 +8,17 @@
|
|||||||
|
|
||||||
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
|
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v0 {
|
||||||
{
|
|
||||||
namespace v0
|
|
||||||
{
|
|
||||||
/// \brief Elementwise inverse cosine (arccos) operation.
|
/// \brief Elementwise inverse cosine (arccos) operation.
|
||||||
///
|
///
|
||||||
class NGRAPH_API Acos : public util::UnaryElementwiseArithmetic
|
class NGRAPH_API Acos : public util::UnaryElementwiseArithmetic {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
static constexpr NodeTypeInfo type_info{"Acos", 0};
|
static constexpr NodeTypeInfo type_info{"Acos", 0};
|
||||||
const NodeTypeInfo& get_type_info() const override { return type_info; }
|
const NodeTypeInfo& get_type_info() const override {
|
||||||
|
return type_info;
|
||||||
|
}
|
||||||
/// \brief Constructs an arccos operation.
|
/// \brief Constructs an arccos operation.
|
||||||
Acos() = default;
|
Acos() = default;
|
||||||
/// \brief Constructs an arccos operation.
|
/// \brief Constructs an arccos operation.
|
||||||
@ -31,11 +29,11 @@ namespace ngraph
|
|||||||
/// Output `[d1, ...]`
|
/// Output `[d1, ...]`
|
||||||
///
|
///
|
||||||
Acos(const Output<Node>& arg);
|
Acos(const Output<Node>& arg);
|
||||||
bool visit_attributes(AttributeVisitor&) override { return true; }
|
bool visit_attributes(AttributeVisitor&) override {
|
||||||
std::shared_ptr<Node>
|
return true;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
}
|
||||||
bool evaluate(const HostTensorVector& outputs,
|
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
const HostTensorVector& inputs) const override;
|
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||||
bool has_evaluate() const override;
|
bool has_evaluate() const override;
|
||||||
};
|
};
|
||||||
} // namespace v0
|
} // namespace v0
|
||||||
|
@ -8,16 +8,12 @@
|
|||||||
|
|
||||||
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
|
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v3 {
|
||||||
{
|
|
||||||
namespace v3
|
|
||||||
{
|
|
||||||
/// \brief Elementwise inverse hyperbolic cos operation.
|
/// \brief Elementwise inverse hyperbolic cos operation.
|
||||||
///
|
///
|
||||||
class NGRAPH_API Acosh : public util::UnaryElementwiseArithmetic
|
class NGRAPH_API Acosh : public util::UnaryElementwiseArithmetic {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
|
||||||
@ -32,11 +28,11 @@ namespace ngraph
|
|||||||
///
|
///
|
||||||
Acosh(const Output<Node>& arg);
|
Acosh(const Output<Node>& arg);
|
||||||
|
|
||||||
virtual std::shared_ptr<Node>
|
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
bool visit_attributes(AttributeVisitor&) override {
|
||||||
bool visit_attributes(AttributeVisitor&) override { return true; }
|
return true;
|
||||||
bool evaluate(const HostTensorVector& outputs,
|
}
|
||||||
const HostTensorVector& inputs) const override;
|
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||||
bool has_evaluate() const override;
|
bool has_evaluate() const override;
|
||||||
};
|
};
|
||||||
} // namespace v3
|
} // namespace v3
|
||||||
|
@ -7,16 +7,12 @@
|
|||||||
#include "ngraph/op/op.hpp"
|
#include "ngraph/op/op.hpp"
|
||||||
#include "ngraph/op/util/attr_types.hpp"
|
#include "ngraph/op/util/attr_types.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v8 {
|
||||||
{
|
|
||||||
namespace v8
|
|
||||||
{
|
|
||||||
/// \brief Adaptive average pooling operation.
|
/// \brief Adaptive average pooling operation.
|
||||||
///
|
///
|
||||||
class NGRAPH_API AdaptiveAvgPool : public Op
|
class NGRAPH_API AdaptiveAvgPool : public Op {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
|
||||||
@ -35,8 +31,7 @@ namespace ngraph
|
|||||||
void validate_and_infer_types() override;
|
void validate_and_infer_types() override;
|
||||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||||
|
|
||||||
std::shared_ptr<Node>
|
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
|
||||||
};
|
};
|
||||||
} // namespace v8
|
} // namespace v8
|
||||||
} // namespace op
|
} // namespace op
|
||||||
|
@ -7,16 +7,12 @@
|
|||||||
#include "ngraph/op/op.hpp"
|
#include "ngraph/op/op.hpp"
|
||||||
#include "ngraph/op/util/attr_types.hpp"
|
#include "ngraph/op/util/attr_types.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v8 {
|
||||||
{
|
|
||||||
namespace v8
|
|
||||||
{
|
|
||||||
/// \brief Adaptive max pooling operation.
|
/// \brief Adaptive max pooling operation.
|
||||||
///
|
///
|
||||||
class NGRAPH_API AdaptiveMaxPool : public Op
|
class NGRAPH_API AdaptiveMaxPool : public Op {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
|
||||||
@ -33,18 +29,18 @@ namespace ngraph
|
|||||||
/// \param index_element_type Specifies the output tensor type for indices
|
/// \param index_element_type Specifies the output tensor type for indices
|
||||||
/// output
|
/// output
|
||||||
///
|
///
|
||||||
AdaptiveMaxPool(
|
AdaptiveMaxPool(const Output<Node>& data,
|
||||||
const Output<Node>& data,
|
|
||||||
const Output<Node>& output_shape,
|
const Output<Node>& output_shape,
|
||||||
const ngraph::element::Type& index_element_type = ngraph::element::i64);
|
const ngraph::element::Type& index_element_type = ngraph::element::i64);
|
||||||
|
|
||||||
void validate_and_infer_types() override;
|
void validate_and_infer_types() override;
|
||||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||||
|
|
||||||
std::shared_ptr<Node>
|
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
|
||||||
|
|
||||||
element::Type get_index_element_type() const { return m_index_element_type; }
|
element::Type get_index_element_type() const {
|
||||||
|
return m_index_element_type;
|
||||||
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
ngraph::element::Type m_index_element_type = ngraph::element::i64;
|
ngraph::element::Type m_index_element_type = ngraph::element::i64;
|
||||||
|
@ -8,24 +8,17 @@
|
|||||||
|
|
||||||
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
|
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v1 {
|
||||||
{
|
|
||||||
namespace v1
|
|
||||||
{
|
|
||||||
/// \brief Elementwise addition operation.
|
/// \brief Elementwise addition operation.
|
||||||
///
|
///
|
||||||
class NGRAPH_API Add : public util::BinaryElementwiseArithmetic
|
class NGRAPH_API Add : public util::BinaryElementwiseArithmetic {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
|
||||||
/// \brief Constructs an uninitialized addition operation
|
/// \brief Constructs an uninitialized addition operation
|
||||||
Add()
|
Add() : util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NUMPY) {}
|
||||||
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NUMPY)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
/// \brief Constructs an addition operation.
|
/// \brief Constructs an addition operation.
|
||||||
///
|
///
|
||||||
@ -40,16 +33,13 @@ namespace ngraph
|
|||||||
///
|
///
|
||||||
Add(const Output<Node>& arg0,
|
Add(const Output<Node>& arg0,
|
||||||
const Output<Node>& arg1,
|
const Output<Node>& arg1,
|
||||||
const AutoBroadcastSpec& auto_broadcast =
|
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec(AutoBroadcastType::NUMPY));
|
||||||
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
|
|
||||||
|
|
||||||
std::shared_ptr<Node>
|
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
|
||||||
|
|
||||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||||
|
|
||||||
bool evaluate(const HostTensorVector& outputs,
|
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||||
const HostTensorVector& inputs) const override;
|
|
||||||
bool has_evaluate() const override;
|
bool has_evaluate() const override;
|
||||||
};
|
};
|
||||||
} // namespace v1
|
} // namespace v1
|
||||||
|
@ -8,16 +8,12 @@
|
|||||||
|
|
||||||
#include "ngraph/op/util/binary_elementwise_logical.hpp"
|
#include "ngraph/op/util/binary_elementwise_logical.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v1 {
|
||||||
{
|
|
||||||
namespace v1
|
|
||||||
{
|
|
||||||
/// \brief Elementwise logical-and operation.
|
/// \brief Elementwise logical-and operation.
|
||||||
///
|
///
|
||||||
class NGRAPH_API LogicalAnd : public util::BinaryElementwiseLogical
|
class NGRAPH_API LogicalAnd : public util::BinaryElementwiseLogical {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
/// \brief Constructs a logical-and operation.
|
/// \brief Constructs a logical-and operation.
|
||||||
@ -35,14 +31,11 @@ namespace ngraph
|
|||||||
///
|
///
|
||||||
LogicalAnd(const Output<Node>& arg0,
|
LogicalAnd(const Output<Node>& arg0,
|
||||||
const Output<Node>& arg1,
|
const Output<Node>& arg1,
|
||||||
const AutoBroadcastSpec& auto_broadcast =
|
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec(AutoBroadcastType::NUMPY));
|
||||||
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
|
|
||||||
|
|
||||||
std::shared_ptr<Node>
|
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
|
||||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||||
bool evaluate(const HostTensorVector& outputs,
|
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||||
const HostTensorVector& inputs) const override;
|
|
||||||
bool has_evaluate() const override;
|
bool has_evaluate() const override;
|
||||||
};
|
};
|
||||||
} // namespace v1
|
} // namespace v1
|
||||||
|
@ -8,19 +8,17 @@
|
|||||||
|
|
||||||
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
|
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v0 {
|
||||||
{
|
|
||||||
namespace v0
|
|
||||||
{
|
|
||||||
/// \brief Elementwise inverse sine (arcsin) operation.
|
/// \brief Elementwise inverse sine (arcsin) operation.
|
||||||
///
|
///
|
||||||
class NGRAPH_API Asin : public util::UnaryElementwiseArithmetic
|
class NGRAPH_API Asin : public util::UnaryElementwiseArithmetic {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
static constexpr NodeTypeInfo type_info{"Asin", 0};
|
static constexpr NodeTypeInfo type_info{"Asin", 0};
|
||||||
const NodeTypeInfo& get_type_info() const override { return type_info; }
|
const NodeTypeInfo& get_type_info() const override {
|
||||||
|
return type_info;
|
||||||
|
}
|
||||||
/// \brief Constructs an arcsin operation.
|
/// \brief Constructs an arcsin operation.
|
||||||
Asin() = default;
|
Asin() = default;
|
||||||
/// \brief Constructs an arcsin operation.
|
/// \brief Constructs an arcsin operation.
|
||||||
@ -32,11 +30,11 @@ namespace ngraph
|
|||||||
///
|
///
|
||||||
Asin(const Output<Node>& arg);
|
Asin(const Output<Node>& arg);
|
||||||
|
|
||||||
virtual std::shared_ptr<Node>
|
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
bool visit_attributes(AttributeVisitor&) override {
|
||||||
bool visit_attributes(AttributeVisitor&) override { return true; }
|
return true;
|
||||||
bool evaluate(const HostTensorVector& outputs,
|
}
|
||||||
const HostTensorVector& inputs) const override;
|
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||||
bool has_evaluate() const override;
|
bool has_evaluate() const override;
|
||||||
};
|
};
|
||||||
} // namespace v0
|
} // namespace v0
|
||||||
|
@ -8,16 +8,12 @@
|
|||||||
|
|
||||||
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
|
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v3 {
|
||||||
{
|
|
||||||
namespace v3
|
|
||||||
{
|
|
||||||
/// \brief Elementwise inverse hyperbolic sin operation.
|
/// \brief Elementwise inverse hyperbolic sin operation.
|
||||||
///
|
///
|
||||||
class NGRAPH_API Asinh : public util::UnaryElementwiseArithmetic
|
class NGRAPH_API Asinh : public util::UnaryElementwiseArithmetic {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
|
||||||
@ -32,11 +28,11 @@ namespace ngraph
|
|||||||
///
|
///
|
||||||
Asinh(const Output<Node>& arg);
|
Asinh(const Output<Node>& arg);
|
||||||
|
|
||||||
virtual std::shared_ptr<Node>
|
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
bool visit_attributes(AttributeVisitor&) override {
|
||||||
bool visit_attributes(AttributeVisitor&) override { return true; }
|
return true;
|
||||||
bool evaluate(const HostTensorVector& outputs,
|
}
|
||||||
const HostTensorVector& inputs) const override;
|
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||||
bool has_evaluate() const override;
|
bool has_evaluate() const override;
|
||||||
};
|
};
|
||||||
} // namespace v3
|
} // namespace v3
|
||||||
|
@ -8,27 +8,19 @@
|
|||||||
#include "ngraph/op/util/variable.hpp"
|
#include "ngraph/op/util/variable.hpp"
|
||||||
#include "ngraph/op/util/variable_extension.hpp"
|
#include "ngraph/op/util/variable_extension.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
class NGRAPH_API AssignBase : public Sink, public VariableExtension {
|
||||||
{
|
|
||||||
class NGRAPH_API AssignBase : public Sink, public VariableExtension
|
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
AssignBase() = default;
|
AssignBase() = default;
|
||||||
/// \brief Constructs an AssignBase operation.
|
/// \brief Constructs an AssignBase operation.
|
||||||
explicit AssignBase(const OutputVector& arguments)
|
explicit AssignBase(const OutputVector& arguments) : Sink(arguments) {}
|
||||||
: Sink(arguments)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
namespace v3
|
namespace v3 {
|
||||||
{
|
|
||||||
/// \brief Assign operation sets an input value to the variable with `variable_id`
|
/// \brief Assign operation sets an input value to the variable with `variable_id`
|
||||||
class NGRAPH_API Assign : public AssignBase
|
class NGRAPH_API Assign : public AssignBase {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
Assign() = default;
|
Assign() = default;
|
||||||
@ -40,10 +32,11 @@ namespace ngraph
|
|||||||
Assign(const Output<Node>& new_value, const std::string& variable_id);
|
Assign(const Output<Node>& new_value, const std::string& variable_id);
|
||||||
|
|
||||||
void validate_and_infer_types() override;
|
void validate_and_infer_types() override;
|
||||||
std::string get_variable_id() const override { return m_variable_id; }
|
std::string get_variable_id() const override {
|
||||||
|
return m_variable_id;
|
||||||
|
}
|
||||||
|
|
||||||
std::shared_ptr<Node>
|
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
|
||||||
|
|
||||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||||
|
|
||||||
@ -51,11 +44,9 @@ namespace ngraph
|
|||||||
std::string m_variable_id;
|
std::string m_variable_id;
|
||||||
};
|
};
|
||||||
} // namespace v3
|
} // namespace v3
|
||||||
namespace v6
|
namespace v6 {
|
||||||
{
|
|
||||||
/// \brief Assign operation sets an input value to the variable with `variable_id`
|
/// \brief Assign operation sets an input value to the variable with `variable_id`
|
||||||
class NGRAPH_API Assign : public AssignBase
|
class NGRAPH_API Assign : public AssignBase {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
Assign() = default;
|
Assign() = default;
|
||||||
@ -70,23 +61,19 @@ namespace ngraph
|
|||||||
|
|
||||||
void validate_and_infer_types() override;
|
void validate_and_infer_types() override;
|
||||||
|
|
||||||
std::shared_ptr<Node>
|
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
|
||||||
|
|
||||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||||
|
|
||||||
std::string get_variable_id() const override
|
std::string get_variable_id() const override {
|
||||||
{
|
NGRAPH_CHECK(m_variable, "Variable is not initialized. Variable_id is unavailable");
|
||||||
NGRAPH_CHECK(m_variable,
|
|
||||||
"Variable is not initialized. Variable_id is unavailable");
|
|
||||||
return m_variable->get_info().variable_id;
|
return m_variable->get_info().variable_id;
|
||||||
}
|
}
|
||||||
bool evaluate(const HostTensorVector& outputs,
|
bool evaluate(const HostTensorVector& outputs,
|
||||||
const HostTensorVector& inputs,
|
const HostTensorVector& inputs,
|
||||||
const EvaluationContext& evaluation_context) const override;
|
const EvaluationContext& evaluation_context) const override;
|
||||||
bool has_evaluate() const override;
|
bool has_evaluate() const override;
|
||||||
bool constant_fold(OutputVector& output_values,
|
bool constant_fold(OutputVector& output_values, const OutputVector& inputs_values) override;
|
||||||
const OutputVector& inputs_values) override;
|
|
||||||
};
|
};
|
||||||
} // namespace v6
|
} // namespace v6
|
||||||
} // namespace op
|
} // namespace op
|
||||||
|
@ -8,16 +8,12 @@
|
|||||||
|
|
||||||
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
|
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v0 {
|
||||||
{
|
|
||||||
namespace v0
|
|
||||||
{
|
|
||||||
/// \brief Elementwise inverse tangent (arctan) operation.
|
/// \brief Elementwise inverse tangent (arctan) operation.
|
||||||
///
|
///
|
||||||
class NGRAPH_API Atan : public util::UnaryElementwiseArithmetic
|
class NGRAPH_API Atan : public util::UnaryElementwiseArithmetic {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
/// \brief Constructs an arctan operation.
|
/// \brief Constructs an arctan operation.
|
||||||
@ -32,11 +28,11 @@ namespace ngraph
|
|||||||
///
|
///
|
||||||
Atan(const Output<Node>& arg);
|
Atan(const Output<Node>& arg);
|
||||||
|
|
||||||
virtual std::shared_ptr<Node>
|
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
bool visit_attributes(AttributeVisitor&) override {
|
||||||
bool visit_attributes(AttributeVisitor&) override { return true; }
|
return true;
|
||||||
bool evaluate(const HostTensorVector& outputs,
|
}
|
||||||
const HostTensorVector& inputs) const override;
|
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||||
bool has_evaluate() const override;
|
bool has_evaluate() const override;
|
||||||
};
|
};
|
||||||
} // namespace v0
|
} // namespace v0
|
||||||
|
@ -8,16 +8,12 @@
|
|||||||
|
|
||||||
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
|
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v3 {
|
||||||
{
|
|
||||||
namespace v3
|
|
||||||
{
|
|
||||||
/// \brief Elementwise inverse hyperbolic tangent operation.
|
/// \brief Elementwise inverse hyperbolic tangent operation.
|
||||||
///
|
///
|
||||||
class NGRAPH_API Atanh : public util::UnaryElementwiseArithmetic
|
class NGRAPH_API Atanh : public util::UnaryElementwiseArithmetic {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
|
||||||
@ -32,11 +28,11 @@ namespace ngraph
|
|||||||
///
|
///
|
||||||
Atanh(const Output<Node>& arg);
|
Atanh(const Output<Node>& arg);
|
||||||
|
|
||||||
virtual std::shared_ptr<Node>
|
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
bool visit_attributes(AttributeVisitor&) override {
|
||||||
bool visit_attributes(AttributeVisitor&) override { return true; }
|
return true;
|
||||||
bool evaluate(const HostTensorVector& outputs,
|
}
|
||||||
const HostTensorVector& inputs) const override;
|
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||||
bool has_evaluate() const override;
|
bool has_evaluate() const override;
|
||||||
};
|
};
|
||||||
} // namespace v3
|
} // namespace v3
|
||||||
|
@ -7,16 +7,12 @@
|
|||||||
#include "ngraph/op/op.hpp"
|
#include "ngraph/op/op.hpp"
|
||||||
#include "ngraph/op/util/attr_types.hpp"
|
#include "ngraph/op/util/attr_types.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v1 {
|
||||||
{
|
|
||||||
namespace v1
|
|
||||||
{
|
|
||||||
/// \brief Batched average pooling operation.
|
/// \brief Batched average pooling operation.
|
||||||
///
|
///
|
||||||
class NGRAPH_API AvgPool : public Op
|
class NGRAPH_API AvgPool : public Op {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
|
||||||
@ -52,8 +48,7 @@ namespace ngraph
|
|||||||
void validate_and_infer_types() override;
|
void validate_and_infer_types() override;
|
||||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||||
|
|
||||||
virtual std::shared_ptr<Node>
|
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
|
||||||
|
|
||||||
/// \return The kernel shape.
|
/// \return The kernel shape.
|
||||||
const Shape& get_kernel() const;
|
const Shape& get_kernel() const;
|
||||||
|
@ -10,14 +10,10 @@
|
|||||||
#include "ngraph/node.hpp"
|
#include "ngraph/node.hpp"
|
||||||
#include "ngraph/op/op.hpp"
|
#include "ngraph/op/op.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v0 {
|
||||||
{
|
class NGRAPH_API BatchNormInference : public Op {
|
||||||
namespace v0
|
|
||||||
{
|
|
||||||
class NGRAPH_API BatchNormInference : public Op
|
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
BatchNormInference() = default;
|
BatchNormInference() = default;
|
||||||
@ -38,10 +34,13 @@ namespace ngraph
|
|||||||
|
|
||||||
void validate_and_infer_types() override;
|
void validate_and_infer_types() override;
|
||||||
|
|
||||||
double get_eps_value() const { return m_epsilon; }
|
double get_eps_value() const {
|
||||||
void set_eps_value(double epsilon) { m_epsilon = epsilon; }
|
return m_epsilon;
|
||||||
std::shared_ptr<Node>
|
}
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
void set_eps_value(double epsilon) {
|
||||||
|
m_epsilon = epsilon;
|
||||||
|
}
|
||||||
|
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static constexpr size_t INPUT_GAMMA = 0;
|
static constexpr size_t INPUT_GAMMA = 0;
|
||||||
@ -53,10 +52,8 @@ namespace ngraph
|
|||||||
double m_epsilon;
|
double m_epsilon;
|
||||||
};
|
};
|
||||||
} // namespace v0
|
} // namespace v0
|
||||||
namespace v5
|
namespace v5 {
|
||||||
{
|
class NGRAPH_API BatchNormInference : public Op {
|
||||||
class NGRAPH_API BatchNormInference : public Op
|
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
BatchNormInference() = default;
|
BatchNormInference() = default;
|
||||||
@ -77,10 +74,13 @@ namespace ngraph
|
|||||||
|
|
||||||
void validate_and_infer_types() override;
|
void validate_and_infer_types() override;
|
||||||
|
|
||||||
double get_eps_value() const { return m_epsilon; }
|
double get_eps_value() const {
|
||||||
void set_eps_value(double epsilon) { m_epsilon = epsilon; }
|
return m_epsilon;
|
||||||
std::shared_ptr<Node>
|
}
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
void set_eps_value(double epsilon) {
|
||||||
|
m_epsilon = epsilon;
|
||||||
|
}
|
||||||
|
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static constexpr size_t INPUT_DATA = 0;
|
static constexpr size_t INPUT_DATA = 0;
|
||||||
|
@ -7,12 +7,9 @@
|
|||||||
#include "ngraph/node.hpp"
|
#include "ngraph/node.hpp"
|
||||||
#include "ngraph/op/op.hpp"
|
#include "ngraph/op/op.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v1 {
|
||||||
{
|
|
||||||
namespace v1
|
|
||||||
{
|
|
||||||
/// \brief BatchToSpace permutes data from the batch dimension of the data tensor into
|
/// \brief BatchToSpace permutes data from the batch dimension of the data tensor into
|
||||||
/// spatial dimensions.
|
/// spatial dimensions.
|
||||||
///
|
///
|
||||||
@ -24,8 +21,7 @@ namespace ngraph
|
|||||||
/// D_2 * block_shape[2] - crops_begin[2] - crops_end[2], ...,
|
/// D_2 * block_shape[2] - crops_begin[2] - crops_end[2], ...,
|
||||||
/// D_{N - 1} * block_shape[N - 1] - crops_begin[N - 1] - crops_end[N - 1]`
|
/// D_{N - 1} * block_shape[N - 1] - crops_begin[N - 1] - crops_end[N - 1]`
|
||||||
/// of the same type as `data` input.
|
/// of the same type as `data` input.
|
||||||
class NGRAPH_API BatchToSpace : public Op
|
class NGRAPH_API BatchToSpace : public Op {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
BatchToSpace() = default;
|
BatchToSpace() = default;
|
||||||
@ -41,13 +37,11 @@ namespace ngraph
|
|||||||
const Output<Node>& block_shape,
|
const Output<Node>& block_shape,
|
||||||
const Output<Node>& crops_begin,
|
const Output<Node>& crops_begin,
|
||||||
const Output<Node>& crops_end);
|
const Output<Node>& crops_end);
|
||||||
bool evaluate(const HostTensorVector& outputs,
|
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||||
const HostTensorVector& inputs) const override;
|
|
||||||
bool has_evaluate() const override;
|
bool has_evaluate() const override;
|
||||||
|
|
||||||
void validate_and_infer_types() override;
|
void validate_and_infer_types() override;
|
||||||
std::shared_ptr<Node>
|
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
|
||||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||||
};
|
};
|
||||||
} // namespace v1
|
} // namespace v1
|
||||||
|
@ -8,19 +8,14 @@
|
|||||||
#include "ngraph/op/op.hpp"
|
#include "ngraph/op/op.hpp"
|
||||||
#include "ngraph/op/util/attr_types.hpp"
|
#include "ngraph/op/util/attr_types.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v1 {
|
||||||
{
|
class NGRAPH_API BinaryConvolution : public Op {
|
||||||
namespace v1
|
|
||||||
{
|
|
||||||
class NGRAPH_API BinaryConvolution : public Op
|
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
|
||||||
enum class BinaryConvolutionMode
|
enum class BinaryConvolutionMode {
|
||||||
{
|
|
||||||
// Interpret input data and kernel values: 0 as -1, 1 as 1
|
// Interpret input data and kernel values: 0 as -1, 1 as 1
|
||||||
XNOR_POPCOUNT
|
XNOR_POPCOUNT
|
||||||
};
|
};
|
||||||
@ -63,30 +58,57 @@ namespace ngraph
|
|||||||
|
|
||||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||||
|
|
||||||
std::shared_ptr<Node>
|
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
|
||||||
|
|
||||||
/// \return The strides.
|
/// \return The strides.
|
||||||
const Strides& get_strides() const { return m_strides; }
|
const Strides& get_strides() const {
|
||||||
void set_strides(const Strides& strides) { m_strides = strides; }
|
return m_strides;
|
||||||
|
}
|
||||||
|
void set_strides(const Strides& strides) {
|
||||||
|
m_strides = strides;
|
||||||
|
}
|
||||||
/// \return The dilations.
|
/// \return The dilations.
|
||||||
const Strides& get_dilations() const { return m_dilations; }
|
const Strides& get_dilations() const {
|
||||||
void set_dilations(const Strides& dilations) { m_dilations = dilations; }
|
return m_dilations;
|
||||||
|
}
|
||||||
|
void set_dilations(const Strides& dilations) {
|
||||||
|
m_dilations = dilations;
|
||||||
|
}
|
||||||
/// \return The padding-below sizes (possibly negative).
|
/// \return The padding-below sizes (possibly negative).
|
||||||
const CoordinateDiff& get_pads_begin() const { return m_pads_begin; }
|
const CoordinateDiff& get_pads_begin() const {
|
||||||
void set_pads_begin(const CoordinateDiff& pads_begin) { m_pads_begin = pads_begin; }
|
return m_pads_begin;
|
||||||
|
}
|
||||||
|
void set_pads_begin(const CoordinateDiff& pads_begin) {
|
||||||
|
m_pads_begin = pads_begin;
|
||||||
|
}
|
||||||
/// \return The padding-above sizes (possibly negative).
|
/// \return The padding-above sizes (possibly negative).
|
||||||
const CoordinateDiff& get_pads_end() const { return m_pads_end; }
|
const CoordinateDiff& get_pads_end() const {
|
||||||
void set_adding_above(const CoordinateDiff& pads_end) { m_pads_end = pads_end; }
|
return m_pads_end;
|
||||||
|
}
|
||||||
|
void set_adding_above(const CoordinateDiff& pads_end) {
|
||||||
|
m_pads_end = pads_end;
|
||||||
|
}
|
||||||
/// \return The pad type for convolution.
|
/// \return The pad type for convolution.
|
||||||
const PadType& get_auto_pad() const { return m_auto_pad; }
|
const PadType& get_auto_pad() const {
|
||||||
void set_auto_pad(const PadType& auto_pad) { m_auto_pad = auto_pad; }
|
return m_auto_pad;
|
||||||
|
}
|
||||||
|
void set_auto_pad(const PadType& auto_pad) {
|
||||||
|
m_auto_pad = auto_pad;
|
||||||
|
}
|
||||||
/// \return The mode of convolution.
|
/// \return The mode of convolution.
|
||||||
const BinaryConvolutionMode& get_mode() const { return m_mode; }
|
const BinaryConvolutionMode& get_mode() const {
|
||||||
void set_mode(const BinaryConvolutionMode& mode) { m_mode = mode; }
|
return m_mode;
|
||||||
|
}
|
||||||
|
void set_mode(const BinaryConvolutionMode& mode) {
|
||||||
|
m_mode = mode;
|
||||||
|
}
|
||||||
/// \return The pad value.
|
/// \return The pad value.
|
||||||
float get_pad_value() const { return m_pad_value; }
|
float get_pad_value() const {
|
||||||
void set_pad_value(float pad_value) { m_pad_value = pad_value; }
|
return m_pad_value;
|
||||||
|
}
|
||||||
|
void set_pad_value(float pad_value) {
|
||||||
|
m_pad_value = pad_value;
|
||||||
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
BinaryConvolutionMode mode_from_string(const std::string& mode) const;
|
BinaryConvolutionMode mode_from_string(const std::string& mode) const;
|
||||||
@ -102,22 +124,20 @@ namespace ngraph
|
|||||||
} // namespace op
|
} // namespace op
|
||||||
|
|
||||||
NGRAPH_API
|
NGRAPH_API
|
||||||
std::ostream& operator<<(std::ostream& s,
|
std::ostream& operator<<(std::ostream& s, const op::v1::BinaryConvolution::BinaryConvolutionMode& type);
|
||||||
const op::v1::BinaryConvolution::BinaryConvolutionMode& type);
|
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
class NGRAPH_API AttributeAdapter<op::v1::BinaryConvolution::BinaryConvolutionMode>
|
class NGRAPH_API AttributeAdapter<op::v1::BinaryConvolution::BinaryConvolutionMode>
|
||||||
: public EnumAttributeAdapterBase<op::v1::BinaryConvolution::BinaryConvolutionMode>
|
: public EnumAttributeAdapterBase<op::v1::BinaryConvolution::BinaryConvolutionMode> {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
AttributeAdapter(op::v1::BinaryConvolution::BinaryConvolutionMode& value)
|
AttributeAdapter(op::v1::BinaryConvolution::BinaryConvolutionMode& value)
|
||||||
: EnumAttributeAdapterBase<op::v1::BinaryConvolution::BinaryConvolutionMode>(value)
|
: EnumAttributeAdapterBase<op::v1::BinaryConvolution::BinaryConvolutionMode>(value) {}
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
static constexpr DiscreteTypeInfo type_info{
|
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::v1::BinaryConvolution::BinaryConvolutionMode>",
|
||||||
"AttributeAdapter<op::v1::BinaryConvolution::BinaryConvolutionMode>", 0};
|
0};
|
||||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
const DiscreteTypeInfo& get_type_info() const override {
|
||||||
|
return type_info;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace ngraph
|
} // namespace ngraph
|
||||||
|
@ -9,16 +9,12 @@
|
|||||||
#include "ngraph/op/util/attr_types.hpp"
|
#include "ngraph/op/util/attr_types.hpp"
|
||||||
#include "ngraph/op/util/broadcast_base.hpp"
|
#include "ngraph/op/util/broadcast_base.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v3 {
|
||||||
{
|
|
||||||
namespace v3
|
|
||||||
{
|
|
||||||
/// \brief Operation which "adds" axes to an input tensor, replicating elements from the
|
/// \brief Operation which "adds" axes to an input tensor, replicating elements from the
|
||||||
/// input as needed along the new axes.
|
/// input as needed along the new axes.
|
||||||
class NGRAPH_API Broadcast : public util::BroadcastBase
|
class NGRAPH_API Broadcast : public util::BroadcastBase {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
|
||||||
@ -54,13 +50,13 @@ namespace ngraph
|
|||||||
|
|
||||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||||
|
|
||||||
std::shared_ptr<Node>
|
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
|
||||||
|
|
||||||
// \return Broadcast Specification.
|
// \return Broadcast Specification.
|
||||||
const BroadcastModeSpec& get_broadcast_spec() const { return m_mode; }
|
const BroadcastModeSpec& get_broadcast_spec() const {
|
||||||
void set_broadcast_spec(const BroadcastModeSpec& broadcast_spec)
|
return m_mode;
|
||||||
{
|
}
|
||||||
|
void set_broadcast_spec(const BroadcastModeSpec& broadcast_spec) {
|
||||||
m_mode = broadcast_spec;
|
m_mode = broadcast_spec;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -68,22 +64,18 @@ namespace ngraph
|
|||||||
|
|
||||||
/// \return true and the AxisSet if broadcast axes can be fully determined.
|
/// \return true and the AxisSet if broadcast axes can be fully determined.
|
||||||
std::pair<bool, AxisSet> get_broadcast_axes() const override;
|
std::pair<bool, AxisSet> get_broadcast_axes() const override;
|
||||||
bool evaluate(const HostTensorVector& outputs,
|
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||||
const HostTensorVector& inputs) const override;
|
|
||||||
bool has_evaluate() const override;
|
bool has_evaluate() const override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
bool broadcast_evaluate(const HostTensorVector& outputs,
|
bool broadcast_evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const;
|
||||||
const HostTensorVector& inputs) const;
|
|
||||||
};
|
};
|
||||||
} // namespace v3
|
} // namespace v3
|
||||||
|
|
||||||
namespace v1
|
namespace v1 {
|
||||||
{
|
|
||||||
/// \brief Operation which "adds" axes to an input tensor, replicating elements from the
|
/// \brief Operation which "adds" axes to an input tensor, replicating elements from the
|
||||||
/// input as needed along the new axes.
|
/// input as needed along the new axes.
|
||||||
class NGRAPH_API Broadcast : public util::BroadcastBase
|
class NGRAPH_API Broadcast : public util::BroadcastBase {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
|
||||||
@ -115,24 +107,22 @@ namespace ngraph
|
|||||||
/// axes
|
/// axes
|
||||||
Broadcast(const Output<Node>& arg,
|
Broadcast(const Output<Node>& arg,
|
||||||
const Output<Node>& target_shape,
|
const Output<Node>& target_shape,
|
||||||
const AutoBroadcastSpec& broadcast_spec =
|
const AutoBroadcastSpec& broadcast_spec = AutoBroadcastSpec(AutoBroadcastType::NUMPY));
|
||||||
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
|
|
||||||
|
|
||||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||||
|
|
||||||
std::shared_ptr<Node>
|
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
|
||||||
|
|
||||||
/// \return Broadcast Specification.
|
/// \return Broadcast Specification.
|
||||||
const AutoBroadcastSpec& get_broadcast_spec() const { return m_broadcast_spec; }
|
const AutoBroadcastSpec& get_broadcast_spec() const {
|
||||||
void set_broadcast_spec(const AutoBroadcastSpec& broadcast_spec)
|
return m_broadcast_spec;
|
||||||
{
|
}
|
||||||
|
void set_broadcast_spec(const AutoBroadcastSpec& broadcast_spec) {
|
||||||
m_broadcast_spec = broadcast_spec;
|
m_broadcast_spec = broadcast_spec;
|
||||||
}
|
}
|
||||||
|
|
||||||
void validate_and_infer_types() override;
|
void validate_and_infer_types() override;
|
||||||
bool evaluate(const HostTensorVector& outputs,
|
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||||
const HostTensorVector& inputs) const override;
|
|
||||||
bool has_evaluate() const override;
|
bool has_evaluate() const override;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
@ -6,15 +6,11 @@
|
|||||||
|
|
||||||
#include "ngraph/op/op.hpp"
|
#include "ngraph/op/op.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v3 {
|
||||||
{
|
|
||||||
namespace v3
|
|
||||||
{
|
|
||||||
/// \brief Operation that bucketizes the input based on boundaries
|
/// \brief Operation that bucketizes the input based on boundaries
|
||||||
class NGRAPH_API Bucketize : public Op
|
class NGRAPH_API Bucketize : public Op {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
|
||||||
@ -34,17 +30,21 @@ namespace ngraph
|
|||||||
virtual void validate_and_infer_types() override;
|
virtual void validate_and_infer_types() override;
|
||||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||||
|
|
||||||
virtual std::shared_ptr<Node>
|
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& inputs) const override;
|
||||||
clone_with_new_inputs(const OutputVector& inputs) const override;
|
|
||||||
|
|
||||||
element::Type get_output_type() const { return m_output_type; }
|
element::Type get_output_type() const {
|
||||||
void set_output_type(element::Type output_type) { m_output_type = output_type; }
|
return m_output_type;
|
||||||
|
}
|
||||||
|
void set_output_type(element::Type output_type) {
|
||||||
|
m_output_type = output_type;
|
||||||
|
}
|
||||||
// Overload collision with method on Node
|
// Overload collision with method on Node
|
||||||
using Node::set_output_type;
|
using Node::set_output_type;
|
||||||
|
|
||||||
bool get_with_right_bound() const { return m_with_right_bound; }
|
bool get_with_right_bound() const {
|
||||||
void set_with_right_bound(bool with_right_bound)
|
return m_with_right_bound;
|
||||||
{
|
}
|
||||||
|
void set_with_right_bound(bool with_right_bound) {
|
||||||
m_with_right_bound = with_right_bound;
|
m_with_right_bound = with_right_bound;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6,15 +6,11 @@
|
|||||||
|
|
||||||
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
|
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v0 {
|
||||||
{
|
|
||||||
namespace v0
|
|
||||||
{
|
|
||||||
/// \brief Elementwise ceiling operation.
|
/// \brief Elementwise ceiling operation.
|
||||||
class NGRAPH_API Ceiling : public util::UnaryElementwiseArithmetic
|
class NGRAPH_API Ceiling : public util::UnaryElementwiseArithmetic {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
/// \brief Constructs a ceiling operation.
|
/// \brief Constructs a ceiling operation.
|
||||||
@ -24,11 +20,11 @@ namespace ngraph
|
|||||||
/// \param arg Node that produces the input tensor.
|
/// \param arg Node that produces the input tensor.
|
||||||
Ceiling(const Output<Node>& arg);
|
Ceiling(const Output<Node>& arg);
|
||||||
|
|
||||||
bool visit_attributes(AttributeVisitor&) override { return true; }
|
bool visit_attributes(AttributeVisitor&) override {
|
||||||
virtual std::shared_ptr<Node>
|
return true;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
}
|
||||||
bool evaluate(const HostTensorVector& outputs,
|
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
const HostTensorVector& inputs) const override;
|
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||||
bool has_evaluate() const override;
|
bool has_evaluate() const override;
|
||||||
};
|
};
|
||||||
} // namespace v0
|
} // namespace v0
|
||||||
|
@ -7,19 +7,15 @@
|
|||||||
#include "ngraph/node.hpp"
|
#include "ngraph/node.hpp"
|
||||||
#include "ngraph/op/op.hpp"
|
#include "ngraph/op/op.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v0 {
|
||||||
{
|
|
||||||
namespace v0
|
|
||||||
{
|
|
||||||
/// \brief Performs a clipping operation on all elements of the input node
|
/// \brief Performs a clipping operation on all elements of the input node
|
||||||
///
|
///
|
||||||
/// All input values that are outside of the <min;max> range are set to 'min' or 'max'
|
/// All input values that are outside of the <min;max> range are set to 'min' or 'max'
|
||||||
/// depending on which side of the <min;max> range they are. The values that fall into
|
/// depending on which side of the <min;max> range they are. The values that fall into
|
||||||
/// this range remain unchanged.
|
/// this range remain unchanged.
|
||||||
class NGRAPH_API Clamp : public ngraph::op::Op
|
class NGRAPH_API Clamp : public ngraph::op::Op {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
|
||||||
@ -33,15 +29,17 @@ namespace ngraph
|
|||||||
|
|
||||||
void validate_and_infer_types() override;
|
void validate_and_infer_types() override;
|
||||||
|
|
||||||
virtual std::shared_ptr<Node>
|
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
|
||||||
|
|
||||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||||
|
|
||||||
double get_min() const { return m_min; }
|
double get_min() const {
|
||||||
double get_max() const { return m_max; }
|
return m_min;
|
||||||
bool evaluate(const HostTensorVector& outputs,
|
}
|
||||||
const HostTensorVector& inputs) const override;
|
double get_max() const {
|
||||||
|
return m_max;
|
||||||
|
}
|
||||||
|
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||||
bool has_evaluate() const override;
|
bool has_evaluate() const override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -8,15 +8,11 @@
|
|||||||
|
|
||||||
#include "ngraph/op/op.hpp"
|
#include "ngraph/op/op.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v0 {
|
||||||
{
|
|
||||||
namespace v0
|
|
||||||
{
|
|
||||||
/// \brief Concatenation operation.
|
/// \brief Concatenation operation.
|
||||||
class NGRAPH_API Concat : public Op
|
class NGRAPH_API Concat : public Op {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
|
||||||
@ -37,20 +33,23 @@ namespace ngraph
|
|||||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||||
void validate_and_infer_types() override;
|
void validate_and_infer_types() override;
|
||||||
|
|
||||||
virtual std::shared_ptr<Node>
|
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
|
||||||
|
|
||||||
/// \return The concatenation axis.
|
/// \return The concatenation axis.
|
||||||
int64_t get_concatenation_axis() const { return m_concat_axis; }
|
int64_t get_concatenation_axis() const {
|
||||||
void set_concatenation_axis(int64_t concatenation_axis)
|
return m_concat_axis;
|
||||||
{
|
}
|
||||||
|
void set_concatenation_axis(int64_t concatenation_axis) {
|
||||||
m_concat_axis = concatenation_axis;
|
m_concat_axis = concatenation_axis;
|
||||||
}
|
}
|
||||||
/// \return The concatenation axis.
|
/// \return The concatenation axis.
|
||||||
int64_t get_axis() const { return m_axis; }
|
int64_t get_axis() const {
|
||||||
void set_axis(int64_t axis) { m_axis = axis; }
|
return m_axis;
|
||||||
bool evaluate(const HostTensorVector& outputs,
|
}
|
||||||
const HostTensorVector& inputs) const override;
|
void set_axis(int64_t axis) {
|
||||||
|
m_axis = axis;
|
||||||
|
}
|
||||||
|
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||||
bool has_evaluate() const override;
|
bool has_evaluate() const override;
|
||||||
bool evaluate_lower(const HostTensorVector& output_values) const override;
|
bool evaluate_lower(const HostTensorVector& output_values) const override;
|
||||||
bool evaluate_upper(const HostTensorVector& output_values) const override;
|
bool evaluate_upper(const HostTensorVector& output_values) const override;
|
||||||
|
@ -16,15 +16,11 @@
|
|||||||
#include "ngraph/type/element_type_traits.hpp"
|
#include "ngraph/type/element_type_traits.hpp"
|
||||||
#include "ngraph/util.hpp"
|
#include "ngraph/util.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v0 {
|
||||||
{
|
|
||||||
namespace v0
|
|
||||||
{
|
|
||||||
/// \brief Class for constants.
|
/// \brief Class for constants.
|
||||||
class NGRAPH_API Constant : public Op
|
class NGRAPH_API Constant : public Op {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
|
||||||
@ -41,13 +37,8 @@ namespace ngraph
|
|||||||
/// \param values A vector of literals for initializing the tensor constant. The
|
/// \param values A vector of literals for initializing the tensor constant. The
|
||||||
/// size of values must match the size of the shape.
|
/// size of values must match the size of the shape.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
Constant(const element::Type& type,
|
Constant(const element::Type& type, const Shape& shape, const std::vector<T>& values) : Constant(type, shape) {
|
||||||
const Shape& shape,
|
NODE_VALIDATION_CHECK(this,
|
||||||
const std::vector<T>& values)
|
|
||||||
: Constant(type, shape)
|
|
||||||
{
|
|
||||||
NODE_VALIDATION_CHECK(
|
|
||||||
this,
|
|
||||||
values.size() == 1 || values.size() == shape_size(m_shape),
|
values.size() == 1 || values.size() == shape_size(m_shape),
|
||||||
"Did not get the expected number of literals for a constant of shape ",
|
"Did not get the expected number of literals for a constant of shape ",
|
||||||
m_shape,
|
m_shape,
|
||||||
@ -58,12 +49,9 @@ namespace ngraph
|
|||||||
shape_size(m_shape),
|
shape_size(m_shape),
|
||||||
").");
|
").");
|
||||||
|
|
||||||
if (values.size() == 1)
|
if (values.size() == 1) {
|
||||||
{
|
|
||||||
fill_data(type, values.front());
|
fill_data(type, values.front());
|
||||||
}
|
} else {
|
||||||
else
|
|
||||||
{
|
|
||||||
write_values(values);
|
write_values(values);
|
||||||
}
|
}
|
||||||
m_all_elements_bitwise_identical = are_all_data_elements_bitwise_identical();
|
m_all_elements_bitwise_identical = are_all_data_elements_bitwise_identical();
|
||||||
@ -77,44 +65,72 @@ namespace ngraph
|
|||||||
/// \param shape The shape of the tensor constant.
|
/// \param shape The shape of the tensor constant.
|
||||||
/// \param value A scalar for initializing the uniform tensor constant. The
|
/// \param value A scalar for initializing the uniform tensor constant. The
|
||||||
/// value is broadcast to the specified shape.
|
/// value is broadcast to the specified shape.
|
||||||
template <class T,
|
template <class T, class = typename std::enable_if<std::is_fundamental<T>::value>::type>
|
||||||
class = typename std::enable_if<std::is_fundamental<T>::value>::type>
|
Constant(const element::Type& type, const Shape& shape, T value) : Constant(type, shape) {
|
||||||
Constant(const element::Type& type, const Shape& shape, T value)
|
|
||||||
: Constant(type, shape)
|
|
||||||
{
|
|
||||||
fill_data(type, value);
|
fill_data(type, value);
|
||||||
m_all_elements_bitwise_identical = true;
|
m_all_elements_bitwise_identical = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void fill_data(const element::Type& type, T value)
|
void fill_data(const element::Type& type, T value) {
|
||||||
{
|
|
||||||
using Type_t = element::Type_t;
|
using Type_t = element::Type_t;
|
||||||
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
|
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
|
||||||
# pragma GCC diagnostic push
|
# pragma GCC diagnostic push
|
||||||
# pragma GCC diagnostic error "-Wswitch"
|
# pragma GCC diagnostic error "-Wswitch"
|
||||||
# pragma GCC diagnostic error "-Wswitch-enum"
|
# pragma GCC diagnostic error "-Wswitch-enum"
|
||||||
#endif
|
#endif
|
||||||
switch (type)
|
switch (type) {
|
||||||
{
|
case Type_t::boolean:
|
||||||
case Type_t::boolean: fill_data<Type_t::boolean>(value); break;
|
fill_data<Type_t::boolean>(value);
|
||||||
case Type_t::bf16: fill_data<Type_t::bf16>(value); break;
|
break;
|
||||||
case Type_t::f16: fill_data<Type_t::f16>(value); break;
|
case Type_t::bf16:
|
||||||
case Type_t::f32: fill_data<Type_t::f32>(value); break;
|
fill_data<Type_t::bf16>(value);
|
||||||
case Type_t::f64: fill_data<Type_t::f64>(value); break;
|
break;
|
||||||
case Type_t::i4: fill_data<Type_t::i4>(value); break;
|
case Type_t::f16:
|
||||||
case Type_t::i8: fill_data<Type_t::i8>(value); break;
|
fill_data<Type_t::f16>(value);
|
||||||
case Type_t::i16: fill_data<Type_t::i16>(value); break;
|
break;
|
||||||
case Type_t::i32: fill_data<Type_t::i32>(value); break;
|
case Type_t::f32:
|
||||||
case Type_t::i64: fill_data<Type_t::i64>(value); break;
|
fill_data<Type_t::f32>(value);
|
||||||
case Type_t::u1: fill_data<Type_t::u1>(value); break;
|
break;
|
||||||
case Type_t::u4: fill_data<Type_t::u4>(value); break;
|
case Type_t::f64:
|
||||||
case Type_t::u8: fill_data<Type_t::u8>(value); break;
|
fill_data<Type_t::f64>(value);
|
||||||
case Type_t::u16: fill_data<Type_t::u16>(value); break;
|
break;
|
||||||
case Type_t::u32: fill_data<Type_t::u32>(value); break;
|
case Type_t::i4:
|
||||||
case Type_t::u64: fill_data<Type_t::u64>(value); break;
|
fill_data<Type_t::i4>(value);
|
||||||
|
break;
|
||||||
|
case Type_t::i8:
|
||||||
|
fill_data<Type_t::i8>(value);
|
||||||
|
break;
|
||||||
|
case Type_t::i16:
|
||||||
|
fill_data<Type_t::i16>(value);
|
||||||
|
break;
|
||||||
|
case Type_t::i32:
|
||||||
|
fill_data<Type_t::i32>(value);
|
||||||
|
break;
|
||||||
|
case Type_t::i64:
|
||||||
|
fill_data<Type_t::i64>(value);
|
||||||
|
break;
|
||||||
|
case Type_t::u1:
|
||||||
|
fill_data<Type_t::u1>(value);
|
||||||
|
break;
|
||||||
|
case Type_t::u4:
|
||||||
|
fill_data<Type_t::u4>(value);
|
||||||
|
break;
|
||||||
|
case Type_t::u8:
|
||||||
|
fill_data<Type_t::u8>(value);
|
||||||
|
break;
|
||||||
|
case Type_t::u16:
|
||||||
|
fill_data<Type_t::u16>(value);
|
||||||
|
break;
|
||||||
|
case Type_t::u32:
|
||||||
|
fill_data<Type_t::u32>(value);
|
||||||
|
break;
|
||||||
|
case Type_t::u64:
|
||||||
|
fill_data<Type_t::u64>(value);
|
||||||
|
break;
|
||||||
case Type_t::undefined:
|
case Type_t::undefined:
|
||||||
case Type_t::dynamic: throw std::runtime_error("unsupported type");
|
case Type_t::dynamic:
|
||||||
|
throw std::runtime_error("unsupported type");
|
||||||
}
|
}
|
||||||
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
|
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
|
||||||
# pragma GCC diagnostic pop
|
# pragma GCC diagnostic pop
|
||||||
@ -127,9 +143,7 @@ namespace ngraph
|
|||||||
/// \param type The element type of the tensor constant.
|
/// \param type The element type of the tensor constant.
|
||||||
/// \param shape The shape of the tensor constant.
|
/// \param shape The shape of the tensor constant.
|
||||||
/// \param values A list of string values to use as the constant data.
|
/// \param values A list of string values to use as the constant data.
|
||||||
Constant(const element::Type& type,
|
Constant(const element::Type& type, const Shape& shape, const std::vector<std::string>& values);
|
||||||
const Shape& shape,
|
|
||||||
const std::vector<std::string>& values);
|
|
||||||
|
|
||||||
/// \brief Constructs a tensor constant with the supplied data
|
/// \brief Constructs a tensor constant with the supplied data
|
||||||
///
|
///
|
||||||
@ -144,12 +158,9 @@ namespace ngraph
|
|||||||
/// \param shape The shape of the tensor constant.
|
/// \param shape The shape of the tensor constant.
|
||||||
/// \param data A pointer to pre-allocated shared data.
|
/// \param data A pointer to pre-allocated shared data.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
Constant(const element::Type& type,
|
Constant(const element::Type& type, const Shape& shape, std::shared_ptr<runtime::SharedBuffer<T>> data)
|
||||||
const Shape& shape,
|
: m_element_type(type),
|
||||||
std::shared_ptr<runtime::SharedBuffer<T>> data)
|
m_shape(shape) {
|
||||||
: m_element_type(type)
|
|
||||||
, m_shape(shape)
|
|
||||||
{
|
|
||||||
m_data = data;
|
m_data = data;
|
||||||
constructor_validate_and_infer_types();
|
constructor_validate_and_infer_types();
|
||||||
}
|
}
|
||||||
@ -160,23 +171,20 @@ namespace ngraph
|
|||||||
|
|
||||||
virtual ~Constant() override;
|
virtual ~Constant() override;
|
||||||
|
|
||||||
void validate_and_infer_types() override
|
void validate_and_infer_types() override {
|
||||||
{
|
|
||||||
infer_element_type();
|
infer_element_type();
|
||||||
set_output_type(0, m_element_type, m_shape);
|
set_output_type(0, m_element_type, m_shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||||
|
|
||||||
bool evaluate(const HostTensorVector& outputs,
|
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||||
const HostTensorVector& inputs) const override;
|
|
||||||
bool has_evaluate() const override;
|
bool has_evaluate() const override;
|
||||||
bool evaluate_lower(const HostTensorVector& outputs) const override;
|
bool evaluate_lower(const HostTensorVector& outputs) const override;
|
||||||
bool evaluate_upper(const HostTensorVector& outputs) const override;
|
bool evaluate_upper(const HostTensorVector& outputs) const override;
|
||||||
|
|
||||||
// Don't constant fold a constant; it would make a copy
|
// Don't constant fold a constant; it would make a copy
|
||||||
bool constant_fold(OutputVector& outputs, const OutputVector& inputs) override
|
bool constant_fold(OutputVector& outputs, const OutputVector& inputs) override {
|
||||||
{
|
|
||||||
(void)outputs;
|
(void)outputs;
|
||||||
(void)inputs;
|
(void)inputs;
|
||||||
return false;
|
return false;
|
||||||
@ -227,8 +235,7 @@ namespace ngraph
|
|||||||
template <typename T>
|
template <typename T>
|
||||||
static std::shared_ptr<Constant> create(const element::Type& type,
|
static std::shared_ptr<Constant> create(const element::Type& type,
|
||||||
const Shape& shape,
|
const Shape& shape,
|
||||||
const std::vector<T>& values)
|
const std::vector<T>& values) {
|
||||||
{
|
|
||||||
return std::make_shared<Constant>(type, shape, values);
|
return std::make_shared<Constant>(type, shape, values);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -240,8 +247,7 @@ namespace ngraph
|
|||||||
template <typename T>
|
template <typename T>
|
||||||
static std::shared_ptr<Constant> create(const element::Type& type,
|
static std::shared_ptr<Constant> create(const element::Type& type,
|
||||||
const Shape& shape,
|
const Shape& shape,
|
||||||
std::initializer_list<T> values)
|
std::initializer_list<T> values) {
|
||||||
{
|
|
||||||
return std::make_shared<Constant>(type, shape, std::vector<T>{values});
|
return std::make_shared<Constant>(type, shape, std::vector<T>{values});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -250,21 +256,17 @@ namespace ngraph
|
|||||||
/// \param type The element type of the tensor constant.
|
/// \param type The element type of the tensor constant.
|
||||||
/// \param shape The shape of the tensor constant.
|
/// \param shape The shape of the tensor constant.
|
||||||
/// \param memory An continues memory chunk which contains the constant data.
|
/// \param memory An continues memory chunk which contains the constant data.
|
||||||
static std::shared_ptr<Constant>
|
static std::shared_ptr<Constant> create(const element::Type& type, const Shape& shape, const void* memory) {
|
||||||
create(const element::Type& type, const Shape& shape, const void* memory)
|
|
||||||
{
|
|
||||||
return std::make_shared<Constant>(type, shape, memory);
|
return std::make_shared<Constant>(type, shape, memory);
|
||||||
}
|
}
|
||||||
|
|
||||||
virtual std::shared_ptr<Node>
|
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
|
||||||
|
|
||||||
/// \return The initialization literals for the tensor constant.
|
/// \return The initialization literals for the tensor constant.
|
||||||
std::vector<std::string> get_value_strings() const;
|
std::vector<std::string> get_value_strings() const;
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
std::vector<T> get_vector() const
|
std::vector<T> get_vector() const {
|
||||||
{
|
|
||||||
const T* p = get_data_ptr<T>();
|
const T* p = get_data_ptr<T>();
|
||||||
if (p == nullptr)
|
if (p == nullptr)
|
||||||
throw std::runtime_error("Cannot create vector! Buffer is not allocated.");
|
throw std::runtime_error("Cannot create vector! Buffer is not allocated.");
|
||||||
@ -276,8 +278,7 @@ namespace ngraph
|
|||||||
/// \tparam T Type to which data vector's entries will be cast.
|
/// \tparam T Type to which data vector's entries will be cast.
|
||||||
/// \return Constant's data vector.
|
/// \return Constant's data vector.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
std::vector<T> cast_vector() const
|
std::vector<T> cast_vector() const {
|
||||||
{
|
|
||||||
auto source_type = get_element_type();
|
auto source_type = get_element_type();
|
||||||
std::vector<T> rc;
|
std::vector<T> rc;
|
||||||
using Type_t = element::Type_t;
|
using Type_t = element::Type_t;
|
||||||
@ -285,25 +286,57 @@ namespace ngraph
|
|||||||
# pragma warning(push)
|
# pragma warning(push)
|
||||||
# pragma warning(disable : 4244)
|
# pragma warning(disable : 4244)
|
||||||
#endif
|
#endif
|
||||||
switch (source_type)
|
switch (source_type) {
|
||||||
{
|
case Type_t::boolean:
|
||||||
case Type_t::boolean: cast_vector<Type_t::boolean>(rc); break;
|
cast_vector<Type_t::boolean>(rc);
|
||||||
case Type_t::bf16: cast_vector<Type_t::bf16>(rc); break;
|
break;
|
||||||
case Type_t::f16: cast_vector<Type_t::f16>(rc); break;
|
case Type_t::bf16:
|
||||||
case Type_t::f32: cast_vector<Type_t::f32>(rc); break;
|
cast_vector<Type_t::bf16>(rc);
|
||||||
case Type_t::f64: cast_vector<Type_t::f64>(rc); break;
|
break;
|
||||||
case Type_t::i4: cast_vector<Type_t::i4>(rc); break;
|
case Type_t::f16:
|
||||||
case Type_t::i8: cast_vector<Type_t::i8>(rc); break;
|
cast_vector<Type_t::f16>(rc);
|
||||||
case Type_t::i16: cast_vector<Type_t::i16>(rc); break;
|
break;
|
||||||
case Type_t::i32: cast_vector<Type_t::i32>(rc); break;
|
case Type_t::f32:
|
||||||
case Type_t::i64: cast_vector<Type_t::i64>(rc); break;
|
cast_vector<Type_t::f32>(rc);
|
||||||
case Type_t::u1: cast_vector<Type_t::u1>(rc); break;
|
break;
|
||||||
case Type_t::u4: cast_vector<Type_t::u4>(rc); break;
|
case Type_t::f64:
|
||||||
case Type_t::u8: cast_vector<Type_t::u8>(rc); break;
|
cast_vector<Type_t::f64>(rc);
|
||||||
case Type_t::u16: cast_vector<Type_t::u16>(rc); break;
|
break;
|
||||||
case Type_t::u32: cast_vector<Type_t::u32>(rc); break;
|
case Type_t::i4:
|
||||||
case Type_t::u64: cast_vector<Type_t::u64>(rc); break;
|
cast_vector<Type_t::i4>(rc);
|
||||||
default: throw std::runtime_error("unsupported type");
|
break;
|
||||||
|
case Type_t::i8:
|
||||||
|
cast_vector<Type_t::i8>(rc);
|
||||||
|
break;
|
||||||
|
case Type_t::i16:
|
||||||
|
cast_vector<Type_t::i16>(rc);
|
||||||
|
break;
|
||||||
|
case Type_t::i32:
|
||||||
|
cast_vector<Type_t::i32>(rc);
|
||||||
|
break;
|
||||||
|
case Type_t::i64:
|
||||||
|
cast_vector<Type_t::i64>(rc);
|
||||||
|
break;
|
||||||
|
case Type_t::u1:
|
||||||
|
cast_vector<Type_t::u1>(rc);
|
||||||
|
break;
|
||||||
|
case Type_t::u4:
|
||||||
|
cast_vector<Type_t::u4>(rc);
|
||||||
|
break;
|
||||||
|
case Type_t::u8:
|
||||||
|
cast_vector<Type_t::u8>(rc);
|
||||||
|
break;
|
||||||
|
case Type_t::u16:
|
||||||
|
cast_vector<Type_t::u16>(rc);
|
||||||
|
break;
|
||||||
|
case Type_t::u32:
|
||||||
|
cast_vector<Type_t::u32>(rc);
|
||||||
|
break;
|
||||||
|
case Type_t::u64:
|
||||||
|
cast_vector<Type_t::u64>(rc);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw std::runtime_error("unsupported type");
|
||||||
}
|
}
|
||||||
#if defined(_MSC_VER)
|
#if defined(_MSC_VER)
|
||||||
# pragma warning(pop)
|
# pragma warning(pop)
|
||||||
@ -311,12 +344,12 @@ namespace ngraph
|
|||||||
return rc;
|
return rc;
|
||||||
}
|
}
|
||||||
|
|
||||||
const void* get_data_ptr() const { return (m_data ? m_data->get_ptr() : nullptr); }
|
const void* get_data_ptr() const {
|
||||||
|
return (m_data ? m_data->get_ptr() : nullptr);
|
||||||
|
}
|
||||||
template <typename T>
|
template <typename T>
|
||||||
const T* get_data_ptr() const
|
const T* get_data_ptr() const {
|
||||||
{
|
if (sizeof(T) > m_element_type.size() && shape_size(m_shape) > 0) {
|
||||||
if (sizeof(T) > m_element_type.size() && shape_size(m_shape) > 0)
|
|
||||||
{
|
|
||||||
throw ngraph_error("Buffer over-read");
|
throw ngraph_error("Buffer over-read");
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -324,16 +357,12 @@ namespace ngraph
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <element::Type_t ET>
|
template <element::Type_t ET>
|
||||||
const typename element_type_traits<ET>::value_type* get_data_ptr() const
|
const typename element_type_traits<ET>::value_type* get_data_ptr() const {
|
||||||
{
|
NGRAPH_CHECK(ET == get_element_type(), "get_data_ptr() called for incorrect element type.");
|
||||||
NGRAPH_CHECK(ET == get_element_type(),
|
return static_cast<const typename element_type_traits<ET>::value_type*>(get_data_ptr());
|
||||||
"get_data_ptr() called for incorrect element type.");
|
|
||||||
return static_cast<const typename element_type_traits<ET>::value_type*>(
|
|
||||||
get_data_ptr());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool get_all_data_elements_bitwise_identical() const
|
bool get_all_data_elements_bitwise_identical() const {
|
||||||
{
|
|
||||||
return m_all_elements_bitwise_identical;
|
return m_all_elements_bitwise_identical;
|
||||||
}
|
}
|
||||||
std::string convert_value_to_string(size_t index) const;
|
std::string convert_value_to_string(size_t index) const;
|
||||||
@ -341,46 +370,39 @@ namespace ngraph
|
|||||||
/**
|
/**
|
||||||
* \brief Allows to avoid buffer allocation on the visit_attributes call
|
* \brief Allows to avoid buffer allocation on the visit_attributes call
|
||||||
*/
|
*/
|
||||||
void alloc_buffer_on_visit_attributes(bool val)
|
void alloc_buffer_on_visit_attributes(bool val) {
|
||||||
{
|
|
||||||
m_alloc_buffer_on_visit_attributes = val;
|
m_alloc_buffer_on_visit_attributes = val;
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
template <element::Type_t Type,
|
template <element::Type_t Type,
|
||||||
typename StorageDataType = fundamental_type_for<Type>,
|
typename StorageDataType = fundamental_type_for<Type>,
|
||||||
typename std::enable_if<Type != element::Type_t::u1 &&
|
typename std::enable_if<Type != element::Type_t::u1 && Type != element::Type_t::u4 &&
|
||||||
Type != element::Type_t::u4 &&
|
|
||||||
Type != element::Type_t::i4,
|
Type != element::Type_t::i4,
|
||||||
bool>::type = true>
|
bool>::type = true>
|
||||||
StorageDataType get_element_value(size_t index) const
|
StorageDataType get_element_value(size_t index) const {
|
||||||
{
|
|
||||||
return get_data_ptr<Type>()[index];
|
return get_data_ptr<Type>()[index];
|
||||||
}
|
}
|
||||||
|
|
||||||
template <element::Type_t Type,
|
template <element::Type_t Type,
|
||||||
typename StorageDataType = fundamental_type_for<Type>,
|
typename StorageDataType = fundamental_type_for<Type>,
|
||||||
typename std::enable_if<Type == element::Type_t::u1, bool>::type = true>
|
typename std::enable_if<Type == element::Type_t::u1, bool>::type = true>
|
||||||
StorageDataType get_element_value(size_t index) const
|
StorageDataType get_element_value(size_t index) const {
|
||||||
{
|
|
||||||
return (get_data_ptr<uint8_t>()[index / 8] >> (7 - (index % 8))) & 1;
|
return (get_data_ptr<uint8_t>()[index / 8] >> (7 - (index % 8))) & 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <element::Type_t Type,
|
template <element::Type_t Type,
|
||||||
typename StorageDataType = fundamental_type_for<Type>,
|
typename StorageDataType = fundamental_type_for<Type>,
|
||||||
typename std::enable_if<Type == element::Type_t::u4, bool>::type = true>
|
typename std::enable_if<Type == element::Type_t::u4, bool>::type = true>
|
||||||
StorageDataType get_element_value(size_t index) const
|
StorageDataType get_element_value(size_t index) const {
|
||||||
{
|
|
||||||
return (get_data_ptr<uint8_t>()[index / 2] >> (index % 2 ? 0 : 4)) & 0x0F;
|
return (get_data_ptr<uint8_t>()[index / 2] >> (index % 2 ? 0 : 4)) & 0x0F;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <element::Type_t Type,
|
template <element::Type_t Type,
|
||||||
typename StorageDataType = fundamental_type_for<Type>,
|
typename StorageDataType = fundamental_type_for<Type>,
|
||||||
typename std::enable_if<Type == element::Type_t::i4, bool>::type = true>
|
typename std::enable_if<Type == element::Type_t::i4, bool>::type = true>
|
||||||
StorageDataType get_element_value(size_t index) const
|
StorageDataType get_element_value(size_t index) const {
|
||||||
{
|
const uint8_t i4data = (get_data_ptr<uint8_t>()[index / 2] >> (index % 2 ? 0 : 4)) & 0x0F;
|
||||||
const uint8_t i4data =
|
|
||||||
(get_data_ptr<uint8_t>()[index / 2] >> (index % 2 ? 0 : 4)) & 0x0F;
|
|
||||||
const bool is_negative_number = (i4data >> 3) & 0x01;
|
const bool is_negative_number = (i4data >> 3) & 0x01;
|
||||||
const int8_t data = is_negative_number ? i4data | 0xF0 : i4data;
|
const int8_t data = is_negative_number ? i4data | 0xF0 : i4data;
|
||||||
return data;
|
return data;
|
||||||
@ -388,12 +410,10 @@ namespace ngraph
|
|||||||
|
|
||||||
template <element::Type_t Type,
|
template <element::Type_t Type,
|
||||||
typename OUT_T,
|
typename OUT_T,
|
||||||
typename std::enable_if<Type != element::Type_t::u1 &&
|
typename std::enable_if<Type != element::Type_t::u1 && Type != element::Type_t::u4 &&
|
||||||
Type != element::Type_t::u4 &&
|
|
||||||
Type != element::Type_t::i4,
|
Type != element::Type_t::i4,
|
||||||
bool>::type = true>
|
bool>::type = true>
|
||||||
void cast_vector(std::vector<OUT_T>& output_vector) const
|
void cast_vector(std::vector<OUT_T>& output_vector) const {
|
||||||
{
|
|
||||||
// this function is workaround for waring during windows building
|
// this function is workaround for waring during windows building
|
||||||
// build complains for vector creation based on iterators
|
// build complains for vector creation based on iterators
|
||||||
// which point on different type than destination vector::value_type
|
// which point on different type than destination vector::value_type
|
||||||
@ -401,28 +421,23 @@ namespace ngraph
|
|||||||
auto source_vector = get_vector<IN_T>();
|
auto source_vector = get_vector<IN_T>();
|
||||||
output_vector.reserve(source_vector.size());
|
output_vector.reserve(source_vector.size());
|
||||||
|
|
||||||
std::transform(source_vector.begin(),
|
std::transform(source_vector.begin(), source_vector.end(), std::back_inserter(output_vector), [](IN_T c) {
|
||||||
source_vector.end(),
|
return static_cast<OUT_T>(c);
|
||||||
std::back_inserter(output_vector),
|
});
|
||||||
[](IN_T c) { return static_cast<OUT_T>(c); });
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <element::Type_t Type,
|
template <element::Type_t Type,
|
||||||
typename OUT_T,
|
typename OUT_T,
|
||||||
typename std::enable_if<Type == element::Type_t::u1, bool>::type = true>
|
typename std::enable_if<Type == element::Type_t::u1, bool>::type = true>
|
||||||
void cast_vector(std::vector<OUT_T>& output) const
|
void cast_vector(std::vector<OUT_T>& output) const {
|
||||||
{
|
|
||||||
using IN_T = fundamental_type_for<Type>;
|
using IN_T = fundamental_type_for<Type>;
|
||||||
const auto element_number = shape_size(m_shape);
|
const auto element_number = shape_size(m_shape);
|
||||||
const auto source_begin = get_data_ptr<uint8_t>();
|
const auto source_begin = get_data_ptr<uint8_t>();
|
||||||
const auto source_end = std::next(source_begin, (element_number + 7) / 8);
|
const auto source_end = std::next(source_begin, (element_number + 7) / 8);
|
||||||
const auto round_element_no = element_number % 8
|
const auto round_element_no = element_number % 8 ? element_number - element_number % 8 + 8 : element_number;
|
||||||
? element_number - element_number % 8 + 8
|
|
||||||
: element_number;
|
|
||||||
output.reserve(round_element_no); // adds 7 more elements here?
|
output.reserve(round_element_no); // adds 7 more elements here?
|
||||||
std::for_each(source_begin, source_end, [&](IN_T c) {
|
std::for_each(source_begin, source_end, [&](IN_T c) {
|
||||||
for (const auto i : {7, 6, 5, 4, 3, 2, 1, 0})
|
for (const auto i : {7, 6, 5, 4, 3, 2, 1, 0}) {
|
||||||
{
|
|
||||||
const uint8_t data = (c >> i) & 0x01;
|
const uint8_t data = (c >> i) & 0x01;
|
||||||
output.push_back(data);
|
output.push_back(data);
|
||||||
}
|
}
|
||||||
@ -433,18 +448,15 @@ namespace ngraph
|
|||||||
template <element::Type_t Type,
|
template <element::Type_t Type,
|
||||||
typename OUT_T,
|
typename OUT_T,
|
||||||
typename std::enable_if<Type == element::Type_t::u4, bool>::type = true>
|
typename std::enable_if<Type == element::Type_t::u4, bool>::type = true>
|
||||||
void cast_vector(std::vector<OUT_T>& output) const
|
void cast_vector(std::vector<OUT_T>& output) const {
|
||||||
{
|
|
||||||
using IN_T = fundamental_type_for<Type>;
|
using IN_T = fundamental_type_for<Type>;
|
||||||
const auto element_number = shape_size(m_shape);
|
const auto element_number = shape_size(m_shape);
|
||||||
const auto source_begin = get_data_ptr<uint8_t>();
|
const auto source_begin = get_data_ptr<uint8_t>();
|
||||||
const auto source_end = std::next(source_begin, (element_number + 1) / 2);
|
const auto source_end = std::next(source_begin, (element_number + 1) / 2);
|
||||||
const auto round_element_no =
|
const auto round_element_no = element_number % 2 ? element_number + 1 : element_number;
|
||||||
element_number % 2 ? element_number + 1 : element_number;
|
|
||||||
output.reserve(round_element_no); // adds 1 more elements here?
|
output.reserve(round_element_no); // adds 1 more elements here?
|
||||||
std::for_each(source_begin, source_end, [&](IN_T c) {
|
std::for_each(source_begin, source_end, [&](IN_T c) {
|
||||||
for (const auto i : {4, 0})
|
for (const auto i : {4, 0}) {
|
||||||
{
|
|
||||||
const uint8_t data = (c >> i) & 0x0F;
|
const uint8_t data = (c >> i) & 0x0F;
|
||||||
output.push_back(data);
|
output.push_back(data);
|
||||||
}
|
}
|
||||||
@ -454,18 +466,15 @@ namespace ngraph
|
|||||||
template <element::Type_t Type,
|
template <element::Type_t Type,
|
||||||
typename OUT_T,
|
typename OUT_T,
|
||||||
typename std::enable_if<Type == element::Type_t::i4, bool>::type = true>
|
typename std::enable_if<Type == element::Type_t::i4, bool>::type = true>
|
||||||
void cast_vector(std::vector<OUT_T>& output) const
|
void cast_vector(std::vector<OUT_T>& output) const {
|
||||||
{
|
|
||||||
using IN_T = fundamental_type_for<Type>;
|
using IN_T = fundamental_type_for<Type>;
|
||||||
const auto element_number = shape_size(m_shape);
|
const auto element_number = shape_size(m_shape);
|
||||||
const auto source_begin = get_data_ptr<uint8_t>();
|
const auto source_begin = get_data_ptr<uint8_t>();
|
||||||
const auto source_end = std::next(source_begin, (element_number + 1) / 2);
|
const auto source_end = std::next(source_begin, (element_number + 1) / 2);
|
||||||
const auto round_element_no =
|
const auto round_element_no = element_number % 2 ? element_number + 1 : element_number;
|
||||||
element_number % 2 ? element_number + 1 : element_number;
|
|
||||||
output.reserve(round_element_no); // adds 1 more elements here?
|
output.reserve(round_element_no); // adds 1 more elements here?
|
||||||
std::for_each(source_begin, source_end, [&](IN_T c) {
|
std::for_each(source_begin, source_end, [&](IN_T c) {
|
||||||
for (const auto i : {4, 0})
|
for (const auto i : {4, 0}) {
|
||||||
{
|
|
||||||
const uint8_t i4data = (c >> i) & 0x0F;
|
const uint8_t i4data = (c >> i) & 0x0F;
|
||||||
const bool is_negative_number = (i4data >> 3) & 0x01;
|
const bool is_negative_number = (i4data >> 3) & 0x01;
|
||||||
const int8_t data = is_negative_number ? i4data | 0xF0 : i4data;
|
const int8_t data = is_negative_number ? i4data | 0xF0 : i4data;
|
||||||
@ -478,12 +487,10 @@ namespace ngraph
|
|||||||
template <element::Type_t Type,
|
template <element::Type_t Type,
|
||||||
typename T,
|
typename T,
|
||||||
typename StorageDataType = fundamental_type_for<Type>,
|
typename StorageDataType = fundamental_type_for<Type>,
|
||||||
typename std::enable_if<Type != element::Type_t::u1 &&
|
typename std::enable_if<Type != element::Type_t::u1 && Type != element::Type_t::u4 &&
|
||||||
Type != element::Type_t::u4 &&
|
|
||||||
Type != element::Type_t::i4,
|
Type != element::Type_t::i4,
|
||||||
bool>::type = true>
|
bool>::type = true>
|
||||||
void fill_data(const T& value)
|
void fill_data(const T& value) {
|
||||||
{
|
|
||||||
const auto size = shape_size(m_shape);
|
const auto size = shape_size(m_shape);
|
||||||
const auto v = static_cast<StorageDataType>(value);
|
const auto v = static_cast<StorageDataType>(value);
|
||||||
std::fill_n(get_data_ptr_nc<Type>(), size, v);
|
std::fill_n(get_data_ptr_nc<Type>(), size, v);
|
||||||
@ -493,8 +500,7 @@ namespace ngraph
|
|||||||
typename T,
|
typename T,
|
||||||
typename StorageDataType = fundamental_type_for<Type>,
|
typename StorageDataType = fundamental_type_for<Type>,
|
||||||
typename std::enable_if<Type == element::Type_t::u1, bool>::type = true>
|
typename std::enable_if<Type == element::Type_t::u1, bool>::type = true>
|
||||||
void fill_data(const T& value)
|
void fill_data(const T& value) {
|
||||||
{
|
|
||||||
const StorageDataType v = value ? 0xFF : 0x00;
|
const StorageDataType v = value ? 0xFF : 0x00;
|
||||||
std::fill_n(get_data_ptr_nc<Type>(), mem_size(), v);
|
std::fill_n(get_data_ptr_nc<Type>(), mem_size(), v);
|
||||||
}
|
}
|
||||||
@ -502,11 +508,8 @@ namespace ngraph
|
|||||||
template <element::Type_t Type,
|
template <element::Type_t Type,
|
||||||
typename T,
|
typename T,
|
||||||
typename StorageDataType = fundamental_type_for<Type>,
|
typename StorageDataType = fundamental_type_for<Type>,
|
||||||
typename std::enable_if<Type == element::Type_t::u4 ||
|
typename std::enable_if<Type == element::Type_t::u4 || Type == element::Type_t::i4, bool>::type = true>
|
||||||
Type == element::Type_t::i4,
|
void fill_data(const T& value) {
|
||||||
bool>::type = true>
|
|
||||||
void fill_data(const T& value)
|
|
||||||
{
|
|
||||||
uint8_t v = value_in_range<Type>(value);
|
uint8_t v = value_in_range<Type>(value);
|
||||||
v &= 0x0F;
|
v &= 0x0F;
|
||||||
v += v << 4;
|
v += v << 4;
|
||||||
@ -515,42 +518,33 @@ namespace ngraph
|
|||||||
|
|
||||||
void allocate_buffer();
|
void allocate_buffer();
|
||||||
|
|
||||||
void* get_data_ptr_nc() { return (m_data ? m_data->get_ptr() : nullptr); }
|
void* get_data_ptr_nc() {
|
||||||
|
return (m_data ? m_data->get_ptr() : nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
template <element::Type_t ET>
|
template <element::Type_t ET>
|
||||||
typename element_type_traits<ET>::value_type* get_data_ptr_nc()
|
typename element_type_traits<ET>::value_type* get_data_ptr_nc() {
|
||||||
{
|
NGRAPH_CHECK(ET == get_element_type(), "get_data_ptr_nc() called for incorrect element type.");
|
||||||
NGRAPH_CHECK(ET == get_element_type(),
|
return static_cast<typename element_type_traits<ET>::value_type*>(get_data_ptr_nc());
|
||||||
"get_data_ptr_nc() called for incorrect element type.");
|
|
||||||
return static_cast<typename element_type_traits<ET>::value_type*>(
|
|
||||||
get_data_ptr_nc());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Constant(const OutputVector& args)
|
Constant(const OutputVector& args) : Op(args), m_shape({}) {}
|
||||||
: Op(args)
|
|
||||||
, m_shape({})
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual void infer_element_type() {}
|
virtual void infer_element_type() {}
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void write_values(const std::vector<T>& values)
|
void write_values(const std::vector<T>& values) {
|
||||||
{
|
|
||||||
write_to_buffer(values);
|
write_to_buffer(values);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <element::Type_t Type,
|
template <element::Type_t Type,
|
||||||
typename T,
|
typename T,
|
||||||
typename StorageDataType = fundamental_type_for<Type>,
|
typename StorageDataType = fundamental_type_for<Type>,
|
||||||
typename std::enable_if<Type != element::Type_t::u1 &&
|
typename std::enable_if<Type != element::Type_t::u1 && Type != element::Type_t::u4 &&
|
||||||
Type != element::Type_t::u4 &&
|
|
||||||
Type != element::Type_t::i4,
|
Type != element::Type_t::i4,
|
||||||
bool>::type = true>
|
bool>::type = true>
|
||||||
void write_buffer(const std::vector<T>& source)
|
void write_buffer(const std::vector<T>& source) {
|
||||||
{
|
|
||||||
auto p = get_data_ptr_nc<Type>();
|
auto p = get_data_ptr_nc<Type>();
|
||||||
for (size_t i = 0; i < source.size(); i++)
|
for (size_t i = 0; i < source.size(); i++) {
|
||||||
{
|
|
||||||
p[i] = static_cast<StorageDataType>(source[i]);
|
p[i] = static_cast<StorageDataType>(source[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -558,22 +552,17 @@ namespace ngraph
|
|||||||
template <element::Type_t Type,
|
template <element::Type_t Type,
|
||||||
typename T,
|
typename T,
|
||||||
typename StorageDataType = fundamental_type_for<Type>,
|
typename StorageDataType = fundamental_type_for<Type>,
|
||||||
typename std::enable_if<Type == element::Type_t::u4 ||
|
typename std::enable_if<Type == element::Type_t::u4 || Type == element::Type_t::i4, bool>::type = true>
|
||||||
Type == element::Type_t::i4,
|
void write_buffer(const std::vector<T>& source) {
|
||||||
bool>::type = true>
|
|
||||||
void write_buffer(const std::vector<T>& source)
|
|
||||||
{
|
|
||||||
auto p = get_data_ptr_nc<Type>();
|
auto p = get_data_ptr_nc<Type>();
|
||||||
size_t i = 0;
|
size_t i = 0;
|
||||||
for (; i < source.size() / 2; i++)
|
for (; i < source.size() / 2; i++) {
|
||||||
{
|
|
||||||
const auto v1 = value_in_range<Type>(source[i * 2]) & 0x0F;
|
const auto v1 = value_in_range<Type>(source[i * 2]) & 0x0F;
|
||||||
const auto v2 = value_in_range<Type>(source[i * 2 + 1]) & 0x0F;
|
const auto v2 = value_in_range<Type>(source[i * 2 + 1]) & 0x0F;
|
||||||
const auto v = (v1 << 4) | v2;
|
const auto v = (v1 << 4) | v2;
|
||||||
p[i] = static_cast<StorageDataType>(v);
|
p[i] = static_cast<StorageDataType>(v);
|
||||||
}
|
}
|
||||||
if (source.size() % 2)
|
if (source.size() % 2) {
|
||||||
{
|
|
||||||
const auto v1 = value_in_range<Type>(source[i * 2]) & 0x0F;
|
const auto v1 = value_in_range<Type>(source[i * 2]) & 0x0F;
|
||||||
const auto v = v1 << 4;
|
const auto v = v1 << 4;
|
||||||
p[i] = static_cast<StorageDataType>(v);
|
p[i] = static_cast<StorageDataType>(v);
|
||||||
@ -584,23 +573,19 @@ namespace ngraph
|
|||||||
typename T,
|
typename T,
|
||||||
typename StorageDataType = fundamental_type_for<Type>,
|
typename StorageDataType = fundamental_type_for<Type>,
|
||||||
typename std::enable_if<Type == element::Type_t::u1, bool>::type = true>
|
typename std::enable_if<Type == element::Type_t::u1, bool>::type = true>
|
||||||
void write_buffer(const std::vector<T>& source)
|
void write_buffer(const std::vector<T>& source) {
|
||||||
{
|
|
||||||
auto p = get_data_ptr_nc<Type>();
|
auto p = get_data_ptr_nc<Type>();
|
||||||
size_t i = 0;
|
size_t i = 0;
|
||||||
for (; i < source.size() / 8; i++)
|
for (; i < source.size() / 8; i++) {
|
||||||
{
|
|
||||||
uint8_t v{};
|
uint8_t v{};
|
||||||
for (int j = 0; j != 8; j++)
|
for (int j = 0; j != 8; j++) {
|
||||||
{
|
|
||||||
const uint8_t b = source[i * 8 + j] ? 0x01 << (7 - j) : 0;
|
const uint8_t b = source[i * 8 + j] ? 0x01 << (7 - j) : 0;
|
||||||
v |= b;
|
v |= b;
|
||||||
}
|
}
|
||||||
p[i] = static_cast<StorageDataType>(v);
|
p[i] = static_cast<StorageDataType>(v);
|
||||||
}
|
}
|
||||||
uint8_t v{};
|
uint8_t v{};
|
||||||
for (unsigned j = 0; j != source.size() % 8; j++)
|
for (unsigned j = 0; j != source.size() % 8; j++) {
|
||||||
{
|
|
||||||
const uint8_t b = source[i * 8 + j] ? 0x01 << (7 - j) : 0;
|
const uint8_t b = source[i * 8 + j] ? 0x01 << (7 - j) : 0;
|
||||||
v |= b;
|
v |= b;
|
||||||
}
|
}
|
||||||
@ -608,12 +593,10 @@ namespace ngraph
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void write_to_buffer(const std::vector<T>& source)
|
void write_to_buffer(const std::vector<T>& source) {
|
||||||
{
|
|
||||||
const auto& target_type = m_element_type;
|
const auto& target_type = m_element_type;
|
||||||
size_t target_element_count = shape_size(m_shape);
|
size_t target_element_count = shape_size(m_shape);
|
||||||
if (source.size() != target_element_count)
|
if (source.size() != target_element_count) {
|
||||||
{
|
|
||||||
throw std::runtime_error("Constant initializer does not match shape");
|
throw std::runtime_error("Constant initializer does not match shape");
|
||||||
}
|
}
|
||||||
using Type_t = element::Type_t;
|
using Type_t = element::Type_t;
|
||||||
@ -622,63 +605,89 @@ namespace ngraph
|
|||||||
# pragma GCC diagnostic error "-Wswitch"
|
# pragma GCC diagnostic error "-Wswitch"
|
||||||
# pragma GCC diagnostic error "-Wswitch-enum"
|
# pragma GCC diagnostic error "-Wswitch-enum"
|
||||||
#endif
|
#endif
|
||||||
switch (target_type)
|
switch (target_type) {
|
||||||
{
|
case Type_t::boolean:
|
||||||
case Type_t::boolean: write_buffer<Type_t::boolean>(source); break;
|
write_buffer<Type_t::boolean>(source);
|
||||||
case Type_t::bf16: write_buffer<Type_t::bf16>(source); break;
|
break;
|
||||||
case Type_t::f16: write_buffer<Type_t::f16>(source); break;
|
case Type_t::bf16:
|
||||||
case Type_t::f32: write_buffer<Type_t::f32>(source); break;
|
write_buffer<Type_t::bf16>(source);
|
||||||
case Type_t::f64: write_buffer<Type_t::f64>(source); break;
|
break;
|
||||||
case Type_t::i4: write_buffer<Type_t::i4>(source); break;
|
case Type_t::f16:
|
||||||
case Type_t::i8: write_buffer<Type_t::i8>(source); break;
|
write_buffer<Type_t::f16>(source);
|
||||||
case Type_t::i16: write_buffer<Type_t::i16>(source); break;
|
break;
|
||||||
case Type_t::i32: write_buffer<Type_t::i32>(source); break;
|
case Type_t::f32:
|
||||||
case Type_t::i64: write_buffer<Type_t::i64>(source); break;
|
write_buffer<Type_t::f32>(source);
|
||||||
case Type_t::u1: write_buffer<Type_t::u1>(source); break;
|
break;
|
||||||
case Type_t::u4: write_buffer<Type_t::u4>(source); break;
|
case Type_t::f64:
|
||||||
case Type_t::u8: write_buffer<Type_t::u8>(source); break;
|
write_buffer<Type_t::f64>(source);
|
||||||
case Type_t::u16: write_buffer<Type_t::u16>(source); break;
|
break;
|
||||||
case Type_t::u32: write_buffer<Type_t::u32>(source); break;
|
case Type_t::i4:
|
||||||
case Type_t::u64: write_buffer<Type_t::u64>(source); break;
|
write_buffer<Type_t::i4>(source);
|
||||||
|
break;
|
||||||
|
case Type_t::i8:
|
||||||
|
write_buffer<Type_t::i8>(source);
|
||||||
|
break;
|
||||||
|
case Type_t::i16:
|
||||||
|
write_buffer<Type_t::i16>(source);
|
||||||
|
break;
|
||||||
|
case Type_t::i32:
|
||||||
|
write_buffer<Type_t::i32>(source);
|
||||||
|
break;
|
||||||
|
case Type_t::i64:
|
||||||
|
write_buffer<Type_t::i64>(source);
|
||||||
|
break;
|
||||||
|
case Type_t::u1:
|
||||||
|
write_buffer<Type_t::u1>(source);
|
||||||
|
break;
|
||||||
|
case Type_t::u4:
|
||||||
|
write_buffer<Type_t::u4>(source);
|
||||||
|
break;
|
||||||
|
case Type_t::u8:
|
||||||
|
write_buffer<Type_t::u8>(source);
|
||||||
|
break;
|
||||||
|
case Type_t::u16:
|
||||||
|
write_buffer<Type_t::u16>(source);
|
||||||
|
break;
|
||||||
|
case Type_t::u32:
|
||||||
|
write_buffer<Type_t::u32>(source);
|
||||||
|
break;
|
||||||
|
case Type_t::u64:
|
||||||
|
write_buffer<Type_t::u64>(source);
|
||||||
|
break;
|
||||||
case element::Type_t::undefined:
|
case element::Type_t::undefined:
|
||||||
case element::Type_t::dynamic: throw std::runtime_error("unsupported type");
|
case element::Type_t::dynamic:
|
||||||
|
throw std::runtime_error("unsupported type");
|
||||||
}
|
}
|
||||||
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
|
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
|
||||||
# pragma GCC diagnostic pop
|
# pragma GCC diagnostic pop
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
template <
|
template <ngraph::element::Type_t Type,
|
||||||
ngraph::element::Type_t Type,
|
|
||||||
typename ValueT,
|
typename ValueT,
|
||||||
typename std::enable_if<Type == ngraph::element::Type_t::u4, bool>::type = true>
|
typename std::enable_if<Type == ngraph::element::Type_t::u4, bool>::type = true>
|
||||||
static ngraph::fundamental_type_for<Type> value_in_range(const ValueT& value)
|
static ngraph::fundamental_type_for<Type> value_in_range(const ValueT& value) {
|
||||||
{
|
|
||||||
const auto result = ngraph::fundamental_type_for<Type>(value);
|
const auto result = ngraph::fundamental_type_for<Type>(value);
|
||||||
NGRAPH_CHECK(0 <= result && result <= 15,
|
NGRAPH_CHECK(0 <= result && result <= 15, "assigned value out of range u4 values");
|
||||||
"assigned value out of range u4 values");
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <
|
template <ngraph::element::Type_t Type,
|
||||||
ngraph::element::Type_t Type,
|
|
||||||
typename ValueT,
|
typename ValueT,
|
||||||
typename std::enable_if<Type == ngraph::element::Type_t::i4, bool>::type = true>
|
typename std::enable_if<Type == ngraph::element::Type_t::i4, bool>::type = true>
|
||||||
static ngraph::fundamental_type_for<Type> value_in_range(const ValueT& value)
|
static ngraph::fundamental_type_for<Type> value_in_range(const ValueT& value) {
|
||||||
{
|
|
||||||
const auto result = ngraph::fundamental_type_for<Type>(value);
|
const auto result = ngraph::fundamental_type_for<Type>(value);
|
||||||
NGRAPH_CHECK(-8 <= result && result <= 7,
|
NGRAPH_CHECK(-8 <= result && result <= 7, "assigned value out of range i4 values");
|
||||||
"assigned value out of range i4 values");
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool are_all_data_elements_bitwise_identical() const;
|
bool are_all_data_elements_bitwise_identical() const;
|
||||||
static constexpr size_t host_alignment() { return 64; }
|
static constexpr size_t host_alignment() {
|
||||||
|
return 64;
|
||||||
|
}
|
||||||
|
|
||||||
size_t mem_size() const
|
size_t mem_size() const {
|
||||||
{
|
|
||||||
const bool bitwidth_less_than_byte = m_element_type.bitwidth() < 8;
|
const bool bitwidth_less_than_byte = m_element_type.bitwidth() < 8;
|
||||||
if (bitwidth_less_than_byte)
|
if (bitwidth_less_than_byte) {
|
||||||
{
|
|
||||||
const auto size = shape_size(m_shape);
|
const auto size = shape_size(m_shape);
|
||||||
const auto bitwidth = size * m_element_type.bitwidth();
|
const auto bitwidth = size * m_element_type.bitwidth();
|
||||||
// for rounding by `(bitwidth + 7) / 8` will work for
|
// for rounding by `(bitwidth + 7) / 8` will work for
|
||||||
|
@ -7,15 +7,11 @@
|
|||||||
#include "ngraph/op/op.hpp"
|
#include "ngraph/op/op.hpp"
|
||||||
#include "ngraph/runtime/host_tensor.hpp"
|
#include "ngraph/runtime/host_tensor.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v0 {
|
||||||
{
|
|
||||||
namespace v0
|
|
||||||
{
|
|
||||||
/// \brief Elementwise type conversion operation.
|
/// \brief Elementwise type conversion operation.
|
||||||
class NGRAPH_API Convert : public Op
|
class NGRAPH_API Convert : public Op {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
|
||||||
@ -29,21 +25,21 @@ namespace ngraph
|
|||||||
|
|
||||||
void validate_and_infer_types() override;
|
void validate_and_infer_types() override;
|
||||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||||
virtual std::shared_ptr<Node>
|
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
const element::Type& get_destination_type() const {
|
||||||
const element::Type& get_destination_type() const { return m_destination_type; }
|
return m_destination_type;
|
||||||
void set_destination_type(const element::Type& destination_type)
|
}
|
||||||
{
|
void set_destination_type(const element::Type& destination_type) {
|
||||||
m_destination_type = destination_type;
|
m_destination_type = destination_type;
|
||||||
}
|
}
|
||||||
const element::Type& get_convert_element_type() const { return m_destination_type; }
|
const element::Type& get_convert_element_type() const {
|
||||||
void set_convert_element_type(const element::Type& destination_type)
|
return m_destination_type;
|
||||||
{
|
}
|
||||||
|
void set_convert_element_type(const element::Type& destination_type) {
|
||||||
m_destination_type = destination_type;
|
m_destination_type = destination_type;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool evaluate(const HostTensorVector& outputs,
|
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||||
const HostTensorVector& inputs) const override;
|
|
||||||
bool has_evaluate() const override;
|
bool has_evaluate() const override;
|
||||||
bool evaluate_lower(const HostTensorVector& outputs) const override;
|
bool evaluate_lower(const HostTensorVector& outputs) const override;
|
||||||
bool evaluate_upper(const HostTensorVector& outputs) const override;
|
bool evaluate_upper(const HostTensorVector& outputs) const override;
|
||||||
|
@ -6,15 +6,11 @@
|
|||||||
|
|
||||||
#include "ngraph/op/op.hpp"
|
#include "ngraph/op/op.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v1 {
|
||||||
{
|
|
||||||
namespace v1
|
|
||||||
{
|
|
||||||
/// \brief Elementwise type conversion operation.
|
/// \brief Elementwise type conversion operation.
|
||||||
class NGRAPH_API ConvertLike : public Op
|
class NGRAPH_API ConvertLike : public Op {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
|
||||||
@ -28,11 +24,9 @@ namespace ngraph
|
|||||||
void validate_and_infer_types() override;
|
void validate_and_infer_types() override;
|
||||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||||
|
|
||||||
virtual std::shared_ptr<Node>
|
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
|
||||||
|
|
||||||
bool constant_fold(OutputVector& output_values,
|
bool constant_fold(OutputVector& output_values, const OutputVector& input_values) override;
|
||||||
const OutputVector& input_values) override;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace v1
|
} // namespace v1
|
||||||
|
@ -8,16 +8,12 @@
|
|||||||
#include "ngraph/op/op.hpp"
|
#include "ngraph/op/op.hpp"
|
||||||
#include "ngraph/op/util/attr_types.hpp"
|
#include "ngraph/op/util/attr_types.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v1 {
|
||||||
{
|
|
||||||
namespace v1
|
|
||||||
{
|
|
||||||
/// \brief Batched convolution operation, with optional window dilation and stride.
|
/// \brief Batched convolution operation, with optional window dilation and stride.
|
||||||
///
|
///
|
||||||
class NGRAPH_API Convolution : public Op
|
class NGRAPH_API Convolution : public Op {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
|
||||||
@ -53,24 +49,43 @@ namespace ngraph
|
|||||||
void validate_and_infer_types() override;
|
void validate_and_infer_types() override;
|
||||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||||
|
|
||||||
virtual std::shared_ptr<Node>
|
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
|
||||||
|
|
||||||
/// \return The strides.
|
/// \return The strides.
|
||||||
const Strides& get_strides() const { return m_strides; }
|
const Strides& get_strides() const {
|
||||||
void set_strides(const Strides& strides) { m_strides = strides; }
|
return m_strides;
|
||||||
|
}
|
||||||
|
void set_strides(const Strides& strides) {
|
||||||
|
m_strides = strides;
|
||||||
|
}
|
||||||
/// \return The dilations.
|
/// \return The dilations.
|
||||||
const Strides& get_dilations() const { return m_dilations; }
|
const Strides& get_dilations() const {
|
||||||
void set_dilations(const Strides& dilations) { m_dilations = dilations; }
|
return m_dilations;
|
||||||
|
}
|
||||||
|
void set_dilations(const Strides& dilations) {
|
||||||
|
m_dilations = dilations;
|
||||||
|
}
|
||||||
/// \return The padding-below sizes (possibly negative).
|
/// \return The padding-below sizes (possibly negative).
|
||||||
const CoordinateDiff& get_pads_begin() const { return m_pads_begin; }
|
const CoordinateDiff& get_pads_begin() const {
|
||||||
void set_pads_begin(const CoordinateDiff& pads_begin) { m_pads_begin = pads_begin; }
|
return m_pads_begin;
|
||||||
|
}
|
||||||
|
void set_pads_begin(const CoordinateDiff& pads_begin) {
|
||||||
|
m_pads_begin = pads_begin;
|
||||||
|
}
|
||||||
/// \return The padding-above sizes (possibly negative).
|
/// \return The padding-above sizes (possibly negative).
|
||||||
const CoordinateDiff& get_pads_end() const { return m_pads_end; }
|
const CoordinateDiff& get_pads_end() const {
|
||||||
void set_adding_above(const CoordinateDiff& pads_end) { m_pads_end = pads_end; }
|
return m_pads_end;
|
||||||
|
}
|
||||||
|
void set_adding_above(const CoordinateDiff& pads_end) {
|
||||||
|
m_pads_end = pads_end;
|
||||||
|
}
|
||||||
/// \return The pad type for convolution.
|
/// \return The pad type for convolution.
|
||||||
const PadType& get_auto_pad() const { return m_auto_pad; }
|
const PadType& get_auto_pad() const {
|
||||||
void set_auto_pad(const PadType& auto_pad) { m_auto_pad = auto_pad; }
|
return m_auto_pad;
|
||||||
|
}
|
||||||
|
void set_auto_pad(const PadType& auto_pad) {
|
||||||
|
m_auto_pad = auto_pad;
|
||||||
|
}
|
||||||
/// \return The default value for Convolution.
|
/// \return The default value for Convolution.
|
||||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||||
virtual std::shared_ptr<Node> get_default_value() const override;
|
virtual std::shared_ptr<Node> get_default_value() const override;
|
||||||
@ -85,8 +100,7 @@ namespace ngraph
|
|||||||
};
|
};
|
||||||
|
|
||||||
/// \brief Data batch backprop for batched convolution operation.
|
/// \brief Data batch backprop for batched convolution operation.
|
||||||
class NGRAPH_API ConvolutionBackpropData : public Op
|
class NGRAPH_API ConvolutionBackpropData : public Op {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
|
||||||
|
@ -6,15 +6,11 @@
|
|||||||
|
|
||||||
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
|
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v0 {
|
||||||
{
|
|
||||||
namespace v0
|
|
||||||
{
|
|
||||||
/// \brief Elementwise cosine operation.
|
/// \brief Elementwise cosine operation.
|
||||||
class NGRAPH_API Cos : public util::UnaryElementwiseArithmetic
|
class NGRAPH_API Cos : public util::UnaryElementwiseArithmetic {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
|
||||||
@ -26,10 +22,8 @@ namespace ngraph
|
|||||||
Cos(const Output<Node>& arg);
|
Cos(const Output<Node>& arg);
|
||||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||||
|
|
||||||
virtual std::shared_ptr<Node>
|
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||||
bool evaluate(const HostTensorVector& outputs,
|
|
||||||
const HostTensorVector& inputs) const override;
|
|
||||||
bool has_evaluate() const override;
|
bool has_evaluate() const override;
|
||||||
};
|
};
|
||||||
} // namespace v0
|
} // namespace v0
|
||||||
|
@ -6,15 +6,11 @@
|
|||||||
|
|
||||||
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
|
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v0 {
|
||||||
{
|
|
||||||
namespace v0
|
|
||||||
{
|
|
||||||
/// \brief Elementwise hyperbolic cosine (cosh) operation.
|
/// \brief Elementwise hyperbolic cosine (cosh) operation.
|
||||||
class NGRAPH_API Cosh : public util::UnaryElementwiseArithmetic
|
class NGRAPH_API Cosh : public util::UnaryElementwiseArithmetic {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
|
||||||
@ -26,10 +22,8 @@ namespace ngraph
|
|||||||
Cosh(const Output<Node>& arg);
|
Cosh(const Output<Node>& arg);
|
||||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||||
|
|
||||||
virtual std::shared_ptr<Node>
|
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||||
bool evaluate(const HostTensorVector& outputs,
|
|
||||||
const HostTensorVector& inputs) const override;
|
|
||||||
bool has_evaluate() const override;
|
bool has_evaluate() const override;
|
||||||
};
|
};
|
||||||
} // namespace v0
|
} // namespace v0
|
||||||
|
@ -6,14 +6,10 @@
|
|||||||
|
|
||||||
#include "ngraph/op/op.hpp"
|
#include "ngraph/op/op.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v0 {
|
||||||
{
|
class NGRAPH_API CTCGreedyDecoder : public Op {
|
||||||
namespace v0
|
|
||||||
{
|
|
||||||
class NGRAPH_API CTCGreedyDecoder : public Op
|
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
|
||||||
@ -23,16 +19,15 @@ namespace ngraph
|
|||||||
/// \param input Logits on which greedy decoding is performed
|
/// \param input Logits on which greedy decoding is performed
|
||||||
/// \param seq_len Sequence lengths
|
/// \param seq_len Sequence lengths
|
||||||
/// \param ctc_merge_repeated Whether to merge repeated labels
|
/// \param ctc_merge_repeated Whether to merge repeated labels
|
||||||
CTCGreedyDecoder(const Output<Node>& input,
|
CTCGreedyDecoder(const Output<Node>& input, const Output<Node>& seq_len, const bool ctc_merge_repeated);
|
||||||
const Output<Node>& seq_len,
|
|
||||||
const bool ctc_merge_repeated);
|
|
||||||
|
|
||||||
void validate_and_infer_types() override;
|
void validate_and_infer_types() override;
|
||||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||||
virtual std::shared_ptr<Node>
|
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
|
||||||
|
|
||||||
bool get_ctc_merge_repeated() const { return m_ctc_merge_repeated; }
|
bool get_ctc_merge_repeated() const {
|
||||||
|
return m_ctc_merge_repeated;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
bool m_ctc_merge_repeated;
|
bool m_ctc_merge_repeated;
|
||||||
|
@ -6,16 +6,12 @@
|
|||||||
|
|
||||||
#include "ngraph/op/op.hpp"
|
#include "ngraph/op/op.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v6 {
|
||||||
{
|
|
||||||
namespace v6
|
|
||||||
{
|
|
||||||
/// \brief Operator performing CTCGreedyDecoder
|
/// \brief Operator performing CTCGreedyDecoder
|
||||||
///
|
///
|
||||||
class NGRAPH_API CTCGreedyDecoderSeqLen : public Op
|
class NGRAPH_API CTCGreedyDecoderSeqLen : public Op {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
CTCGreedyDecoderSeqLen() = default;
|
CTCGreedyDecoderSeqLen() = default;
|
||||||
@ -52,25 +48,27 @@ namespace ngraph
|
|||||||
void validate_and_infer_types() override;
|
void validate_and_infer_types() override;
|
||||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||||
|
|
||||||
std::shared_ptr<Node>
|
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
|
||||||
|
|
||||||
/// \brief Get merge_repeated attribute
|
/// \brief Get merge_repeated attribute
|
||||||
///
|
///
|
||||||
/// \return Current value of merge_repeated attribute
|
/// \return Current value of merge_repeated attribute
|
||||||
///
|
///
|
||||||
bool get_merge_repeated() const { return m_merge_repeated; }
|
bool get_merge_repeated() const {
|
||||||
|
return m_merge_repeated;
|
||||||
|
}
|
||||||
/// \brief Get classes_index_type attribute
|
/// \brief Get classes_index_type attribute
|
||||||
///
|
///
|
||||||
/// \return Current value of classes_index_type attribute
|
/// \return Current value of classes_index_type attribute
|
||||||
///
|
///
|
||||||
const element::Type& get_classes_index_type() const { return m_classes_index_type; }
|
const element::Type& get_classes_index_type() const {
|
||||||
|
return m_classes_index_type;
|
||||||
|
}
|
||||||
/// \brief Set classes_index_type attribute
|
/// \brief Set classes_index_type attribute
|
||||||
///
|
///
|
||||||
/// \param classes_index_type Type of classes_index
|
/// \param classes_index_type Type of classes_index
|
||||||
///
|
///
|
||||||
void set_classes_index_type(const element::Type& classes_index_type)
|
void set_classes_index_type(const element::Type& classes_index_type) {
|
||||||
{
|
|
||||||
m_classes_index_type = classes_index_type;
|
m_classes_index_type = classes_index_type;
|
||||||
validate_and_infer_types();
|
validate_and_infer_types();
|
||||||
}
|
}
|
||||||
@ -79,8 +77,7 @@ namespace ngraph
|
|||||||
///
|
///
|
||||||
/// \return Current value of sequence_length_type attribute
|
/// \return Current value of sequence_length_type attribute
|
||||||
///
|
///
|
||||||
const element::Type& get_sequence_length_type() const
|
const element::Type& get_sequence_length_type() const {
|
||||||
{
|
|
||||||
return m_sequence_length_type;
|
return m_sequence_length_type;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -88,8 +85,7 @@ namespace ngraph
|
|||||||
///
|
///
|
||||||
/// \param sequence_length_type Type of sequence length
|
/// \param sequence_length_type Type of sequence length
|
||||||
///
|
///
|
||||||
void set_sequence_length_type(const element::Type& sequence_length_type)
|
void set_sequence_length_type(const element::Type& sequence_length_type) {
|
||||||
{
|
|
||||||
m_sequence_length_type = sequence_length_type;
|
m_sequence_length_type = sequence_length_type;
|
||||||
validate_and_infer_types();
|
validate_and_infer_types();
|
||||||
}
|
}
|
||||||
|
@ -6,17 +6,15 @@
|
|||||||
|
|
||||||
#include "ngraph/op/op.hpp"
|
#include "ngraph/op/op.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v4 {
|
||||||
{
|
class NGRAPH_API CTCLoss : public Op {
|
||||||
namespace v4
|
|
||||||
{
|
|
||||||
class NGRAPH_API CTCLoss : public Op
|
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
static constexpr NodeTypeInfo type_info{"CTCLoss", 0};
|
static constexpr NodeTypeInfo type_info{"CTCLoss", 0};
|
||||||
const NodeTypeInfo& get_type_info() const override { return type_info; }
|
const NodeTypeInfo& get_type_info() const override {
|
||||||
|
return type_info;
|
||||||
|
}
|
||||||
CTCLoss() = default;
|
CTCLoss() = default;
|
||||||
/// \brief Constructs a CTCLoss operation
|
/// \brief Constructs a CTCLoss operation
|
||||||
///
|
///
|
||||||
@ -53,15 +51,17 @@ namespace ngraph
|
|||||||
|
|
||||||
void validate_and_infer_types() override;
|
void validate_and_infer_types() override;
|
||||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||||
virtual std::shared_ptr<Node>
|
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
|
||||||
|
|
||||||
bool get_preprocess_collapse_repeated() const
|
bool get_preprocess_collapse_repeated() const {
|
||||||
{
|
|
||||||
return preprocess_collapse_repeated_;
|
return preprocess_collapse_repeated_;
|
||||||
}
|
}
|
||||||
bool get_ctc_merge_repeated() const { return ctc_merge_repeated_; }
|
bool get_ctc_merge_repeated() const {
|
||||||
bool get_unique() const { return unique_; }
|
return ctc_merge_repeated_;
|
||||||
|
}
|
||||||
|
bool get_unique() const {
|
||||||
|
return unique_;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
bool preprocess_collapse_repeated_;
|
bool preprocess_collapse_repeated_;
|
||||||
|
@ -7,12 +7,9 @@
|
|||||||
#include "ngraph/axis_set.hpp"
|
#include "ngraph/axis_set.hpp"
|
||||||
#include "ngraph/op/op.hpp"
|
#include "ngraph/op/op.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v0 {
|
||||||
{
|
|
||||||
namespace v0
|
|
||||||
{
|
|
||||||
/// \brief Tensor cumulative sum operation.
|
/// \brief Tensor cumulative sum operation.
|
||||||
///
|
///
|
||||||
/// Compute the cumulative sum of the input tensor along the axis specified.
|
/// Compute the cumulative sum of the input tensor along the axis specified.
|
||||||
@ -57,8 +54,7 @@ namespace ngraph
|
|||||||
/// | Output tensor of the same type as `arg` with cumulative sums of the arg's elements
|
/// | Output tensor of the same type as `arg` with cumulative sums of the arg's elements
|
||||||
/// |
|
/// |
|
||||||
|
|
||||||
class NGRAPH_API CumSum : public Op
|
class NGRAPH_API CumSum : public Op {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
|
||||||
@ -72,20 +68,14 @@ namespace ngraph
|
|||||||
/// cumulative sum must be performed
|
/// cumulative sum must be performed
|
||||||
/// \param exclusive if set to true, the top element is not included
|
/// \param exclusive if set to true, the top element is not included
|
||||||
/// \param reverse if set to true, will perform the sums in reverse direction
|
/// \param reverse if set to true, will perform the sums in reverse direction
|
||||||
CumSum(const Output<Node>& arg,
|
CumSum(const Output<Node>& arg, const Output<Node>& axis, const bool exclusive = false, const bool reverse = false);
|
||||||
const Output<Node>& axis,
|
|
||||||
const bool exclusive = false,
|
|
||||||
const bool reverse = false);
|
|
||||||
|
|
||||||
/// \brief Constructs a cumulative summation operation with axis = 0
|
/// \brief Constructs a cumulative summation operation with axis = 0
|
||||||
///
|
///
|
||||||
/// \param arg The tensor to be summed
|
/// \param arg The tensor to be summed
|
||||||
CumSum(const Output<Node>& arg,
|
CumSum(const Output<Node>& arg, const bool exclusive = false, const bool reverse = false);
|
||||||
const bool exclusive = false,
|
|
||||||
const bool reverse = false);
|
|
||||||
|
|
||||||
virtual std::shared_ptr<Node>
|
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
|
||||||
|
|
||||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||||
void validate_and_infer_types() override;
|
void validate_and_infer_types() override;
|
||||||
@ -94,8 +84,12 @@ namespace ngraph
|
|||||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||||
virtual std::shared_ptr<Node> get_default_value() const override;
|
virtual std::shared_ptr<Node> get_default_value() const override;
|
||||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||||
bool is_exclusive() const { return m_exclusive; }
|
bool is_exclusive() const {
|
||||||
bool is_reverse() const { return m_reverse; }
|
return m_exclusive;
|
||||||
|
}
|
||||||
|
bool is_reverse() const {
|
||||||
|
return m_reverse;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
bool m_exclusive;
|
bool m_exclusive;
|
||||||
|
@ -9,15 +9,11 @@
|
|||||||
#include "ngraph/op/util/attr_types.hpp"
|
#include "ngraph/op/util/attr_types.hpp"
|
||||||
#include "ngraph/op/util/deformable_convolution_base.hpp"
|
#include "ngraph/op/util/deformable_convolution_base.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v1 {
|
||||||
{
|
|
||||||
namespace v1
|
|
||||||
{
|
|
||||||
/// \brief DeformableConvolution operation.
|
/// \brief DeformableConvolution operation.
|
||||||
class NGRAPH_API DeformableConvolution : public op::util::DeformableConvolutionBase
|
class NGRAPH_API DeformableConvolution : public op::util::DeformableConvolutionBase {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
|
||||||
@ -55,15 +51,12 @@ namespace ngraph
|
|||||||
const int64_t group = 1,
|
const int64_t group = 1,
|
||||||
const int64_t deformable_group = 1);
|
const int64_t deformable_group = 1);
|
||||||
|
|
||||||
std::shared_ptr<Node>
|
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
|
||||||
};
|
};
|
||||||
} // namespace v1
|
} // namespace v1
|
||||||
|
|
||||||
namespace v8
|
namespace v8 {
|
||||||
{
|
class NGRAPH_API DeformableConvolution : public op::util::DeformableConvolutionBase {
|
||||||
class NGRAPH_API DeformableConvolution : public op::util::DeformableConvolutionBase
|
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
|
||||||
@ -160,18 +153,17 @@ namespace ngraph
|
|||||||
|
|
||||||
void validate_and_infer_types() override;
|
void validate_and_infer_types() override;
|
||||||
|
|
||||||
bool evaluate(const HostTensorVector& outputs,
|
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||||
const HostTensorVector& inputs) const override;
|
|
||||||
|
|
||||||
bool has_evaluate() const override;
|
bool has_evaluate() const override;
|
||||||
|
|
||||||
std::shared_ptr<Node>
|
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
|
||||||
|
|
||||||
bool get_bilinear_interpolation_pad() const { return m_bilinear_interpolation_pad; }
|
bool get_bilinear_interpolation_pad() const {
|
||||||
|
return m_bilinear_interpolation_pad;
|
||||||
|
}
|
||||||
|
|
||||||
void set_bilinear_interpolation_pad(const bool bilinear_interpolation_pad)
|
void set_bilinear_interpolation_pad(const bool bilinear_interpolation_pad) {
|
||||||
{
|
|
||||||
m_bilinear_interpolation_pad = bilinear_interpolation_pad;
|
m_bilinear_interpolation_pad = bilinear_interpolation_pad;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6,14 +6,10 @@
|
|||||||
|
|
||||||
#include "ngraph/op/op.hpp"
|
#include "ngraph/op/op.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v1 {
|
||||||
{
|
class NGRAPH_API DeformablePSROIPooling : public Op {
|
||||||
namespace v1
|
|
||||||
{
|
|
||||||
class NGRAPH_API DeformablePSROIPooling : public Op
|
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
|
||||||
@ -69,17 +65,32 @@ namespace ngraph
|
|||||||
|
|
||||||
void validate_and_infer_types() override;
|
void validate_and_infer_types() override;
|
||||||
|
|
||||||
virtual std::shared_ptr<Node>
|
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
|
||||||
|
|
||||||
int64_t get_output_dim() const { return m_output_dim; }
|
int64_t get_output_dim() const {
|
||||||
int64_t get_group_size() const { return m_group_size; }
|
return m_output_dim;
|
||||||
float get_spatial_scale() const { return m_spatial_scale; }
|
}
|
||||||
const std::string& get_mode() const { return m_mode; }
|
int64_t get_group_size() const {
|
||||||
int64_t get_spatial_bins_x() const { return m_spatial_bins_x; }
|
return m_group_size;
|
||||||
int64_t get_spatial_bins_y() const { return m_spatial_bins_y; }
|
}
|
||||||
float get_trans_std() const { return m_trans_std; }
|
float get_spatial_scale() const {
|
||||||
int64_t get_part_size() const { return m_part_size; }
|
return m_spatial_scale;
|
||||||
|
}
|
||||||
|
const std::string& get_mode() const {
|
||||||
|
return m_mode;
|
||||||
|
}
|
||||||
|
int64_t get_spatial_bins_x() const {
|
||||||
|
return m_spatial_bins_x;
|
||||||
|
}
|
||||||
|
int64_t get_spatial_bins_y() const {
|
||||||
|
return m_spatial_bins_y;
|
||||||
|
}
|
||||||
|
float get_trans_std() const {
|
||||||
|
return m_trans_std;
|
||||||
|
}
|
||||||
|
int64_t get_part_size() const {
|
||||||
|
return m_part_size;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int64_t m_output_dim;
|
int64_t m_output_dim;
|
||||||
|
@ -7,12 +7,9 @@
|
|||||||
#include "ngraph/op/op.hpp"
|
#include "ngraph/op/op.hpp"
|
||||||
#include "ngraph/op/util/attr_types.hpp"
|
#include "ngraph/op/util/attr_types.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v0 {
|
||||||
{
|
|
||||||
namespace v0
|
|
||||||
{
|
|
||||||
/// \brief DepthToSpace permutes data from the depth dimension of the input blob into
|
/// \brief DepthToSpace permutes data from the depth dimension of the input blob into
|
||||||
/// spatial dimensions.
|
/// spatial dimensions.
|
||||||
///
|
///
|
||||||
@ -21,13 +18,11 @@ namespace ngraph
|
|||||||
///
|
///
|
||||||
/// Output node produces a tensor with shape:
|
/// Output node produces a tensor with shape:
|
||||||
/// [N, C/(blocksize * blocksize), H * blocksize, W * blocksize]
|
/// [N, C/(blocksize * blocksize), H * blocksize, W * blocksize]
|
||||||
class NGRAPH_API DepthToSpace : public Op
|
class NGRAPH_API DepthToSpace : public Op {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
|
||||||
enum class DepthToSpaceMode
|
enum class DepthToSpaceMode {
|
||||||
{
|
|
||||||
// The input depth is divided to [block_size, ..., block_size, new_depth]
|
// The input depth is divided to [block_size, ..., block_size, new_depth]
|
||||||
BLOCKS_FIRST,
|
BLOCKS_FIRST,
|
||||||
// The input depth is divided to [new_depth, block_size, ..., block_size]
|
// The input depth is divided to [new_depth, block_size, ..., block_size]
|
||||||
@ -41,22 +36,20 @@ namespace ngraph
|
|||||||
/// \param mode Specifies how the input depth dimension is split to block
|
/// \param mode Specifies how the input depth dimension is split to block
|
||||||
/// coordinates
|
/// coordinates
|
||||||
/// \param block_size The size of the block of values to be moved
|
/// \param block_size The size of the block of values to be moved
|
||||||
DepthToSpace(const Output<Node>& data,
|
DepthToSpace(const Output<Node>& data, const DepthToSpaceMode& mode, std::size_t block_size = 1);
|
||||||
const DepthToSpaceMode& mode,
|
|
||||||
std::size_t block_size = 1);
|
|
||||||
|
|
||||||
DepthToSpace(const Output<Node>& data,
|
DepthToSpace(const Output<Node>& data, const std::string& mode, std::size_t block_size = 1);
|
||||||
const std::string& mode,
|
|
||||||
std::size_t block_size = 1);
|
|
||||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||||
|
|
||||||
std::size_t get_block_size() const { return m_blocksize; }
|
std::size_t get_block_size() const {
|
||||||
DepthToSpaceMode get_mode() const { return m_mode; }
|
return m_blocksize;
|
||||||
virtual std::shared_ptr<Node>
|
}
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
DepthToSpaceMode get_mode() const {
|
||||||
|
return m_mode;
|
||||||
|
}
|
||||||
|
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
void validate_and_infer_types() override;
|
void validate_and_infer_types() override;
|
||||||
bool evaluate(const HostTensorVector& outputs,
|
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||||
const HostTensorVector& inputs) const override;
|
|
||||||
bool has_evaluate() const override;
|
bool has_evaluate() const override;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
@ -72,16 +65,14 @@ namespace ngraph
|
|||||||
|
|
||||||
template <>
|
template <>
|
||||||
class NGRAPH_API AttributeAdapter<op::v0::DepthToSpace::DepthToSpaceMode>
|
class NGRAPH_API AttributeAdapter<op::v0::DepthToSpace::DepthToSpaceMode>
|
||||||
: public EnumAttributeAdapterBase<op::v0::DepthToSpace::DepthToSpaceMode>
|
: public EnumAttributeAdapterBase<op::v0::DepthToSpace::DepthToSpaceMode> {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
AttributeAdapter(op::v0::DepthToSpace::DepthToSpaceMode& value)
|
AttributeAdapter(op::v0::DepthToSpace::DepthToSpaceMode& value)
|
||||||
: EnumAttributeAdapterBase<op::v0::DepthToSpace::DepthToSpaceMode>(value)
|
: EnumAttributeAdapterBase<op::v0::DepthToSpace::DepthToSpaceMode>(value) {}
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
static constexpr DiscreteTypeInfo type_info{
|
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::v0::DepthToSpace::DepthToSpaceMode>", 0};
|
||||||
"AttributeAdapter<op::v0::DepthToSpace::DepthToSpaceMode>", 0};
|
const DiscreteTypeInfo& get_type_info() const override {
|
||||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
return type_info;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
} // namespace ngraph
|
} // namespace ngraph
|
||||||
|
@ -6,12 +6,9 @@
|
|||||||
|
|
||||||
#include "ngraph/op/op.hpp"
|
#include "ngraph/op/op.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
struct DetectionOutputAttrs {
|
||||||
{
|
|
||||||
struct DetectionOutputAttrs
|
|
||||||
{
|
|
||||||
int num_classes;
|
int num_classes;
|
||||||
int background_label_id = 0;
|
int background_label_id = 0;
|
||||||
int top_k = -1;
|
int top_k = -1;
|
||||||
@ -30,12 +27,10 @@ namespace ngraph
|
|||||||
float objectness_score = 0;
|
float objectness_score = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
namespace v0
|
namespace v0 {
|
||||||
{
|
|
||||||
/// \brief Layer which performs non-max suppression to
|
/// \brief Layer which performs non-max suppression to
|
||||||
/// generate detection output using location and confidence predictions
|
/// generate detection output using location and confidence predictions
|
||||||
class NGRAPH_API DetectionOutput : public Op
|
class NGRAPH_API DetectionOutput : public Op {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
|
||||||
@ -68,10 +63,11 @@ namespace ngraph
|
|||||||
|
|
||||||
void validate_and_infer_types() override;
|
void validate_and_infer_types() override;
|
||||||
|
|
||||||
virtual std::shared_ptr<Node>
|
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
|
||||||
|
|
||||||
const DetectionOutputAttrs& get_attrs() const { return m_attrs; }
|
const DetectionOutputAttrs& get_attrs() const {
|
||||||
|
return m_attrs;
|
||||||
|
}
|
||||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -18,20 +18,17 @@
|
|||||||
|
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "ngraph/attribute_adapter.hpp"
|
#include "ngraph/attribute_adapter.hpp"
|
||||||
#include "ngraph/op/op.hpp"
|
#include "ngraph/op/op.hpp"
|
||||||
#include "ngraph/op/util/attr_types.hpp"
|
#include "ngraph/op/util/attr_types.hpp"
|
||||||
#include "ngraph/op/util/fft_base.hpp"
|
#include "ngraph/op/util/fft_base.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v7 {
|
||||||
{
|
|
||||||
namespace v7
|
|
||||||
{
|
|
||||||
/// \brief An operation DFT that computes the discrete Fourier transformation.
|
/// \brief An operation DFT that computes the discrete Fourier transformation.
|
||||||
class NGRAPH_API DFT : public util::FFTBase
|
class NGRAPH_API DFT : public util::FFTBase {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
DFT() = default;
|
DFT() = default;
|
||||||
@ -47,14 +44,11 @@ namespace ngraph
|
|||||||
/// \param data Input data
|
/// \param data Input data
|
||||||
/// \param axes Axes to perform DFT
|
/// \param axes Axes to perform DFT
|
||||||
/// \param signal_size Signal sizes for 'axes'
|
/// \param signal_size Signal sizes for 'axes'
|
||||||
DFT(const Output<Node>& data,
|
DFT(const Output<Node>& data, const Output<Node>& axes, const Output<Node>& signal_size);
|
||||||
const Output<Node>& axes,
|
|
||||||
const Output<Node>& signal_size);
|
|
||||||
|
|
||||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||||
|
|
||||||
std::shared_ptr<Node>
|
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
|
||||||
};
|
};
|
||||||
} // namespace v7
|
} // namespace v7
|
||||||
} // namespace op
|
} // namespace op
|
||||||
|
@ -6,22 +6,15 @@
|
|||||||
|
|
||||||
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
|
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v1 {
|
||||||
{
|
|
||||||
namespace v1
|
|
||||||
{
|
|
||||||
/// \brief Elementwise division operation.
|
/// \brief Elementwise division operation.
|
||||||
class NGRAPH_API Divide : public util::BinaryElementwiseArithmetic
|
class NGRAPH_API Divide : public util::BinaryElementwiseArithmetic {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
/// \brief Constructs a division operation.
|
/// \brief Constructs a division operation.
|
||||||
Divide()
|
Divide() : util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NUMPY) {}
|
||||||
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NUMPY)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
/// \brief Constructs a division operation.
|
/// \brief Constructs a division operation.
|
||||||
///
|
///
|
||||||
@ -32,8 +25,7 @@ namespace ngraph
|
|||||||
Divide(const Output<Node>& arg0,
|
Divide(const Output<Node>& arg0,
|
||||||
const Output<Node>& arg1,
|
const Output<Node>& arg1,
|
||||||
bool pythondiv,
|
bool pythondiv,
|
||||||
const AutoBroadcastSpec& auto_broadcast =
|
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec(AutoBroadcastType::NUMPY));
|
||||||
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
|
|
||||||
|
|
||||||
/// \brief Constructs a division operation.
|
/// \brief Constructs a division operation.
|
||||||
///
|
///
|
||||||
@ -42,16 +34,17 @@ namespace ngraph
|
|||||||
/// \param auto_broadcast Auto broadcast specification
|
/// \param auto_broadcast Auto broadcast specification
|
||||||
Divide(const Output<Node>& arg0,
|
Divide(const Output<Node>& arg0,
|
||||||
const Output<Node>& arg1,
|
const Output<Node>& arg1,
|
||||||
const AutoBroadcastSpec& auto_broadcast =
|
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec(AutoBroadcastType::NUMPY));
|
||||||
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
|
|
||||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||||
bool is_pythondiv() const { return m_pythondiv; }
|
bool is_pythondiv() const {
|
||||||
void set_is_pythondiv(bool pythondiv) { m_pythondiv = pythondiv; }
|
return m_pythondiv;
|
||||||
virtual std::shared_ptr<Node>
|
}
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
void set_is_pythondiv(bool pythondiv) {
|
||||||
|
m_pythondiv = pythondiv;
|
||||||
|
}
|
||||||
|
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
|
|
||||||
bool evaluate(const HostTensorVector& outputs,
|
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||||
const HostTensorVector& inputs) const override;
|
|
||||||
bool has_evaluate() const override;
|
bool has_evaluate() const override;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
@ -7,15 +7,11 @@
|
|||||||
#include "ngraph/node.hpp"
|
#include "ngraph/node.hpp"
|
||||||
#include "ngraph/op/op.hpp"
|
#include "ngraph/op/op.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v7 {
|
||||||
{
|
|
||||||
namespace v7
|
|
||||||
{
|
|
||||||
/// \brief Einsum operation.
|
/// \brief Einsum operation.
|
||||||
class NGRAPH_API Einsum : public Op
|
class NGRAPH_API Einsum : public Op {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
|
||||||
@ -35,14 +31,15 @@ namespace ngraph
|
|||||||
|
|
||||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||||
|
|
||||||
std::shared_ptr<Node>
|
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
|
||||||
|
|
||||||
/// \brief Get an equation of Einsum operation
|
/// \brief Get an equation of Einsum operation
|
||||||
///
|
///
|
||||||
/// \return Einsum equation
|
/// \return Einsum equation
|
||||||
///
|
///
|
||||||
std::string get_equation() const { return m_equation; }
|
std::string get_equation() const {
|
||||||
|
return m_equation;
|
||||||
|
}
|
||||||
|
|
||||||
/// \brief Check correctness of equation format and extract input subscripts
|
/// \brief Check correctness of equation format and extract input subscripts
|
||||||
/// and output subscript
|
/// and output subscript
|
||||||
|
@ -7,18 +7,14 @@
|
|||||||
#include "ngraph/node.hpp"
|
#include "ngraph/node.hpp"
|
||||||
#include "ngraph/op/op.hpp"
|
#include "ngraph/op/op.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v0 {
|
||||||
{
|
|
||||||
namespace v0
|
|
||||||
{
|
|
||||||
/// \brief Exponential Linear Unit
|
/// \brief Exponential Linear Unit
|
||||||
/// x < 0 => f(x) = alpha * (exp(x) - 1.)
|
/// x < 0 => f(x) = alpha * (exp(x) - 1.)
|
||||||
/// x >= 0 => f(x) = x
|
/// x >= 0 => f(x) = x
|
||||||
///
|
///
|
||||||
class NGRAPH_API Elu : public ngraph::op::Op
|
class NGRAPH_API Elu : public ngraph::op::Op {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
|
||||||
@ -32,10 +28,11 @@ namespace ngraph
|
|||||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||||
void validate_and_infer_types() override;
|
void validate_and_infer_types() override;
|
||||||
|
|
||||||
virtual std::shared_ptr<Node>
|
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
|
||||||
|
|
||||||
double get_alpha() const { return m_alpha; }
|
double get_alpha() const {
|
||||||
|
return m_alpha;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
double m_alpha;
|
double m_alpha;
|
||||||
|
@ -7,18 +7,16 @@
|
|||||||
#include "ngraph/axis_set.hpp"
|
#include "ngraph/axis_set.hpp"
|
||||||
#include "ngraph/op/util/index_reduction.hpp"
|
#include "ngraph/op/util/index_reduction.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v3 {
|
||||||
{
|
|
||||||
namespace v3
|
|
||||||
{
|
|
||||||
/// \brief Returns embeddings for given indices
|
/// \brief Returns embeddings for given indices
|
||||||
class NGRAPH_API EmbeddingSegmentsSum : public Op
|
class NGRAPH_API EmbeddingSegmentsSum : public Op {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
static constexpr NodeTypeInfo type_info{"EmbeddingSegmentsSum", 3};
|
static constexpr NodeTypeInfo type_info{"EmbeddingSegmentsSum", 3};
|
||||||
const NodeTypeInfo& get_type_info() const override { return type_info; }
|
const NodeTypeInfo& get_type_info() const override {
|
||||||
|
return type_info;
|
||||||
|
}
|
||||||
/// \brief Constructs a EmbeddingSegmentsSum operation.
|
/// \brief Constructs a EmbeddingSegmentsSum operation.
|
||||||
EmbeddingSegmentsSum() = default;
|
EmbeddingSegmentsSum() = default;
|
||||||
/// \brief Constructs a EmbeddingSegmentsSum operation.
|
/// \brief Constructs a EmbeddingSegmentsSum operation.
|
||||||
@ -63,10 +61,11 @@ namespace ngraph
|
|||||||
|
|
||||||
void validate_and_infer_types() override;
|
void validate_and_infer_types() override;
|
||||||
|
|
||||||
virtual std::shared_ptr<Node>
|
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
|
||||||
|
|
||||||
bool visit_attributes(AttributeVisitor&) override { return true; }
|
bool visit_attributes(AttributeVisitor&) override {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static constexpr int EMB_TABLE = 0;
|
static constexpr int EMB_TABLE = 0;
|
||||||
|
@ -8,18 +8,16 @@
|
|||||||
#include "ngraph/op/util/embeddingbag_offsets_base.hpp"
|
#include "ngraph/op/util/embeddingbag_offsets_base.hpp"
|
||||||
#include "ngraph/op/util/index_reduction.hpp"
|
#include "ngraph/op/util/index_reduction.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v3 {
|
||||||
{
|
|
||||||
namespace v3
|
|
||||||
{
|
|
||||||
/// \brief Returns embeddings for given indices
|
/// \brief Returns embeddings for given indices
|
||||||
class NGRAPH_API EmbeddingBagOffsetsSum : public util::EmbeddingBagOffsetsBase
|
class NGRAPH_API EmbeddingBagOffsetsSum : public util::EmbeddingBagOffsetsBase {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
static constexpr NodeTypeInfo type_info{"EmbeddingBagOffsetsSum", 3};
|
static constexpr NodeTypeInfo type_info{"EmbeddingBagOffsetsSum", 3};
|
||||||
const NodeTypeInfo& get_type_info() const override { return type_info; }
|
const NodeTypeInfo& get_type_info() const override {
|
||||||
|
return type_info;
|
||||||
|
}
|
||||||
/// \brief Constructs a EmbeddingBagOffsetsSum operation.
|
/// \brief Constructs a EmbeddingBagOffsetsSum operation.
|
||||||
EmbeddingBagOffsetsSum() = default;
|
EmbeddingBagOffsetsSum() = default;
|
||||||
/// \brief Constructs a EmbeddingBagOffsetsSum operation.
|
/// \brief Constructs a EmbeddingBagOffsetsSum operation.
|
||||||
@ -51,12 +49,9 @@ namespace ngraph
|
|||||||
const Output<Node>& offsets,
|
const Output<Node>& offsets,
|
||||||
const Output<Node>& default_index);
|
const Output<Node>& default_index);
|
||||||
|
|
||||||
EmbeddingBagOffsetsSum(const Output<Node>& emb_table,
|
EmbeddingBagOffsetsSum(const Output<Node>& emb_table, const Output<Node>& indices, const Output<Node>& offsets);
|
||||||
const Output<Node>& indices,
|
|
||||||
const Output<Node>& offsets);
|
|
||||||
|
|
||||||
virtual std::shared_ptr<Node>
|
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
|
||||||
};
|
};
|
||||||
} // namespace v3
|
} // namespace v3
|
||||||
using v3::EmbeddingBagOffsetsSum;
|
using v3::EmbeddingBagOffsetsSum;
|
||||||
|
@ -8,18 +8,16 @@
|
|||||||
#include "ngraph/op/util/embeddingbag_packed_base.hpp"
|
#include "ngraph/op/util/embeddingbag_packed_base.hpp"
|
||||||
#include "ngraph/op/util/index_reduction.hpp"
|
#include "ngraph/op/util/index_reduction.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v3 {
|
||||||
{
|
|
||||||
namespace v3
|
|
||||||
{
|
|
||||||
/// \brief Returns embeddings for given indices
|
/// \brief Returns embeddings for given indices
|
||||||
class NGRAPH_API EmbeddingBagPackedSum : public util::EmbeddingBagPackedBase
|
class NGRAPH_API EmbeddingBagPackedSum : public util::EmbeddingBagPackedBase {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
static constexpr NodeTypeInfo type_info{"EmbeddingBagPackedSum", 3};
|
static constexpr NodeTypeInfo type_info{"EmbeddingBagPackedSum", 3};
|
||||||
const NodeTypeInfo& get_type_info() const override { return type_info; }
|
const NodeTypeInfo& get_type_info() const override {
|
||||||
|
return type_info;
|
||||||
|
}
|
||||||
/// \brief Constructs a EmbeddingBagPackedSum operation.
|
/// \brief Constructs a EmbeddingBagPackedSum operation.
|
||||||
EmbeddingBagPackedSum() = default;
|
EmbeddingBagPackedSum() = default;
|
||||||
/// \brief Constructs a EmbeddingBagPackedSum operation.
|
/// \brief Constructs a EmbeddingBagPackedSum operation.
|
||||||
@ -42,8 +40,7 @@ namespace ngraph
|
|||||||
|
|
||||||
EmbeddingBagPackedSum(const Output<Node>& emb_table, const Output<Node>& indices);
|
EmbeddingBagPackedSum(const Output<Node>& emb_table, const Output<Node>& indices);
|
||||||
|
|
||||||
virtual std::shared_ptr<Node>
|
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
|
||||||
};
|
};
|
||||||
} // namespace v3
|
} // namespace v3
|
||||||
using v3::EmbeddingBagPackedSum;
|
using v3::EmbeddingBagPackedSum;
|
||||||
|
@ -6,12 +6,9 @@
|
|||||||
|
|
||||||
#include "ngraph/op/util/binary_elementwise_comparison.hpp"
|
#include "ngraph/op/util/binary_elementwise_comparison.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v1 {
|
||||||
{
|
|
||||||
namespace v1
|
|
||||||
{
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
/// \brief Elementwise is-equal operation.
|
/// \brief Elementwise is-equal operation.
|
||||||
///
|
///
|
||||||
@ -29,15 +26,11 @@ namespace ngraph
|
|||||||
/// | ---------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------ |
|
/// | ---------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||||
/// | \f$\texttt{bool}[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = 1\text{ if }\texttt{arg0}[i_1,\dots,i_n] = \texttt{arg1}[i_1,\dots,i_n]\text{, else } 0\f$ |
|
/// | \f$\texttt{bool}[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = 1\text{ if }\texttt{arg0}[i_1,\dots,i_n] = \texttt{arg1}[i_1,\dots,i_n]\text{, else } 0\f$ |
|
||||||
// clang-format on
|
// clang-format on
|
||||||
class NGRAPH_API Equal : public util::BinaryElementwiseComparison
|
class NGRAPH_API Equal : public util::BinaryElementwiseComparison {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
/// \brief Constructs an equal operation.
|
/// \brief Constructs an equal operation.
|
||||||
Equal()
|
Equal() : util::BinaryElementwiseComparison(AutoBroadcastSpec::NUMPY) {}
|
||||||
: util::BinaryElementwiseComparison(AutoBroadcastSpec::NUMPY)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
/// \brief Constructs an equal operation.
|
/// \brief Constructs an equal operation.
|
||||||
///
|
///
|
||||||
/// \param arg0 Node that produces the first input tensor.
|
/// \param arg0 Node that produces the first input tensor.
|
||||||
@ -45,15 +38,12 @@ namespace ngraph
|
|||||||
/// \param auto_broadcast Auto broadcast specification
|
/// \param auto_broadcast Auto broadcast specification
|
||||||
Equal(const Output<Node>& arg0,
|
Equal(const Output<Node>& arg0,
|
||||||
const Output<Node>& arg1,
|
const Output<Node>& arg1,
|
||||||
const AutoBroadcastSpec& auto_broadcast =
|
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec(AutoBroadcastType::NUMPY));
|
||||||
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
|
|
||||||
|
|
||||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||||
virtual std::shared_ptr<Node>
|
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
|
||||||
|
|
||||||
bool evaluate(const HostTensorVector& outputs,
|
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||||
const HostTensorVector& inputs) const override;
|
|
||||||
bool has_evaluate() const override;
|
bool has_evaluate() const override;
|
||||||
};
|
};
|
||||||
} // namespace v1
|
} // namespace v1
|
||||||
|
@ -6,15 +6,11 @@
|
|||||||
|
|
||||||
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
|
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v0 {
|
||||||
{
|
|
||||||
namespace v0
|
|
||||||
{
|
|
||||||
/// \brief Elementwise erf operation.
|
/// \brief Elementwise erf operation.
|
||||||
class NGRAPH_API Erf : public util::UnaryElementwiseArithmetic
|
class NGRAPH_API Erf : public util::UnaryElementwiseArithmetic {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
/// \brief Constructs a floor operation.
|
/// \brief Constructs a floor operation.
|
||||||
@ -25,10 +21,8 @@ namespace ngraph
|
|||||||
Erf(const Output<Node>& arg);
|
Erf(const Output<Node>& arg);
|
||||||
|
|
||||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||||
virtual std::shared_ptr<Node>
|
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||||
bool evaluate(const HostTensorVector& outputs,
|
|
||||||
const HostTensorVector& inputs) const override;
|
|
||||||
bool has_evaluate() const override;
|
bool has_evaluate() const override;
|
||||||
};
|
};
|
||||||
} // namespace v0
|
} // namespace v0
|
||||||
|
@ -6,15 +6,11 @@
|
|||||||
|
|
||||||
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
|
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v0 {
|
||||||
{
|
|
||||||
namespace v0
|
|
||||||
{
|
|
||||||
/// \brief Elementwise natural exponential (exp) operation.
|
/// \brief Elementwise natural exponential (exp) operation.
|
||||||
class NGRAPH_API Exp : public util::UnaryElementwiseArithmetic
|
class NGRAPH_API Exp : public util::UnaryElementwiseArithmetic {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
|
||||||
@ -26,11 +22,9 @@ namespace ngraph
|
|||||||
Exp(const Output<Node>& arg);
|
Exp(const Output<Node>& arg);
|
||||||
|
|
||||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||||
virtual std::shared_ptr<Node>
|
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
|
||||||
|
|
||||||
bool evaluate(const HostTensorVector& outputs,
|
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||||
const HostTensorVector& inputs) const override;
|
|
||||||
bool has_evaluate() const override;
|
bool has_evaluate() const override;
|
||||||
};
|
};
|
||||||
} // namespace v0
|
} // namespace v0
|
||||||
|
@ -6,27 +6,23 @@
|
|||||||
|
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "ngraph/attribute_adapter.hpp"
|
#include "ngraph/attribute_adapter.hpp"
|
||||||
#include "ngraph/op/op.hpp"
|
#include "ngraph/op/op.hpp"
|
||||||
#include "ngraph/op/util/attr_types.hpp"
|
#include "ngraph/op/util/attr_types.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v6 {
|
||||||
{
|
|
||||||
namespace v6
|
|
||||||
{
|
|
||||||
/// \brief An operation ExperimentalDetectronDetectionOutput performs
|
/// \brief An operation ExperimentalDetectronDetectionOutput performs
|
||||||
/// non-maximum suppression to generate the detection output using
|
/// non-maximum suppression to generate the detection output using
|
||||||
/// information on location and score predictions.
|
/// information on location and score predictions.
|
||||||
class NGRAPH_API ExperimentalDetectronDetectionOutput : public Op
|
class NGRAPH_API ExperimentalDetectronDetectionOutput : public Op {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
|
||||||
/// \brief Structure that specifies attributes of the operation
|
/// \brief Structure that specifies attributes of the operation
|
||||||
struct Attributes
|
struct Attributes {
|
||||||
{
|
|
||||||
// specifies score threshold
|
// specifies score threshold
|
||||||
float score_threshold;
|
float score_threshold;
|
||||||
// specifies NMS threshold
|
// specifies NMS threshold
|
||||||
@ -64,10 +60,11 @@ namespace ngraph
|
|||||||
|
|
||||||
void validate_and_infer_types() override;
|
void validate_and_infer_types() override;
|
||||||
|
|
||||||
std::shared_ptr<Node>
|
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
|
||||||
/// \brief Returns attributes of the operation ExperimentalDetectronDetectionOutput
|
/// \brief Returns attributes of the operation ExperimentalDetectronDetectionOutput
|
||||||
const Attributes& get_attrs() const { return m_attrs; }
|
const Attributes& get_attrs() const {
|
||||||
|
return m_attrs;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Attributes m_attrs;
|
Attributes m_attrs;
|
||||||
|
@ -6,26 +6,22 @@
|
|||||||
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "ngraph/attribute_adapter.hpp"
|
#include "ngraph/attribute_adapter.hpp"
|
||||||
#include "ngraph/op/op.hpp"
|
#include "ngraph/op/op.hpp"
|
||||||
#include "ngraph/op/util/attr_types.hpp"
|
#include "ngraph/op/util/attr_types.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v6 {
|
||||||
{
|
|
||||||
namespace v6
|
|
||||||
{
|
|
||||||
/// \brief An operation ExperimentalDetectronGenerateProposalsSingleImage
|
/// \brief An operation ExperimentalDetectronGenerateProposalsSingleImage
|
||||||
/// computes ROIs and their scores based on input data.
|
/// computes ROIs and their scores based on input data.
|
||||||
class NGRAPH_API ExperimentalDetectronGenerateProposalsSingleImage : public Op
|
class NGRAPH_API ExperimentalDetectronGenerateProposalsSingleImage : public Op {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
|
||||||
/// \brief Structure that specifies attributes of the operation
|
/// \brief Structure that specifies attributes of the operation
|
||||||
struct Attributes
|
struct Attributes {
|
||||||
{
|
|
||||||
// minimum box width & height
|
// minimum box width & height
|
||||||
float min_size;
|
float min_size;
|
||||||
// specifies NMS threshold
|
// specifies NMS threshold
|
||||||
@ -54,10 +50,11 @@ namespace ngraph
|
|||||||
|
|
||||||
void validate_and_infer_types() override;
|
void validate_and_infer_types() override;
|
||||||
|
|
||||||
std::shared_ptr<Node>
|
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
|
||||||
|
|
||||||
const Attributes& get_attrs() const { return m_attrs; }
|
const Attributes& get_attrs() const {
|
||||||
|
return m_attrs;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Attributes m_attrs;
|
Attributes m_attrs;
|
||||||
|
@ -6,26 +6,22 @@
|
|||||||
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "ngraph/attribute_adapter.hpp"
|
#include "ngraph/attribute_adapter.hpp"
|
||||||
#include "ngraph/op/op.hpp"
|
#include "ngraph/op/op.hpp"
|
||||||
#include "ngraph/op/util/attr_types.hpp"
|
#include "ngraph/op/util/attr_types.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v6 {
|
||||||
{
|
|
||||||
namespace v6
|
|
||||||
{
|
|
||||||
/// \brief An operation ExperimentalDetectronPriorGridGenerator generates prior
|
/// \brief An operation ExperimentalDetectronPriorGridGenerator generates prior
|
||||||
/// grids of specified sizes.
|
/// grids of specified sizes.
|
||||||
class NGRAPH_API ExperimentalDetectronPriorGridGenerator : public Op
|
class NGRAPH_API ExperimentalDetectronPriorGridGenerator : public Op {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
|
||||||
/// \brief Structure that specifies attributes of the operation
|
/// \brief Structure that specifies attributes of the operation
|
||||||
struct Attributes
|
struct Attributes {
|
||||||
{
|
|
||||||
// Specifies whether the output tensor should be 2D or 4D
|
// Specifies whether the output tensor should be 2D or 4D
|
||||||
// `true` means the output tensor should be 2D tensor,
|
// `true` means the output tensor should be 2D tensor,
|
||||||
// `false` means the output tensor should be 4D tensor.
|
// `false` means the output tensor should be 4D tensor.
|
||||||
@ -55,10 +51,11 @@ namespace ngraph
|
|||||||
|
|
||||||
void validate_and_infer_types() override;
|
void validate_and_infer_types() override;
|
||||||
|
|
||||||
std::shared_ptr<Node>
|
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
|
||||||
/// \brief Returns attributes of this operation.
|
/// \brief Returns attributes of this operation.
|
||||||
const Attributes& get_attrs() const { return m_attrs; }
|
const Attributes& get_attrs() const {
|
||||||
|
return m_attrs;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Attributes m_attrs;
|
Attributes m_attrs;
|
||||||
|
@ -7,26 +7,22 @@
|
|||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "ngraph/attribute_adapter.hpp"
|
#include "ngraph/attribute_adapter.hpp"
|
||||||
#include "ngraph/op/op.hpp"
|
#include "ngraph/op/op.hpp"
|
||||||
#include "ngraph/op/util/attr_types.hpp"
|
#include "ngraph/op/util/attr_types.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v6 {
|
||||||
{
|
|
||||||
namespace v6
|
|
||||||
{
|
|
||||||
/// \brief An operation ExperimentalDetectronROIFeatureExtractor
|
/// \brief An operation ExperimentalDetectronROIFeatureExtractor
|
||||||
/// is the ROIAlign operation applied over a feature pyramid.
|
/// is the ROIAlign operation applied over a feature pyramid.
|
||||||
class NGRAPH_API ExperimentalDetectronROIFeatureExtractor : public Op
|
class NGRAPH_API ExperimentalDetectronROIFeatureExtractor : public Op {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
|
||||||
/// \brief Structure that specifies attributes of the operation
|
/// \brief Structure that specifies attributes of the operation
|
||||||
struct Attributes
|
struct Attributes {
|
||||||
{
|
|
||||||
int64_t output_size;
|
int64_t output_size;
|
||||||
int64_t sampling_ratio;
|
int64_t sampling_ratio;
|
||||||
std::vector<int64_t> pyramid_scales;
|
std::vector<int64_t> pyramid_scales;
|
||||||
@ -38,23 +34,22 @@ namespace ngraph
|
|||||||
///
|
///
|
||||||
/// \param args Inputs of ExperimentalDetectronROIFeatureExtractor
|
/// \param args Inputs of ExperimentalDetectronROIFeatureExtractor
|
||||||
/// \param attrs Operation attributes
|
/// \param attrs Operation attributes
|
||||||
ExperimentalDetectronROIFeatureExtractor(const OutputVector& args,
|
ExperimentalDetectronROIFeatureExtractor(const OutputVector& args, const Attributes& attrs);
|
||||||
const Attributes& attrs);
|
|
||||||
|
|
||||||
/// \brief Constructs a ExperimentalDetectronROIFeatureExtractor operation.
|
/// \brief Constructs a ExperimentalDetectronROIFeatureExtractor operation.
|
||||||
///
|
///
|
||||||
/// \param args Inputs of ExperimentalDetectronROIFeatureExtractor
|
/// \param args Inputs of ExperimentalDetectronROIFeatureExtractor
|
||||||
/// \param attrs Operation attributes
|
/// \param attrs Operation attributes
|
||||||
ExperimentalDetectronROIFeatureExtractor(const NodeVector& args,
|
ExperimentalDetectronROIFeatureExtractor(const NodeVector& args, const Attributes& attrs);
|
||||||
const Attributes& attrs);
|
|
||||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||||
|
|
||||||
void validate_and_infer_types() override;
|
void validate_and_infer_types() override;
|
||||||
|
|
||||||
std::shared_ptr<Node>
|
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
|
||||||
/// \brief Returns attributes of the operation.
|
/// \brief Returns attributes of the operation.
|
||||||
const Attributes& get_attrs() const { return m_attrs; }
|
const Attributes& get_attrs() const {
|
||||||
|
return m_attrs;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Attributes m_attrs;
|
Attributes m_attrs;
|
||||||
|
@ -6,20 +6,17 @@
|
|||||||
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "ngraph/attribute_adapter.hpp"
|
#include "ngraph/attribute_adapter.hpp"
|
||||||
#include "ngraph/op/op.hpp"
|
#include "ngraph/op/op.hpp"
|
||||||
#include "ngraph/op/util/attr_types.hpp"
|
#include "ngraph/op/util/attr_types.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v6 {
|
||||||
{
|
|
||||||
namespace v6
|
|
||||||
{
|
|
||||||
/// \brief An operation ExperimentalDetectronTopKROIs, according to the repository
|
/// \brief An operation ExperimentalDetectronTopKROIs, according to the repository
|
||||||
/// is TopK operation applied to probabilities of input ROIs.
|
/// is TopK operation applied to probabilities of input ROIs.
|
||||||
class NGRAPH_API ExperimentalDetectronTopKROIs : public Op
|
class NGRAPH_API ExperimentalDetectronTopKROIs : public Op {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
|
||||||
@ -29,17 +26,16 @@ namespace ngraph
|
|||||||
/// \param input_rois Input rois
|
/// \param input_rois Input rois
|
||||||
/// \param rois_probs Probabilities for input rois
|
/// \param rois_probs Probabilities for input rois
|
||||||
/// \param max_rois Maximal numbers of output rois
|
/// \param max_rois Maximal numbers of output rois
|
||||||
ExperimentalDetectronTopKROIs(const Output<Node>& input_rois,
|
ExperimentalDetectronTopKROIs(const Output<Node>& input_rois, const Output<Node>& rois_probs, size_t max_rois = 0);
|
||||||
const Output<Node>& rois_probs,
|
|
||||||
size_t max_rois = 0);
|
|
||||||
|
|
||||||
void validate_and_infer_types() override;
|
void validate_and_infer_types() override;
|
||||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||||
|
|
||||||
std::shared_ptr<Node>
|
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
|
||||||
|
|
||||||
size_t get_max_rois() const { return m_max_rois; }
|
size_t get_max_rois() const {
|
||||||
|
return m_max_rois;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
size_t m_max_rois;
|
size_t m_max_rois;
|
||||||
|
@ -6,14 +6,10 @@
|
|||||||
|
|
||||||
#include "ngraph/op/op.hpp"
|
#include "ngraph/op/op.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v3 {
|
||||||
{
|
class NGRAPH_API ExtractImagePatches : public Op {
|
||||||
namespace v3
|
|
||||||
{
|
|
||||||
class NGRAPH_API ExtractImagePatches : public Op
|
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
|
||||||
@ -36,17 +32,32 @@ namespace ngraph
|
|||||||
void validate_and_infer_types() override;
|
void validate_and_infer_types() override;
|
||||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||||
|
|
||||||
virtual std::shared_ptr<Node>
|
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
|
||||||
|
|
||||||
const Shape& get_sizes() const { return m_patch_sizes; }
|
const Shape& get_sizes() const {
|
||||||
void set_sizes(const Shape& sizes) { m_patch_sizes = sizes; }
|
return m_patch_sizes;
|
||||||
const Strides& get_strides() const { return m_patch_movement_strides; }
|
}
|
||||||
void set_strides(const Strides& strides) { m_patch_movement_strides = strides; }
|
void set_sizes(const Shape& sizes) {
|
||||||
const Shape& get_rates() const { return m_patch_selection_rates; }
|
m_patch_sizes = sizes;
|
||||||
void set_rates(const Shape& rates) { m_patch_selection_rates = rates; }
|
}
|
||||||
const PadType& get_auto_pad() const { return m_padding; }
|
const Strides& get_strides() const {
|
||||||
void set_auto_pad(PadType& padding) { m_padding = padding; }
|
return m_patch_movement_strides;
|
||||||
|
}
|
||||||
|
void set_strides(const Strides& strides) {
|
||||||
|
m_patch_movement_strides = strides;
|
||||||
|
}
|
||||||
|
const Shape& get_rates() const {
|
||||||
|
return m_patch_selection_rates;
|
||||||
|
}
|
||||||
|
void set_rates(const Shape& rates) {
|
||||||
|
m_patch_selection_rates = rates;
|
||||||
|
}
|
||||||
|
const PadType& get_auto_pad() const {
|
||||||
|
return m_padding;
|
||||||
|
}
|
||||||
|
void set_auto_pad(PadType& padding) {
|
||||||
|
m_padding = padding;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Shape m_patch_sizes;
|
Shape m_patch_sizes;
|
||||||
|
@ -8,12 +8,9 @@
|
|||||||
#include "ngraph/op/op.hpp"
|
#include "ngraph/op/op.hpp"
|
||||||
#include "ngraph/op/util/attr_types.hpp"
|
#include "ngraph/op/util/attr_types.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v0 {
|
||||||
{
|
|
||||||
namespace v0
|
|
||||||
{
|
|
||||||
///
|
///
|
||||||
/// \brief Class performing element-wise linear quantization.
|
/// \brief Class performing element-wise linear quantization.
|
||||||
///
|
///
|
||||||
@ -27,8 +24,7 @@ namespace ngraph
|
|||||||
/// (levels-1) * (output_high - output_low) + output_low
|
/// (levels-1) * (output_high - output_low) + output_low
|
||||||
///
|
///
|
||||||
///
|
///
|
||||||
class NGRAPH_API FakeQuantize : public ngraph::op::Op
|
class NGRAPH_API FakeQuantize : public ngraph::op::Op {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
|
||||||
@ -51,20 +47,23 @@ namespace ngraph
|
|||||||
const Output<Node>& output_low,
|
const Output<Node>& output_low,
|
||||||
const Output<Node>& output_high,
|
const Output<Node>& output_high,
|
||||||
std::size_t levels,
|
std::size_t levels,
|
||||||
const AutoBroadcastSpec& auto_broadcast =
|
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec(AutoBroadcastType::NUMPY));
|
||||||
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
|
|
||||||
|
|
||||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||||
virtual void validate_and_infer_types() override;
|
virtual void validate_and_infer_types() override;
|
||||||
|
|
||||||
virtual std::shared_ptr<Node>
|
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
|
||||||
|
|
||||||
std::size_t get_levels() const { return m_levels; }
|
std::size_t get_levels() const {
|
||||||
void set_levels(std::size_t levels) { m_levels = levels; }
|
return m_levels;
|
||||||
const AutoBroadcastSpec& get_auto_broadcast() const { return m_auto_broadcast; }
|
}
|
||||||
void set_auto_broadcast(const AutoBroadcastSpec& auto_broadcast)
|
void set_levels(std::size_t levels) {
|
||||||
{
|
m_levels = levels;
|
||||||
|
}
|
||||||
|
const AutoBroadcastSpec& get_auto_broadcast() const {
|
||||||
|
return m_auto_broadcast;
|
||||||
|
}
|
||||||
|
void set_auto_broadcast(const AutoBroadcastSpec& auto_broadcast) {
|
||||||
m_auto_broadcast = auto_broadcast;
|
m_auto_broadcast = auto_broadcast;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6,15 +6,11 @@
|
|||||||
|
|
||||||
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
|
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v0 {
|
||||||
{
|
|
||||||
namespace v0
|
|
||||||
{
|
|
||||||
/// \brief Elementwise floor operation.
|
/// \brief Elementwise floor operation.
|
||||||
class NGRAPH_API Floor : public util::UnaryElementwiseArithmetic
|
class NGRAPH_API Floor : public util::UnaryElementwiseArithmetic {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
/// \brief Constructs a floor operation.
|
/// \brief Constructs a floor operation.
|
||||||
@ -25,10 +21,8 @@ namespace ngraph
|
|||||||
Floor(const Output<Node>& arg);
|
Floor(const Output<Node>& arg);
|
||||||
|
|
||||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||||
virtual std::shared_ptr<Node>
|
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||||
bool evaluate(const HostTensorVector& outputs,
|
|
||||||
const HostTensorVector& inputs) const override;
|
|
||||||
bool has_evaluate() const override;
|
bool has_evaluate() const override;
|
||||||
};
|
};
|
||||||
} // namespace v0
|
} // namespace v0
|
||||||
|
@ -8,22 +8,17 @@
|
|||||||
|
|
||||||
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
|
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v1 {
|
||||||
{
|
|
||||||
namespace v1
|
|
||||||
{
|
|
||||||
/// \brief Elementwise FloorMod operation.
|
/// \brief Elementwise FloorMod operation.
|
||||||
///
|
///
|
||||||
class NGRAPH_API FloorMod : public util::BinaryElementwiseArithmetic
|
class NGRAPH_API FloorMod : public util::BinaryElementwiseArithmetic {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
|
||||||
/// \brief Constructs an uninitialized addition operation
|
/// \brief Constructs an uninitialized addition operation
|
||||||
FloorMod()
|
FloorMod() : util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NUMPY){};
|
||||||
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NUMPY){};
|
|
||||||
|
|
||||||
/// \brief Constructs an Floor Mod operation.
|
/// \brief Constructs an Floor Mod operation.
|
||||||
///
|
///
|
||||||
@ -39,11 +34,9 @@ namespace ngraph
|
|||||||
const Output<Node>& arg1,
|
const Output<Node>& arg1,
|
||||||
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastType::NUMPY);
|
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastType::NUMPY);
|
||||||
|
|
||||||
std::shared_ptr<Node>
|
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
|
||||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||||
bool evaluate(const HostTensorVector& outputs,
|
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||||
const HostTensorVector& inputs) const override;
|
|
||||||
bool has_evaluate() const override;
|
bool has_evaluate() const override;
|
||||||
};
|
};
|
||||||
} // namespace v1
|
} // namespace v1
|
||||||
|
@ -6,15 +6,11 @@
|
|||||||
|
|
||||||
#include "ngraph/op/util/gather_base.hpp"
|
#include "ngraph/op/util/gather_base.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v1 {
|
||||||
{
|
|
||||||
namespace v1
|
|
||||||
{
|
|
||||||
/// \brief Gather slices from axis of params according to indices
|
/// \brief Gather slices from axis of params according to indices
|
||||||
class NGRAPH_API Gather : public op::util::GatherBase
|
class NGRAPH_API Gather : public op::util::GatherBase {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
static const int64_t AXIS_NOT_SET_VALUE = std::numeric_limits<int64_t>::max();
|
static const int64_t AXIS_NOT_SET_VALUE = std::numeric_limits<int64_t>::max();
|
||||||
@ -22,23 +18,18 @@ namespace ngraph
|
|||||||
/// \param params The tensor from which slices are gathered
|
/// \param params The tensor from which slices are gathered
|
||||||
/// \param indices Tensor with indexes to gather
|
/// \param indices Tensor with indexes to gather
|
||||||
/// \param axis The tensor is a dimension index to gather data from
|
/// \param axis The tensor is a dimension index to gather data from
|
||||||
Gather(const Output<Node>& params,
|
Gather(const Output<Node>& params, const Output<Node>& indices, const Output<Node>& axis);
|
||||||
const Output<Node>& indices,
|
|
||||||
const Output<Node>& axis);
|
|
||||||
|
|
||||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||||
int64_t get_axis() const override;
|
int64_t get_axis() const override;
|
||||||
|
|
||||||
std::shared_ptr<Node>
|
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
|
||||||
};
|
};
|
||||||
} // namespace v1
|
} // namespace v1
|
||||||
|
|
||||||
namespace v7
|
namespace v7 {
|
||||||
{
|
|
||||||
/// \brief Gather slices from axis of params according to indices
|
/// \brief Gather slices from axis of params according to indices
|
||||||
class NGRAPH_API Gather : public op::util::GatherBase
|
class NGRAPH_API Gather : public op::util::GatherBase {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
Gather() = default;
|
Gather() = default;
|
||||||
@ -57,16 +48,13 @@ namespace ngraph
|
|||||||
void validate_and_infer_types() override;
|
void validate_and_infer_types() override;
|
||||||
int64_t get_batch_dims() const;
|
int64_t get_batch_dims() const;
|
||||||
|
|
||||||
std::shared_ptr<Node>
|
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
|
||||||
};
|
};
|
||||||
} // namespace v7
|
} // namespace v7
|
||||||
|
|
||||||
namespace v8
|
namespace v8 {
|
||||||
{
|
|
||||||
/// \brief Gather slices from axis of params according to indices
|
/// \brief Gather slices from axis of params according to indices
|
||||||
class NGRAPH_API Gather : public op::util::GatherBase
|
class NGRAPH_API Gather : public op::util::GatherBase {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
Gather() = default;
|
Gather() = default;
|
||||||
@ -84,8 +72,7 @@ namespace ngraph
|
|||||||
void validate_and_infer_types() override;
|
void validate_and_infer_types() override;
|
||||||
int64_t get_batch_dims() const;
|
int64_t get_batch_dims() const;
|
||||||
|
|
||||||
std::shared_ptr<Node>
|
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
|
||||||
};
|
};
|
||||||
} // namespace v8
|
} // namespace v8
|
||||||
} // namespace op
|
} // namespace op
|
||||||
|
@ -6,16 +6,12 @@
|
|||||||
|
|
||||||
#include "ngraph/op/op.hpp"
|
#include "ngraph/op/op.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v6 {
|
||||||
{
|
|
||||||
namespace v6
|
|
||||||
{
|
|
||||||
/// \brief GatherElements operation
|
/// \brief GatherElements operation
|
||||||
///
|
///
|
||||||
class NGRAPH_API GatherElements : public Op
|
class NGRAPH_API GatherElements : public Op {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
GatherElements() = default;
|
GatherElements() = default;
|
||||||
@ -25,16 +21,15 @@ namespace ngraph
|
|||||||
/// \param data Node producing data that are gathered
|
/// \param data Node producing data that are gathered
|
||||||
/// \param indices Node producing indices by which the operation gathers elements
|
/// \param indices Node producing indices by which the operation gathers elements
|
||||||
/// \param axis specifies axis along which indices are specified
|
/// \param axis specifies axis along which indices are specified
|
||||||
GatherElements(const Output<Node>& data,
|
GatherElements(const Output<Node>& data, const Output<Node>& indices, const int64_t axis);
|
||||||
const Output<Node>& indices,
|
|
||||||
const int64_t axis);
|
|
||||||
|
|
||||||
void validate_and_infer_types() override;
|
void validate_and_infer_types() override;
|
||||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||||
std::shared_ptr<Node>
|
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
|
||||||
|
|
||||||
int64_t get_axis() const { return m_axis; }
|
int64_t get_axis() const {
|
||||||
|
return m_axis;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int64_t m_axis;
|
int64_t m_axis;
|
||||||
|
@ -6,16 +6,12 @@
|
|||||||
|
|
||||||
#include "ngraph/op/op.hpp"
|
#include "ngraph/op/op.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v5 {
|
||||||
{
|
|
||||||
namespace v5
|
|
||||||
{
|
|
||||||
/// \brief GatherND operation
|
/// \brief GatherND operation
|
||||||
///
|
///
|
||||||
class NGRAPH_API GatherND : public Op
|
class NGRAPH_API GatherND : public Op {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
GatherND() = default;
|
GatherND() = default;
|
||||||
@ -26,16 +22,15 @@ namespace ngraph
|
|||||||
/// \param indices Node producing indices by which the operation gathers elements
|
/// \param indices Node producing indices by which the operation gathers elements
|
||||||
/// or slices from data
|
/// or slices from data
|
||||||
/// \param batch_dims Specifies a number of batch dimensions
|
/// \param batch_dims Specifies a number of batch dimensions
|
||||||
GatherND(const Output<Node>& data,
|
GatherND(const Output<Node>& data, const Output<Node>& indices, const size_t batch_dims = 0);
|
||||||
const Output<Node>& indices,
|
|
||||||
const size_t batch_dims = 0);
|
|
||||||
|
|
||||||
void validate_and_infer_types() override;
|
void validate_and_infer_types() override;
|
||||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||||
virtual std::shared_ptr<Node>
|
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
|
||||||
|
|
||||||
size_t get_batch_dims() const { return m_batch_dims; }
|
size_t get_batch_dims() const {
|
||||||
|
return m_batch_dims;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
size_t m_batch_dims;
|
size_t m_batch_dims;
|
||||||
|
@ -6,16 +6,12 @@
|
|||||||
|
|
||||||
#include "ngraph/op/op.hpp"
|
#include "ngraph/op/op.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
namespace v1 {
|
||||||
{
|
|
||||||
namespace v1
|
|
||||||
{
|
|
||||||
/// \brief Generates the complete beams from the ids per each step and the parent beam
|
/// \brief Generates the complete beams from the ids per each step and the parent beam
|
||||||
/// ids.
|
/// ids.
|
||||||
class NGRAPH_API GatherTree : public Op
|
class NGRAPH_API GatherTree : public Op {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
|
||||||
@ -35,8 +31,7 @@ namespace ngraph
|
|||||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||||
void validate_and_infer_types() override;
|
void validate_and_infer_types() override;
|
||||||
|
|
||||||
virtual std::shared_ptr<Node>
|
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
|
||||||
};
|
};
|
||||||
} // namespace v1
|
} // namespace v1
|
||||||
} // namespace op
|
} // namespace op
|
||||||
|
@ -9,20 +9,18 @@
|
|||||||
#include "ngraph/op/util/fused_op.hpp"
|
#include "ngraph/op/util/fused_op.hpp"
|
||||||
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
|
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph {
|
||||||
{
|
namespace op {
|
||||||
namespace op
|
|
||||||
{
|
|
||||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||||
namespace v0
|
namespace v0 {
|
||||||
{
|
|
||||||
/// \brief Gaussian Error Linear Unit
|
/// \brief Gaussian Error Linear Unit
|
||||||
/// f(x) = 0.5 * x * (1 + erf( x / sqrt(2) )
|
/// f(x) = 0.5 * x * (1 + erf( x / sqrt(2) )
|
||||||
class NGRAPH_API Gelu : public ngraph::op::util::FusedOp
|
class NGRAPH_API Gelu : public ngraph::op::util::FusedOp {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
static constexpr NodeTypeInfo type_info{"Gelu", 0};
|
static constexpr NodeTypeInfo type_info{"Gelu", 0};
|
||||||
const NodeTypeInfo& get_type_info() const override { return type_info; }
|
const NodeTypeInfo& get_type_info() const override {
|
||||||
|
return type_info;
|
||||||
|
}
|
||||||
Gelu();
|
Gelu();
|
||||||
/// \brief Constructs a Gelu operation.
|
/// \brief Constructs a Gelu operation.
|
||||||
///
|
///
|
||||||
@ -34,29 +32,22 @@ namespace ngraph
|
|||||||
|
|
||||||
void pre_validate_and_infer_types() override;
|
void pre_validate_and_infer_types() override;
|
||||||
|
|
||||||
virtual std::shared_ptr<Node>
|
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
|
||||||
};
|
};
|
||||||
} // namespace v0
|
} // namespace v0
|
||||||
using v0::Gelu;
|
using v0::Gelu;
|
||||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||||
|
|
||||||
/// \brief Specifies the approximation to calculate Gelu
|
/// \brief Specifies the approximation to calculate Gelu
|
||||||
enum class GeluApproximationMode
|
enum class GeluApproximationMode { TANH, ERF };
|
||||||
{
|
|
||||||
TANH,
|
|
||||||
ERF
|
|
||||||
};
|
|
||||||
NGRAPH_API std::ostream& operator<<(std::ostream& s, const GeluApproximationMode& type);
|
NGRAPH_API std::ostream& operator<<(std::ostream& s, const GeluApproximationMode& type);
|
||||||
|
|
||||||
namespace v7
|
namespace v7 {
|
||||||
{
|
|
||||||
/// \brief Gaussian Error Linear Unit
|
/// \brief Gaussian Error Linear Unit
|
||||||
/// f(x) = 0.5 * x * (1 + erf( x / sqrt(2) ) for "approximation" = "erf"
|
/// f(x) = 0.5 * x * (1 + erf( x / sqrt(2) ) for "approximation" = "erf"
|
||||||
/// f(x) = 0.5 * x * (1 + tanh([sqrt(2 / pi)] * [x + 0.044715^3]) for "approximation" =
|
/// f(x) = 0.5 * x * (1 + tanh([sqrt(2 / pi)] * [x + 0.044715^3]) for "approximation" =
|
||||||
/// "tanh"
|
/// "tanh"
|
||||||
class NGRAPH_API Gelu : public util::UnaryElementwiseArithmetic
|
class NGRAPH_API Gelu : public util::UnaryElementwiseArithmetic {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
|
||||||
@ -65,19 +56,16 @@ namespace ngraph
|
|||||||
///
|
///
|
||||||
/// \param data Input tensor
|
/// \param data Input tensor
|
||||||
/// \param mode Approximation mode
|
/// \param mode Approximation mode
|
||||||
Gelu(const Output<Node>& data,
|
Gelu(const Output<Node>& data, GeluApproximationMode mode = GeluApproximationMode::ERF);
|
||||||
GeluApproximationMode mode = GeluApproximationMode::ERF);
|
|
||||||
|
|
||||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||||
|
|
||||||
void validate_and_infer_types() override;
|
void validate_and_infer_types() override;
|
||||||
|
|
||||||
bool evaluate(const HostTensorVector& outputs,
|
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||||
const HostTensorVector& inputs) const override;
|
|
||||||
bool has_evaluate() const override;
|
bool has_evaluate() const override;
|
||||||
|
|
||||||
std::shared_ptr<Node>
|
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
|
||||||
|
|
||||||
GeluApproximationMode get_approximation_mode() const;
|
GeluApproximationMode get_approximation_mode() const;
|
||||||
|
|
||||||
@ -88,16 +76,13 @@ namespace ngraph
|
|||||||
} // namespace op
|
} // namespace op
|
||||||
template <>
|
template <>
|
||||||
class NGRAPH_API AttributeAdapter<op::GeluApproximationMode>
|
class NGRAPH_API AttributeAdapter<op::GeluApproximationMode>
|
||||||
: public EnumAttributeAdapterBase<op::GeluApproximationMode>
|
: public EnumAttributeAdapterBase<op::GeluApproximationMode> {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
AttributeAdapter(op::GeluApproximationMode& value)
|
AttributeAdapter(op::GeluApproximationMode& value) : EnumAttributeAdapterBase<op::GeluApproximationMode>(value) {}
|
||||||
: EnumAttributeAdapterBase<op::GeluApproximationMode>(value)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::GeluApproximationMode>",
|
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::GeluApproximationMode>", 0};
|
||||||
0};
|
const DiscreteTypeInfo& get_type_info() const override {
|
||||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
return type_info;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
} // namespace ngraph
|
} // namespace ngraph
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user