nGraph shell of operations RDFT and IRDFT (#10353)

* Written header files for the nGraph operations RDFT and IRDFT.

* Written nGraph shell for the operation RDFT.

* Added missed include.

* Added RDFT to opset9 table.

* Code style fixes.

* Written the nGraph shell of the operation IRDFT.

* Added IRDFT to opset9 table.

* Started to write shape infer tests for RDFT.

* Refactoring: shape infer functions of RDFT and IRDFT moved into separate files.

* Written shape infer tests for RDFT.

* Written shape infer tests for IRDFT operation.

* Fixed code style.

* Fixes in the shape infer function of RDFT.

* Fixes in the shape infer function of RDFT.

* Fixes in the shape infer function of IRDFT.

* Deleted redundant includes in include/ngraph/op/irdft.hpp and include/ngraph/op/rdft.hpp

* Deleted redundant includes in include/openvino/op/rdft.hpp and include/openvino/op/irdft.hpp.

* Deleted redundant includes in cpp-files of nGraph shells of operations IRDFT and RDFT.

* Code style fixes.

* Shape inference functions of operations RDFT and IRDFT moved to the namespace ov::op::util.

* Deleted RDFT and IRDFT from docs/template_plugin/backend/opset_int_tbl.hpp.

* Deleted 'using namespace ngraph' from cpp-files of nGraph shells of operations RDFT and IRDFT.

* Fixed typos.

* Merged some loops in shape inference functions of RDFT and IRDFT.

* Written visitor tests for RDFT and IRDFT.

* Small change.

* Common part of RDFT and IRDFT shape validation moved into the separate file.

Co-authored-by: Ilya Churaev <ilya.churaev@intel.com>
This commit is contained in:
Vladimir Gavrilov 2022-03-16 14:54:32 +03:00 committed by GitHub
parent 097006d97a
commit f7875da083
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
26 changed files with 1445 additions and 3 deletions

View File

@ -0,0 +1,15 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/op/irdft.hpp"
namespace ngraph {
namespace op {
namespace v9 {
using ov::op::v9::IRDFT;
} // namespace v9
} // namespace op
} // namespace ngraph

View File

@ -0,0 +1,15 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/op/rdft.hpp"
namespace ngraph {
namespace op {
namespace v9 {
using ov::op::v9::RDFT;
} // namespace v9
} // namespace op
} // namespace ngraph

View File

@ -79,6 +79,7 @@
#include "ngraph/op/idft.hpp" #include "ngraph/op/idft.hpp"
#include "ngraph/op/if.hpp" #include "ngraph/op/if.hpp"
#include "ngraph/op/interpolate.hpp" #include "ngraph/op/interpolate.hpp"
#include "ngraph/op/irdft.hpp"
#include "ngraph/op/less.hpp" #include "ngraph/op/less.hpp"
#include "ngraph/op/less_eq.hpp" #include "ngraph/op/less_eq.hpp"
#include "ngraph/op/log.hpp" #include "ngraph/op/log.hpp"
@ -119,6 +120,7 @@
#include "ngraph/op/psroi_pooling.hpp" #include "ngraph/op/psroi_pooling.hpp"
#include "ngraph/op/random_uniform.hpp" #include "ngraph/op/random_uniform.hpp"
#include "ngraph/op/range.hpp" #include "ngraph/op/range.hpp"
#include "ngraph/op/rdft.hpp"
#include "ngraph/op/read_value.hpp" #include "ngraph/op/read_value.hpp"
#include "ngraph/op/reduce_l1.hpp" #include "ngraph/op/reduce_l1.hpp"
#include "ngraph/op/reduce_l2.hpp" #include "ngraph/op/reduce_l2.hpp"

View File

@ -60,4 +60,5 @@ const NGRAPH_API OpSet& get_opset5();
const NGRAPH_API OpSet& get_opset6(); const NGRAPH_API OpSet& get_opset6();
const NGRAPH_API OpSet& get_opset7(); const NGRAPH_API OpSet& get_opset7();
const NGRAPH_API OpSet& get_opset8(); const NGRAPH_API OpSet& get_opset8();
const NGRAPH_API OpSet& get_opset9();
} // namespace ngraph } // namespace ngraph

View File

@ -0,0 +1,15 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "ngraph/ops.hpp"
namespace ngraph {
namespace opset9 {
#define NGRAPH_OP(a, b) using b::a;
#include "ngraph/opsets/opset9_tbl.hpp"
#undef NGRAPH_OP
} // namespace opset9
} // namespace ngraph

View File

@ -0,0 +1,12 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#ifndef NGRAPH_OP
# warning "NGRAPH_OP not defined"
# define NGRAPH_OP(x, y)
#endif
#define _OPENVINO_OP_REG NGRAPH_OP
#include "openvino/opsets/opset9_tbl.hpp"
#undef _OPENVINO_OP_REG

View File

@ -0,0 +1,40 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/op/op.hpp"
#include "openvino/op/util/fft_base.hpp"
namespace ov {
namespace op {
namespace v9 {
/// \brief An operation IRDFT that computes the discrete inverse complex-to-real Fourier transformation.
class OPENVINO_API IRDFT : public util::FFTBase {
public:
OPENVINO_OP("IRDFT", "opset9", util::FFTBase);
BWDCMP_RTTI_DECLARATION;
IRDFT() = default;
/// \brief Constructs a IRDFT operation. IRDFT is performed for full size axes.
///
/// \param data Input data
/// \param axes Axes to perform IRDFT
IRDFT(const Output<Node>& data, const Output<Node>& axes);
/// \brief Constructs a IRDFT operation.
///
/// \param data Input data
/// \param axes Axes to perform IRDFT
/// \param signal_size Signal sizes for 'axes'
IRDFT(const Output<Node>& data, const Output<Node>& axes, const Output<Node>& signal_size);
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
};
} // namespace v9
} // namespace op
} // namespace ov

View File

@ -78,6 +78,7 @@
#include "openvino/op/idft.hpp" #include "openvino/op/idft.hpp"
#include "openvino/op/if.hpp" #include "openvino/op/if.hpp"
#include "openvino/op/interpolate.hpp" #include "openvino/op/interpolate.hpp"
#include "openvino/op/irdft.hpp"
#include "openvino/op/less.hpp" #include "openvino/op/less.hpp"
#include "openvino/op/less_eq.hpp" #include "openvino/op/less_eq.hpp"
#include "openvino/op/log.hpp" #include "openvino/op/log.hpp"
@ -118,6 +119,7 @@
#include "openvino/op/psroi_pooling.hpp" #include "openvino/op/psroi_pooling.hpp"
#include "openvino/op/random_uniform.hpp" #include "openvino/op/random_uniform.hpp"
#include "openvino/op/range.hpp" #include "openvino/op/range.hpp"
#include "openvino/op/rdft.hpp"
#include "openvino/op/read_value.hpp" #include "openvino/op/read_value.hpp"
#include "openvino/op/reduce_l1.hpp" #include "openvino/op/reduce_l1.hpp"
#include "openvino/op/reduce_l2.hpp" #include "openvino/op/reduce_l2.hpp"

View File

@ -0,0 +1,40 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/op/op.hpp"
#include "openvino/op/util/fft_base.hpp"
namespace ov {
namespace op {
namespace v9 {
/// \brief An operation RDFT that computes the discrete real-to-complex Fourier transformation.
class OPENVINO_API RDFT : public util::FFTBase {
public:
OPENVINO_OP("RDFT", "opset9", util::FFTBase);
BWDCMP_RTTI_DECLARATION;
RDFT() = default;
/// \brief Constructs a RDFT operation. RDFT is performed for full size axes.
///
/// \param data Input data
/// \param axes Axes to perform RDFT
RDFT(const Output<Node>& data, const Output<Node>& axes);
/// \brief Constructs a RDFT operation.
///
/// \param data Input data
/// \param axes Axes to perform RDFT
/// \param signal_size Signal sizes for 'axes'
RDFT(const Output<Node>& data, const Output<Node>& axes, const Output<Node>& signal_size);
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
};
} // namespace v9
} // namespace op
} // namespace ov

View File

@ -33,6 +33,9 @@ protected:
/// \param axes Axes to perform FFT /// \param axes Axes to perform FFT
/// \param signal_size Signal sizes for 'axes' /// \param signal_size Signal sizes for 'axes'
FFTBase(const Output<Node>& data, const Output<Node>& axes, const Output<Node>& signal_size); FFTBase(const Output<Node>& data, const Output<Node>& axes, const Output<Node>& signal_size);
/// \brief Validates input data types of FFT operation.
void validate_types();
}; };
} // namespace util } // namespace util
} // namespace op } // namespace op

