Added test for opset7::Gather (#5373)

This commit is contained in:
Vitaly Tuzov 2021-04-30 19:17:48 +03:00 committed by GitHub
parent a7353f4b28
commit bb022e2d26
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 69 additions and 0 deletions

View File

@ -12,4 +12,8 @@ TEST_P(GatherLayerTest, CompareWithRefs) {
Run(); Run();
}; };
TEST_P(Gather7LayerTest, CompareWithRefs) {
Run();
};
} // namespace LayerTestsDefinitions } // namespace LayerTestsDefinitions

View File

@ -41,4 +41,26 @@ protected:
void SetUp() override; void SetUp() override;
}; };
typedef std::tuple<
std::vector<size_t>, // Input shapes
std::vector<size_t>, // Indices shape
std::tuple<int, int>, // Gather axis and batch
InferenceEngine::Precision, // Network precision
InferenceEngine::Precision, // Input precision
InferenceEngine::Precision, // Output precision
InferenceEngine::Layout, // Input layout
InferenceEngine::Layout, // Output layout
std::string // Device name
> gather7ParamsTuple;
class Gather7LayerTest : public testing::WithParamInterface<gather7ParamsTuple>,
virtual public LayerTestsUtils::LayerTestsCommon {
public:
static std::string getTestCaseName(const testing::TestParamInfo<gather7ParamsTuple>& obj);
protected:
void SetUp() override;
};
} // namespace LayerTestsDefinitions } // namespace LayerTestsDefinitions

View File

@ -51,4 +51,46 @@ void GatherLayerTest::SetUp() {
GatherLayerTestBase::SetUp(GetParam()); GatherLayerTestBase::SetUp(GetParam());
} }
std::string Gather7LayerTest::getTestCaseName(const testing::TestParamInfo<gather7ParamsTuple>& obj) {
std::tuple<int, int> axis_batchIdx;
std::vector<int> indices;
std::vector<size_t> indicesShape, inputShape;
InferenceEngine::Precision netPrecision;
InferenceEngine::Precision inPrc, outPrc;
InferenceEngine::Layout inLayout, outLayout;
std::string targetName;
std::tie(inputShape, indicesShape, axis_batchIdx, netPrecision, inPrc, outPrc, inLayout, outLayout, targetName) = obj.param;
std::ostringstream result;
result << "IS=" << CommonTestUtils::vec2str(inputShape) << "_";
result << "axis=" << std::get<0>(axis_batchIdx) << "_";
result << "batchIdx=" << std::get<1>(axis_batchIdx) << "_";
result << "indicesShape=" << CommonTestUtils::vec2str(indicesShape) << "_";
result << "netPRC=" << netPrecision.name() << "_";
result << "inPRC=" << inPrc.name() << "_";
result << "outPRC=" << outPrc.name() << "_";
result << "inL=" << inLayout << "_";
result << "outL=" << outLayout << "_";
result << "trgDev=" << targetName << "_";
return result.str();
}
void Gather7LayerTest::SetUp() {
std::tuple<int, int> axis_batchIdx;
std::vector<size_t> indicesShape;
std::vector<size_t> inputShape;
InferenceEngine::Precision netPrecision;
std::tie(inputShape, indicesShape, axis_batchIdx, netPrecision, inPrc, outPrc, inLayout, outLayout, targetDevice) = GetParam();
int axis = std::get<0>(axis_batchIdx);
int batchIdx = std::get<1>(axis_batchIdx);
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
auto functionParams = ngraph::builder::makeParams(ngPrc, { inputShape });
auto paramOuts = ngraph::helpers::convert2OutputVector(ngraph::helpers::castOps2Nodes<ngraph::op::Parameter>(functionParams));
auto indicesNode = ngraph::builder::makeConstant<int>(ngraph::element::i64, indicesShape, {}, true,
inputShape[axis < 0 ? axis + inputShape.size() : axis] - 1, 0);
auto axisNode = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape({}), { axis });
auto gather = std::make_shared<ngraph::opset7::Gather>(paramOuts[0], indicesNode, axisNode, batchIdx);
ngraph::ResultVector results{ std::make_shared<ngraph::opset7::Result>(gather) };
function = std::make_shared<ngraph::Function>(results, functionParams, "gather");
}
} // namespace LayerTestsDefinitions } // namespace LayerTestsDefinitions

View File

@ -13,6 +13,7 @@
#include <ngraph/opsets/opset4.hpp> #include <ngraph/opsets/opset4.hpp>
#include <ngraph/opsets/opset5.hpp> #include <ngraph/opsets/opset5.hpp>
#include <ngraph/opsets/opset6.hpp> #include <ngraph/opsets/opset6.hpp>
#include <ngraph/opsets/opset7.hpp>
#include "ngraph_functions/utils/data_utils.hpp" #include "ngraph_functions/utils/data_utils.hpp"