Added test for opset7::Gather (#5373)
This commit is contained in:
parent
a7353f4b28
commit
bb022e2d26
@ -12,4 +12,8 @@ TEST_P(GatherLayerTest, CompareWithRefs) {
|
||||
Run();
|
||||
};
|
||||
|
||||
TEST_P(Gather7LayerTest, CompareWithRefs) {
|
||||
Run();
|
||||
};
|
||||
|
||||
} // namespace LayerTestsDefinitions
|
||||
|
@ -41,4 +41,26 @@ protected:
|
||||
void SetUp() override;
|
||||
};
|
||||
|
||||
|
||||
typedef std::tuple<
|
||||
std::vector<size_t>, // Input shapes
|
||||
std::vector<size_t>, // Indices shape
|
||||
std::tuple<int, int>, // Gather axis and batch
|
||||
InferenceEngine::Precision, // Network precision
|
||||
InferenceEngine::Precision, // Input precision
|
||||
InferenceEngine::Precision, // Output precision
|
||||
InferenceEngine::Layout, // Input layout
|
||||
InferenceEngine::Layout, // Output layout
|
||||
std::string // Device name
|
||||
> gather7ParamsTuple;
|
||||
|
||||
class Gather7LayerTest : public testing::WithParamInterface<gather7ParamsTuple>,
|
||||
virtual public LayerTestsUtils::LayerTestsCommon {
|
||||
public:
|
||||
static std::string getTestCaseName(const testing::TestParamInfo<gather7ParamsTuple>& obj);
|
||||
|
||||
protected:
|
||||
void SetUp() override;
|
||||
};
|
||||
|
||||
} // namespace LayerTestsDefinitions
|
@ -51,4 +51,46 @@ void GatherLayerTest::SetUp() {
|
||||
GatherLayerTestBase::SetUp(GetParam());
|
||||
}
|
||||
|
||||
std::string Gather7LayerTest::getTestCaseName(const testing::TestParamInfo<gather7ParamsTuple>& obj) {
|
||||
std::tuple<int, int> axis_batchIdx;
|
||||
std::vector<int> indices;
|
||||
std::vector<size_t> indicesShape, inputShape;
|
||||
InferenceEngine::Precision netPrecision;
|
||||
InferenceEngine::Precision inPrc, outPrc;
|
||||
InferenceEngine::Layout inLayout, outLayout;
|
||||
std::string targetName;
|
||||
std::tie(inputShape, indicesShape, axis_batchIdx, netPrecision, inPrc, outPrc, inLayout, outLayout, targetName) = obj.param;
|
||||
std::ostringstream result;
|
||||
result << "IS=" << CommonTestUtils::vec2str(inputShape) << "_";
|
||||
result << "axis=" << std::get<0>(axis_batchIdx) << "_";
|
||||
result << "batchIdx=" << std::get<1>(axis_batchIdx) << "_";
|
||||
result << "indicesShape=" << CommonTestUtils::vec2str(indicesShape) << "_";
|
||||
result << "netPRC=" << netPrecision.name() << "_";
|
||||
result << "inPRC=" << inPrc.name() << "_";
|
||||
result << "outPRC=" << outPrc.name() << "_";
|
||||
result << "inL=" << inLayout << "_";
|
||||
result << "outL=" << outLayout << "_";
|
||||
result << "trgDev=" << targetName << "_";
|
||||
return result.str();
|
||||
}
|
||||
|
||||
void Gather7LayerTest::SetUp() {
|
||||
std::tuple<int, int> axis_batchIdx;
|
||||
std::vector<size_t> indicesShape;
|
||||
std::vector<size_t> inputShape;
|
||||
InferenceEngine::Precision netPrecision;
|
||||
std::tie(inputShape, indicesShape, axis_batchIdx, netPrecision, inPrc, outPrc, inLayout, outLayout, targetDevice) = GetParam();
|
||||
int axis = std::get<0>(axis_batchIdx);
|
||||
int batchIdx = std::get<1>(axis_batchIdx);
|
||||
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
|
||||
auto functionParams = ngraph::builder::makeParams(ngPrc, { inputShape });
|
||||
auto paramOuts = ngraph::helpers::convert2OutputVector(ngraph::helpers::castOps2Nodes<ngraph::op::Parameter>(functionParams));
|
||||
auto indicesNode = ngraph::builder::makeConstant<int>(ngraph::element::i64, indicesShape, {}, true,
|
||||
inputShape[axis < 0 ? axis + inputShape.size() : axis] - 1, 0);
|
||||
auto axisNode = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape({}), { axis });
|
||||
auto gather = std::make_shared<ngraph::opset7::Gather>(paramOuts[0], indicesNode, axisNode, batchIdx);
|
||||
ngraph::ResultVector results{ std::make_shared<ngraph::opset7::Result>(gather) };
|
||||
function = std::make_shared<ngraph::Function>(results, functionParams, "gather");
|
||||
}
|
||||
|
||||
} // namespace LayerTestsDefinitions
|
||||
|
@ -13,6 +13,7 @@
|
||||
#include <ngraph/opsets/opset4.hpp>
|
||||
#include <ngraph/opsets/opset5.hpp>
|
||||
#include <ngraph/opsets/opset6.hpp>
|
||||
#include <ngraph/opsets/opset7.hpp>
|
||||
|
||||
#include "ngraph_functions/utils/data_utils.hpp"
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user