View File

@ -136,4 +136,5 @@ const OPENVINO_API OpSet& get_opset5();
const OPENVINO_API OpSet& get_opset6(); const OPENVINO_API OpSet& get_opset6();
const OPENVINO_API OpSet& get_opset7(); const OPENVINO_API OpSet& get_opset7();
const OPENVINO_API OpSet& get_opset8(); const OPENVINO_API OpSet& get_opset8();
const OPENVINO_API OpSet& get_opset9();
} // namespace ov } // namespace ov

View File

@ -0,0 +1,15 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/op/ops.hpp"
namespace ov {
namespace opset9 {
#define _OPENVINO_OP_REG(a, b) using b::a;
#include "openvino/opsets/opset9_tbl.hpp"
#undef _OPENVINO_OP_REG
} // namespace opset9
} // namespace ov

View File

@ -0,0 +1,194 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#ifndef _OPENVINO_OP_REG
# warning "_OPENVINO_OP_REG not defined"
# define _OPENVINO_OP_REG(x, y)
#endif
_OPENVINO_OP_REG(Abs, ov::op::v0)
_OPENVINO_OP_REG(Acos, ov::op::v0)
_OPENVINO_OP_REG(Add, ov::op::v1)
_OPENVINO_OP_REG(Asin, ov::op::v0)
_OPENVINO_OP_REG(Atan, ov::op::v0)
_OPENVINO_OP_REG(AvgPool, ov::op::v1)
_OPENVINO_OP_REG(BatchNormInference, ov::op::v5)
_OPENVINO_OP_REG(BinaryConvolution, ov::op::v1)
_OPENVINO_OP_REG(Broadcast, ov::op::v3)
_OPENVINO_OP_REG(Bucketize, ov::op::v3)
_OPENVINO_OP_REG(CTCGreedyDecoder, ov::op::v0)
_OPENVINO_OP_REG(Ceiling, ov::op::v0)
_OPENVINO_OP_REG(Clamp, ov::op::v0)
_OPENVINO_OP_REG(Concat, ov::op::v0)
_OPENVINO_OP_REG(Constant, ov::op::v0)
_OPENVINO_OP_REG(Convert, ov::op::v0)
_OPENVINO_OP_REG(ConvertLike, ov::op::v1)
_OPENVINO_OP_REG(Convolution, ov::op::v1)
_OPENVINO_OP_REG(ConvolutionBackpropData, ov::op::v1)
_OPENVINO_OP_REG(Cos, ov::op::v0)
_OPENVINO_OP_REG(Cosh, ov::op::v0)
_OPENVINO_OP_REG(CumSum, ov::op::v0)
_OPENVINO_OP_REG(DeformablePSROIPooling, ov::op::v1)
_OPENVINO_OP_REG(DepthToSpace, ov::op::v0)
_OPENVINO_OP_REG(Divide, ov::op::v1)
_OPENVINO_OP_REG(Elu, ov::op::v0)
_OPENVINO_OP_REG(Erf, ov::op::v0)
_OPENVINO_OP_REG(Equal, ov::op::v1)
_OPENVINO_OP_REG(Exp, ov::op::v0)
_OPENVINO_OP_REG(ExtractImagePatches, ov::op::v3)
_OPENVINO_OP_REG(FakeQuantize, ov::op::v0)
_OPENVINO_OP_REG(Floor, ov::op::v0)
_OPENVINO_OP_REG(FloorMod, ov::op::v1)
_OPENVINO_OP_REG(GatherTree, ov::op::v1)
_OPENVINO_OP_REG(Greater, ov::op::v1)
_OPENVINO_OP_REG(GreaterEqual, ov::op::v1)
_OPENVINO_OP_REG(GroupConvolution, ov::op::v1)
_OPENVINO_OP_REG(GroupConvolutionBackpropData, ov::op::v1)
_OPENVINO_OP_REG(GRN, ov::op::v0)
_OPENVINO_OP_REG(HardSigmoid, ov::op::v0)
_OPENVINO_OP_REG(Less, ov::op::v1)
_OPENVINO_OP_REG(LessEqual, ov::op::v1)
_OPENVINO_OP_REG(Log, ov::op::v0)
_OPENVINO_OP_REG(LogicalAnd, ov::op::v1)
_OPENVINO_OP_REG(LogicalNot, ov::op::v1)
_OPENVINO_OP_REG(LogicalOr, ov::op::v1)
_OPENVINO_OP_REG(LogicalXor, ov::op::v1)
_OPENVINO_OP_REG(LRN, ov::op::v0)
_OPENVINO_OP_REG(LSTMCell, ov::op::v4)
_OPENVINO_OP_REG(MatMul, ov::op::v0)
_OPENVINO_OP_REG(Maximum, ov::op::v1)
_OPENVINO_OP_REG(Minimum, ov::op::v1)
_OPENVINO_OP_REG(Mod, ov::op::v1)
_OPENVINO_OP_REG(Multiply, ov::op::v1)
_OPENVINO_OP_REG(Negative, ov::op::v0)
_OPENVINO_OP_REG(NormalizeL2, ov::op::v0)
_OPENVINO_OP_REG(NotEqual, ov::op::v1)
_OPENVINO_OP_REG(OneHot, ov::op::v1)
_OPENVINO_OP_REG(PRelu, ov::op::v0)
_OPENVINO_OP_REG(PSROIPooling, ov::op::v0)
_OPENVINO_OP_REG(Pad, ov::op::v1)
_OPENVINO_OP_REG(Parameter, ov::op::v0)
_OPENVINO_OP_REG(Power, ov::op::v1)
_OPENVINO_OP_REG(PriorBoxClustered, ov::op::v0)
_OPENVINO_OP_REG(Proposal, ov::op::v4)
_OPENVINO_OP_REG(Range, ov::op::v4)
_OPENVINO_OP_REG(Relu, ov::op::v0)
_OPENVINO_OP_REG(ReduceMax, ov::op::v1)
_OPENVINO_OP_REG(ReduceLogicalAnd, ov::op::v1)
_OPENVINO_OP_REG(ReduceLogicalOr, ov::op::v1)
_OPENVINO_OP_REG(ReduceMean, ov::op::v1)
_OPENVINO_OP_REG(ReduceMin, ov::op::v1)
_OPENVINO_OP_REG(ReduceProd, ov::op::v1)
_OPENVINO_OP_REG(ReduceSum, ov::op::v1)
_OPENVINO_OP_REG(RegionYolo, ov::op::v0)
_OPENVINO_OP_REG(ReorgYolo, ov::op::v0)
_OPENVINO_OP_REG(Reshape, ov::op::v1)
_OPENVINO_OP_REG(Result, ov::op::v0)
_OPENVINO_OP_REG(ReverseSequence, ov::op::v0)
_OPENVINO_OP_REG(ROIPooling, ov::op::v0)
_OPENVINO_OP_REG(ScatterNDUpdate, ov::op::v3)
_OPENVINO_OP_REG(Select, ov::op::v1)
_OPENVINO_OP_REG(Selu, ov::op::v0)
_OPENVINO_OP_REG(Sign, ov::op::v0)
_OPENVINO_OP_REG(Sigmoid, ov::op::v0)
_OPENVINO_OP_REG(Sin, ov::op::v0)
_OPENVINO_OP_REG(Sinh, ov::op::v0)
_OPENVINO_OP_REG(Sqrt, ov::op::v0)
_OPENVINO_OP_REG(SpaceToDepth, ov::op::v0)
_OPENVINO_OP_REG(Split, ov::op::v1)
_OPENVINO_OP_REG(SquaredDifference, ov::op::v0)
_OPENVINO_OP_REG(Squeeze, ov::op::v0)
_OPENVINO_OP_REG(StridedSlice, ov::op::v1)
_OPENVINO_OP_REG(Subtract, ov::op::v1)
_OPENVINO_OP_REG(Tan, ov::op::v0)
_OPENVINO_OP_REG(Tanh, ov::op::v0)
_OPENVINO_OP_REG(TensorIterator, ov::op::v0)
_OPENVINO_OP_REG(Tile, ov::op::v0)
_OPENVINO_OP_REG(Transpose, ov::op::v1)
_OPENVINO_OP_REG(Unsqueeze, ov::op::v0)
_OPENVINO_OP_REG(VariadicSplit, ov::op::v1)
// New operations added in opset2
_OPENVINO_OP_REG(BatchToSpace, ov::op::v1)
_OPENVINO_OP_REG(SpaceToBatch, ov::op::v1)
// New operations added in opset3
_OPENVINO_OP_REG(EmbeddingBagPackedSum, ov::op::v3)
_OPENVINO_OP_REG(EmbeddingSegmentsSum, ov::op::v3)
_OPENVINO_OP_REG(EmbeddingBagOffsetsSum, ov::op::v3)
_OPENVINO_OP_REG(GRUCell, ov::op::v3)
_OPENVINO_OP_REG(NonZero, ov::op::v3)
_OPENVINO_OP_REG(RNNCell, ov::op::v0)
_OPENVINO_OP_REG(ROIAlign, ov::op::v3)
_OPENVINO_OP_REG(ScatterElementsUpdate, ov::op::v3)
_OPENVINO_OP_REG(ScatterUpdate, ov::op::v3)
_OPENVINO_OP_REG(ShuffleChannels, ov::op::v0)
_OPENVINO_OP_REG(ShapeOf, ov::op::v3)
_OPENVINO_OP_REG(TopK, ov::op::v3)
// New operations added in opset4
_OPENVINO_OP_REG(Acosh, ov::op::v3)
_OPENVINO_OP_REG(Asinh, ov::op::v3)
_OPENVINO_OP_REG(Atanh, ov::op::v3)
_OPENVINO_OP_REG(CTCLoss, ov::op::v4)
_OPENVINO_OP_REG(HSwish, ov::op::v4)
_OPENVINO_OP_REG(Interpolate, ov::op::v4)
_OPENVINO_OP_REG(Mish, ov::op::v4)
_OPENVINO_OP_REG(ReduceL1, ov::op::v4)
_OPENVINO_OP_REG(ReduceL2, ov::op::v4)
_OPENVINO_OP_REG(SoftPlus, ov::op::v4)
_OPENVINO_OP_REG(Swish, ov::op::v4)
// New operations added in opset5
_OPENVINO_OP_REG(GRUSequence, ov::op::v5)
_OPENVINO_OP_REG(HSigmoid, ov::op::v5)
_OPENVINO_OP_REG(LogSoftmax, ov::op::v5)
_OPENVINO_OP_REG(Loop, ov::op::v5)
_OPENVINO_OP_REG(LSTMSequence, ov::op::v5)
_OPENVINO_OP_REG(NonMaxSuppression, ov::op::v5)
_OPENVINO_OP_REG(RNNSequence, ov::op::v5)
_OPENVINO_OP_REG(Round, ov::op::v5)
// New operations added in opset6
_OPENVINO_OP_REG(CTCGreedyDecoderSeqLen, ov::op::v6)
_OPENVINO_OP_REG(ExperimentalDetectronDetectionOutput, ov::op::v6)
_OPENVINO_OP_REG(ExperimentalDetectronGenerateProposalsSingleImage, ov::op::v6)
_OPENVINO_OP_REG(ExperimentalDetectronPriorGridGenerator, ov::op::v6)
_OPENVINO_OP_REG(ExperimentalDetectronROIFeatureExtractor, ov::op::v6)
_OPENVINO_OP_REG(ExperimentalDetectronTopKROIs, ov::op::v6)
_OPENVINO_OP_REG(GatherElements, ov::op::v6)
_OPENVINO_OP_REG(MVN, ov::op::v6)
_OPENVINO_OP_REG(Assign, ov::op::v6) // new version
_OPENVINO_OP_REG(ReadValue, ov::op::v6) // new version
// New operations added in opset7
_OPENVINO_OP_REG(DFT, ov::op::v7)
_OPENVINO_OP_REG(Einsum, ov::op::v7)
_OPENVINO_OP_REG(Gelu, ov::op::v7)
_OPENVINO_OP_REG(IDFT, ov::op::v7)
_OPENVINO_OP_REG(Roll, ov::op::v7)
// New operations added in opset8
_OPENVINO_OP_REG(Gather, ov::op::v8)
_OPENVINO_OP_REG(GatherND, ov::op::v8)
_OPENVINO_OP_REG(AdaptiveAvgPool, ov::op::v8)
_OPENVINO_OP_REG(AdaptiveMaxPool, ov::op::v8)
_OPENVINO_OP_REG(DeformableConvolution, ov::op::v8)
_OPENVINO_OP_REG(DetectionOutput, ov::op::v8)
_OPENVINO_OP_REG(I420toBGR, ov::op::v8)
_OPENVINO_OP_REG(I420toRGB, ov::op::v8)
_OPENVINO_OP_REG(MatrixNms, ov::op::v8)
_OPENVINO_OP_REG(MaxPool, ov::op::v8)
_OPENVINO_OP_REG(MulticlassNms, ov::op::v8)
_OPENVINO_OP_REG(NV12toBGR, ov::op::v8)
_OPENVINO_OP_REG(NV12toRGB, ov::op::v8)
_OPENVINO_OP_REG(RandomUniform, ov::op::v8)
_OPENVINO_OP_REG(Slice, ov::op::v8)
_OPENVINO_OP_REG(Softmax, ov::op::v8)
_OPENVINO_OP_REG(If, ov::op::v8)
_OPENVINO_OP_REG(PriorBox, ov::op::v8)
// New operations added in opset9
_OPENVINO_OP_REG(IRDFT, ov::op::v9)
_OPENVINO_OP_REG(RDFT, ov::op::v9)

