ScatterNDUpdate, ScatterElementsUpdate, Roll layer tests to API2.0 (#20048)

* `ScatterNDUpdateLayerTest` to API2.0

* `ScatterElementsUpdateLayerTest` to API2.0

* `RollLayerTest` to API2.0
This commit is contained in:
Vitaliy Urusovskij 2023-09-26 17:29:40 +04:00 committed by GitHub
parent 79ff291314
commit c3565e3eac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 455 additions and 91 deletions

View File

@ -4,101 +4,109 @@
#include <vector>
#include "single_layer_tests/roll.hpp"
#include "single_op_tests/roll.hpp"
#include "common_test_utils/test_constants.hpp"
using namespace LayerTestsDefinitions;
using ov::test::RollLayerTest;
namespace {
const std::vector<InferenceEngine::Precision> inputPrecision = {
InferenceEngine::Precision::I8,
InferenceEngine::Precision::U8,
InferenceEngine::Precision::I16,
InferenceEngine::Precision::I32,
InferenceEngine::Precision::FP32,
InferenceEngine::Precision::BF16
const std::vector<ov::element::Type> model_types = {
ov::element::i8,
ov::element::u8,
ov::element::i16,
ov::element::i32,
ov::element::f32,
ov::element::bf16
};
const auto testCase2DZeroShifts = ::testing::Combine(
::testing::Values(std::vector<size_t>{17, 19}), // Input shape
::testing::ValuesIn(inputPrecision), // Precision
::testing::Values(std::vector<int64_t>{0, 0}), // Shift
::testing::Values(std::vector<int64_t>{0, 1}), // Axes
const auto test_case_2D_zero_shifts = ::testing::Combine(
::testing::Values(
ov::test::static_shapes_to_test_representation({{17, 19}})), // Input shape
::testing::ValuesIn(model_types), // Model type
::testing::Values(std::vector<int64_t>{0, 0}), // Shift
::testing::Values(std::vector<int64_t>{0, 1}), // Axes
::testing::Values(ov::test::utils::DEVICE_CPU)
);
const auto testCase1D = ::testing::Combine(
::testing::Values(std::vector<size_t>{16}), // Input shape
::testing::ValuesIn(inputPrecision), // Precision
::testing::Values(std::vector<int64_t>{5}), // Shift
::testing::Values(std::vector<int64_t>{0}), // Axes
const auto test_case_1D = ::testing::Combine(
::testing::Values(
ov::test::static_shapes_to_test_representation({ov::Shape{16}})), // Input shape
::testing::ValuesIn(model_types), // Model type
::testing::Values(std::vector<int64_t>{5}), // Shift
::testing::Values(std::vector<int64_t>{0}), // Axes
::testing::Values(ov::test::utils::DEVICE_CPU)
);
const auto testCase2D = ::testing::Combine(
::testing::Values(std::vector<size_t>{600, 450}), // Input shape
::testing::ValuesIn(inputPrecision), // Precision
::testing::Values(std::vector<int64_t>{300, 250}), // Shift
::testing::Values(std::vector<int64_t>{0, 1}), // Axes
const auto test_case_2D = ::testing::Combine(
::testing::Values(
ov::test::static_shapes_to_test_representation({{600, 450}})), // Input shape
::testing::ValuesIn(model_types), // Model type
::testing::Values(std::vector<int64_t>{300, 250}), // Shift
::testing::Values(std::vector<int64_t>{0, 1}), // Axes
::testing::Values(ov::test::utils::DEVICE_CPU)
);
const auto testCase3D = ::testing::Combine(
::testing::Values(std::vector<size_t>{2, 320, 320}), // Input shape
::testing::ValuesIn(inputPrecision), // Precision
::testing::Values(std::vector<int64_t>{160, 160}), // Shift
::testing::Values(std::vector<int64_t>{1, 2}), // Axes
const auto test_case_3D = ::testing::Combine(
::testing::Values(
ov::test::static_shapes_to_test_representation({{2, 320, 320}})), // Input shape
::testing::ValuesIn(model_types), // Model type
::testing::Values(std::vector<int64_t>{160, 160}), // Shift
::testing::Values(std::vector<int64_t>{1, 2}), // Axes
::testing::Values(ov::test::utils::DEVICE_CPU)
);
const auto testCaseNegativeUnorderedAxes4D = ::testing::Combine(
::testing::Values(std::vector<size_t>{3, 11, 6, 4}), // Input shape
::testing::ValuesIn(inputPrecision), // Precision
::testing::Values(std::vector<int64_t>{7, 3}), // Shift
::testing::Values(std::vector<int64_t>{-3, -2}), // Axes
const auto test_case_negative_unordered_axes_4D = ::testing::Combine(
::testing::Values(
ov::test::static_shapes_to_test_representation({{3, 11, 6, 4}})), // Input shape
::testing::ValuesIn(model_types), // Model type
::testing::Values(std::vector<int64_t>{7, 3}), // Shift
::testing::Values(std::vector<int64_t>{-3, -2}), // Axes
::testing::Values(ov::test::utils::DEVICE_CPU)
);
const auto testCaseRepeatingAxes5D = ::testing::Combine(
::testing::Values(std::vector<size_t>{2, 16, 32, 7, 32}), // Input shape
::testing::ValuesIn(inputPrecision), // Precision
::testing::Values(std::vector<int64_t>{16, 15, 10, 2, 1, 7, 2, 8, 1, 1}), // Shift
const auto test_case_repeating_axes_5D = ::testing::Combine(
::testing::Values(
ov::test::static_shapes_to_test_representation({{2, 16, 32, 7, 32}})), // Input shape
::testing::ValuesIn(model_types), // Model type
::testing::Values(std::vector<int64_t>{16, 15, 10, 2, 1, 7, 2, 8, 1, 1}), // Shift
::testing::Values(std::vector<int64_t>{-1, -2, -3, 1, 0, 3, 3, 2, -2, -3}), // Axes
::testing::Values(ov::test::utils::DEVICE_CPU)
);
const auto testCaseNegativeShifts6D = ::testing::Combine(
::testing::Values(std::vector<size_t>{4, 16, 3, 6, 5, 2}), // Input shape
::testing::ValuesIn(inputPrecision), // Precision
::testing::Values(std::vector<int64_t>{-2, -15, -2, -1, -4, -1}), // Shift
::testing::Values(std::vector<int64_t>{0, 1, 2, 3, 4, 5}), // Axes
const auto test_case_negative_shifts_6D = ::testing::Combine(
::testing::Values(
ov::test::static_shapes_to_test_representation({{4, 16, 3, 6, 5, 2}})), // Input shape
::testing::ValuesIn(model_types), // Model type
::testing::Values(std::vector<int64_t>{-2, -15, -2, -1, -4, -1}), // Shift
::testing::Values(std::vector<int64_t>{0, 1, 2, 3, 4, 5}), // Axes
::testing::Values(ov::test::utils::DEVICE_CPU)
);
const auto testCaseUnordNegAxesAndShifts10D = ::testing::Combine(
::testing::Values(std::vector<size_t>{2, 2, 4, 2, 3, 6, 3, 2, 3, 2}), // Input shape
::testing::ValuesIn(inputPrecision), // Precision
::testing::Values(std::vector<int64_t>{-2, -1, 1, 1, 1, -2}), // Shift
::testing::Values(std::vector<int64_t>{-6, -4, -3, 1, -10, -2}), // Axes
const auto test_case_unord_neg_axes_and_shifts_10D = ::testing::Combine(
::testing::Values(
ov::test::static_shapes_to_test_representation({{2, 2, 4, 2, 3, 6, 3, 2, 3, 2}})), // Input shape
::testing::ValuesIn(model_types), // Model type
::testing::Values(std::vector<int64_t>{-2, -1, 1, 1, 1, -2}), // Shift
::testing::Values(std::vector<int64_t>{-6, -4, -3, 1, -10, -2}), // Axes
::testing::Values(ov::test::utils::DEVICE_CPU)
);
INSTANTIATE_TEST_SUITE_P(smoke_TestsRoll_2d_zero_shifts, RollLayerTest,
testCase2DZeroShifts, RollLayerTest::getTestCaseName);
test_case_2D_zero_shifts, RollLayerTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_TestsRoll_1d, RollLayerTest,
testCase1D, RollLayerTest::getTestCaseName);
test_case_1D, RollLayerTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_TestsRoll_2d, RollLayerTest,
testCase2D, RollLayerTest::getTestCaseName);
test_case_2D, RollLayerTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_TestsRoll_3d, RollLayerTest,
testCase3D, RollLayerTest::getTestCaseName);
test_case_3D, RollLayerTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_TestsRoll_negative_unordered_axes_4d, RollLayerTest,
testCaseNegativeUnorderedAxes4D, RollLayerTest::getTestCaseName);
test_case_negative_unordered_axes_4D, RollLayerTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_TestsRoll_negative_unordered_axes_5d, RollLayerTest,
testCaseRepeatingAxes5D, RollLayerTest::getTestCaseName);
test_case_repeating_axes_5D, RollLayerTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_TestsRoll_negative_shifts_6d, RollLayerTest,
testCaseNegativeShifts6D, RollLayerTest::getTestCaseName);
test_case_negative_shifts_6D, RollLayerTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_TestsRoll_unord_neg_shifts_and_axes_10d, RollLayerTest,
testCaseUnordNegAxesAndShifts10D, RollLayerTest::getTestCaseName);
test_case_unord_neg_axes_and_shifts_10D, RollLayerTest::getTestCaseName);
} // namespace

View File

@ -3,36 +3,60 @@
//
#include <vector>
#include <ngraph/opsets/opset3.hpp>
#include "single_layer_tests/scatter_ND_update.hpp"
#include "single_op_tests/scatter_ND_update.hpp"
#include "common_test_utils/test_constants.hpp"
using namespace LayerTestsDefinitions;
using namespace ngraph::opset3;
using ov::test::ScatterNDUpdateLayerTest;
namespace {
const std::vector<InferenceEngine::Precision> inputPrecisions = {
InferenceEngine::Precision::FP32,
InferenceEngine::Precision::FP16,
InferenceEngine::Precision::I32,
const std::vector<ov::element::Type> model_types = {
ov::element::f32,
ov::element::f16,
ov::element::i32,
};
const std::vector<InferenceEngine::Precision> idxPrecisions = {
InferenceEngine::Precision::I32,
InferenceEngine::Precision::I64,
const std::vector<ov::element::Type> idx_types = {
ov::element::i32,
ov::element::i64,
};
// map<inputShape map<indicesShape, indicesValue>>
// updateShape is gotten from inputShape and indicesShape
// map<input_shape map<indices_shape, indices_value>>
// update_shape is gotten from input_shape and indices_shape
std::map<std::vector<size_t>, std::map<std::vector<size_t>, std::vector<size_t>>> sliceSelectInShape {
{{10, 9, 9, 11}, {{{4, 1}, {1, 3, 5, 7}}, {{1, 2}, {4, 6}}, {{2, 3}, {0, 1, 1, 2, 2, 2}}, {{1, 4}, {5, 5, 4, 9}}}},
{{10, 9, 10, 9, 10}, {{{2, 2, 1}, {5, 6, 2, 8}}, {{2, 3}, {0, 4, 6, 5, 7, 1}}}},
};
std::vector<ov::test::scatterNDUpdateSpecParams> combineShapes(
const std::map<std::vector<size_t>, std::map<std::vector<size_t>, std::vector<size_t>>>& input_shapes) {
std::vector<ov::test::scatterNDUpdateSpecParams> resVec;
for (auto& input_shape : input_shapes) {
for (auto& item : input_shape.second) {
auto indices_shape = item.first;
size_t indices_rank = indices_shape.size();
std::vector<size_t> update_shape;
for (size_t i = 0; i < indices_rank - 1; i++) {
update_shape.push_back(indices_shape[i]);
}
auto src_shape = input_shape.first;
for (size_t j = indices_shape[indices_rank - 1]; j < src_shape.size(); j++) {
update_shape.push_back(src_shape[j]);
}
std::vector<ov::Shape> in_shapes{src_shape, update_shape};
resVec.push_back(
ov::test::scatterNDUpdateSpecParams{
ov::test::static_shapes_to_test_representation(in_shapes),
ov::Shape{indices_shape},
item.second});
}
}
return resVec;
}
const auto ScatterNDUpdateCases = ::testing::Combine(
::testing::ValuesIn(ScatterNDUpdateLayerTest::combineShapes(sliceSelectInShape)),
::testing::ValuesIn(inputPrecisions),
::testing::ValuesIn(idxPrecisions),
::testing::ValuesIn(combineShapes(sliceSelectInShape)),
::testing::ValuesIn(model_types),
::testing::ValuesIn(idx_types),
::testing::Values(ov::test::utils::DEVICE_CPU)
);

View File

@ -3,46 +3,59 @@
//
#include <vector>
#include <ngraph/opsets/opset3.hpp>
#include "single_layer_tests/scatter_elements_update.hpp"
#include "single_op_tests/scatter_elements_update.hpp"
#include "common_test_utils/test_constants.hpp"
using namespace LayerTestsDefinitions;
using namespace ngraph::opset3;
using ov::test::ScatterElementsUpdateLayerTest;
namespace {
// map<inputShape, map<indicesShape, axis>>
// map<input_shape, map<indices_shape, axis>>
std::map<std::vector<size_t>, std::map<std::vector<size_t>, std::vector<int>>> axesShapeInShape {
{{10, 12, 15}, {{{1, 2, 4}, {0, 1, 2}}, {{2, 2, 2}, {-1, -2, -3}}}},
{{15, 9, 8, 12}, {{{1, 2, 2, 2}, {0, 1, 2, 3}}, {{1, 2, 1, 4}, {-1, -2, -3, -4}}}},
{{9, 9, 8, 8, 11, 10}, {{{1, 2, 1, 2, 1, 2}, {5, -3}}}},
};
// index value should not be random data
const std::vector<std::vector<size_t>> idxValue = {
const std::vector<std::vector<size_t>> idx_value = {
{1, 0, 4, 6, 2, 3, 7, 5}
};
const std::vector<InferenceEngine::Precision> inputPrecisions = {
InferenceEngine::Precision::FP32,
InferenceEngine::Precision::FP16,
InferenceEngine::Precision::I32,
const std::vector<ov::element::Type> model_types = {
ov::element::f32,
ov::element::f16,
ov::element::i32,
};
const std::vector<InferenceEngine::Precision> idxPrecisions = {
InferenceEngine::Precision::I32,
InferenceEngine::Precision::I64,
const std::vector<ov::element::Type> idx_types = {
ov::element::i32,
ov::element::i64,
};
const auto ScatterEltUpdateCases = ::testing::Combine(
::testing::ValuesIn(ScatterElementsUpdateLayerTest::combineShapes(axesShapeInShape)),
::testing::ValuesIn(idxValue),
::testing::ValuesIn(inputPrecisions),
::testing::ValuesIn(idxPrecisions),
std::vector<ov::test::axisShapeInShape> combine_shapes(
const std::map<std::vector<size_t>, std::map<std::vector<size_t>, std::vector<int>>>& input_shapes) {
std::vector<ov::test::axisShapeInShape> res_vec;
for (auto& input_shape : input_shapes) {
for (auto& item : input_shape.second) {
for (auto& elt : item.second) {
res_vec.push_back(ov::test::axisShapeInShape{
ov::test::static_shapes_to_test_representation({input_shape.first, item.first}),
elt});
}
}
}
return res_vec;
}
const auto scatter_elt_update_cases = ::testing::Combine(
::testing::ValuesIn(combine_shapes(axesShapeInShape)),
::testing::ValuesIn(idx_value),
::testing::ValuesIn(model_types),
::testing::ValuesIn(idx_types),
::testing::Values(ov::test::utils::DEVICE_CPU)
);
INSTANTIATE_TEST_SUITE_P(smoke_ScatterEltsUpdate, ScatterElementsUpdateLayerTest,
ScatterEltUpdateCases, ScatterElementsUpdateLayerTest::getTestCaseName);
scatter_elt_update_cases, ScatterElementsUpdateLayerTest::getTestCaseName);
} // namespace

View File

@ -0,0 +1,15 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "shared_test_classes/single_op/roll.hpp"
namespace ov {
namespace test {
TEST_P(RollLayerTest, Inference) {
run();
}
} // namespace test
} // namespace ov

View File

@ -0,0 +1,15 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "shared_test_classes/single_op/scatter_ND_update.hpp"
namespace ov {
namespace test {
TEST_P(ScatterNDUpdateLayerTest, Inference) {
run();
}
} // namespace test
} // namespace ov

View File

@ -0,0 +1,15 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "shared_test_classes/single_op/scatter_elements_update.hpp"
namespace ov {
namespace test {
TEST_P(ScatterElementsUpdateLayerTest, Inference) {
run();
}
} // namespace test
} // namespace ov

View File

@ -0,0 +1,30 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <tuple>
#include <string>
#include "shared_test_classes/base/ov_subgraph.hpp"
namespace ov {
namespace test {
typedef std::tuple<
std::vector<InputShape>, // Input shapes
ov::element::Type, // Model type
std::vector<int64_t>, // Shift
std::vector<int64_t>, // Axes
ov::test::TargetDevice // Device name
> rollParams;
class RollLayerTest : public testing::WithParamInterface<rollParams>, virtual public ov::test::SubgraphBaseTest {
public:
static std::string getTestCaseName(const testing::TestParamInfo<rollParams>& obj);
protected:
void SetUp() override;
};
} // namespace test
} // namespace ov

View File

@ -0,0 +1,38 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <memory>
#include <string>
#include <tuple>
#include <vector>
#include "shared_test_classes/base/ov_subgraph.hpp"
namespace ov {
namespace test {
using scatterNDUpdateSpecParams = std::tuple<
std::vector<InputShape>, // input, update shapes
ov::Shape, // indices shape
std::vector<size_t> // indices value
>;
using scatterNDUpdateParamsTuple = typename std::tuple<
scatterNDUpdateSpecParams,
ov::element::Type, // Model type
ov::element::Type, // Indices type
ov::test::TargetDevice // Device name
>;
class ScatterNDUpdateLayerTest : public testing::WithParamInterface<scatterNDUpdateParamsTuple>,
virtual public ov::test::SubgraphBaseTest {
public:
static std::string getTestCaseName(const testing::TestParamInfo<scatterNDUpdateParamsTuple> &obj);
protected:
void SetUp() override;
};
} // namespace test
} // namespace ov

View File

@ -0,0 +1,37 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <memory>
#include <string>
#include <tuple>
#include <vector>
#include "shared_test_classes/base/ov_subgraph.hpp"
namespace ov {
namespace test {
using axisShapeInShape = std::tuple<
std::vector<InputShape>, // Input, update/indices shapes
int // Axis
>;
using scatterElementsUpdateParamsTuple = typename std::tuple<
axisShapeInShape, // Shape description
std::vector<size_t>, // Indices value
ov::element::Type, // Model type
ov::element::Type, // Indices type
ov::test::TargetDevice // Device name
>;
class ScatterElementsUpdateLayerTest : public testing::WithParamInterface<scatterElementsUpdateParamsTuple>,
virtual public ov::test::SubgraphBaseTest {
public:
static std::string getTestCaseName(const testing::TestParamInfo<scatterElementsUpdateParamsTuple> &obj);
protected:
void SetUp() override;
};
} // namespace test
} // namespace ov

View File

@ -0,0 +1,55 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "shared_test_classes/single_op/roll.hpp"
namespace ov {
namespace test {
std::string RollLayerTest::getTestCaseName(const testing::TestParamInfo<rollParams>& obj) {
std::vector<InputShape> input_shapes;
ov::element::Type model_type;
std::vector<int64_t> shift;
std::vector<int64_t> axes;
std::string target_device;
std::tie(input_shapes, model_type, shift, axes, target_device) = obj.param;
std::ostringstream result;
result << "IS=(";
for (size_t i = 0lu; i < input_shapes.size(); i++) {
result << ov::test::utils::partialShape2str({input_shapes[i].first})
<< (i < input_shapes.size() - 1lu ? "_" : "");
}
result << ")_TS=";
for (size_t i = 0lu; i < input_shapes.front().second.size(); i++) {
result << "{";
for (size_t j = 0lu; j < input_shapes.size(); j++) {
result << ov::test::utils::vec2str(input_shapes[j].second[i]) << (j < input_shapes.size() - 1lu ? "_" : "");
}
result << "}_";
}
result << "modelType=" << model_type.to_string() << "_";
result << "Shift=" << ov::test::utils::vec2str(shift) << "_";
result << "Axes=" << ov::test::utils::vec2str(axes) << "_";
result << "trgDev=" << target_device;
return result.str();
}
void RollLayerTest::SetUp() {
std::vector<InputShape> input_shapes;
ov::element::Type model_type;
std::vector<int64_t> shift;
std::vector<int64_t> axes;
std::string target_device;
std::tie(input_shapes, model_type, shift, axes, targetDevice) = this->GetParam();
init_input_shapes(input_shapes);
auto param = std::make_shared<ov::op::v0::Parameter>(model_type, inputDynamicShapes.at(0));
auto shift_const = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{shift.size()}, shift);
auto axes_const = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{axes.size()}, axes);
auto roll = std::make_shared<ov::op::v7::Roll>(param, shift_const, axes_const);
function = std::make_shared<ov::Model>(roll->outputs(), ov::ParameterVector{param}, "Roll");
}
} // namespace test
} // namespace ov

View File

@ -0,0 +1,57 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "shared_test_classes/single_op/scatter_ND_update.hpp"
namespace ov {
namespace test {
std::string ScatterNDUpdateLayerTest::getTestCaseName(const testing::TestParamInfo<scatterNDUpdateParamsTuple> &obj) {
auto shapes_ss = [](const InputShape& shape) {
std::stringstream ss;
ss << "_IS=(" << ov::test::utils::partialShape2str({shape.first}) << ")_TS=";
for (size_t j = 0lu; j < shape.second.size(); j++)
ss << "{" << ov::test::utils::vec2str(shape.second[j]) << "}";
return ss;
};
scatterNDUpdateSpecParams shapes_desc;
std::vector<InputShape> input_shapes;
ov::Shape indices_shape;
std::vector<size_t> indices_value;
ov::element::Type model_type, indices_type;
std::string target_device;
std::tie(shapes_desc, model_type, indices_type, target_device) = obj.param;
std::tie(input_shapes, indices_shape, indices_value) = shapes_desc;
std::ostringstream result;
result << "InputShape=" << shapes_ss(input_shapes.at(0)).str() << "_";
result << "IndicesShape=" << ov::test::utils::vec2str(indices_shape) << "_";
result << "IndicesValue=" << ov::test::utils::vec2str(indices_value) << "_";
result << "UpdateShape=" << shapes_ss(input_shapes.at(1)).str() << "_";
result << "modelType=" << model_type.to_string() << "_";
result << "idxType=" << indices_type.to_string() << "_";
result << "trgDev=" << target_device;
return result.str();
}
void ScatterNDUpdateLayerTest::SetUp() {
scatterNDUpdateSpecParams shapes_desc;
std::vector<InputShape> input_shapes;
ov::Shape indices_shape;
std::vector<size_t> indices_value;
ov::element::Type model_type, indices_type;
std::tie(shapes_desc, model_type, indices_type, targetDevice) = this->GetParam();
std::tie(input_shapes, indices_shape, indices_value) = shapes_desc;
init_input_shapes(input_shapes);
auto param = std::make_shared<ov::op::v0::Parameter>(model_type, inputDynamicShapes.at(0));
auto update_param = std::make_shared<ov::op::v0::Parameter>(model_type, inputDynamicShapes.at(1));
auto indices_const = std::make_shared<ov::op::v0::Constant>(indices_type, indices_shape, indices_value);
auto scatter_nd = std::make_shared<ov::op::v3::ScatterNDUpdate>(param, indices_const, update_param);
function = std::make_shared<ov::Model>(scatter_nd->outputs(), ov::ParameterVector{param, update_param}, "ScatterNDUpdate");
}
} // namespace test
} // namespace ov

View File

@ -0,0 +1,57 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "shared_test_classes/single_op/scatter_elements_update.hpp"
namespace ov {
namespace test {
std::string ScatterElementsUpdateLayerTest::getTestCaseName(const testing::TestParamInfo<scatterElementsUpdateParamsTuple> &obj) {
auto shapes_ss = [](const InputShape& shape) {
std::stringstream ss;
ss << "_IS=(" << ov::test::utils::partialShape2str({shape.first}) << ")_TS=";
for (size_t j = 0lu; j < shape.second.size(); j++)
ss << "{" << ov::test::utils::vec2str(shape.second[j]) << "}";
return ss;
};
axisShapeInShape shapes_desc;
std::vector<InputShape> input_shapes;
int axis;
std::vector<size_t> indices_value;
ov::element::Type model_type, indices_type;
std::string target_device;
std::tie(shapes_desc, indices_value, model_type, indices_type, target_device) = obj.param;
std::tie(input_shapes, axis) = shapes_desc;
std::ostringstream result;
result << "InputShape=" << shapes_ss(input_shapes.at(0)).str() << "_";
result << "IndicesShape=" << ov::test::utils::vec2str(input_shapes.at(1).second) << "_";
result << "Axis=" << axis << "_";
result << "modelType=" << model_type.to_string() << "_";
result << "idxType=" << indices_type.to_string() << "_";
result << "trgDev=" << target_device;
return result.str();
}
void ScatterElementsUpdateLayerTest::SetUp() {
axisShapeInShape shapes_desc;
std::vector<InputShape> input_shapes;
int axis;
std::vector<size_t> indices_value;
ov::element::Type model_type, indices_type;
std::string target_device;
std::tie(shapes_desc, indices_value, model_type, indices_type, targetDevice) = this->GetParam();
std::tie(input_shapes, axis) = shapes_desc;
init_input_shapes(input_shapes);
auto param = std::make_shared<ov::op::v0::Parameter>(model_type, inputDynamicShapes.at(0));
auto update_param = std::make_shared<ov::op::v0::Parameter>(model_type, inputDynamicShapes.at(1));
auto indices_const = std::make_shared<ov::op::v0::Constant>(indices_type, targetStaticShapes.at(0).at(1), indices_value);
auto axis_const =
std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{}, std::vector<int>{axis});
auto scatter_elements_update = std::make_shared<ov::op::v3::ScatterElementsUpdate>(param, indices_const, update_param, axis_const);
function = std::make_shared<ov::Model>(scatter_elements_update->outputs(), ov::ParameterVector{param, update_param}, "ScatterElementsUpdate");
}
} // namespace test
} // namespace ov