[GPU] ScatterElementsUpdate blocked layout support (#12465)
* add parameterized test
* add blocked layouts support
* support for mixed input formats
* fix RHEL 8.2 build
* add scatter_elements_update to whitelist for blocked formats
* Added bs_fs_yx_bsv16_fsv32 format
This commit is contained in:
@@ -76,17 +76,26 @@ public:
|
||||
namespace detail {
|
||||
|
||||
// Registers scatter_elements_update_impl::create as the OCL implementation of
// scatter_elements_update for every combination of the data types and formats
// listed below.
//
// The plain formats (bfyx, bfzyx, bfwzyx) that used to be registered through an
// explicit tuple list are part of `formats`, so a single registration call
// covers both plain and blocked layouts.
attach_scatter_elements_update_impl::attach_scatter_elements_update_impl() {
    auto types = {data_types::f16, data_types::f32, data_types::i32};
    auto formats = {
        // formats with y/x spatial dimensions
        format::bfyx,
        format::b_fs_yx_fsv16,
        format::b_fs_yx_fsv32,
        format::bs_fs_yx_bsv16_fsv16,
        format::bs_fs_yx_bsv32_fsv16,
        format::bs_fs_yx_bsv16_fsv32,
        format::bs_fs_yx_bsv32_fsv32,
        // formats with z/y/x spatial dimensions
        format::bfzyx,
        format::b_fs_zyx_fsv16,
        format::b_fs_zyx_fsv32,
        format::bs_fs_zyx_bsv16_fsv32,
        format::bs_fs_zyx_bsv16_fsv16,
        format::bs_fs_zyx_bsv32_fsv32,
        format::bs_fs_zyx_bsv32_fsv16,
        // format with w/z/y/x spatial dimensions
        format::bfwzyx
    };

    implementation_map<scatter_elements_update>::add(impl_types::ocl, scatter_elements_update_impl::create, types, formats);
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
@@ -1448,7 +1448,8 @@ void program::set_layout_optimizer_attributes(layout_optimizer& lo) {
|
||||
prim.type() != cldnn::eye::type_id() &&
|
||||
prim.type() != cldnn::generate_proposals::type_id() &&
|
||||
prim.type() != cldnn::reverse::type_id() &&
|
||||
prim.type() != cldnn::reorg_yolo::type_id()) {
|
||||
prim.type() != cldnn::reorg_yolo::type_id() &&
|
||||
prim.type() != cldnn::scatter_elements_update::type_id()) {
|
||||
can_use_fsv16 = false;
|
||||
}
|
||||
|
||||
@@ -1488,7 +1489,8 @@ void program::set_layout_optimizer_attributes(layout_optimizer& lo) {
|
||||
prim.type() != cldnn::eye::type_id() &&
|
||||
prim.type() != cldnn::generate_proposals::type_id() &&
|
||||
prim.type() != cldnn::reverse::type_id() &&
|
||||
prim.type() != cldnn::reorg_yolo::type_id()) {
|
||||
prim.type() != cldnn::reorg_yolo::type_id() &&
|
||||
prim.type() != cldnn::scatter_elements_update::type_id()) {
|
||||
can_use_bs_fs_yx_bsv16_fsv16 = false;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,7 +5,8 @@
|
||||
#include "include/batch_headers/data_types.cl"
|
||||
#include "include/batch_headers/fetch_data.cl"
|
||||
|
||||
// Offset helpers used by the kernel. Each one forwards a coordinate list to the
// *_GET_INDEX macro of one tensor (those macros are defined outside this
// section), so indices, updates and output are each addressed through their
// own layout:
//   GET_INDICES_INDEX -> INPUT1_GET_INDEX
//   GET_UPDATES_INDEX -> INPUT2_GET_INDEX
//   GET_OUTPUT_INDEX  -> OUTPUT_GET_INDEX
// GET_UPDATES_INDEX is defined exactly once; a second definition with a
// different parameter list would be an incompatible macro redefinition.
#define GET_INDICES_INDEX(idx_order) INPUT1_GET_INDEX(idx_order)
#define GET_UPDATES_INDEX(idx_order) INPUT2_GET_INDEX(idx_order)
#define GET_OUTPUT_INDEX(idx_order) OUTPUT_GET_INDEX(idx_order)
|
||||
#if OUTPUT_DIMS == 4
|
||||
#define ORDER b,f,y,x
|
||||
@@ -87,8 +88,8 @@ KERNEL(scatter_elements_update_ref)(const __global INPUT0_TYPE* data,
|
||||
const uint idx_b = dim2 / INPUT2_FEATURE_NUM;
|
||||
#endif
|
||||
|
||||
const uint updates_idx = GET_UPDATES_INDEX(INPUT2, IDX_ORDER);
|
||||
INPUT1_TYPE index = indices[(int)updates_idx];
|
||||
const uint indices_idx = GET_INDICES_INDEX(IDX_ORDER);
|
||||
INPUT1_TYPE index = indices[(int)indices_idx];
|
||||
|
||||
#if OUTPUT_DIMS == 4
|
||||
#if AXIS_VALUE == 0
|
||||
@@ -129,6 +130,7 @@ KERNEL(scatter_elements_update_ref)(const __global INPUT0_TYPE* data,
|
||||
#endif
|
||||
const uint output_idx = GET_OUTPUT_INDEX(ORDER);
|
||||
|
||||
const uint updates_idx = GET_UPDATES_INDEX(IDX_ORDER);
|
||||
INPUT2_TYPE val = updates[(int)updates_idx];
|
||||
#if HAS_FUSED_OPS
|
||||
FUSED_OPS_SECOND_KERNEL;
|
||||
@@ -139,6 +141,7 @@ KERNEL(scatter_elements_update_ref)(const __global INPUT0_TYPE* data,
|
||||
#endif
|
||||
}
|
||||
|
||||
#undef GET_INDICES_INDEX
|
||||
#undef GET_UPDATES_INDEX
|
||||
#undef GET_OUTPUT_INDEX
|
||||
#undef IDX_ORDER
|
||||
|
||||
@@ -34,20 +34,36 @@ static size_t GetScatterElementsUpdateChannelIndex(const scatter_elements_update
|
||||
|
||||
ParamsKey ScatterElementsUpdateKernelRef::GetSupportedKey() const {
|
||||
ParamsKey k;
|
||||
k.EnableInputDataType(Datatype::F16);
|
||||
k.EnableInputDataType(Datatype::F32);
|
||||
k.EnableInputDataType(Datatype::INT32);
|
||||
k.EnableOutputDataType(Datatype::F16);
|
||||
k.EnableOutputDataType(Datatype::F32);
|
||||
k.EnableOutputDataType(Datatype::INT32);
|
||||
k.EnableOutputDataType(Datatype::INT8);
|
||||
k.EnableOutputDataType(Datatype::UINT8);
|
||||
k.EnableInputLayout(DataLayout::bfyx);
|
||||
k.EnableOutputLayout(DataLayout::bfyx);
|
||||
k.EnableInputLayout(DataLayout::bfzyx);
|
||||
k.EnableOutputLayout(DataLayout::bfzyx);
|
||||
k.EnableInputLayout(DataLayout::bfwzyx);
|
||||
k.EnableOutputLayout(DataLayout::bfwzyx);
|
||||
const std::vector<Datatype> supportedTypes{
|
||||
Datatype::F16, Datatype::F32, Datatype::INT32, Datatype::INT8, Datatype::UINT8
|
||||
};
|
||||
for (const auto t : supportedTypes) {
|
||||
k.EnableInputDataType(t);
|
||||
k.EnableOutputDataType(t);
|
||||
}
|
||||
|
||||
const std::vector<DataLayout> supportedLayots{
|
||||
DataLayout::bfyx,
|
||||
DataLayout::b_fs_yx_fsv16,
|
||||
DataLayout::b_fs_yx_fsv32,
|
||||
DataLayout::bs_fs_yx_bsv16_fsv16,
|
||||
DataLayout::bs_fs_yx_bsv32_fsv16,
|
||||
DataLayout::bs_fs_yx_bsv16_fsv32,
|
||||
DataLayout::bs_fs_yx_bsv32_fsv32,
|
||||
DataLayout::bfzyx,
|
||||
DataLayout::b_fs_zyx_fsv16,
|
||||
DataLayout::b_fs_zyx_fsv32,
|
||||
DataLayout::bs_fs_zyx_bsv16_fsv32,
|
||||
DataLayout::bs_fs_zyx_bsv16_fsv16,
|
||||
DataLayout::bs_fs_zyx_bsv32_fsv32,
|
||||
DataLayout::bs_fs_zyx_bsv32_fsv16,
|
||||
DataLayout::bfwzyx
|
||||
};
|
||||
for (const auto l : supportedLayots) {
|
||||
k.EnableInputLayout(l);
|
||||
k.EnableOutputLayout(l);
|
||||
}
|
||||
|
||||
k.EnableTensorOffset();
|
||||
k.EnableTensorPitches();
|
||||
k.EnableBatching();
|
||||
@@ -79,22 +95,24 @@ CommonDispatchData ScatterElementsUpdateKernelRef::SetDefault(const scatter_elem
|
||||
|
||||
const auto& scope = is_second ? indices : output;
|
||||
|
||||
switch (params.inputs[0].GetLayout()) {
|
||||
case DataLayout::bfyx:
|
||||
const auto rank = params.inputs[0].GetDims().size();
|
||||
|
||||
switch (rank) {
|
||||
case 4:
|
||||
dispatchData.gws = {scope.X().v, scope.Y().v, scope.Feature().v * scope.Batch().v};
|
||||
dims_by_gws = {{Tensor::DataChannelName::X},
|
||||
{Tensor::DataChannelName::Y},
|
||||
{Tensor::DataChannelName::FEATURE, Tensor::DataChannelName::BATCH}};
|
||||
break;
|
||||
|
||||
case DataLayout::bfzyx:
|
||||
case 5:
|
||||
dispatchData.gws = {scope.X().v * scope.Y().v, scope.Z().v, scope.Feature().v * scope.Batch().v};
|
||||
dims_by_gws = {{Tensor::DataChannelName::X, Tensor::DataChannelName::Y},
|
||||
{Tensor::DataChannelName::Z},
|
||||
{Tensor::DataChannelName::FEATURE, Tensor::DataChannelName::BATCH}};
|
||||
break;
|
||||
|
||||
case DataLayout::bfwzyx:
|
||||
case 6:
|
||||
dispatchData.gws = {scope.X().v * scope.Y().v, scope.Z().v * scope.W().v, scope.Feature().v * scope.Batch().v};
|
||||
dims_by_gws = {{Tensor::DataChannelName::X, Tensor::DataChannelName::Y},
|
||||
{Tensor::DataChannelName::Z, Tensor::DataChannelName::W},
|
||||
|
||||
@@ -90,3 +90,305 @@ TEST(scatter_elements_update_gpu_fp16, d2411_axisF) {
|
||||
EXPECT_EQ(expected_results[i], float16_to_float32(output_ptr[i]));
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
template<typename T>
|
||||
struct ScatterElementsUpdateParams {
|
||||
int64_t axis;
|
||||
tensor data_tensor;
|
||||
std::vector<T> data;
|
||||
tensor indices_tensor;
|
||||
std::vector<T> indices;
|
||||
std::vector<T> updates;
|
||||
std::vector<T> expected;
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
using ScatterElementsUpdateParamsWithFormat = std::tuple<
|
||||
ScatterElementsUpdateParams<T>,
|
||||
format::type, // source (plain) layout
|
||||
format::type, // target (blocked) data layout
|
||||
format::type, // target (blocked) indices layout
|
||||
format::type // target (blocked) updates layout
|
||||
>;
|
||||
|
||||
// Target data layouts exercised by the *_2d instantiations (y/x spatial dims).
const std::vector<format::type> formats2D{
    format::bfyx,
    format::b_fs_yx_fsv16,
    format::b_fs_yx_fsv32,
    format::bs_fs_yx_bsv16_fsv16,
    format::bs_fs_yx_bsv16_fsv32,
    format::bs_fs_yx_bsv32_fsv16,
    format::bs_fs_yx_bsv32_fsv32,
};
|
||||
|
||||
// Target data layouts exercised by the *_3d instantiation (z/y/x spatial dims).
const std::vector<format::type> formats3D{
    format::bfzyx,
    format::b_fs_zyx_fsv16,
    format::bs_fs_zyx_bsv16_fsv16
};
|
||||
|
||||
// Target layouts exercised by the *_4d instantiation (w/z/y/x spatial dims).
const std::vector<format::type> formats4D{
    format::bfwzyx
};
|
||||
|
||||
// Converts a list of float literals into a vector of the element type under
// test, constructing each T from the corresponding float.
template<typename T>
std::vector<T> getValues(const std::vector<float> &values) {
    return std::vector<T>(values.begin(), values.end());
}
|
||||
|
||||
template<typename T>
|
||||
std::vector<ScatterElementsUpdateParams<T>> generateScatterElementsUpdateParams2D() {
|
||||
const std::vector<ScatterElementsUpdateParams<T>> result = {
|
||||
{ 1,
|
||||
tensor{2, 4, 1, 1},
|
||||
getValues<T>({ 0, 1, 2, 3, 4, 5, 6, 7 }),
|
||||
tensor{2, 2, 1, 1},
|
||||
getValues<T>({ 0, 1, 2, 3 }),
|
||||
getValues<T>({ -10, -11, -12, -13 }),
|
||||
getValues<T>({ -10, -11, 2, 3, 4, 5, -12, -13 })
|
||||
},
|
||||
{ 2,
|
||||
tensor{2, 1, 2, 2},
|
||||
getValues<T>({ 0, 1, 2, 3, 4, 5, 6, 7 }),
|
||||
tensor{2, 1, 2, 1},
|
||||
getValues<T>({ 0, 1, 0, 1 }),
|
||||
getValues<T>({ -10, -11, -12, -13 }),
|
||||
getValues<T>({ -10, 1, 2, -11, -12, 5, 6, -13 })
|
||||
},
|
||||
{ 3,
|
||||
tensor{2, 1, 2, 2},
|
||||
getValues<T>({ 0, 1, 2, 3, 4, 5, 6, 7 }),
|
||||
tensor{2, 1, 1, 2},
|
||||
getValues<T>({ 0, 1, 0, 1 }),
|
||||
getValues<T>({ -10, -11, -12, -13 }),
|
||||
getValues<T>({ -10, 1, 2, -11, -12, 5, 6, -13 })
|
||||
},
|
||||
};
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
std::vector<ScatterElementsUpdateParams<T>> generateScatterElementsUpdateParams3D() {
|
||||
const std::vector<ScatterElementsUpdateParams<T>> result = {
|
||||
{ 1,
|
||||
tensor{2, 4, 1, 1, 3},
|
||||
getValues<T>({ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23 }),
|
||||
tensor{2, 1, 1, 1, 2},
|
||||
getValues<T>({ 0, 3, 1, 2 }),
|
||||
getValues<T>({ -100, -110, -120, -130 }),
|
||||
getValues<T>({ -100, 1, 2, 3, 4, 5, 6, 7, 8, 9, -110, 11, 12, 13, 14, -120, 16, 17, 18, -130, 20, 21, 22, 23 })
|
||||
},
|
||||
{ 4,
|
||||
tensor{2, 4, 1, 1, 3},
|
||||
getValues<T>({ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23 }),
|
||||
tensor{2, 1, 1, 1, 2},
|
||||
getValues<T>({ 0, 1, 0, 1 }),
|
||||
getValues<T>({ -100, -110, -120, -130 }),
|
||||
getValues<T>({ -100, 1, -110, 3, 4, 5, 6, 7, 8, 9, 10, 11, -120, 13, -130, 15, 16, 17, 18, 19, 20, 21, 22, 23 })
|
||||
},
|
||||
};
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
std::vector<ScatterElementsUpdateParams<T>> generateScatterElementsUpdateParams4D() {
|
||||
const std::vector<ScatterElementsUpdateParams<T>> result = {
|
||||
{ 5,
|
||||
tensor{2, 4, 2, 1, 1, 3},
|
||||
getValues<T>({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
|
||||
24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}),
|
||||
tensor{2, 1, 1, 1, 1, 2},
|
||||
getValues<T>({2, 1, 1, 1, 2}),
|
||||
getValues<T>({-100, -110, -120, -130}),
|
||||
getValues<T>({0, 1, -100, -110, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
|
||||
24, -120, 26, -130, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}),
|
||||
}
|
||||
};
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// Absolute tolerance passed to EXPECT_NEAR for element type T.
// Types without a specialization must match exactly.
template<typename T>
float getError() {
    return 0.0f;
}

// float results may differ from the expected value by a small rounding error.
template<>
float getError<float>() {
    return 0.001f;
}
|
||||
|
||||
template<>
|
||||
float getError<half_t>() {
|
||||
return 0.2;
|
||||
}
|
||||
|
||||
struct PrintToStringParamName {
|
||||
template<class T>
|
||||
std::string operator()(const testing::TestParamInfo<ScatterElementsUpdateParamsWithFormat<T> > ¶m) {
|
||||
std::stringstream buf;
|
||||
ScatterElementsUpdateParams<T> p;
|
||||
format::type plain_format;
|
||||
format::type target_data_format;
|
||||
format::type target_indices_format;
|
||||
format::type target_updates_format;
|
||||
std::tie(p, plain_format, target_data_format, target_indices_format, target_updates_format) = param.param;
|
||||
buf << "_axis=" << p.axis
|
||||
<< "_data=" << p.data_tensor.to_string()
|
||||
<< "_indices=" << p.indices_tensor.to_string()
|
||||
<< "_plainFormat=" << fmt_to_str(plain_format)
|
||||
<< "_targetDataFormat=" << fmt_to_str(target_data_format)
|
||||
<< "_targetIndicesFormat=" << fmt_to_str(target_indices_format)
|
||||
<< "_targetUpdatesFormat=" << fmt_to_str(target_updates_format);
|
||||
return buf.str();
|
||||
}
|
||||
};
|
||||
}; // namespace
|
||||
|
||||
template<typename T>
|
||||
struct scatter_elements_update_gpu_formats_test
|
||||
: public ::testing::TestWithParam<ScatterElementsUpdateParamsWithFormat<T> > {
|
||||
public:
|
||||
void test() {
|
||||
const auto data_type = type_to_data_type<T>::value;
|
||||
ScatterElementsUpdateParams<T> params;
|
||||
format::type plain_format;
|
||||
format::type target_data_format;
|
||||
format::type target_indices_format;
|
||||
format::type target_updates_format;
|
||||
|
||||
std::tie(params, plain_format, target_data_format, target_indices_format, target_updates_format) = this->GetParam();
|
||||
if (target_indices_format == format::any) {
|
||||
target_indices_format = target_data_format;
|
||||
}
|
||||
if (target_updates_format == format::any) {
|
||||
target_updates_format = target_data_format;
|
||||
}
|
||||
|
||||
auto& engine = get_test_engine();
|
||||
const auto data = engine.allocate_memory({data_type, plain_format, params.data_tensor});
|
||||
const auto indices = engine.allocate_memory({data_type, plain_format, params.indices_tensor});
|
||||
const auto updates = engine.allocate_memory({data_type, plain_format, params.indices_tensor});
|
||||
|
||||
set_values(data, params.data);
|
||||
set_values(indices, params.indices);
|
||||
set_values(updates, params.updates);
|
||||
|
||||
topology topology;
|
||||
topology.add(input_layout("Data", data->get_layout()));
|
||||
topology.add(input_layout("Indices", indices->get_layout()));
|
||||
topology.add(input_layout("Updates", updates->get_layout()));
|
||||
topology.add(reorder("DataReordered", "Data", target_data_format, data_type));
|
||||
topology.add(reorder("IndicesReordered", "Indices", target_indices_format, data_type));
|
||||
topology.add(reorder("UpdatesReordered", "Updates", target_updates_format, data_type));
|
||||
topology.add(
|
||||
scatter_elements_update("ScatterEelementsUpdate", "DataReordered", "IndicesReordered",
|
||||
"UpdatesReordered", params.axis)
|
||||
);
|
||||
topology.add(reorder("ScatterEelementsUpdatePlain", "ScatterEelementsUpdate", plain_format, data_type));
|
||||
|
||||
network network{engine, topology};
|
||||
|
||||
network.set_input_data("Data", data);
|
||||
network.set_input_data("Indices", indices);
|
||||
network.set_input_data("Updates", updates);
|
||||
|
||||
const auto outputs = network.execute();
|
||||
const auto output = outputs.at("ScatterEelementsUpdatePlain").get_memory();
|
||||
const cldnn::mem_lock<T> output_ptr(output, get_test_stream());
|
||||
|
||||
ASSERT_EQ(params.data.size(), output_ptr.size());
|
||||
ASSERT_EQ(params.expected.size(), output_ptr.size());
|
||||
for (uint32_t i = 0; i < output_ptr.size(); i++) {
|
||||
EXPECT_NEAR(output_ptr[i], params.expected[i], getError<T>())
|
||||
<< "format=" << fmt_to_str(target_data_format) << ", i=" << i;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Concrete fixtures, one per tested element type.
using scatter_elements_update_gpu_formats_test_f32 = scatter_elements_update_gpu_formats_test<float>;
using scatter_elements_update_gpu_formats_test_f16 = scatter_elements_update_gpu_formats_test<half_t>;
using scatter_elements_update_gpu_formats_test_i32 = scatter_elements_update_gpu_formats_test<int32_t>;
|
||||
|
||||
// Each typed fixture runs the shared test() body; ASSERT_NO_FATAL_FAILURE
// propagates fatal assertion failures raised inside it.
TEST_P(scatter_elements_update_gpu_formats_test_f32, basic) {
    ASSERT_NO_FATAL_FAILURE(test());
}

TEST_P(scatter_elements_update_gpu_formats_test_f16, basic) {
    ASSERT_NO_FATAL_FAILURE(test());
}

TEST_P(scatter_elements_update_gpu_formats_test_i32, basic) {
    ASSERT_NO_FATAL_FAILURE(test());
}
|
||||
|
||||
|
||||
// Four-extent cases for f32, f16 and i32: every layout in formats2D is used
// for the data input; indices and updates follow it (format::any).
INSTANTIATE_TEST_SUITE_P(scatter_elements_update_gpu_formats_test_f32_2d,
                         scatter_elements_update_gpu_formats_test_f32,
                         ::testing::Combine(
                                 ::testing::ValuesIn(generateScatterElementsUpdateParams2D<float>()),
                                 ::testing::Values(format::bfyx),
                                 ::testing::ValuesIn(formats2D),
                                 ::testing::Values(format::any),
                                 ::testing::Values(format::any)
                         ),
                         PrintToStringParamName());

INSTANTIATE_TEST_SUITE_P(scatter_elements_update_gpu_formats_test_f16_2d,
                         scatter_elements_update_gpu_formats_test_f16,
                         ::testing::Combine(
                                 ::testing::ValuesIn(generateScatterElementsUpdateParams2D<half_t>()),
                                 ::testing::Values(format::bfyx),
                                 ::testing::ValuesIn(formats2D),
                                 ::testing::Values(format::any),
                                 ::testing::Values(format::any)
                         ),
                         PrintToStringParamName());

INSTANTIATE_TEST_SUITE_P(scatter_elements_update_gpu_formats_test_i32_2d,
                         scatter_elements_update_gpu_formats_test_i32,
                         ::testing::Combine(
                                 ::testing::ValuesIn(generateScatterElementsUpdateParams2D<int32_t>()),
                                 ::testing::Values(format::bfyx),
                                 ::testing::ValuesIn(formats2D),
                                 ::testing::Values(format::any),
                                 ::testing::Values(format::any)
                         ),
                         PrintToStringParamName());
|
||||
|
||||
// Five-extent cases (f32 only): every layout in formats3D is used for the
// data input; indices and updates follow it (format::any).
INSTANTIATE_TEST_SUITE_P(scatter_elements_update_gpu_formats_test_f32_3d,
                         scatter_elements_update_gpu_formats_test_f32,
                         ::testing::Combine(
                                 ::testing::ValuesIn(generateScatterElementsUpdateParams3D<float>()),
                                 ::testing::Values(format::bfzyx),
                                 ::testing::ValuesIn(formats3D),
                                 ::testing::Values(format::any),
                                 ::testing::Values(format::any)
                         ),
                         PrintToStringParamName());
|
||||
|
||||
// Six-extent case (f32 only): data, indices and updates layouts are each
// drawn from formats4D.
INSTANTIATE_TEST_SUITE_P(scatter_elements_update_gpu_formats_test_f32_4d,
                         scatter_elements_update_gpu_formats_test_f32,
                         ::testing::Combine(
                                 ::testing::ValuesIn(generateScatterElementsUpdateParams4D<float>()),
                                 ::testing::Values(format::bfwzyx),
                                 ::testing::ValuesIn(formats4D),
                                 ::testing::ValuesIn(formats4D),
                                 ::testing::ValuesIn(formats4D)
                         ),
                         PrintToStringParamName());
|
||||
|
||||
// Mixed-layout cases (f32 only): data, indices and updates each get a
// different layout, so every input is addressed through its own format.
// The layout lists are passed with ::testing::Values, like the single-value
// generators in the other instantiations, instead of ValuesIn on a braced list.
INSTANTIATE_TEST_SUITE_P(scatter_elements_update_gpu_formats_test_mixed_inputs,
                         scatter_elements_update_gpu_formats_test_f32,
                         ::testing::Combine(
                                 ::testing::ValuesIn(generateScatterElementsUpdateParams2D<float>()),
                                 ::testing::Values(format::bfyx),
                                 ::testing::Values(format::b_fs_yx_fsv16, format::b_fs_yx_fsv32),
                                 ::testing::Values(format::bs_fs_yx_bsv16_fsv16, format::bs_fs_yx_bsv32_fsv16),
                                 ::testing::Values(format::bs_fs_yx_bsv32_fsv32, format::bfyx)
                         ),
                         PrintToStringParamName());
|
||||
|
||||
Reference in New Issue
Block a user