View File

@ -7,6 +7,9 @@
#include "openvino/core/axis_vector.hpp" #include "openvino/core/axis_vector.hpp"
#include "utils.hpp" #include "utils.hpp"
namespace ov {
namespace op {
namespace util {
template <class T> template <class T>
void shape_infer(const ov::op::util::FFTBase* op, void shape_infer(const ov::op::util::FFTBase* op,
const std::vector<T>& input_shapes, const std::vector<T>& input_shapes,
@ -114,4 +117,7 @@ void shape_infer(const ov::op::util::FFTBase* op,
output_shape[i] = ov::Dimension::dynamic(); output_shape[i] = ov::Dimension::dynamic();
} }
} }
} }
} // namespace util
} // namespace op
} // namespace ov

View File

@ -0,0 +1,75 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <openvino/op/irdft.hpp>
#include "openvino/core/axis_vector.hpp"
#include "rfft_common_validation.hpp"
#include "utils.hpp"
namespace ov {
namespace op {
namespace util {
template <class T>
void irdft_shape_infer(const ov::op::v9::IRDFT* op,
const std::vector<T>& input_shapes,
std::vector<T>& output_shapes,
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data = {}) {
using DimType = typename std::iterator_traits<typename T::iterator>::value_type;
NODE_VALIDATION_CHECK(op, (input_shapes.size() == 2 || input_shapes.size() == 3) && output_shapes.size() == 1);
const auto& input_shape = input_shapes[0];
const auto& axes_shape = input_shapes[1];
auto& output_shape = output_shapes[0];
std::vector<int64_t> axes;
bool axes_are_known = get_data_as_int64<T>(1, op, axes, constant_data);
rfft_common_validation::shape_validation(op, input_shapes, axes, axes_are_known, rfft_common_validation::RFFTKind::Inverse);
if (input_shape.rank().is_dynamic()) {
output_shape = ov::PartialShape::dynamic();
return;
}
const auto input_rank = input_shape.size();
output_shape = input_shape;
output_shape.resize(input_rank - 1);
if (axes_shape.rank().is_dynamic() || !axes_are_known) {
for (int64_t i = 0; i < input_rank - 1; ++i) {
output_shape[i] = ov::Dimension::dynamic();
}
return;
}
const auto last_axis = axes.back();
if (input_shapes.size() == 2) {
output_shape[last_axis] = DimType(2) * (input_shape[last_axis] - DimType(1));
return;
}
const auto& signal_size_shape = input_shapes[2];
std::vector<int64_t> signal_size;
bool status_signal_size = get_data_as_int64<T>(2, op, signal_size, constant_data);
if (signal_size_shape.rank().is_dynamic() || !status_signal_size) {
output_shape[last_axis] = ov::Dimension::dynamic();
return;
}
size_t num_of_axes = axes.size();
for (size_t i = 0; i < num_of_axes; ++i) {
if (signal_size[i] != -1) {
output_shape[axes[i]] = DimType(signal_size[i]);
}
}
if (signal_size.back() == -1) {
output_shape[last_axis] = DimType(2) * (input_shape[last_axis] - DimType(1));
}
}
} // namespace util
} // namespace op
} // namespace ov

