[GPU] ScatterElementsUpdate blocked layout support (#12465)

* add parameterized test
* add blocked layouts support
* support for mixed input formats
* fix RHEL 8.2 build
* add scatter_elements_update to whitelist for blocked formats
* Added bs_fs_yx_bsv16_fsv32 format
This commit is contained in:
Oleksii Khovan
2022-10-25 07:22:27 +02:00
committed by GitHub
parent 0e242b3244
commit 1960746c44
5 changed files with 368 additions and 34 deletions

View File

@@ -76,17 +76,26 @@ public:
namespace detail {
// Registers scatter_elements_update_impl::create as the OCL implementation factory for the
// scatter_elements_update primitive.
// NOTE(review): two registrations are visible back to back here - an explicit (data type, format)
// tuple list for the plain formats, followed by a types x formats call that also lists
// bfyx/bfzyx/bfwzyx. This looks like the removed and added sides of the change shown together;
// confirm that only one of them is meant to remain.
attach_scatter_elements_update_impl::attach_scatter_elements_update_impl() {
// Explicit tuple list: f32/f16/i32 for each of bfyx, bfzyx and bfwzyx.
implementation_map<scatter_elements_update>::add(impl_types::ocl, scatter_elements_update_impl::create, {
std::make_tuple(data_types::f32, format::bfyx),
std::make_tuple(data_types::f16, format::bfyx),
std::make_tuple(data_types::i32, format::bfyx),
std::make_tuple(data_types::f32, format::bfzyx),
std::make_tuple(data_types::f16, format::bfzyx),
std::make_tuple(data_types::i32, format::bfzyx),
std::make_tuple(data_types::f32, format::bfwzyx),
std::make_tuple(data_types::f16, format::bfwzyx),
std::make_tuple(data_types::i32, format::bfwzyx),
});
// Separate lists of data types and formats, handed to the add() overload that takes both lists.
auto types = {data_types::f16, data_types::f32, data_types::i32};
auto formats = {
// plain and blocked layouts with y/x spatial dimensions
format::bfyx,
format::b_fs_yx_fsv16,
format::b_fs_yx_fsv32,
format::bs_fs_yx_bsv16_fsv16,
format::bs_fs_yx_bsv32_fsv16,
format::bs_fs_yx_bsv16_fsv32,
format::bs_fs_yx_bsv32_fsv32,
// plain and blocked layouts with z/y/x spatial dimensions
format::bfzyx,
format::b_fs_zyx_fsv16,
format::b_fs_zyx_fsv32,
format::bs_fs_zyx_bsv16_fsv32,
format::bs_fs_zyx_bsv16_fsv16,
format::bs_fs_zyx_bsv32_fsv32,
format::bs_fs_zyx_bsv32_fsv16,
// plain layout with w/z/y/x spatial dimensions (no blocked variant listed)
format::bfwzyx
};
implementation_map<scatter_elements_update>::add(impl_types::ocl, scatter_elements_update_impl::create, types, formats);
}
} // namespace detail

View File

@@ -1448,7 +1448,8 @@ void program::set_layout_optimizer_attributes(layout_optimizer& lo) {
prim.type() != cldnn::eye::type_id() &&
prim.type() != cldnn::generate_proposals::type_id() &&
prim.type() != cldnn::reverse::type_id() &&
prim.type() != cldnn::reorg_yolo::type_id()) {
prim.type() != cldnn::reorg_yolo::type_id() &&
prim.type() != cldnn::scatter_elements_update::type_id()) {
can_use_fsv16 = false;
}
@@ -1488,7 +1489,8 @@ void program::set_layout_optimizer_attributes(layout_optimizer& lo) {
prim.type() != cldnn::eye::type_id() &&
prim.type() != cldnn::generate_proposals::type_id() &&
prim.type() != cldnn::reverse::type_id() &&
prim.type() != cldnn::reorg_yolo::type_id()) {
prim.type() != cldnn::reorg_yolo::type_id() &&
prim.type() != cldnn::scatter_elements_update::type_id()) {
can_use_bs_fs_yx_bsv16_fsv16 = false;
}
}

View File

