add scatterNDupdate serialize SLT (#4962)

Co-authored-by: Patryk Elszkowski <patryk.elszkowki@intel.com>
This commit is contained in:
Patryk Elszkowski 2021-03-26 05:08:01 +01:00 committed by GitHub
parent e55954e81f
commit 4d2fc1c678
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 57 additions and 7 deletions

View File

@ -0,0 +1,50 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <map>
#include <vector>
#include "shared_test_classes/single_layer/scatter_ND_update.hpp"
using namespace LayerTestsDefinitions;
namespace {
TEST_P(ScatterNDUpdateLayerTest, Serialize) {
Serialize();
}
const std::vector<InferenceEngine::Precision> inputPrecisions = {
InferenceEngine::Precision::FP32,
InferenceEngine::Precision::FP16,
InferenceEngine::Precision::I32,
};
const std::vector<InferenceEngine::Precision> idxPrecisions = {
InferenceEngine::Precision::I32,
InferenceEngine::Precision::I64,
};
// map<inputShape map<indicesShape, indicesValue>>
// updateShape is gotten from inputShape and indicesShape
std::map<std::vector<size_t>, std::map<std::vector<size_t>, std::vector<size_t>>>
sliceSelectInShape{
{{10, 9, 9, 11},
{{{4, 1}, {1, 3, 5, 7}},
{{1, 2}, {4, 6}},
{{2, 3}, {0, 1, 1, 2, 2, 2}},
{{1, 4}, {5, 5, 4, 9}}}},
{{10, 9, 10, 9, 10}, {{{2, 2, 1}, {5, 6, 2, 8}}, {{2, 3}, {0, 4, 6, 5, 7, 1}}}},
};
const auto ScatterNDUpdateCases = ::testing::Combine(
::testing::ValuesIn(ScatterNDUpdateLayerTest::combineShapes(sliceSelectInShape)),
::testing::ValuesIn(inputPrecisions),
::testing::ValuesIn(idxPrecisions),
::testing::Values(CommonTestUtils::DEVICE_CPU));
INSTANTIATE_TEST_CASE_P(
smoke_ScatterNDUpdateLayerTestSerialization,
ScatterNDUpdateLayerTest,
ScatterNDUpdateCases,
ScatterNDUpdateLayerTest::getTestCaseName);
} // namespace

View File

@ -12,14 +12,14 @@
#include "shared_test_classes/base/layer_test_utils.hpp"
namespace LayerTestsDefinitions {
using sliceSelcetInShape = std::tuple<
using sliceSelectInShape = std::tuple<
std::vector<size_t>, // input shape
std::vector<size_t>, // indices shape
std::vector<size_t>, // indices value
std::vector<size_t>>; // update shape
using scatterNDUpdateParamsTuple = typename std::tuple<
sliceSelcetInShape, // Input description
sliceSelectInShape, // Input description
InferenceEngine::Precision, // Network precision
InferenceEngine::Precision, // indices precision
std::string>; // Device name
@ -28,7 +28,7 @@ class ScatterNDUpdateLayerTest : public testing::WithParamInterface<scatterNDUpd
virtual public LayerTestsUtils::LayerTestsCommon {
public:
static std::string getTestCaseName(const testing::TestParamInfo<scatterNDUpdateParamsTuple> &obj);
static std::vector<sliceSelcetInShape> combineShapes(
static std::vector<sliceSelectInShape> combineShapes(
const std::map<std::vector<size_t>, std::map<std::vector<size_t>, std::vector<size_t>>>& inputShapes);
protected:

View File

@ -8,7 +8,7 @@
namespace LayerTestsDefinitions {
std::string ScatterNDUpdateLayerTest::getTestCaseName(const testing::TestParamInfo<scatterNDUpdateParamsTuple> &obj) {
sliceSelcetInShape shapeDescript;
sliceSelectInShape shapeDescript;
std::vector<size_t> inShape;
std::vector<size_t> indicesShape;
std::vector<size_t> indicesValue;
@ -28,9 +28,9 @@ std::string ScatterNDUpdateLayerTest::getTestCaseName(const testing::TestParamIn
return result.str();
}
std::vector<sliceSelcetInShape> ScatterNDUpdateLayerTest::combineShapes(
std::vector<sliceSelectInShape> ScatterNDUpdateLayerTest::combineShapes(
const std::map<std::vector<size_t>, std::map<std::vector<size_t>, std::vector<size_t>>>& inputShapes) {
std::vector<sliceSelcetInShape> resVec;
std::vector<sliceSelectInShape> resVec;
for (auto& inputShape : inputShapes) {
for (auto& item : inputShape.second) {
auto indiceShape = item.first;
@ -50,7 +50,7 @@ std::vector<sliceSelcetInShape> ScatterNDUpdateLayerTest::combineShapes(
}
void ScatterNDUpdateLayerTest::SetUp() {
sliceSelcetInShape shapeDescript;
sliceSelectInShape shapeDescript;
InferenceEngine::SizeVector inShape;
InferenceEngine::SizeVector indicesShape;
InferenceEngine::SizeVector indicesValue;