View File

@ -0,0 +1,87 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <openvino/op/rdft.hpp>
#include "openvino/core/axis_vector.hpp"
#include "rfft_common_validation.hpp"
#include "utils.hpp"
namespace ov {
namespace op {
namespace util {
template <typename B>
B get_ouput_dimension_bound(B b) {
if (b <= 0) {
return b;
}
return b / 2 + 1;
}
template <class DimType>
DimType get_rdft_output_dimension(DimType d) {
return DimType(get_ouput_dimension_bound(d.get_min_length()), get_ouput_dimension_bound(d.get_max_length()));
}
template <class T>
void rdft_shape_infer(const ov::op::v9::RDFT* op,
const std::vector<T>& input_shapes,
std::vector<T>& output_shapes,
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data = {}) {
using DimType = typename std::iterator_traits<typename T::iterator>::value_type;
NODE_VALIDATION_CHECK(op, (input_shapes.size() == 2 || input_shapes.size() == 3) && output_shapes.size() == 1);
const auto& input_shape = input_shapes[0];
const auto& axes_shape = input_shapes[1];
auto& output_shape = output_shapes[0];
std::vector<int64_t> axes;
bool axes_are_known = get_data_as_int64<T>(1, op, axes, constant_data);
rfft_common_validation::shape_validation(op, input_shapes, axes, axes_are_known, rfft_common_validation::RFFTKind::Forward);
if (input_shape.rank().is_dynamic()) {
output_shape = ov::PartialShape::dynamic();
return;
}
output_shape = input_shape;
output_shape.push_back(DimType(2));
const auto input_rank = input_shape.size();
if (axes_shape.rank().is_dynamic() || !axes_are_known) {
for (int64_t i = 0; i < input_rank; ++i) {
output_shape[i] = ov::Dimension::dynamic();
}
return;
}
const auto last_axis = axes.back();
if (input_shapes.size() == 2) {
output_shape[last_axis] = get_rdft_output_dimension(input_shape[last_axis]);
return;
}
const auto& signal_size_shape = input_shapes[2];
std::vector<int64_t> signal_size;
bool status_signal_size = get_data_as_int64<T>(2, op, signal_size, constant_data);
if (signal_size_shape.rank().is_dynamic() || !status_signal_size) {
output_shape[last_axis] = ov::Dimension::dynamic();
return;
}
size_t num_of_axes = axes.size();
for (size_t i = 0; i < num_of_axes; ++i) {
const int64_t current_axis = axes[i];
if (signal_size[i] != -1) {
output_shape[current_axis] = DimType(signal_size[i]);
}
}
output_shape[last_axis] = get_rdft_output_dimension(output_shape[last_axis]);
}
} // namespace util
} // namespace op
} // namespace ov

View File

@ -0,0 +1,143 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <openvino/op/util/fft_base.hpp>
#include "openvino/core/axis_vector.hpp"
#include "utils.hpp"
namespace ov {
namespace op {
namespace util {
namespace rfft_common_validation {
enum class RFFTKind { Forward, Inverse };
template <class T>
void validate_input_rank(const ov::op::util::FFTBase* op,
const T& input_shape,
const T& axes_shape,
size_t input_rank,
RFFTKind rfft_kind) {
const size_t min_rank = (rfft_kind == RFFTKind::Forward) ? 1 : 2;
NODE_VALIDATION_CHECK(op,
input_rank >= min_rank,
"The input rank must be greater or equal to ",
min_rank,
". Got input rank: ",
input_rank);
if (rfft_kind == RFFTKind::Inverse) {
NODE_VALIDATION_CHECK(op,
input_shape[input_rank - 1].compatible(2),
"The last dimension of input data must be 2. Got: ",
input_shape[input_rank - 1]);
}
if (axes_shape.is_dynamic()) {
return;
}
if (rfft_kind == RFFTKind::Forward) {
NODE_VALIDATION_CHECK(op,
input_rank >= static_cast<int64_t>(axes_shape[0].get_length()),
"The input rank must be greater than or equal to the number of RDFT op axes. "
"Got input rank: ",
input_rank,
", number of axes: ",
axes_shape[0].get_length());
} else {
NODE_VALIDATION_CHECK(op,
input_rank >= static_cast<int64_t>(axes_shape[0].get_length() + 1),
"The input rank must be greater than number of IRDFT op axes. Got "
"input rank: ",
input_rank,
", number of axes: ",
axes_shape[0].get_length());
}
}
template <class T>
void validate_axes(const ov::op::util::FFTBase* op,
const T& axes_shape,
std::vector<int64_t>& axes,
size_t input_rank,
bool axes_are_known,
RFFTKind rfft_kind) {
if (axes_shape.rank().is_dynamic() || !axes_are_known) {
return;
}
// IRDFT operation supports negative axes to transform. More precisely, according to
// the IRDFT operation specification, axes should be integers from -(r - 1) to (r - 2)
// inclusively, where r = rank(data). A negative axis 'a' is interpreted as an axis
// 'r - 1 + a'. The reason is the following: real input tensor of the shape
// [n_0, ..., n_{r - 1}, 2] is interpreted as a complex tensor with the shape
// [n_0, ..., n_{r - 1}].
//
// But RDFT operation supports negative axes to transform in other sense. More precisely,
// according to the RDFT operation specification, axes should be integers from -r to (r - 1)
// inclusively, where r = rank(data). A negative axis 'a' is interpreted as an axis 'r + a'.
const int64_t axis_correction = (rfft_kind == RFFTKind::Forward) ? input_rank : (input_rank - 1);
ov::AxisSet axes_set;
for (int64_t& axis : axes) {
if (axis < 0) {
axis += axis_correction;
}
axes_set.insert(static_cast<size_t>(axis));
}
NODE_VALIDATION_CHECK(op, axes.size() == axes_set.size(), "(I)RDFT op axes must be unique.");
if (rfft_kind == RFFTKind::Inverse) {
NODE_VALIDATION_CHECK(op,
std::find(axes.begin(), axes.end(), input_rank - 1) == axes.end(),
"IRDFT op axes cannot contain the last axis.");
}
}
template <class T>
void validate_signal_size(const ov::op::util::FFTBase* op,
const T& axes_shape,
const T& signal_size_shape) {
NODE_VALIDATION_CHECK(op,
signal_size_shape.rank().compatible(1),
"(I)RDFT op signal size input must be 1D tensor. Got signal: ",
signal_size_shape);
if (axes_shape.is_static() && signal_size_shape.is_static()) {
NODE_VALIDATION_CHECK(op,
axes_shape[0].compatible(signal_size_shape[0]),
"Sizes of inputs 'axes' and 'signal_size' of (I)RDFT op must be equal. "
"Got size of 'axes': ",
axes_shape[0],
"size of 'signal_size': ",
signal_size_shape[0]);
}
}
template <class T>
void shape_validation(const ov::op::util::FFTBase* op,
const std::vector<T>& input_shapes,
std::vector<int64_t>& axes,
bool axes_are_known,
RFFTKind rfft_kind) {
const auto& input_shape = input_shapes[0];
const auto& axes_shape = input_shapes[1];
if (input_shape.rank().is_static()) {
const auto input_rank = input_shape.size();
validate_input_rank(op, input_shape, axes_shape, input_rank, rfft_kind);
validate_axes(op, axes_shape, axes, input_rank, axes_are_known, rfft_kind);
}
NODE_VALIDATION_CHECK(op, axes_shape.rank().compatible(1), "(I)RDFT op axes input must be 1D tensor.");
if (input_shapes.size() == 3) {
const auto& signal_size_shape = input_shapes[2];
validate_signal_size(op, axes_shape, signal_size_shape);
}
}
} // rfft_common_validation
} // namespace util
} // namespace op
} // namespace ov

61
src/core/src/op/irdft.cpp Normal file
View File