@@ -5,7 +5,8 @@
#include "include/batch_headers/data_types.cl"
#include "include/batch_headers/fetch_data.cl"
#define GET_UPDATES_INDEX(prefix, idx_order) CAT(prefix, _GET_INDEX)(idx_order)
#define GET_INDICES_INDEX(idx_order) INPUT1_GET_INDEX(idx_order)
#define GET_UPDATES_INDEX(idx_order) INPUT2_GET_INDEX(idx_order)
#define GET_OUTPUT_INDEX(idx_order) OUTPUT_GET_INDEX(idx_order)
#if OUTPUT_DIMS == 4
#define ORDER b,f,y,x
@@ -87,8 +88,8 @@ KERNEL(scatter_elements_update_ref)(const __global INPUT0_TYPE* data,
const uint idx_b = dim2 / INPUT2_FEATURE_NUM;
#endif
const uint updates_idx = GET_UPDATES_INDEX(INPUT2, IDX_ORDER);
INPUT1_TYPE index = indices[(int)updates_idx];
const uint indices_idx = GET_INDICES_INDEX(IDX_ORDER);
INPUT1_TYPE index = indices[(int)indices_idx];
#if OUTPUT_DIMS == 4
#if AXIS_VALUE == 0
@@ -129,6 +130,7 @@ KERNEL(scatter_elements_update_ref)(const __global INPUT0_TYPE* data,
#endif
const uint output_idx = GET_OUTPUT_INDEX(ORDER);
const uint updates_idx = GET_UPDATES_INDEX(IDX_ORDER);
INPUT2_TYPE val = updates[(int)updates_idx];
#if HAS_FUSED_OPS
FUSED_OPS_SECOND_KERNEL;
@@ -139,6 +141,7 @@ KERNEL(scatter_elements_update_ref)(const __global INPUT0_TYPE* data,
#endif
}
#undef GET_INDICES_INDEX
#undef GET_UPDATES_INDEX
#undef GET_OUTPUT_INDEX
#undef IDX_ORDER

View File

@@ -34,20 +34,36 @@ static size_t GetScatterElementsUpdateChannelIndex(const scatter_elements_update
ParamsKey ScatterElementsUpdateKernelRef::GetSupportedKey() const {
ParamsKey k;
k.EnableInputDataType(Datatype::F16);
k.EnableInputDataType(Datatype::F32);
k.EnableInputDataType(Datatype::INT32);
k.EnableOutputDataType(Datatype::F16);
k.EnableOutputDataType(Datatype::F32);
k.EnableOutputDataType(Datatype::INT32);
k.EnableOutputDataType(Datatype::INT8);
k.EnableOutputDataType(Datatype::UINT8);
k.EnableInputLayout(DataLayout::bfyx);
k.EnableOutputLayout(DataLayout::bfyx);
k.EnableInputLayout(DataLayout::bfzyx);
k.EnableOutputLayout(DataLayout::bfzyx);
k.EnableInputLayout(DataLayout::bfwzyx);
k.EnableOutputLayout(DataLayout::bfwzyx);
const std::vector<Datatype> supportedTypes{
Datatype::F16, Datatype::F32, Datatype::INT32, Datatype::INT8, Datatype::UINT8
};
for (const auto t : supportedTypes) {
k.EnableInputDataType(t);
k.EnableOutputDataType(t);
}
const std::vector<DataLayout> supportedLayots{
DataLayout::bfyx,
DataLayout::b_fs_yx_fsv16,
DataLayout::b_fs_yx_fsv32,
DataLayout::bs_fs_yx_bsv16_fsv16,
DataLayout::bs_fs_yx_bsv32_fsv16,
DataLayout::bs_fs_yx_bsv16_fsv32,
DataLayout::bs_fs_yx_bsv32_fsv32,
DataLayout::bfzyx,
DataLayout::b_fs_zyx_fsv16,
DataLayout::b_fs_zyx_fsv32,
DataLayout::bs_fs_zyx_bsv16_fsv32,
DataLayout::bs_fs_zyx_bsv16_fsv16,
DataLayout::bs_fs_zyx_bsv32_fsv32,
DataLayout::bs_fs_zyx_bsv32_fsv16,
DataLayout::bfwzyx
};
for (const auto l : supportedLayots) {
k.EnableInputLayout(l);
k.EnableOutputLayout(l);
}
k.EnableTensorOffset();
k.EnableTensorPitches();
k.EnableBatching();
@@ -79,22 +95,24 @@ CommonDispatchData ScatterElementsUpdateKernelRef::SetDefault(const scatter_elem
const auto& scope = is_second ? indices : output;
switch (params.inputs[0].GetLayout()) {
case DataLayout::bfyx:
const auto rank = params.inputs[0].GetDims().size();
switch (rank) {
case 4:
dispatchData.gws = {scope.X().v, scope.Y().v, scope.Feature().v * scope.Batch().v};
dims_by_gws = {{Tensor::DataChannelName::X},
{Tensor::DataChannelName::Y},
{Tensor::DataChannelName::FEATURE, Tensor::DataChannelName::BATCH}};
break;
case DataLayout::bfzyx:
case 5:
dispatchData.gws = {scope.X().v * scope.Y().v, scope.Z().v, scope.Feature().v * scope.Batch().v};
dims_by_gws = {{Tensor::DataChannelName::X, Tensor::DataChannelName::Y},
{Tensor::DataChannelName::Z},
{Tensor::DataChannelName::FEATURE, Tensor::DataChannelName::BATCH}};
break;
case DataLayout::bfwzyx:
case 6:
dispatchData.gws = {scope.X().v * scope.Y().v, scope.Z().v * scope.W().v, scope.Feature().v * scope.Batch().v};
dims_by_gws = {{Tensor::DataChannelName::X, Tensor::DataChannelName::Y},
{Tensor::DataChannelName::Z, Tensor::DataChannelName::W},

View File

@@ -90,3 +90,305 @@ TEST(scatter_elements_update_gpu_fp16, d2411_axisF) {
EXPECT_EQ(expected_results[i], float16_to_float32(output_ptr[i]));
}
}
namespace {
// One scatter_elements_update test case. T is the element type used for every value list,
// including the indices. The field order matters: the case generators fill this struct
// positionally with brace initialization.
template<typename T>
struct ScatterElementsUpdateParams {
// axis argument passed to the scatter_elements_update primitive
int64_t axis;
// shape used to allocate the data input
tensor data_tensor;
// values written into the data input
std::vector<T> data;
// shape used to allocate both the indices input and the updates input
tensor indices_tensor;
// values written into the indices input (stored as T, like the data)
std::vector<T> indices;
// values written into the updates input
std::vector<T> updates;
// reference output; compared element by element against the primitive's result
std::vector<T> expected;
};
// Full gtest parameter: a test case plus the layouts to exercise.
// The inputs are allocated in the plain layout and reordered to the target layouts before the
// primitive runs; the fixture replaces format::any in the indices/updates slots with the
// target data layout.
template<typename T>
using ScatterElementsUpdateParamsWithFormat = std::tuple<
ScatterElementsUpdateParams<T>,
format::type, // source (plain) layout
format::type, // target (blocked) data layout
format::type, // target (blocked) indices layout
format::type // target (blocked) updates layout
>;
// Target layouts with y/x spatial dimensions; used as data layouts by the *_2d instantiations.
const std::vector<format::type> formats2D{
format::bfyx,
format::b_fs_yx_fsv16,
format::b_fs_yx_fsv32,
format::bs_fs_yx_bsv16_fsv16,
format::bs_fs_yx_bsv16_fsv32,
format::bs_fs_yx_bsv32_fsv16,
format::bs_fs_yx_bsv32_fsv32,
};
// Target layouts with z/y/x spatial dimensions; used as data layouts by the *_3d instantiation.
const std::vector<format::type> formats3D{
format::bfzyx,
format::b_fs_zyx_fsv16,
format::bs_fs_zyx_bsv16_fsv16
};
// Target layouts with w/z/y/x spatial dimensions; only the plain layout is listed.
const std::vector<format::type> formats4D{
format::bfwzyx
};
/// Converts a list of float literals into a vector of the test element type T.
/// Each element is constructed directly from the corresponding float value.
template<typename T>
std::vector<T> getValues(const std::vector<float> &values) {
    std::vector<T> converted;
    converted.reserve(values.size());
    for (const float value : values) {
        converted.emplace_back(value);
    }
    return converted;
}
/// Test cases whose tensors are given with four dimensions, one case per axis value 1, 2 and 3.
/// Each entry lists, in order: axis, data shape, data values, indices/updates shape,
/// indices, updates and the expected output.
template<typename T>
std::vector<ScatterElementsUpdateParams<T>> generateScatterElementsUpdateParams2D() {
    return {
        {
            1,
            tensor{2, 4, 1, 1},
            getValues<T>({ 0, 1, 2, 3, 4, 5, 6, 7 }),
            tensor{2, 2, 1, 1},
            getValues<T>({ 0, 1, 2, 3 }),
            getValues<T>({ -10, -11, -12, -13 }),
            getValues<T>({ -10, -11, 2, 3, 4, 5, -12, -13 })
        },
        {
            2,
            tensor{2, 1, 2, 2},
            getValues<T>({ 0, 1, 2, 3, 4, 5, 6, 7 }),
            tensor{2, 1, 2, 1},
            getValues<T>({ 0, 1, 0, 1 }),
            getValues<T>({ -10, -11, -12, -13 }),
            getValues<T>({ -10, 1, 2, -11, -12, 5, 6, -13 })
        },
        {
            3,
            tensor{2, 1, 2, 2},
            getValues<T>({ 0, 1, 2, 3, 4, 5, 6, 7 }),
            tensor{2, 1, 1, 2},
            getValues<T>({ 0, 1, 0, 1 }),
            getValues<T>({ -10, -11, -12, -13 }),
            getValues<T>({ -10, 1, 2, -11, -12, 5, 6, -13 })
        },
    };
}
/// Test cases whose tensors are given with five dimensions, for axis values 1 and 4.
/// Each entry lists, in order: axis, data shape, data values, indices/updates shape,
/// indices, updates and the expected output.
template<typename T>
std::vector<ScatterElementsUpdateParams<T>> generateScatterElementsUpdateParams3D() {
    return {
        {
            1,
            tensor{2, 4, 1, 1, 3},
            getValues<T>({ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23 }),
            tensor{2, 1, 1, 1, 2},
            getValues<T>({ 0, 3, 1, 2 }),
            getValues<T>({ -100, -110, -120, -130 }),
            getValues<T>({ -100, 1, 2, 3, 4, 5, 6, 7, 8, 9, -110, 11, 12, 13, 14, -120, 16, 17, 18, -130, 20, 21, 22, 23 })
        },
        {
            4,
            tensor{2, 4, 1, 1, 3},
            getValues<T>({ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23 }),
            tensor{2, 1, 1, 1, 2},
            getValues<T>({ 0, 1, 0, 1 }),
            getValues<T>({ -100, -110, -120, -130 }),
            getValues<T>({ -100, 1, -110, 3, 4, 5, 6, 7, 8, 9, 10, 11, -120, 13, -130, 15, 16, 17, 18, 19, 20, 21, 22, 23 })
        },
    };
}
/// Test case whose tensors are given with six dimensions, for axis value 5.
/// The entry lists, in order: axis, data shape, data values, indices/updates shape,
/// indices, updates and the expected output.
template<typename T>
std::vector<ScatterElementsUpdateParams<T>> generateScatterElementsUpdateParams4D() {
    const std::vector<ScatterElementsUpdateParams<T>> result = {
        {   5,
            tensor{2, 4, 2, 1, 1, 3},
            getValues<T>({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
                          24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}),
            tensor{2, 1, 1, 1, 1, 2},
            // Exactly one index per update value: the fixture allocates the indices and the
            // updates inputs with the same shape, so both lists must have the same length (4).
            // A surplus fifth index would have no element in the indices buffer to land in.
            getValues<T>({2, 1, 1, 1}),
            getValues<T>({-100, -110, -120, -130}),
            getValues<T>({0, 1, -100, -110, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
                          24, -120, 26, -130, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}),
        }
    };
    return result;
}
/// Absolute tolerance used when comparing results of element type T.
/// Types without a dedicated specialization must match exactly.
template<typename T>
float getError() {
    return 0.0f;
}

/// Tolerance for single-precision results.
template<>
float getError<float>() {
    return 0.001f;
}
/// Tolerance for half-precision results.
template<>
float getError<half_t>() {
    return 0.2f;
}
/// Name generator passed to INSTANTIATE_TEST_SUITE_P: builds a string from the axis, the data and
/// indices shapes, and the plain and target layouts of a test parameter.
struct PrintToStringParamName {
    template<class T>
    std::string operator()(const testing::TestParamInfo<ScatterElementsUpdateParamsWithFormat<T> > &param) {
        // Local copy of the test case; the remaining tuple slots are the four layouts.
        ScatterElementsUpdateParams<T> test_case = std::get<0>(param.param);
        const format::type plain_format = std::get<1>(param.param);
        const format::type target_data_format = std::get<2>(param.param);
        const format::type target_indices_format = std::get<3>(param.param);
        const format::type target_updates_format = std::get<4>(param.param);

        std::ostringstream name;
        name << "_axis=" << test_case.axis;
        name << "_data=" << test_case.data_tensor.to_string();
        name << "_indices=" << test_case.indices_tensor.to_string();
        name << "_plainFormat=" << fmt_to_str(plain_format);
        name << "_targetDataFormat=" << fmt_to_str(target_data_format);
        name << "_targetIndicesFormat=" << fmt_to_str(target_indices_format);
        name << "_targetUpdatesFormat=" << fmt_to_str(target_updates_format);
        return name.str();
    }
};
}; // namespace
/// Parameterized fixture: feeds plain-layout inputs through reorders into the requested target
/// layouts, runs scatter_elements_update, reorders the result back to the plain layout and
/// compares it with the expected values.
template<typename T>
struct scatter_elements_update_gpu_formats_test
    : public ::testing::TestWithParam<ScatterElementsUpdateParamsWithFormat<T> > {
public:
    void test() {
        const auto data_type = type_to_data_type<T>::value;

        const auto& test_param = this->GetParam();
        ScatterElementsUpdateParams<T> params = std::get<0>(test_param);
        const format::type plain_format = std::get<1>(test_param);
        const format::type target_data_format = std::get<2>(test_param);

        // format::any in the indices/updates slot means "use the data layout".
        const auto or_data_format = [target_data_format](format::type requested) -> format::type {
            return requested == format::any ? target_data_format : requested;
        };
        const format::type target_indices_format = or_data_format(std::get<3>(test_param));
        const format::type target_updates_format = or_data_format(std::get<4>(test_param));

        auto& engine = get_test_engine();

        // Indices and updates share one shape; all three inputs use the same element type.
        const auto data = engine.allocate_memory({data_type, plain_format, params.data_tensor});
        const auto indices = engine.allocate_memory({data_type, plain_format, params.indices_tensor});
        const auto updates = engine.allocate_memory({data_type, plain_format, params.indices_tensor});

        set_values(data, params.data);
        set_values(indices, params.indices);
        set_values(updates, params.updates);

        topology topo;
        topo.add(input_layout("Data", data->get_layout()));
        topo.add(input_layout("Indices", indices->get_layout()));
        topo.add(input_layout("Updates", updates->get_layout()));
        // Bring every input into its target layout before the primitive under test.
        topo.add(reorder("DataReordered", "Data", target_data_format, data_type));
        topo.add(reorder("IndicesReordered", "Indices", target_indices_format, data_type));
        topo.add(reorder("UpdatesReordered", "Updates", target_updates_format, data_type));
        topo.add(
            scatter_elements_update("ScatterEelementsUpdate", "DataReordered", "IndicesReordered",
                                    "UpdatesReordered", params.axis)
        );
        // Convert the result back to the plain layout so it can be compared linearly.
        topo.add(reorder("ScatterEelementsUpdatePlain", "ScatterEelementsUpdate", plain_format, data_type));

        network net{engine, topo};
        net.set_input_data("Data", data);
        net.set_input_data("Indices", indices);
        net.set_input_data("Updates", updates);

        const auto outputs = net.execute();
        const auto output = outputs.at("ScatterEelementsUpdatePlain").get_memory();
        const cldnn::mem_lock<T> output_ptr(output, get_test_stream());

        ASSERT_EQ(params.data.size(), output_ptr.size());
        ASSERT_EQ(params.expected.size(), output_ptr.size());

        for (size_t i = 0; i < output_ptr.size(); ++i) {
            EXPECT_NEAR(output_ptr[i], params.expected[i], getError<T>())
                << "format=" << fmt_to_str(target_data_format) << ", i=" << i;
        }
    }
};
// Concrete fixtures, one per element type exercised by the instantiations.
using scatter_elements_update_gpu_formats_test_f32 = scatter_elements_update_gpu_formats_test<float>;
using scatter_elements_update_gpu_formats_test_f16 = scatter_elements_update_gpu_formats_test<half_t>;
using scatter_elements_update_gpu_formats_test_i32 = scatter_elements_update_gpu_formats_test<int32_t>;
// Each TEST_P delegates to the shared test() body; ASSERT_NO_FATAL_FAILURE stops the test
// if one of the ASSERT_* checks inside test() fails.
TEST_P(scatter_elements_update_gpu_formats_test_f32, basic) {
ASSERT_NO_FATAL_FAILURE(test());
}
TEST_P(scatter_elements_update_gpu_formats_test_f16, basic) {
ASSERT_NO_FATAL_FAILURE(test());
}
TEST_P(scatter_elements_update_gpu_formats_test_i32, basic) {
ASSERT_NO_FATAL_FAILURE(test());
}
// Combine() argument order for every instantiation below:
// test cases, plain layout, target data layout, target indices layout, target updates layout.
// format::any in the last two slots makes the fixture reuse the target data layout.

// float, four-dimension cases: every layout in formats2D for the data input.
INSTANTIATE_TEST_SUITE_P(scatter_elements_update_gpu_formats_test_f32_2d,
scatter_elements_update_gpu_formats_test_f32,
::testing::Combine(
::testing::ValuesIn(generateScatterElementsUpdateParams2D<float>()),
::testing::Values(format::bfyx),
::testing::ValuesIn(formats2D),
::testing::Values(format::any),
::testing::Values(format::any)
),
PrintToStringParamName());
// half_t, four-dimension cases: same layout sweep as the float variant.
INSTANTIATE_TEST_SUITE_P(scatter_elements_update_gpu_formats_test_f16_2d,
scatter_elements_update_gpu_formats_test_f16,
::testing::Combine(
::testing::ValuesIn(generateScatterElementsUpdateParams2D<half_t>()),
::testing::Values(format::bfyx),
::testing::ValuesIn(formats2D),
::testing::Values(format::any),
::testing::Values(format::any)
),
PrintToStringParamName());
// int32_t, four-dimension cases: same layout sweep as the float variant.
INSTANTIATE_TEST_SUITE_P(scatter_elements_update_gpu_formats_test_i32_2d,
scatter_elements_update_gpu_formats_test_i32,
::testing::Combine(
::testing::ValuesIn(generateScatterElementsUpdateParams2D<int32_t>()),
::testing::Values(format::bfyx),
::testing::ValuesIn(formats2D),
::testing::Values(format::any),
::testing::Values(format::any)
),
PrintToStringParamName());
// float, five-dimension cases: every layout in formats3D for the data input.
INSTANTIATE_TEST_SUITE_P(scatter_elements_update_gpu_formats_test_f32_3d,
scatter_elements_update_gpu_formats_test_f32,
::testing::Combine(
::testing::ValuesIn(generateScatterElementsUpdateParams3D<float>()),
::testing::Values(format::bfzyx),
::testing::ValuesIn(formats3D),
::testing::Values(format::any),
::testing::Values(format::any)
),
PrintToStringParamName());
// float, six-dimension case: formats4D is passed explicitly for all three inputs
// (it contains only bfwzyx, so the plain and target layouts coincide).
INSTANTIATE_TEST_SUITE_P(scatter_elements_update_gpu_formats_test_f32_4d,
scatter_elements_update_gpu_formats_test_f32,
::testing::Combine(
::testing::ValuesIn(generateScatterElementsUpdateParams4D<float>()),
::testing::Values(format::bfwzyx),
::testing::ValuesIn(formats4D),
::testing::ValuesIn(formats4D),
::testing::ValuesIn(formats4D)
),
PrintToStringParamName());
// float, four-dimension cases with mixed layouts: data, indices and updates each draw from a
// different pair of layouts, so the three inputs of a run can all be in different formats.
INSTANTIATE_TEST_SUITE_P(scatter_elements_update_gpu_formats_test_mixed_inputs,
scatter_elements_update_gpu_formats_test_f32,
::testing::Combine(
::testing::ValuesIn(generateScatterElementsUpdateParams2D<float>()),
::testing::Values(format::bfyx),
::testing::ValuesIn({format::b_fs_yx_fsv16, format::b_fs_yx_fsv32}),
::testing::ValuesIn({format::bs_fs_yx_bsv16_fsv16, format::bs_fs_yx_bsv32_fsv16}),
::testing::ValuesIn({format::bs_fs_yx_bsv32_fsv32, format::bfyx})
),
PrintToStringParamName());