Added test for opset7::Gather (#5373)
This commit is contained in:
parent
a7353f4b28
commit
bb022e2d26
@ -12,4 +12,8 @@ TEST_P(GatherLayerTest, CompareWithRefs) {
|
|||||||
Run();
|
Run();
|
||||||
};
|
};
|
||||||
|
|
||||||
|
TEST_P(Gather7LayerTest, CompareWithRefs) {
|
||||||
|
Run();
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace LayerTestsDefinitions
|
} // namespace LayerTestsDefinitions
|
||||||
|
@ -41,4 +41,26 @@ protected:
|
|||||||
void SetUp() override;
|
void SetUp() override;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
typedef std::tuple<
|
||||||
|
std::vector<size_t>, // Input shapes
|
||||||
|
std::vector<size_t>, // Indices shape
|
||||||
|
std::tuple<int, int>, // Gather axis and batch
|
||||||
|
InferenceEngine::Precision, // Network precision
|
||||||
|
InferenceEngine::Precision, // Input precision
|
||||||
|
InferenceEngine::Precision, // Output precision
|
||||||
|
InferenceEngine::Layout, // Input layout
|
||||||
|
InferenceEngine::Layout, // Output layout
|
||||||
|
std::string // Device name
|
||||||
|
> gather7ParamsTuple;
|
||||||
|
|
||||||
|
class Gather7LayerTest : public testing::WithParamInterface<gather7ParamsTuple>,
|
||||||
|
virtual public LayerTestsUtils::LayerTestsCommon {
|
||||||
|
public:
|
||||||
|
static std::string getTestCaseName(const testing::TestParamInfo<gather7ParamsTuple>& obj);
|
||||||
|
|
||||||
|
protected:
|
||||||
|
void SetUp() override;
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace LayerTestsDefinitions
|
} // namespace LayerTestsDefinitions
|
@ -51,4 +51,46 @@ void GatherLayerTest::SetUp() {
|
|||||||
GatherLayerTestBase::SetUp(GetParam());
|
GatherLayerTestBase::SetUp(GetParam());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string Gather7LayerTest::getTestCaseName(const testing::TestParamInfo<gather7ParamsTuple>& obj) {
|
||||||
|
std::tuple<int, int> axis_batchIdx;
|
||||||
|
std::vector<int> indices;
|
||||||
|
std::vector<size_t> indicesShape, inputShape;
|
||||||
|
InferenceEngine::Precision netPrecision;
|
||||||
|
InferenceEngine::Precision inPrc, outPrc;
|
||||||
|
InferenceEngine::Layout inLayout, outLayout;
|
||||||
|
std::string targetName;
|
||||||
|
std::tie(inputShape, indicesShape, axis_batchIdx, netPrecision, inPrc, outPrc, inLayout, outLayout, targetName) = obj.param;
|
||||||
|
std::ostringstream result;
|
||||||
|
result << "IS=" << CommonTestUtils::vec2str(inputShape) << "_";
|
||||||
|
result << "axis=" << std::get<0>(axis_batchIdx) << "_";
|
||||||
|
result << "batchIdx=" << std::get<1>(axis_batchIdx) << "_";
|
||||||
|
result << "indicesShape=" << CommonTestUtils::vec2str(indicesShape) << "_";
|
||||||
|
result << "netPRC=" << netPrecision.name() << "_";
|
||||||
|
result << "inPRC=" << inPrc.name() << "_";
|
||||||
|
result << "outPRC=" << outPrc.name() << "_";
|
||||||
|
result << "inL=" << inLayout << "_";
|
||||||
|
result << "outL=" << outLayout << "_";
|
||||||
|
result << "trgDev=" << targetName << "_";
|
||||||
|
return result.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
void Gather7LayerTest::SetUp() {
|
||||||
|
std::tuple<int, int> axis_batchIdx;
|
||||||
|
std::vector<size_t> indicesShape;
|
||||||
|
std::vector<size_t> inputShape;
|
||||||
|
InferenceEngine::Precision netPrecision;
|
||||||
|
std::tie(inputShape, indicesShape, axis_batchIdx, netPrecision, inPrc, outPrc, inLayout, outLayout, targetDevice) = GetParam();
|
||||||
|
int axis = std::get<0>(axis_batchIdx);
|
||||||
|
int batchIdx = std::get<1>(axis_batchIdx);
|
||||||
|
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
|
||||||
|
auto functionParams = ngraph::builder::makeParams(ngPrc, { inputShape });
|
||||||
|
auto paramOuts = ngraph::helpers::convert2OutputVector(ngraph::helpers::castOps2Nodes<ngraph::op::Parameter>(functionParams));
|
||||||
|
auto indicesNode = ngraph::builder::makeConstant<int>(ngraph::element::i64, indicesShape, {}, true,
|
||||||
|
inputShape[axis < 0 ? axis + inputShape.size() : axis] - 1, 0);
|
||||||
|
auto axisNode = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape({}), { axis });
|
||||||
|
auto gather = std::make_shared<ngraph::opset7::Gather>(paramOuts[0], indicesNode, axisNode, batchIdx);
|
||||||
|
ngraph::ResultVector results{ std::make_shared<ngraph::opset7::Result>(gather) };
|
||||||
|
function = std::make_shared<ngraph::Function>(results, functionParams, "gather");
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace LayerTestsDefinitions
|
} // namespace LayerTestsDefinitions
|
||||||
|
@ -13,6 +13,7 @@
|
|||||||
#include <ngraph/opsets/opset4.hpp>
|
#include <ngraph/opsets/opset4.hpp>
|
||||||
#include <ngraph/opsets/opset5.hpp>
|
#include <ngraph/opsets/opset5.hpp>
|
||||||
#include <ngraph/opsets/opset6.hpp>
|
#include <ngraph/opsets/opset6.hpp>
|
||||||
|
#include <ngraph/opsets/opset7.hpp>
|
||||||
|
|
||||||
#include "ngraph_functions/utils/data_utils.hpp"
|
#include "ngraph_functions/utils/data_utils.hpp"
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user