@ -0,0 +1,61 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "ngraph/op/irdft.hpp"
#include <memory>
#include "irdft_shape_inference.hpp"
#include "itt.hpp"
using namespace std;
BWDCMP_RTTI_DEFINITION(ov::op::v9::IRDFT);
ov::op::v9::IRDFT::IRDFT(const Output<Node>& data, const Output<Node>& axes) : FFTBase(data, axes) {
constructor_validate_and_infer_types();
}
ov::op::v9::IRDFT::IRDFT(const Output<Node>& data, const Output<Node>& axes, const Output<Node>& signal_size)
: FFTBase(data, axes, signal_size) {
constructor_validate_and_infer_types();
}
bool ov::op::v9::IRDFT::visit_attributes(AttributeVisitor& visitor) {
NGRAPH_OP_SCOPE(v9_IRDFT_visit_attributes);
return true;
}
std::shared_ptr<ov::Node> ov::op::v9::IRDFT::clone_with_new_inputs(const OutputVector& new_args) const {
NGRAPH_OP_SCOPE(v9_IRDFT_clone_with_new_inputs);
check_new_args_count(this, new_args);
NODE_VALIDATION_CHECK(this, new_args.size() == 2 || new_args.size() == 3, "Number of inputs must be 2 or 3");
if (new_args.size() == 2) {
return std::make_shared<ov::op::v9::IRDFT>(new_args.at(0), new_args.at(1));
}
return std::make_shared<ov::op::v9::IRDFT>(new_args.at(0), new_args.at(1), new_args.at(2));
}
void ov::op::v9::IRDFT::validate_and_infer_types() {
NGRAPH_OP_SCOPE(v9_IRDFT_validate_and_infer_types);
validate_types();
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape()};
std::vector<ov::PartialShape> input_shapes;
const auto& data = get_input_partial_shape(0);
const auto& axes = get_input_partial_shape(1);
if (input_values().size() == 2) {
input_shapes = {data, axes};
} else {
const auto& signal_size = get_input_partial_shape(2);
input_shapes = {data, axes, signal_size};
}
ov::op::util::irdft_shape_infer(this, input_shapes, output_shapes);
set_output_type(0, get_input_element_type(0), output_shapes[0]);
}

61
src/core/src/op/rdft.cpp Normal file
View File

@ -0,0 +1,61 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "ngraph/op/rdft.hpp"
#include <memory>
#include "itt.hpp"
#include "rdft_shape_inference.hpp"
using namespace std;
BWDCMP_RTTI_DEFINITION(ov::op::v9::RDFT);
ov::op::v9::RDFT::RDFT(const Output<Node>& data, const Output<Node>& axes) : FFTBase(data, axes) {
constructor_validate_and_infer_types();
}
ov::op::v9::RDFT::RDFT(const Output<Node>& data, const Output<Node>& axes, const Output<Node>& signal_size)
: FFTBase(data, axes, signal_size) {
constructor_validate_and_infer_types();
}
bool ov::op::v9::RDFT::visit_attributes(AttributeVisitor& visitor) {
NGRAPH_OP_SCOPE(v9_RDFT_visit_attributes);
return true;
}
std::shared_ptr<ov::Node> ov::op::v9::RDFT::clone_with_new_inputs(const OutputVector& new_args) const {
NGRAPH_OP_SCOPE(v9_RDFT_clone_with_new_inputs);
check_new_args_count(this, new_args);
NODE_VALIDATION_CHECK(this, new_args.size() == 2 || new_args.size() == 3, "Number of inputs must be 2 or 3");
if (new_args.size() == 2) {
return std::make_shared<ov::op::v9::RDFT>(new_args.at(0), new_args.at(1));
}
return std::make_shared<ov::op::v9::RDFT>(new_args.at(0), new_args.at(1), new_args.at(2));
}
void ov::op::v9::RDFT::validate_and_infer_types() {
NGRAPH_OP_SCOPE(v9_RDFT_validate_and_infer_types);
validate_types();
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape()};
std::vector<ov::PartialShape> input_shapes;
const auto& data = get_input_partial_shape(0);
const auto& axes = get_input_partial_shape(1);
if (input_values().size() == 2) {
input_shapes = {data, axes};
} else {
const auto& signal_size = get_input_partial_shape(2);
input_shapes = {data, axes, signal_size};
}
ov::op::util::rdft_shape_infer(this, input_shapes, output_shapes);
set_output_type(0, get_input_element_type(0), output_shapes[0]);
}

View File

@ -24,8 +24,8 @@ bool ov::op::util::FFTBase::visit_attributes(AttributeVisitor& visitor) {
return true; return true;
} }
void ov::op::util::FFTBase::validate_and_infer_types() { void ov::op::util::FFTBase::validate_types() {
NGRAPH_OP_SCOPE(util_FFTBase_validate_and_infer_types); NGRAPH_OP_SCOPE(util_FFTBase_validate_types);
size_t num_of_inputs = get_input_size(); size_t num_of_inputs = get_input_size();
NODE_VALIDATION_CHECK(this, num_of_inputs == 2 || num_of_inputs == 3, "FFT op must have 2 or 3 inputs."); NODE_VALIDATION_CHECK(this, num_of_inputs == 2 || num_of_inputs == 3, "FFT op must have 2 or 3 inputs.");
@ -46,6 +46,12 @@ void ov::op::util::FFTBase::validate_and_infer_types() {
signal_size_et == element::i64 || signal_size_et == element::i32, signal_size_et == element::i64 || signal_size_et == element::i32,
"FFT op signal_size element type must be i32 or i64"); "FFT op signal_size element type must be i32 or i64");
} }
}
void ov::op::util::FFTBase::validate_and_infer_types() {
NGRAPH_OP_SCOPE(util_FFTBase_validate_and_infer_types);
validate_types();
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape()}; std::vector<ov::PartialShape> output_shapes = {ov::PartialShape()};
std::vector<ov::PartialShape> input_shapes; std::vector<ov::PartialShape> input_shapes;

View File

