nGraph shell for the operation DFT (#4625)
* Written the header file for the nGraph operation DFT. * Small change. * Started to write the implementation of the nGraph shell of the operation DFT. * Started to write void op::v7::DFT::validate_and_infer_types(). * Small fixes. * Code style fixes. * Written the draft of the shape infer for the nGraph operation DFT. * Small fixes. * Code style fixes. * Added DFT into opset7 table. * Some additions. * Small fixes. * Code style fix. * Some fixes. * Some fix. * Small fixes. * Started to write shape infer tests for the nGraph operation DFT. * Written shape infer tests for the nGraph operation DFT. * Some code style fixes. * Small fix. * Code style fixes. * Code style fixes. * Deleted unused variables. * Added support for negative axes. * Started to write IE IR Reader tests for the nGraph operation DFT. * Small fix. * Added the second IE IR Reader test for the nGraph operation DFT. * Small fix. * Added the third IE IR Reader test for the nGraph operation DFT. * Corrected Doxygen comment. * Started to rewrite DFT type_prop tests as parametrized tests. * Small fixes. * Some additions. * Small fix. * Small fix. * Some tests were written as parametrized tests. Some code style fixes. * Code style fixes. * Continued to rewrite tests for DFT as parametrized ones. * Some deletions. * Some additions. * Deleted redundant tests. * Started to rewrite some tests. * Some changes. * Deleted commented code. * Started to convert tests for the case non-constant axes into parametrized tests. * Rewritten tests for the case non-constant axes as parametrized tests. * Started to convert tests for the case non-constant signal_size into parametrized tests. * Rewritten tests for the case non-constant signal size as parametrized tests. * Added checks for number of inputs. * Small fixes. * Small fixes. * Refactored shape infer and corrected tests. * Some refactoring. * Now the function validate() is protected. * Small refactoring. * Fixed typo. * Added some comments. * Fixes in infer function. * Added test. * Fixed test.
This commit is contained in:
parent
b76e965e95
commit
af95452026
@ -0,0 +1,451 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <string>
|
||||
#include "ngraph_reader_tests.hpp"
|
||||
#include "common_test_utils/data_utils.hpp"
|
||||
|
||||
TEST_F(NGraphReaderTests, ReadDFTNetwork) {
|
||||
std::string model = R"V0G0N(
|
||||
<net name="deformable_convolution" version="10">
|
||||
<layers>
|
||||
<layer id="0" name="in1" type="Parameter" version="opset1">
|
||||
<data shape="1,180,180,2" element_type="f32"/>
|
||||
<output>
|
||||
<port id="0" precision="FP32">
|
||||
<dim>1</dim>
|
||||
<dim>180</dim>
|
||||
<dim>180</dim>
|
||||
<dim>2</dim>
|
||||
</port>
|
||||
</output>
|
||||
</layer>
|
||||
<layer id="1" name="axes" type="Const" version="opset1">
|
||||
<data offset="0" size="16" shape="2" element_type="i64"/>
|
||||
<output>
|
||||
<port id="0" precision="I64">
|
||||
<dim>2</dim>
|
||||
</port>
|
||||
</output>
|
||||
</layer>
|
||||
<layer id="2" name="dft" type="DFT" version="opset7">
|
||||
<input>
|
||||
<port id="0">
|
||||
<dim>1</dim>
|
||||
<dim>180</dim>
|
||||
<dim>180</dim>
|
||||
<dim>2</dim>
|
||||
</port>
|
||||
<port id="1">
|
||||
<dim>2</dim>
|
||||
</port>
|
||||
</input>
|
||||
<output>
|
||||
<port id="2" precision="FP32">
|
||||
<dim>1</dim>
|
||||
<dim>180</dim>
|
||||
<dim>180</dim>
|
||||
<dim>2</dim>
|
||||
</port>
|
||||
</output>
|
||||
</layer>
|
||||
<layer id="3" name="output" type="Result" version="opset1">
|
||||
<input>
|
||||
<port id="0">
|
||||
<dim>1</dim>
|
||||
<dim>180</dim>
|
||||
<dim>180</dim>
|
||||
<dim>2</dim>
|
||||
</port>
|
||||
</input>
|
||||
</layer>
|
||||
</layers>
|
||||
<edges>
|
||||
<edge from-layer="0" from-port="0" to-layer="2" to-port="0"/>
|
||||
<edge from-layer="1" from-port="0" to-layer="2" to-port="1"/>
|
||||
<edge from-layer="2" from-port="2" to-layer="3" to-port="0"/>
|
||||
</edges>
|
||||
</net>
|
||||
)V0G0N";
|
||||
std::string modelV7 = R"V0G0N(
|
||||
<net name="deformable_convolution" version="7">
|
||||
<layers>
|
||||
<layer id="0" name="in1" type="Input" version="opset1">
|
||||
<output>
|
||||
<port id="0" precision="FP32">
|
||||
<dim>1</dim>
|
||||
<dim>180</dim>
|
||||
<dim>180</dim>
|
||||
<dim>2</dim>
|
||||
</port>
|
||||
</output>
|
||||
</layer>
|
||||
<layer id="1" name="axes" type="Const" version="opset1">
|
||||
<output>
|
||||
<port id="0" precision="I64">
|
||||
<dim>2</dim>
|
||||
</port>
|
||||
</output>
|
||||
<blobs>
|
||||
<custom offset="0" size="16" precision="I64"/>
|
||||
</blobs>
|
||||
</layer>
|
||||
<layer id="2" name="dft" type="DFT" version="opset7">
|
||||
<input>
|
||||
<port id="0">
|
||||
<dim>1</dim>
|
||||
<dim>180</dim>
|
||||
<dim>180</dim>
|
||||
<dim>2</dim>
|
||||
</port>
|
||||
<port id="1">
|
||||
<dim>2</dim>
|
||||
</port>
|
||||
</input>
|
||||
<output>
|
||||
<port id="2" precision="FP32">
|
||||
<dim>1</dim>
|
||||
<dim>180</dim>
|
||||
<dim>180</dim>
|
||||
<dim>2</dim>
|
||||
</port>
|
||||
</output>
|
||||
</layer>
|
||||
</layers>
|
||||
<edges>
|
||||
<edge from-layer="0" from-port="0" to-layer="2" to-port="0"/>
|
||||
<edge from-layer="1" from-port="0" to-layer="2" to-port="1"/>
|
||||
</edges>
|
||||
</net>
|
||||
)V0G0N";
|
||||
|
||||
compareIRs(model, modelV7, 16, [](Blob::Ptr& weights) {
|
||||
auto * i64w = weights->buffer().as<int64_t*>();
|
||||
i64w[0] = 2;
|
||||
i64w[1] = 0;
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(NGraphReaderTests, ReadDFTNetwork2) {
|
||||
std::string model = R"V0G0N(
|
||||
<net name="deformable_convolution" version="10">
|
||||
<layers>
|
||||
<layer id="0" name="in1" type="Parameter" version="opset1">
|
||||
<data shape="7,50,130,400,2" element_type="f32"/>
|
||||
<output>
|
||||
<port id="0" precision="FP32">
|
||||
<dim>7</dim>
|
||||
<dim>50</dim>
|
||||
<dim>130</dim>
|
||||
<dim>400</dim>
|
||||
<dim>2</dim>
|
||||
</port>
|
||||
</output>
|
||||
</layer>
|
||||
<layer id="1" name="axes" type="Const" version="opset1">
|
||||
<data offset="0" size="24" shape="3" element_type="i64"/>
|
||||
<output>
|
||||
<port id="0" precision="I64">
|
||||
<dim>3</dim>
|
||||
</port>
|
||||
</output>
|
||||
</layer>
|
||||
<layer id="2" name="signal_size" type="Const" version="opset1">
|
||||
<data offset="24" size="24" shape="3" element_type="i64"/>
|
||||
<output>
|
||||
<port id="0" precision="I64">
|
||||
<dim>3</dim>
|
||||
</port>
|
||||
</output>
|
||||
</layer>
|
||||
<layer id="3" name="dft" type="DFT" version="opset7">
|
||||
<input>
|
||||
<port id="0">
|
||||
<dim>7</dim>
|
||||
<dim>50</dim>
|
||||
<dim>130</dim>
|
||||
<dim>400</dim>
|
||||
<dim>2</dim>
|
||||
</port>
|
||||
<port id="1">
|
||||
<dim>3</dim>
|
||||
</port>
|
||||
<port id="2">
|
||||
<dim>3</dim>
|
||||
</port>
|
||||
</input>
|
||||
<output>
|
||||
<port id="3" precision="FP32">
|
||||
<dim>7</dim>
|
||||
<dim>40</dim>
|
||||
<dim>130</dim>
|
||||
<dim>600</dim>
|
||||
<dim>2</dim>
|
||||
</port>
|
||||
</output>
|
||||
</layer>
|
||||
<layer id="4" name="output" type="Result" version="opset1">
|
||||
<input>
|
||||
<port id="0">
|
||||
<dim>7</dim>
|
||||
<dim>40</dim>
|
||||
<dim>130</dim>
|
||||
<dim>600</dim>
|
||||
<dim>2</dim>
|
||||
</port>
|
||||
</input>
|
||||
</layer>
|
||||
</layers>
|
||||
<edges>
|
||||
<edge from-layer="0" from-port="0" to-layer="3" to-port="0"/>
|
||||
<edge from-layer="1" from-port="0" to-layer="3" to-port="1"/>
|
||||
<edge from-layer="2" from-port="0" to-layer="3" to-port="2"/>
|
||||
<edge from-layer="3" from-port="3" to-layer="4" to-port="0"/>
|
||||
</edges>
|
||||
</net>
|
||||
)V0G0N";
|
||||
std::string modelV7 = R"V0G0N(
|
||||
<net name="deformable_convolution" version="7">
|
||||
<layers>
|
||||
<layer id="0" name="in1" type="Input" version="opset1">
|
||||
<output>
|
||||
<port id="0" precision="FP32">
|
||||
<dim>7</dim>
|
||||
<dim>50</dim>
|
||||
<dim>130</dim>
|
||||
<dim>400</dim>
|
||||
<dim>2</dim>
|
||||
</port>
|
||||
</output>
|
||||
</layer>
|
||||
<layer id="1" name="axes" type="Const" version="opset1">
|
||||
<output>
|
||||
<port id="0" precision="I64">
|
||||
<dim>3</dim>
|
||||
</port>
|
||||
</output>
|
||||
<blobs>
|
||||
<custom offset="0" size="24" precision="I64"/>
|
||||
</blobs>
|
||||
</layer>
|
||||
<layer id="2" name="signal_size" type="Const" version="opset1">
|
||||
<output>
|
||||
<port id="0" precision="I64">
|
||||
<dim>3</dim>
|
||||
</port>
|
||||
</output>
|
||||
<blobs>
|
||||
<custom offset="24" size="24" precision="I64"/>
|
||||
</blobs>
|
||||
</layer>
|
||||
<layer id="3" name="dft" type="DFT" version="opset7">
|
||||
<input>
|
||||
<port id="0">
|
||||
<dim>7</dim>
|
||||
<dim>50</dim>
|
||||
<dim>130</dim>
|
||||
<dim>400</dim>
|
||||
<dim>2</dim>
|
||||
</port>
|
||||
<port id="1">
|
||||
<dim>3</dim>
|
||||
</port>
|
||||
<port id="2">
|
||||
<dim>3</dim>
|
||||
</port>
|
||||
</input>
|
||||
<output>
|
||||
<port id="2" precision="FP32">
|
||||
<dim>7</dim>
|
||||
<dim>40</dim>
|
||||
<dim>130</dim>
|
||||
<dim>600</dim>
|
||||
<dim>2</dim>
|
||||
</port>
|
||||
</output>
|
||||
</layer>
|
||||
</layers>
|
||||
<edges>
|
||||
<edge from-layer="0" from-port="0" to-layer="3" to-port="0"/>
|
||||
<edge from-layer="1" from-port="0" to-layer="3" to-port="1"/>
|
||||
<edge from-layer="2" from-port="0" to-layer="3" to-port="2"/>
|
||||
</edges>
|
||||
</net>
|
||||
)V0G0N";
|
||||
|
||||
compareIRs(model, modelV7, 48, [](Blob::Ptr& weights) {
|
||||
auto * i64w = weights->buffer().as<int64_t*>();
|
||||
i64w[0] = 3;
|
||||
i64w[1] = 0;
|
||||
i64w[2] = 1;
|
||||
i64w[3] = 600;
|
||||
i64w[4] = -1;
|
||||
i64w[5] = 40;
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(NGraphReaderTests, ReadDFTNetwork3) {
|
||||
std::string model = R"V0G0N(
|
||||
<net name="deformable_convolution" version="10">
|
||||
<layers>
|
||||
<layer id="0" name="in1" type="Parameter" version="opset1">
|
||||
<data shape="7,15,200,124,70,2" element_type="f32"/>
|
||||
<output>
|
||||
<port id="0" precision="FP32">
|
||||
<dim>7</dim>
|
||||
<dim>15</dim>
|
||||
<dim>200</dim>
|
||||
<dim>124</dim>
|
||||
<dim>70</dim>
|
||||
<dim>2</dim>
|
||||
</port>
|
||||
</output>
|
||||
</layer>
|
||||
<layer id="1" name="axes" type="Const" version="opset1">
|
||||
<data offset="0" size="24" shape="3" element_type="i64"/>
|
||||
<output>
|
||||
<port id="0" precision="I64">
|
||||
<dim>3</dim>
|
||||
</port>
|
||||
</output>
|
||||
</layer>
|
||||
<layer id="2" name="signal_size" type="Const" version="opset1">
|
||||
<data offset="24" size="24" shape="3" element_type="i64"/>
|
||||
<output>
|
||||
<port id="0" precision="I64">
|
||||
<dim>3</dim>
|
||||
</port>
|
||||
</output>
|
||||
</layer>
|
||||
<layer id="3" name="dft" type="DFT" version="opset7">
|
||||
<input>
|
||||
<port id="0">
|
||||
<dim>7</dim>
|
||||
<dim>15</dim>
|
||||
<dim>200</dim>
|
||||
<dim>124</dim>
|
||||
<dim>70</dim>
|
||||
<dim>2</dim>
|
||||
</port>
|
||||
<port id="1">
|
||||
<dim>3</dim>
|
||||
</port>
|
||||
<port id="2">
|
||||
<dim>3</dim>
|
||||
</port>
|
||||
</input>
|
||||
<output>
|
||||
<port id="3" precision="FP32">
|
||||
<dim>7</dim>
|
||||
<dim>15</dim>
|
||||
<dim>100</dim>
|
||||
<dim>124</dim>
|
||||
<dim>280</dim>
|
||||
<dim>2</dim>
|
||||
</port>
|
||||
</output>
|
||||
</layer>
|
||||
<layer id="4" name="output" type="Result" version="opset1">
|
||||
<input>
|
||||
<port id="0">
|
||||
<dim>7</dim>
|
||||
<dim>15</dim>
|
||||
<dim>100</dim>
|
||||
<dim>124</dim>
|
||||
<dim>280</dim>
|
||||
<dim>2</dim>
|
||||
</port>
|
||||
</input>
|
||||
</layer>
|
||||
</layers>
|
||||
<edges>
|
||||
<edge from-layer="0" from-port="0" to-layer="3" to-port="0"/>
|
||||
<edge from-layer="1" from-port="0" to-layer="3" to-port="1"/>
|
||||
<edge from-layer="2" from-port="0" to-layer="3" to-port="2"/>
|
||||
<edge from-layer="3" from-port="3" to-layer="4" to-port="0"/>
|
||||
</edges>
|
||||
</net>
|
||||
)V0G0N";
|
||||
std::string modelV7 = R"V0G0N(
|
||||
<net name="deformable_convolution" version="7">
|
||||
<layers>
|
||||
<layer id="0" name="in1" type="Input" version="opset1">
|
||||
<output>
|
||||
<port id="0" precision="FP32">
|
||||
<dim>7</dim>
|
||||
<dim>15</dim>
|
||||
<dim>200</dim>
|
||||
<dim>124</dim>
|
||||
<dim>70</dim>
|
||||
<dim>2</dim>
|
||||
</port>
|
||||
</output>
|
||||
</layer>
|
||||
<layer id="1" name="axes" type="Const" version="opset1">
|
||||
<output>
|
||||
<port id="0" precision="I64">
|
||||
<dim>3</dim>
|
||||
</port>
|
||||
</output>
|
||||
<blobs>
|
||||
<custom offset="0" size="24" precision="I64"/>
|
||||
</blobs>
|
||||
</layer>
|
||||
<layer id="2" name="signal_size" type="Const" version="opset1">
|
||||
<output>
|
||||
<port id="0" precision="I64">
|
||||
<dim>3</dim>
|
||||
</port>
|
||||
</output>
|
||||
<blobs>
|
||||
<custom offset="24" size="24" precision="I64"/>
|
||||
</blobs>
|
||||
</layer>
|
||||
<layer id="3" name="dft" type="DFT" version="opset7">
|
||||
<input>
|
||||
<port id="0">
|
||||
<dim>7</dim>
|
||||
<dim>15</dim>
|
||||
<dim>200</dim>
|
||||
<dim>124</dim>
|
||||
<dim>70</dim>
|
||||
<dim>2</dim>
|
||||
</port>
|
||||
<port id="1">
|
||||
<dim>3</dim>
|
||||
</port>
|
||||
<port id="2">
|
||||
<dim>3</dim>
|
||||
</port>
|
||||
</input>
|
||||
<output>
|
||||
<port id="2" precision="FP32">
|
||||
<dim>7</dim>
|
||||
<dim>15</dim>
|
||||
<dim>100</dim>
|
||||
<dim>124</dim>
|
||||
<dim>280</dim>
|
||||
<dim>2</dim>
|
||||
</port>
|
||||
</output>
|
||||
</layer>
|
||||
</layers>
|
||||
<edges>
|
||||
<edge from-layer="0" from-port="0" to-layer="3" to-port="0"/>
|
||||
<edge from-layer="1" from-port="0" to-layer="3" to-port="1"/>
|
||||
<edge from-layer="2" from-port="0" to-layer="3" to-port="2"/>
|
||||
</edges>
|
||||
</net>
|
||||
)V0G0N";
|
||||
|
||||
compareIRs(model, modelV7, 48, [](Blob::Ptr& weights) {
|
||||
auto * i64w = weights->buffer().as<int64_t*>();
|
||||
i64w[0] = -3;
|
||||
i64w[1] = 4;
|
||||
i64w[2] = 0;
|
||||
i64w[3] = 100;
|
||||
i64w[4] = 280;
|
||||
i64w[5] = -1;
|
||||
});
|
||||
}
|
65
ngraph/core/include/ngraph/op/dft.hpp
Normal file
65
ngraph/core/include/ngraph/op/dft.hpp
Normal file
@ -0,0 +1,65 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2021 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstddef>
|
||||
#include <vector>
|
||||
#include "ngraph/attribute_adapter.hpp"
|
||||
#include "ngraph/op/op.hpp"
|
||||
#include "ngraph/op/util/attr_types.hpp"
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
namespace op
|
||||
{
|
||||
namespace v7
|
||||
{
|
||||
/// \brief An operation DFT that computes the discrete Fourier transformation.
|
||||
class NGRAPH_API DFT : public Op
|
||||
{
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
DFT() = default;
|
||||
|
||||
/// \brief Constructs a DFT operation. DFT is performed for full size axes.
|
||||
///
|
||||
/// \param data Input data
|
||||
/// \param axes Axes to perform DFT
|
||||
DFT(const Output<Node>& data, const Output<Node>& axes);
|
||||
|
||||
/// \brief Constructs a DFT operation.
|
||||
///
|
||||
/// \param data Input data
|
||||
/// \param axes Axes to perform DFT
|
||||
/// \param signal_size Signal sizes for 'axes'
|
||||
DFT(const Output<Node>& data,
|
||||
const Output<Node>& axes,
|
||||
const Output<Node>& signal_size);
|
||||
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
std::shared_ptr<Node>
|
||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
protected:
|
||||
void validate();
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
@ -39,6 +39,7 @@
|
||||
#include "ngraph/op/deformable_psroi_pooling.hpp"
|
||||
#include "ngraph/op/depth_to_space.hpp"
|
||||
#include "ngraph/op/detection_output.hpp"
|
||||
#include "ngraph/op/dft.hpp"
|
||||
#include "ngraph/op/divide.hpp"
|
||||
#include "ngraph/op/elu.hpp"
|
||||
#include "ngraph/op/embedding_segments_sum.hpp"
|
||||
|
@ -170,6 +170,7 @@ NGRAPH_OP(Assign, ngraph::op::v6) // new version
|
||||
NGRAPH_OP(ReadValue, ngraph::op::v6) // new version
|
||||
|
||||
// New operations added in opset7
|
||||
NGRAPH_OP(DFT, ngraph::op::v7)
|
||||
NGRAPH_OP(Gelu, ngraph::op::v7)
|
||||
NGRAPH_OP(IDFT, ngraph::op::v7)
|
||||
NGRAPH_OP(Roll, ngraph::op::v7)
|
||||
|
267
ngraph/core/src/op/dft.cpp
Normal file
267
ngraph/core/src/op/dft.cpp
Normal file
@ -0,0 +1,267 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2021 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <ngraph/validation_util.hpp>
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "ngraph/attribute_visitor.hpp"
|
||||
#include "ngraph/axis_set.hpp"
|
||||
#include "ngraph/axis_vector.hpp"
|
||||
#include "ngraph/op/constant.hpp"
|
||||
#include "ngraph/op/dft.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
#include "ngraph/runtime/host_tensor.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(op::v7::DFT, "DFT", 7);
|
||||
|
||||
op::v7::DFT::DFT(const Output<Node>& data, const Output<Node>& axes)
|
||||
: Op({data, axes})
|
||||
{
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
op::v7::DFT::DFT(const Output<Node>& data,
|
||||
const Output<Node>& axes,
|
||||
const Output<Node>& signal_size)
|
||||
: Op({data, axes, signal_size})
|
||||
{
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
bool op::v7::DFT::visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
NGRAPH_OP_SCOPE(v7_DFT_visit_attributes);
|
||||
return true;
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> op::v7::DFT::clone_with_new_inputs(const OutputVector& new_args) const
|
||||
{
|
||||
NGRAPH_OP_SCOPE(v7_DFT_clone_with_new_inputs);
|
||||
check_new_args_count(this, new_args);
|
||||
NODE_VALIDATION_CHECK(
|
||||
this, new_args.size() == 2 || new_args.size() == 3, "Number of inputs must be 2 or 3");
|
||||
|
||||
if (new_args.size() == 2)
|
||||
{
|
||||
return std::make_shared<op::v7::DFT>(new_args.at(0), new_args.at(1));
|
||||
}
|
||||
|
||||
return std::make_shared<op::v7::DFT>(new_args.at(0), new_args.at(1), new_args.at(2));
|
||||
}
|
||||
|
||||
void op::v7::DFT::validate()
|
||||
{
|
||||
size_t num_of_inputs = get_input_size();
|
||||
|
||||
NODE_VALIDATION_CHECK(
|
||||
this, num_of_inputs == 2 || num_of_inputs == 3, "DFT must have 2 or 3 inputs.");
|
||||
|
||||
element::Type input_et = get_input_element_type(0);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
input_et == element::f32 || input_et == element::f16 ||
|
||||
input_et == element::bf16,
|
||||
"DFT input element type must be f32, f16, or bf16");
|
||||
|
||||
element::Type axes_et = get_input_element_type(1);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
axes_et == element::i64 || axes_et == element::i32,
|
||||
"DFT axes element type must be i32 or i64");
|
||||
|
||||
const auto& input_shape = PartialShape(get_input_partial_shape(0));
|
||||
if (input_shape.rank().is_static())
|
||||
{
|
||||
const auto input_rank = input_shape.rank().get_length();
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
input_rank >= 2,
|
||||
"The input rank must be greater or equal to 2. Got input rank: ",
|
||||
input_rank);
|
||||
|
||||
auto last_dim_with_two = input_shape[input_rank - 1] & Dimension(2);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
!last_dim_with_two.get_interval().empty(),
|
||||
"The last dimension of input data must be 2. Got: ",
|
||||
input_shape[input_rank - 1]);
|
||||
}
|
||||
|
||||
const auto& axes_shape = PartialShape(get_input_partial_shape(1));
|
||||
if (axes_shape.rank().is_static())
|
||||
{
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
axes_shape.rank().get_length() == 1,
|
||||
"DFT axes input must be 1D tensor. Got axes input rank: ",
|
||||
axes_shape.rank().get_length());
|
||||
}
|
||||
|
||||
if (input_shape.rank().is_static() && axes_shape.is_static())
|
||||
{
|
||||
const auto input_rank = input_shape.rank().get_length();
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
input_rank >= axes_shape.to_shape()[0] + 1,
|
||||
"The input rank must be greater than number of DFT axes. Got "
|
||||
"input rank: ",
|
||||
input_rank,
|
||||
", number of axes: ",
|
||||
axes_shape.to_shape()[0]);
|
||||
}
|
||||
|
||||
if (input_shape.rank().is_static() && is_type<op::Constant>(input_value(1).get_node()))
|
||||
{
|
||||
const auto input_rank = input_shape.rank().get_length();
|
||||
const auto& const_axes = get_constant_from_source(input_value(1));
|
||||
auto axes = const_axes->cast_vector<int64_t>();
|
||||
|
||||
// DFT operation supports for negative axes to transform. More precisely, according to
|
||||
// the DFT operation specification, axes should be integers from -(r - 1) to (r - 2)
|
||||
// inclusively, where r = rank(data). A negative axis 'a' is interpreted as an axis
|
||||
//'r - 1 + a'. The reason is the following.
|
||||
for (int64_t& axis : axes)
|
||||
{
|
||||
if (axis < 0)
|
||||
{
|
||||
axis += input_rank - 1;
|
||||
}
|
||||
}
|
||||
|
||||
AxisVector axes_vector;
|
||||
AxisSet axes_set;
|
||||
for (const int64_t axis : axes)
|
||||
{
|
||||
axes_vector.push_back(static_cast<size_t>(axis));
|
||||
axes_set.insert(static_cast<size_t>(axis));
|
||||
}
|
||||
|
||||
NODE_VALIDATION_CHECK(
|
||||
this, axes.size() == axes_set.size(), "DFT axes must be unique. Got: ", axes_vector);
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
std::find(axes.begin(), axes.end(), input_rank - 1) == axes.end(),
|
||||
"DFT axes cannot contain the last axis. Got axes: ",
|
||||
axes_vector);
|
||||
}
|
||||
|
||||
if (num_of_inputs == 3)
|
||||
{
|
||||
element::Type signal_size_et = get_input_element_type(2);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
signal_size_et == element::i64 || signal_size_et == element::i32,
|
||||
"DFT signal_size element type must be i32 or i64");
|
||||
|
||||
const auto& signal_size_shape = PartialShape(get_input_partial_shape(2));
|
||||
if (signal_size_shape.rank().is_static())
|
||||
{
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
signal_size_shape.rank().get_length() == 1,
|
||||
"DFT Signal size input must be 1D tensor. Got signal size "
|
||||
"input rank: ",
|
||||
signal_size_shape.rank().get_length());
|
||||
}
|
||||
|
||||
if (axes_shape.is_static() && signal_size_shape.is_static())
|
||||
{
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
axes_shape.to_shape()[0] == signal_size_shape.to_shape()[0],
|
||||
"Sizes of inputs 'axes' and 'signal_size' must be equal. Got "
|
||||
"size of 'axes': ",
|
||||
axes_shape.to_shape()[0],
|
||||
"size of 'signal_size': ",
|
||||
signal_size_shape.to_shape()[0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void op::v7::DFT::validate_and_infer_types()
|
||||
{
|
||||
NGRAPH_OP_SCOPE(v7_DFT_validate_and_infer_types);
|
||||
validate();
|
||||
|
||||
const auto& input_shape = PartialShape(get_input_partial_shape(0));
|
||||
const auto& axes_shape = PartialShape(get_input_partial_shape(1));
|
||||
PartialShape output_shape = input_shape;
|
||||
if (input_shape.rank().is_dynamic())
|
||||
{
|
||||
set_output_type(0, get_input_element_type(0), output_shape);
|
||||
return;
|
||||
}
|
||||
|
||||
const auto input_rank = input_shape.rank().get_length();
|
||||
|
||||
if (axes_shape.rank().is_dynamic() || !is_type<op::Constant>(input_value(1).get_node()))
|
||||
{
|
||||
for (size_t i = 0; i < input_rank - 1; ++i)
|
||||
{
|
||||
output_shape[i] = Dimension::dynamic();
|
||||
}
|
||||
set_output_type(0, get_input_element_type(0), output_shape);
|
||||
return;
|
||||
}
|
||||
|
||||
if (input_values().size() == 2)
|
||||
{
|
||||
set_output_type(0, get_input_element_type(0), output_shape);
|
||||
return;
|
||||
}
|
||||
|
||||
const auto& signal_size_shape = PartialShape(get_input_partial_shape(2));
|
||||
if (signal_size_shape.rank().is_dynamic())
|
||||
{
|
||||
set_output_type(0, get_input_element_type(0), output_shape);
|
||||
return;
|
||||
}
|
||||
|
||||
const auto& const_axes = get_constant_from_source(input_value(1));
|
||||
auto axes = const_axes->cast_vector<int64_t>();
|
||||
// DFT operation supports for negative axes to transform. More precisely, according to
|
||||
// the DFT operation specification, axes should be integers from -(r - 1) to (r - 2)
|
||||
// inclusively, where r = rank(data). A negative axis 'a' is interpreted as an axis
|
||||
//'r - 1 + a'. The reason is the following.
|
||||
for (int64_t& axis : axes)
|
||||
{
|
||||
if (axis < 0)
|
||||
{
|
||||
axis += input_rank - 1;
|
||||
}
|
||||
}
|
||||
|
||||
if (!is_type<op::Constant>(input_value(2).get_node()))
|
||||
{
|
||||
for (int64_t axis : axes)
|
||||
{
|
||||
output_shape[axis] = Dimension::dynamic();
|
||||
}
|
||||
set_output_type(0, get_input_element_type(0), output_shape);
|
||||
return;
|
||||
}
|
||||
|
||||
const auto& const_signal_size = get_constant_from_source(input_value(2));
|
||||
const auto signal_size = const_signal_size->cast_vector<int64_t>();
|
||||
|
||||
size_t num_of_axes = axes.size();
|
||||
for (size_t i = 0; i < num_of_axes; ++i)
|
||||
{
|
||||
if (signal_size[i] == -1)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
output_shape[axes[i]] = Dimension(signal_size[i]);
|
||||
}
|
||||
|
||||
set_output_type(0, get_input_element_type(0), output_shape);
|
||||
}
|
@ -113,6 +113,7 @@ set(SRC
|
||||
type_prop/deformable_psroi_pooling.cpp
|
||||
type_prop/detection_output.cpp
|
||||
type_prop/depth_to_space.cpp
|
||||
type_prop/dft.cpp
|
||||
type_prop/dyn_reshape.cpp
|
||||
type_prop/experimental_detectron_generate_proposals.cpp
|
||||
type_prop/experimental_detectron_roi_feature_extractor.cpp
|
||||
|
346
ngraph/test/type_prop/dft.cpp
Normal file
346
ngraph/test/type_prop/dft.cpp
Normal file
@ -0,0 +1,346 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2021 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "util/type_prop.hpp"
|
||||
|
||||
using namespace ngraph;
|
||||
|
||||
struct ConstantAxesAndConstantSignalSizeTestParams
|
||||
{
|
||||
PartialShape input_shape;
|
||||
Shape axes_shape;
|
||||
Shape signal_size_shape;
|
||||
PartialShape ref_output_shape;
|
||||
std::vector<int64_t> axes;
|
||||
std::vector<int64_t> signal_size;
|
||||
};
|
||||
|
||||
struct ConstantAxesAndConstantSignalSizeTest
|
||||
: ::testing::TestWithParam<ConstantAxesAndConstantSignalSizeTestParams>
|
||||
{
|
||||
};
|
||||
|
||||
TEST_P(ConstantAxesAndConstantSignalSizeTest, dft_constant_axes_and_signal_size)
|
||||
{
|
||||
auto params = GetParam();
|
||||
|
||||
auto data = std::make_shared<op::Parameter>(element::f32, params.input_shape);
|
||||
auto axes_input = op::Constant::create<int64_t>(element::i64, params.axes_shape, params.axes);
|
||||
|
||||
std::shared_ptr<op::v7::DFT> dft;
|
||||
if (params.signal_size.empty())
|
||||
{
|
||||
dft = std::make_shared<op::v7::DFT>(data, axes_input);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto signal_size_input = op::Constant::create<int64_t>(
|
||||
element::i64, params.signal_size_shape, params.signal_size);
|
||||
dft = std::make_shared<op::v7::DFT>(data, axes_input, signal_size_input);
|
||||
}
|
||||
|
||||
EXPECT_EQ(dft->get_element_type(), element::f32);
|
||||
ASSERT_TRUE(dft->get_output_partial_shape(0).same_scheme(params.ref_output_shape));
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
type_prop,
|
||||
ConstantAxesAndConstantSignalSizeTest,
|
||||
::testing::Values(
|
||||
ConstantAxesAndConstantSignalSizeTestParams{
|
||||
{2, 180, 180, 2}, {2}, Shape{}, {2, 180, 180, 2}, {1, 2}, {}},
|
||||
ConstantAxesAndConstantSignalSizeTestParams{
|
||||
{2, 180, 180, 2}, {2}, Shape{}, {2, 180, 180, 2}, {2, 0}, {}},
|
||||
ConstantAxesAndConstantSignalSizeTestParams{
|
||||
{16, 500, 180, 369, 2}, {3}, Shape{}, {16, 500, 180, 369, 2}, {0, 3, 1}, {}},
|
||||
ConstantAxesAndConstantSignalSizeTestParams{{2, 180, 180, Dimension(1, 18)},
|
||||
{2},
|
||||
Shape{},
|
||||
{2, 180, 180, Dimension(1, 18)},
|
||||
{1, 2},
|
||||
{}},
|
||||
ConstantAxesAndConstantSignalSizeTestParams{{2, 180, Dimension(7, 500), 2},
|
||||
{2},
|
||||
Shape{},
|
||||
{2, 180, Dimension(7, 500), 2},
|
||||
{1, 2},
|
||||
{}},
|
||||
ConstantAxesAndConstantSignalSizeTestParams{{2, 180, Dimension(7, 500), Dimension(1, 18)},
|
||||
{2},
|
||||
Shape{},
|
||||
{2, 180, Dimension(7, 500), Dimension(1, 18)},
|
||||
{1, 2},
|
||||
{}},
|
||||
ConstantAxesAndConstantSignalSizeTestParams{{2, Dimension(7, 500), 180, 2},
|
||||
{2},
|
||||
Shape{},
|
||||
{2, Dimension(7, 500), 180, 2},
|
||||
{1, 2},
|
||||
{}},
|
||||
ConstantAxesAndConstantSignalSizeTestParams{{2, Dimension(7, 500), 180, Dimension(1, 18)},
|
||||
{2},
|
||||
Shape{},
|
||||
{2, Dimension(7, 500), 180, Dimension(1, 18)},
|
||||
{1, 2},
|
||||
{}},
|
||||
ConstantAxesAndConstantSignalSizeTestParams{{2, Dimension(7, 500), Dimension(7, 500), 2},
|
||||
{2},
|
||||
Shape{},
|
||||
{2, Dimension(7, 500), Dimension(7, 500), 2},
|
||||
{1, 2},
|
||||
{}},
|
||||
ConstantAxesAndConstantSignalSizeTestParams{
|
||||
{2, Dimension(7, 500), Dimension(7, 500), Dimension(1, 18)},
|
||||
{2},
|
||||
Shape{},
|
||||
{2, Dimension(7, 500), Dimension(7, 500), Dimension(1, 18)},
|
||||
{1, 2},
|
||||
{}},
|
||||
ConstantAxesAndConstantSignalSizeTestParams{{Dimension(0, 2), 180, 180, 2},
|
||||
{2},
|
||||
Shape{},
|
||||
{Dimension(0, 2), 180, 180, 2},
|
||||
{1, 2},
|
||||
{}},
|
||||
ConstantAxesAndConstantSignalSizeTestParams{{Dimension(0, 2), 180, 180, Dimension(1, 18)},
|
||||
{2},
|
||||
Shape{},
|
||||
{Dimension(0, 2), 180, 180, Dimension(1, 18)},
|
||||
{1, 2},
|
||||
{}},
|
||||
ConstantAxesAndConstantSignalSizeTestParams{{Dimension(0, 2), 180, Dimension(7, 500), 2},
|
||||
{2},
|
||||
Shape{},
|
||||
{Dimension(0, 2), 180, Dimension(7, 500), 2},
|
||||
{1, 2},
|
||||
{}},
|
||||
ConstantAxesAndConstantSignalSizeTestParams{
|
||||
{Dimension(0, 2), 180, Dimension(7, 500), Dimension(1, 18)},
|
||||
{2},
|
||||
Shape{},
|
||||
{Dimension(0, 2), 180, Dimension(7, 500), Dimension(1, 18)},
|
||||
{1, 2},
|
||||
{}},
|
||||
ConstantAxesAndConstantSignalSizeTestParams{{Dimension(0, 2), Dimension(7, 500), 180, 2},
|
||||
{2},
|
||||
Shape{},
|
||||
{Dimension(0, 2), Dimension(7, 500), 180, 2},
|
||||
{1, 2},
|
||||
{}},
|
||||
ConstantAxesAndConstantSignalSizeTestParams{
|
||||
{Dimension(0, 2), Dimension(7, 500), 180, Dimension(1, 18)},
|
||||
{2},
|
||||
Shape{},
|
||||
{Dimension(0, 2), Dimension(7, 500), 180, Dimension(1, 18)},
|
||||
{1, 2},
|
||||
{}},
|
||||
ConstantAxesAndConstantSignalSizeTestParams{
|
||||
{Dimension(0, 2), Dimension(7, 500), Dimension(7, 500), 2},
|
||||
{2},
|
||||
Shape{},
|
||||
{Dimension(0, 2), Dimension(7, 500), Dimension(7, 500), 2},
|
||||
{1, 2},
|
||||
{}},
|
||||
ConstantAxesAndConstantSignalSizeTestParams{
|
||||
{Dimension(0, 2), Dimension(7, 500), Dimension(7, 500), Dimension(1, 18)},
|
||||
{2},
|
||||
Shape{},
|
||||
{Dimension(0, 2), Dimension(7, 500), Dimension(7, 500), Dimension(1, 18)},
|
||||
{1, 2},
|
||||
{}},
|
||||
ConstantAxesAndConstantSignalSizeTestParams{
|
||||
{2, 180, 180, 2}, {2}, {2}, {2, 180, 77, 2}, {1, 2}, {-1, 77}},
|
||||
ConstantAxesAndConstantSignalSizeTestParams{
|
||||
{2, 180, 180, 2}, {2}, {2}, {87, 180, 390, 2}, {2, 0}, {390, 87}},
|
||||
ConstantAxesAndConstantSignalSizeTestParams{
|
||||
{7, 50, 130, 400, 2}, {3}, {3}, {7, 40, 130, 600, 2}, {3, 0, 1}, {600, -1, 40}},
|
||||
ConstantAxesAndConstantSignalSizeTestParams{{2, Dimension(0, 200), 180, 2},
|
||||
{2},
|
||||
{2},
|
||||
{2, Dimension(0, 200), 77, 2},
|
||||
{1, 2},
|
||||
{-1, 77}},
|
||||
ConstantAxesAndConstantSignalSizeTestParams{{Dimension(0, 18), 180, Dimension(0, 400), 2},
|
||||
{2},
|
||||
{2},
|
||||
{87, 180, 390, 2},
|
||||
{2, 0},
|
||||
{390, 87}},
|
||||
ConstantAxesAndConstantSignalSizeTestParams{
|
||||
{Dimension(8, 129), 50, 130, Dimension(0, 500), 2},
|
||||
{3},
|
||||
{3},
|
||||
{Dimension(8, 129), 40, 130, 600, 2},
|
||||
{3, 0, 1},
|
||||
{600, -1, 40}}),
|
||||
PrintToDummyParamName());
|
||||
|
||||
TEST(type_prop, dft_dynamic_axes)
|
||||
{
|
||||
const auto input_shape = PartialShape{2, 180, 180, Dimension(1, 18)};
|
||||
const auto axes_shape = PartialShape::dynamic();
|
||||
const auto ref_output_shape = PartialShape{Dimension::dynamic(),
|
||||
Dimension::dynamic(),
|
||||
Dimension::dynamic(),
|
||||
Dimension(1, 18)};
|
||||
|
||||
auto data = std::make_shared<op::Parameter>(element::f32, input_shape);
|
||||
auto axes_input = std::make_shared<op::Parameter>(element::i64, axes_shape);
|
||||
auto dft = std::make_shared<op::v7::DFT>(data, axes_input);
|
||||
|
||||
EXPECT_EQ(dft->get_element_type(), element::f32);
|
||||
ASSERT_TRUE(dft->get_output_partial_shape(0).same_scheme(ref_output_shape));
|
||||
}
|
||||
|
||||
struct NonConstantAxesTestParams
|
||||
{
|
||||
PartialShape input_shape;
|
||||
Shape axes_shape;
|
||||
PartialShape ref_output_shape;
|
||||
};
|
||||
|
||||
struct NonConstantAxesTest : ::testing::TestWithParam<NonConstantAxesTestParams>
|
||||
{
|
||||
};
|
||||
|
||||
TEST_P(NonConstantAxesTest, dft_non_constant_axes)
|
||||
{
|
||||
auto params = GetParam();
|
||||
|
||||
auto data = std::make_shared<op::Parameter>(element::f32, params.input_shape);
|
||||
auto axes_input = std::make_shared<op::Parameter>(element::i64, params.axes_shape);
|
||||
auto dft = std::make_shared<op::v7::DFT>(data, axes_input);
|
||||
|
||||
EXPECT_EQ(dft->get_element_type(), element::f32);
|
||||
ASSERT_TRUE(dft->get_output_partial_shape(0).same_scheme(params.ref_output_shape));
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
type_prop,
|
||||
NonConstantAxesTest,
|
||||
::testing::Values(
|
||||
NonConstantAxesTestParams{
|
||||
{2, 180, 180, Dimension(1, 18)},
|
||||
{2},
|
||||
{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension(1, 18)}},
|
||||
NonConstantAxesTestParams{
|
||||
{2, 180, Dimension(7, 500), 2},
|
||||
{2},
|
||||
{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), 2}},
|
||||
NonConstantAxesTestParams{
|
||||
{2, 180, Dimension(7, 500), Dimension(1, 18)},
|
||||
{2},
|
||||
{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension(1, 18)}},
|
||||
NonConstantAxesTestParams{
|
||||
{2, Dimension(7, 500), 180, 2},
|
||||
{2},
|
||||
{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), 2}},
|
||||
NonConstantAxesTestParams{
|
||||
{2, Dimension(7, 500), 180, Dimension(1, 18)},
|
||||
{2},
|
||||
{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension(1, 18)}},
|
||||
NonConstantAxesTestParams{
|
||||
{2, Dimension(7, 500), Dimension(7, 500), 2},
|
||||
{2},
|
||||
{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), 2}},
|
||||
NonConstantAxesTestParams{
|
||||
{2, Dimension(7, 500), Dimension(7, 500), Dimension(1, 18)},
|
||||
{2},
|
||||
{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension(1, 18)}},
|
||||
NonConstantAxesTestParams{
|
||||
{Dimension(0, 2), 180, 180, 2},
|
||||
{2},
|
||||
{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), 2}},
|
||||
NonConstantAxesTestParams{
|
||||
{Dimension(0, 2), 180, 180, Dimension(1, 18)},
|
||||
{2},
|
||||
{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension(1, 18)}},
|
||||
NonConstantAxesTestParams{
|
||||
{Dimension(0, 2), 180, Dimension(7, 500), 2},
|
||||
{2},
|
||||
{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), 2}},
|
||||
NonConstantAxesTestParams{
|
||||
{Dimension(0, 2), 180, Dimension(7, 500), Dimension(1, 18)},
|
||||
{2},
|
||||
{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension(1, 18)}},
|
||||
NonConstantAxesTestParams{
|
||||
{Dimension(0, 2), Dimension(7, 500), 180, 2},
|
||||
{2},
|
||||
{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), 2}},
|
||||
NonConstantAxesTestParams{
|
||||
{Dimension(0, 2), Dimension(7, 500), 180, Dimension(1, 18)},
|
||||
{2},
|
||||
{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension(1, 18)}},
|
||||
NonConstantAxesTestParams{
|
||||
{Dimension(0, 2), Dimension(7, 500), Dimension(7, 500), 2},
|
||||
{2},
|
||||
{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), 2}},
|
||||
NonConstantAxesTestParams{
|
||||
{Dimension(0, 2), Dimension(7, 500), Dimension(7, 500), Dimension(1, 18)},
|
||||
{2},
|
||||
{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension(1, 18)}}),
|
||||
PrintToDummyParamName());
|
||||
|
||||
struct NonConstantSignalSizeTestParams
|
||||
{
|
||||
PartialShape input_shape;
|
||||
Shape axes_shape;
|
||||
Shape signal_size_shape;
|
||||
PartialShape ref_output_shape;
|
||||
std::vector<int64_t> axes;
|
||||
};
|
||||
|
||||
struct NonConstantSignalSizeTest : ::testing::TestWithParam<NonConstantSignalSizeTestParams>
|
||||
{
|
||||
};
|
||||
|
||||
TEST_P(NonConstantSignalSizeTest, dft_non_constant_signal_size)
|
||||
{
|
||||
auto params = GetParam();
|
||||
|
||||
auto data = std::make_shared<op::Parameter>(element::f32, params.input_shape);
|
||||
auto axes_input = op::Constant::create<int64_t>(element::i64, params.axes_shape, params.axes);
|
||||
auto signal_size_input =
|
||||
std::make_shared<op::Parameter>(element::i64, params.signal_size_shape);
|
||||
auto dft = std::make_shared<op::v7::DFT>(data, axes_input, signal_size_input);
|
||||
|
||||
EXPECT_EQ(dft->get_element_type(), element::f32);
|
||||
ASSERT_TRUE(dft->get_output_partial_shape(0).same_scheme(params.ref_output_shape));
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
type_prop,
|
||||
NonConstantSignalSizeTest,
|
||||
::testing::Values(
|
||||
NonConstantSignalSizeTestParams{{2, Dimension(0, 200), 180, 2},
|
||||
{2},
|
||||
{2},
|
||||
{2, Dimension::dynamic(), Dimension::dynamic(), 2},
|
||||
{1, 2}},
|
||||
NonConstantSignalSizeTestParams{{Dimension(0, 18), 180, Dimension(0, 400), 2},
|
||||
{2},
|
||||
{2},
|
||||
{Dimension::dynamic(), 180, Dimension::dynamic(), 2},
|
||||
{2, 0}},
|
||||
NonConstantSignalSizeTestParams{
|
||||
{Dimension(8, 129), 50, 130, Dimension(0, 500), 2},
|
||||
{3},
|
||||
{3},
|
||||
{Dimension::dynamic(), Dimension::dynamic(), 130, Dimension::dynamic(), 2},
|
||||
{3, 0, 1}}),
|
||||
PrintToDummyParamName());
|
Loading…
Reference in New Issue
Block a user