Added GenerateInput to Reduce tests (#5004)

* Added GenerateInput to Reduce tests

* Skip ReduceProd CPU tests
This commit is contained in:
Liubov Batanina 2021-03-30 15:06:36 +03:00 committed by GitHub
parent 41f0eb51c5
commit 2f257a2955
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 0 deletions

View File

@ -32,6 +32,7 @@ class ReduceOpsLayerTest : public testing::WithParamInterface<reduceMeanParams>,
virtual public LayerTestsUtils::LayerTestsCommon {
public:
static std::string getTestCaseName(testing::TestParamInfo<reduceMeanParams> obj);
InferenceEngine::Blob::Ptr GenerateInput(const InferenceEngine::InputInfo &info) const override;
protected:
void SetUp() override;

View File

@ -69,6 +69,21 @@ void ReduceOpsLayerTest::SetUp() {
const ngraph::ResultVector results{std::make_shared<ngraph::opset3::Result>(reduce)};
function = std::make_shared<ngraph::Function>(results, params, "Reduce");
}
InferenceEngine::Blob::Ptr ReduceOpsLayerTest::GenerateInput(const InferenceEngine::InputInfo &info) const {
ngraph::helpers::ReductionType reductionType = std::get<3>(GetParam());
InferenceEngine::Precision netPrecision = std::get<4>(GetParam());
if (reductionType == ngraph::helpers::ReductionType::LogicalOr ||
reductionType == ngraph::helpers::ReductionType::LogicalAnd) {
return FuncTestUtils::createAndFillBlob(info.getTensorDesc(), 2, 0);
} else if (!netPrecision.is_float()) {
return FuncTestUtils::createAndFillBlob(info.getTensorDesc(), 5, 0);
}
auto td = info.getTensorDesc();
auto blob = make_blob_with_precision(td);
blob->allocate();
CommonTestUtils::fill_data_random_float<InferenceEngine::Precision::FP32>(blob, 5, 0, 1000);
return blob;
}
InferenceEngine::Blob::Ptr ReduceOpsLayerWithSpecificInputTest::GenerateInput(const InferenceEngine::InputInfo &info) const {
auto axis_vec = std::get<0>(GetParam());