@ -117,6 +117,17 @@ const ov::OpSet& ov::get_opset8() {
return opset; return opset;
} }
const ov::OpSet& ov::get_opset9() {
static OpSet opset;
static std::once_flag flag;
std::call_once(flag, [&]() {
#define _OPENVINO_OP_REG(NAME, NAMESPACE) opset.insert<NAMESPACE::NAME>();
#include "openvino/opsets/opset9_tbl.hpp"
#undef _OPENVINO_OP_REG
});
return opset;
}
const ngraph::OpSet& ngraph::get_opset1() { const ngraph::OpSet& ngraph::get_opset1() {
static OpSet opset(ov::get_opset1()); static OpSet opset(ov::get_opset1());
return opset; return opset;
@ -156,3 +167,8 @@ const ngraph::OpSet& ngraph::get_opset8() {
static OpSet opset(ov::get_opset8()); static OpSet opset(ov::get_opset8());
return opset; return opset;
} }
const ngraph::OpSet& ngraph::get_opset9() {
static OpSet opset(ov::get_opset9());
return opset;
}

View File

@ -165,6 +165,7 @@ set(SRC
type_prop/idft.cpp type_prop/idft.cpp
type_prop/if.cpp type_prop/if.cpp
type_prop/interpolate.cpp type_prop/interpolate.cpp
type_prop/irdft.cpp
type_prop/logical_and.cpp type_prop/logical_and.cpp
type_prop/logical_not.cpp type_prop/logical_not.cpp
type_prop/logical_or.cpp type_prop/logical_or.cpp
@ -197,6 +198,7 @@ set(SRC
type_prop/prior_box_clustered.cpp type_prop/prior_box_clustered.cpp
type_prop/random_uniform.cpp type_prop/random_uniform.cpp
type_prop/range.cpp type_prop/range.cpp
type_prop/rdft.cpp
type_prop/read_value.cpp type_prop/read_value.cpp
type_prop/reduce_l1.cpp type_prop/reduce_l1.cpp
type_prop/reduce_l2.cpp type_prop/reduce_l2.cpp
@ -324,6 +326,7 @@ set(SRC
visitors/op/interpolate.cpp visitors/op/interpolate.cpp
visitors/op/if.cpp visitors/op/if.cpp
visitors/op/idft.cpp visitors/op/idft.cpp
visitors/op/irdft.cpp
visitors/op/less_equal.cpp visitors/op/less_equal.cpp
visitors/op/less.cpp visitors/op/less.cpp
visitors/op/log.cpp visitors/op/log.cpp
@ -360,6 +363,7 @@ set(SRC
visitors/op/proposal.cpp visitors/op/proposal.cpp
visitors/op/psroi_pooling.cpp visitors/op/psroi_pooling.cpp
visitors/op/random_uniform.cpp visitors/op/random_uniform.cpp
visitors/op/rdft.cpp
visitors/op/reduce_l1.cpp visitors/op/reduce_l1.cpp
visitors/op/reduce_l2.cpp visitors/op/reduce_l2.cpp
visitors/op/reduce_logical_and.cpp visitors/op/reduce_logical_and.cpp

View File

@ -0,0 +1,309 @@
//*****************************************************************************
// Copyright 2017-2022 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/type_prop.hpp"
using namespace ngraph;
struct IRDFTConstantAxesAndConstantSignalSizeTestParams {
PartialShape input_shape;
Shape axes_shape;
Shape signal_size_shape;
PartialShape ref_output_shape;
std::vector<int64_t> axes;
std::vector<int64_t> signal_size;
};
struct IRDFTConstantAxesAndConstantSignalSizeTest
: ::testing::TestWithParam<IRDFTConstantAxesAndConstantSignalSizeTestParams> {};
TEST_P(IRDFTConstantAxesAndConstantSignalSizeTest, irdft_constant_axes_and_signal_size) {
auto params = GetParam();
auto data = std::make_shared<op::Parameter>(element::f32, params.input_shape);
auto axes_input = op::Constant::create<int64_t>(element::i64, params.axes_shape, params.axes);
std::shared_ptr<op::v9::IRDFT> irdft;
if (params.signal_size.empty()) {
irdft = std::make_shared<op::v9::IRDFT>(data, axes_input);
} else {
auto signal_size_input =
op::Constant::create<int64_t>(element::i64, params.signal_size_shape, params.signal_size);
irdft = std::make_shared<op::v9::IRDFT>(data, axes_input, signal_size_input);
}
EXPECT_EQ(irdft->get_element_type(), element::f32);
ASSERT_TRUE(irdft->get_output_partial_shape(0).same_scheme(params.ref_output_shape));
}
INSTANTIATE_TEST_SUITE_P(
type_prop,
IRDFTConstantAxesAndConstantSignalSizeTest,
::testing::Values(
IRDFTConstantAxesAndConstantSignalSizeTestParams{{2, 180, 180, 2}, {2}, Shape{}, {2, 180, 358}, {1, 2}, {}},
IRDFTConstantAxesAndConstantSignalSizeTestParams{{2, 180, 180, 2}, {2}, Shape{}, {2, 180, 180}, {2, 0}, {}},
IRDFTConstantAxesAndConstantSignalSizeTestParams{{16, 500, 180, 369, 2},
{3},
Shape{},
{16, 998, 180, 369},
{0, 3, 1},
{}},
IRDFTConstantAxesAndConstantSignalSizeTestParams{{2, 180, 180, Dimension(1, 18)},
{2},
Shape{},
{2, 180, 358},
{1, 2},
{}},
IRDFTConstantAxesAndConstantSignalSizeTestParams{{2, 180, Dimension(7, 500), 2},
{2},
Shape{},
{2, 180, Dimension(12, 998)},
{1, 2},
{}},
IRDFTConstantAxesAndConstantSignalSizeTestParams{{2, 180, Dimension(7, 500), Dimension(1, 18)},
{2},
Shape{},
{2, 180, Dimension(12, 998)},
{1, 2},
{}},
IRDFTConstantAxesAndConstantSignalSizeTestParams{{2, Dimension(7, 500), 180, 2},
{2},
Shape{},
{2, Dimension(7, 500), 358},
{1, 2},
{}},
IRDFTConstantAxesAndConstantSignalSizeTestParams{{2, Dimension(7, 500), 180, Dimension(1, 18)},
{2},
Shape{},
{2, Dimension(7, 500), 358},
{1, 2},
{}},
IRDFTConstantAxesAndConstantSignalSizeTestParams{{2, Dimension(7, 500), Dimension(7, 500), 2},
{2},
Shape{},
{2, Dimension(7, 500), Dimension(12, 998)},
{1, 2},
{}},
IRDFTConstantAxesAndConstantSignalSizeTestParams{{2, Dimension(7, 500), Dimension(7, 500), Dimension(1, 18)},
{2},
Shape{},
{2, Dimension(7, 500), Dimension(12, 998)},
{1, 2},
{}},
IRDFTConstantAxesAndConstantSignalSizeTestParams{{Dimension(0, 2), 180, 180, 2},
{2},
Shape{},
{Dimension(0, 2), 180, 358},
{1, 2},
{}},
IRDFTConstantAxesAndConstantSignalSizeTestParams{{Dimension(0, 2), 180, 180, Dimension(1, 18)},
{2},
Shape{},
{Dimension(0, 2), 180, 358},
{1, 2},
{}},
IRDFTConstantAxesAndConstantSignalSizeTestParams{{Dimension(0, 2), 180, Dimension(7, 500), 2},
{2},
Shape{},
{Dimension(0, 2), 180, Dimension(12, 998)},
{1, 2},
{}},
IRDFTConstantAxesAndConstantSignalSizeTestParams{{Dimension(0, 2), 180, Dimension(7, 500), Dimension(1, 18)},
{2},
Shape{},
{Dimension(0, 2), 180, Dimension(12, 998)},
{1, 2},
{}},
IRDFTConstantAxesAndConstantSignalSizeTestParams{{Dimension(0, 2), Dimension(7, 500), 180, 2},
{2},
Shape{},
{Dimension(0, 2), Dimension(7, 500), 358},
{1, 2},
{}},
IRDFTConstantAxesAndConstantSignalSizeTestParams{{Dimension(0, 2), Dimension(7, 500), 180, Dimension(1, 18)},
{2},
Shape{},
{Dimension(0, 2), Dimension(7, 500), 358},
{1, 2},
{}},
IRDFTConstantAxesAndConstantSignalSizeTestParams{{Dimension(0, 2), Dimension(7, 500), Dimension(7, 500), 2},
{2},
Shape{},
{Dimension(0, 2), Dimension(7, 500), Dimension(12, 998)},
{1, 2},
{}},
IRDFTConstantAxesAndConstantSignalSizeTestParams{
{Dimension(0, 2), Dimension(7, 500), Dimension(7, 500), Dimension(1, 18)},
{2},
Shape{},
{Dimension(0, 2), Dimension(7, 500), Dimension(12, 998)},
{1, 2},
{}},
IRDFTConstantAxesAndConstantSignalSizeTestParams{{2, 180, 180, 2}, {2}, {2}, {2, 180, 77}, {1, 2}, {-1, 77}},
IRDFTConstantAxesAndConstantSignalSizeTestParams{{2, 180, 180, 2}, {2}, {2}, {87, 180, 390}, {2, 0}, {390, 87}},
IRDFTConstantAxesAndConstantSignalSizeTestParams{{7, 50, 130, 400, 2},
{3},
{3},
{7, 40, 130, 600},
{3, 0, 1},
{600, -1, 40}},
IRDFTConstantAxesAndConstantSignalSizeTestParams{{2, Dimension(0, 200), 180, 2},
{2},
{2},
{2, Dimension(0, 200), 77},
{1, 2},
{-1, 77}},
IRDFTConstantAxesAndConstantSignalSizeTestParams{{Dimension(0, 18), 180, Dimension(0, 400), 2},
{2},
{2},
{87, 180, 390},
{2, 0},
{390, 87}},
IRDFTConstantAxesAndConstantSignalSizeTestParams{{Dimension(8, 129), 50, 130, Dimension(0, 500), 2},
{3},
{3},
{Dimension(8, 129), 40, 130, 600},
{3, 0, 1},
{600, -1, 40}}),
PrintToDummyParamName());
TEST(type_prop, irdft_dynamic_axes) {
const auto input_shape = PartialShape{2, 180, 180, Dimension(1, 18)};
const auto axes_shape = PartialShape::dynamic();
const auto ref_output_shape = PartialShape{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()};
auto data = std::make_shared<op::Parameter>(element::f32, input_shape);
auto axes_input = std::make_shared<op::Parameter>(element::i64, axes_shape);
auto irdft = std::make_shared<op::v9::IRDFT>(data, axes_input);
EXPECT_EQ(irdft->get_element_type(), element::f32);
ASSERT_TRUE(irdft->get_output_partial_shape(0).same_scheme(ref_output_shape));
}
struct IRDFTNonConstantAxesTestParams {
PartialShape input_shape;
Shape axes_shape;
PartialShape ref_output_shape;
};
struct IRDFTNonConstantAxesTest : ::testing::TestWithParam<IRDFTNonConstantAxesTestParams> {};
TEST_P(IRDFTNonConstantAxesTest, irdft_non_constant_axes) {
auto params = GetParam();
auto data = std::make_shared<op::Parameter>(element::f32, params.input_shape);
auto axes_input = std::make_shared<op::Parameter>(element::i64, params.axes_shape);
auto irdft = std::make_shared<op::v9::IRDFT>(data, axes_input);
EXPECT_EQ(irdft->get_element_type(), element::f32);
ASSERT_TRUE(irdft->get_output_partial_shape(0).same_scheme(params.ref_output_shape));
}
INSTANTIATE_TEST_SUITE_P(
type_prop,
IRDFTNonConstantAxesTest,
::testing::Values(
IRDFTNonConstantAxesTestParams{{2, 180, 180, Dimension(1, 18)},
{2},
{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()}},
IRDFTNonConstantAxesTestParams{{2, 180, Dimension(7, 500), 2},
{2},
{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()}},
IRDFTNonConstantAxesTestParams{{2, 180, Dimension(7, 500), Dimension(1, 18)},
{2},
{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()}},
IRDFTNonConstantAxesTestParams{{2, Dimension(7, 500), 180, 2},
{2},
{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()}},
IRDFTNonConstantAxesTestParams{{2, Dimension(7, 500), 180, Dimension(1, 18)},
{2},
{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()}},
IRDFTNonConstantAxesTestParams{{2, Dimension(7, 500), Dimension(7, 500), 2},
{2},
{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()}},
IRDFTNonConstantAxesTestParams{{2, Dimension(7, 500), Dimension(7, 500), Dimension(1, 18)},
{2},
{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()}},
IRDFTNonConstantAxesTestParams{{Dimension(0, 2), 180, 180, 2},
{2},
{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()}},
IRDFTNonConstantAxesTestParams{{Dimension(0, 2), 180, 180, Dimension(1, 18)},
{2},
{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()}},
IRDFTNonConstantAxesTestParams{{Dimension(0, 2), 180, Dimension(7, 500), 2},
{2},
{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()}},
IRDFTNonConstantAxesTestParams{{Dimension(0, 2), 180, Dimension(7, 500), Dimension(1, 18)},
{2},
{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()}},
IRDFTNonConstantAxesTestParams{{Dimension(0, 2), Dimension(7, 500), 180, 2},
{2},
{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()}},
IRDFTNonConstantAxesTestParams{{Dimension(0, 2), Dimension(7, 500), 180, Dimension(1, 18)},
{2},
{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()}},
IRDFTNonConstantAxesTestParams{{Dimension(0, 2), Dimension(7, 500), Dimension(7, 500), 2},
{2},
{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()}},
IRDFTNonConstantAxesTestParams{{Dimension(0, 2), Dimension(7, 500), Dimension(7, 500), Dimension(1, 18)},
{2},
{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()}}),
PrintToDummyParamName());
struct IRDFTNonConstantSignalSizeTestParams {
PartialShape input_shape;
Shape axes_shape;
Shape signal_size_shape;
PartialShape ref_output_shape;
std::vector<int64_t> axes;
};
struct IRDFTNonConstantSignalSizeTest : ::testing::TestWithParam<IRDFTNonConstantSignalSizeTestParams> {};
TEST_P(IRDFTNonConstantSignalSizeTest, irdft_non_constant_signal_size) {
auto params = GetParam();
auto data = std::make_shared<op::Parameter>(element::f32, params.input_shape);
auto axes_input = op::Constant::create<int64_t>(element::i64, params.axes_shape, params.axes);
auto signal_size_input = std::make_shared<op::Parameter>(element::i64, params.signal_size_shape);
auto irdft = std::make_shared<op::v9::IRDFT>(data, axes_input, signal_size_input);
EXPECT_EQ(irdft->get_element_type(), element::f32);
ASSERT_TRUE(irdft->get_output_partial_shape(0).same_scheme(params.ref_output_shape));
}
INSTANTIATE_TEST_SUITE_P(
type_prop,
IRDFTNonConstantSignalSizeTest,
::testing::Values(IRDFTNonConstantSignalSizeTestParams{{2, Dimension(0, 200), 180, 2},
{2},
{2},
{2, Dimension(0, 200), Dimension::dynamic()},
{1, 2}},
IRDFTNonConstantSignalSizeTestParams{{Dimension(0, 18), 180, Dimension(0, 400), 2},
{2},
{2},
{Dimension::dynamic(), 180, Dimension(0, 400)},
{2, 0}},
IRDFTNonConstantSignalSizeTestParams{
{Dimension(8, 129), 50, 130, Dimension(0, 500), 2},
{3},
{3},
{Dimension(8, 129), Dimension::dynamic(), 130, Dimension(0, 500)},
{3, 0, 1}}),
PrintToDummyParamName());

View File

@ -0,0 +1,245 @@
//*****************************************************************************
// Copyright 2017-2022 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/type_prop.hpp"
using namespace ngraph;
struct RDFTConstantAxesAndConstantSignalSizeTestParams {
PartialShape input_shape;
Shape axes_shape;
Shape signal_size_shape;
PartialShape ref_output_shape;
std::vector<int64_t> axes;
std::vector<int64_t> signal_size;
};
struct RDFTConstantAxesAndConstantSignalSizeTest
: ::testing::TestWithParam<RDFTConstantAxesAndConstantSignalSizeTestParams> {};
TEST_P(RDFTConstantAxesAndConstantSignalSizeTest, rdft_constant_axes_and_signal_size) {
auto params = GetParam();
auto data = std::make_shared<op::Parameter>(element::f32, params.input_shape);
auto axes_input = op::Constant::create<int64_t>(element::i64, params.axes_shape, params.axes);
std::shared_ptr<op::v9::RDFT> rdft;
if (params.signal_size.empty()) {
rdft = std::make_shared<op::v9::RDFT>(data, axes_input);
} else {
auto signal_size_input =
op::Constant::create<int64_t>(element::i64, params.signal_size_shape, params.signal_size);
rdft = std::make_shared<op::v9::RDFT>(data, axes_input, signal_size_input);
}
EXPECT_EQ(rdft->get_element_type(), element::f32);
ASSERT_TRUE(rdft->get_output_partial_shape(0).same_scheme(params.ref_output_shape));
}
INSTANTIATE_TEST_SUITE_P(
type_prop,
RDFTConstantAxesAndConstantSignalSizeTest,
::testing::Values(
RDFTConstantAxesAndConstantSignalSizeTestParams{{2, 180, 180}, {2}, Shape{}, {2, 180, 91, 2}, {1, 2}, {}},
RDFTConstantAxesAndConstantSignalSizeTestParams{{6, 180, 180}, {2}, Shape{}, {4, 180, 180, 2}, {2, 0}, {}},
RDFTConstantAxesAndConstantSignalSizeTestParams{{16, 500, 180, 369},
{3},
Shape{},
{16, 251, 180, 369, 2},
{0, 3, 1},
{}},
RDFTConstantAxesAndConstantSignalSizeTestParams{{2, 180, Dimension(1, 18)},
{2},
Shape{},
{2, 180, Dimension(1, 10), 2},
{1, 2},
{}},
RDFTConstantAxesAndConstantSignalSizeTestParams{{2, 180, Dimension(7, 500)},
{2},
Shape{},
{2, 180, Dimension(4, 251), 2},
{1, 2},
{}},
RDFTConstantAxesAndConstantSignalSizeTestParams{{2, Dimension(7, 500), 180},
{2},
Shape{},
{2, Dimension(7, 500), 91, 2},
{1, 2},
{}},
RDFTConstantAxesAndConstantSignalSizeTestParams{{2, Dimension(7, 500), Dimension(7, 500)},
{2},
Shape{},
{2, Dimension(7, 500), Dimension(4, 251), 2},
{1, 2},
{}},
RDFTConstantAxesAndConstantSignalSizeTestParams{{Dimension(0, 2), 180, 180},
{2},
Shape{},
{Dimension(0, 2), 180, 91, 2},
{1, 2},
{}},
RDFTConstantAxesAndConstantSignalSizeTestParams{{Dimension(0, 2), 180, Dimension(7, 500)},
{2},
Shape{},
{Dimension(0, 2), 180, Dimension(4, 251), 2},
{1, 2},
{}},
RDFTConstantAxesAndConstantSignalSizeTestParams{{Dimension(0, 2), Dimension(7, 500), 180},
{2},
Shape{},
{Dimension(0, 2), Dimension(7, 500), 91, 2},
{1, 2},
{}},
RDFTConstantAxesAndConstantSignalSizeTestParams{{Dimension(0, 2), Dimension(7, 500), Dimension(7, 500)},
{2},
Shape{},
{Dimension(0, 2), Dimension(7, 500), Dimension(4, 251), 2},
{1, 2},
{}},
RDFTConstantAxesAndConstantSignalSizeTestParams{{2, 180, 180}, {2}, {2}, {2, 180, 39, 2}, {1, 2}, {-1, 77}},
RDFTConstantAxesAndConstantSignalSizeTestParams{{2, 180, 180}, {2}, {2}, {44, 180, 390, 2}, {2, 0}, {390, 87}},
RDFTConstantAxesAndConstantSignalSizeTestParams{{7, 50, 130, 400},
{3},
{3},
{7, 21, 130, 600, 2},
{3, 0, 1},
{600, -1, 40}},
RDFTConstantAxesAndConstantSignalSizeTestParams{{2, Dimension(0, 200), 180},
{2},
{2},
{2, Dimension(0, 200), 39, 2},
{1, 2},
{-1, 77}},
RDFTConstantAxesAndConstantSignalSizeTestParams{{Dimension(0, 18), 180, Dimension(0, 400)},
{2},
{2},
{44, 180, 390, 2},
{2, 0},
{390, 87}},
RDFTConstantAxesAndConstantSignalSizeTestParams{{Dimension(8, 129), 50, 130, Dimension(0, 500)},
{3},
{3},
{Dimension(8, 129), 21, 130, 600, 2},
{3, 0, 1},
{600, -1, 40}}),
PrintToDummyParamName());
TEST(type_prop, rdft_dynamic_axes) {
const auto input_shape = PartialShape{2, 180, 180};
const auto axes_shape = PartialShape::dynamic();
const auto ref_output_shape = PartialShape{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), 2};
auto data = std::make_shared<op::Parameter>(element::f32, input_shape);
auto axes_input = std::make_shared<op::Parameter>(element::i64, axes_shape);
auto rdft = std::make_shared<op::v9::RDFT>(data, axes_input);
EXPECT_EQ(rdft->get_element_type(), element::f32);
ASSERT_TRUE(rdft->get_output_partial_shape(0).same_scheme(ref_output_shape));
}
struct RDFTNonConstantAxesTestParams {
PartialShape input_shape;
Shape axes_shape;
PartialShape ref_output_shape;
};
struct RDFTNonConstantAxesTest : ::testing::TestWithParam<RDFTNonConstantAxesTestParams> {};
TEST_P(RDFTNonConstantAxesTest, rdft_non_constant_axes) {
auto params = GetParam();
auto data = std::make_shared<op::Parameter>(element::f32, params.input_shape);
auto axes_input = std::make_shared<op::Parameter>(element::i64, params.axes_shape);
auto rdft = std::make_shared<op::v9::RDFT>(data, axes_input);
EXPECT_EQ(rdft->get_element_type(), element::f32);
ASSERT_TRUE(rdft->get_output_partial_shape(0).same_scheme(params.ref_output_shape));
}
INSTANTIATE_TEST_SUITE_P(
type_prop,
RDFTNonConstantAxesTest,
::testing::Values(
RDFTNonConstantAxesTestParams{{2, 180, 180},
{2},
{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), 2}},
RDFTNonConstantAxesTestParams{{2, 180, Dimension(7, 500)},
{2},
{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), 2}},
RDFTNonConstantAxesTestParams{{2, Dimension(7, 500), 180},
{2},
{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), 2}},
RDFTNonConstantAxesTestParams{{2, Dimension(7, 500), Dimension(7, 500)},
{2},
{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), 2}},
RDFTNonConstantAxesTestParams{{Dimension(0, 2), 180, 180},
{2},
{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), 2}},
RDFTNonConstantAxesTestParams{{Dimension(0, 2), 180, Dimension(7, 500)},
{2},
{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), 2}},
RDFTNonConstantAxesTestParams{{Dimension(0, 2), Dimension(7, 500), 180},
{2},
{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), 2}},
RDFTNonConstantAxesTestParams{{Dimension(0, 2), Dimension(7, 500), Dimension(7, 500)},
{2},
{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), 2}}),
PrintToDummyParamName());
struct RDFTNonConstantSignalSizeTestParams {
PartialShape input_shape;
Shape axes_shape;
Shape signal_size_shape;
PartialShape ref_output_shape;
std::vector<int64_t> axes;
};
struct RDFTNonConstantSignalSizeTest : ::testing::TestWithParam<RDFTNonConstantSignalSizeTestParams> {};
TEST_P(RDFTNonConstantSignalSizeTest, rdft_non_constant_signal_size) {
auto params = GetParam();
auto data = std::make_shared<op::Parameter>(element::f32, params.input_shape);
auto axes_input = op::Constant::create<int64_t>(element::i64, params.axes_shape, params.axes);
auto signal_size_input = std::make_shared<op::Parameter>(element::i64, params.signal_size_shape);
auto rdft = std::make_shared<op::v9::RDFT>(data, axes_input, signal_size_input);
EXPECT_EQ(rdft->get_element_type(), element::f32);
ASSERT_TRUE(rdft->get_output_partial_shape(0).same_scheme(params.ref_output_shape));
}
INSTANTIATE_TEST_SUITE_P(
type_prop,
RDFTNonConstantSignalSizeTest,
::testing::Values(RDFTNonConstantSignalSizeTestParams{{2, Dimension(0, 200), 180},
{2},
{2},
{2, Dimension(0, 200), Dimension::dynamic(), 2},
{1, 2}},
RDFTNonConstantSignalSizeTestParams{{Dimension(0, 18), 180, Dimension(0, 400)},
{2},
{2},
{Dimension::dynamic(), 180, Dimension(0, 400), 2},
{2, 0}},
RDFTNonConstantSignalSizeTestParams{
{Dimension(8, 129), 50, 130, Dimension(0, 500)},
{3},
{3},
{Dimension(8, 129), Dimension::dynamic(), 130, Dimension(0, 500), 2},
{3, 0, 1}}),
PrintToDummyParamName());

View File

@ -0,0 +1,37 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "util/visitor.hpp"
using namespace std;
using namespace ngraph;
using ngraph::test::NodeBuilder;
TEST(attributes, irdft_op) {
NodeBuilder::get_ops().register_factory<op::v9::IRDFT>();
auto data = make_shared<op::v0::Parameter>(element::f32, Shape{2, 10, 10, 2});
auto axes = op::v0::Constant::create<int64_t>(element::i64, Shape{1}, {2});
auto irdft = make_shared<op::v9::IRDFT>(data, axes);
NodeBuilder builder(irdft);
const auto expected_attr_count = 0;
EXPECT_EQ(builder.get_value_map_size(), expected_attr_count);
}
TEST(attributes, irdft_op_signal) {
NodeBuilder::get_ops().register_factory<op::v9::IRDFT>();
auto data = make_shared<op::v0::Parameter>(element::f32, Shape{2, 10, 10, 2});
auto signal = op::Constant::create<int64_t>(element::Type_t::i64, Shape{1}, {20});
auto axes = op::v0::Constant::create<int64_t>(element::i64, Shape{1}, {2});
auto irdft = make_shared<op::v9::IRDFT>(data, axes, signal);
NodeBuilder builder(irdft);
const auto expected_attr_count = 0;
EXPECT_EQ(builder.get_value_map_size(), expected_attr_count);
}

View File

@ -0,0 +1,37 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "util/visitor.hpp"
using namespace std;
using namespace ngraph;
using ngraph::test::NodeBuilder;
TEST(attributes, rdft_op) {
NodeBuilder::get_ops().register_factory<op::v9::RDFT>();
auto data = make_shared<op::v0::Parameter>(element::f32, Shape{2, 10, 10});
auto axes = op::v0::Constant::create<int64_t>(element::i64, Shape{1}, {2});
auto rdft = make_shared<op::v9::RDFT>(data, axes);
NodeBuilder builder(rdft);
const auto expected_attr_count = 0;
EXPECT_EQ(builder.get_value_map_size(), expected_attr_count);
}
TEST(attributes, rdft_op_signal) {
NodeBuilder::get_ops().register_factory<op::v9::RDFT>();
auto data = make_shared<op::v0::Parameter>(element::f32, Shape{2, 10, 10});
auto signal = op::Constant::create<int64_t>(element::Type_t::i64, Shape{1}, {20});
auto axes = op::v0::Constant::create<int64_t>(element::i64, Shape{1}, {2});
auto rdft = make_shared<op::v9::RDFT>(data, axes, signal);
NodeBuilder builder(rdft);
const auto expected_attr_count = 0;
EXPECT_EQ(builder.get_value_map_size(), expected_attr_count);
}