[GPU] i8/u8 data types support for gather (#6116)

This commit is contained in:
Vladimir Paramuzov
2021-10-15 09:08:10 +03:00
committed by GitHub
parent 8f487c7f63
commit 5812a150c0
3 changed files with 53 additions and 0 deletions

View File

@@ -38,6 +38,8 @@ ParamsKey GatherKernelRef::GetSupportedKey() const {
k.EnableInputDataType(Datatype::F16);
k.EnableInputDataType(Datatype::F32);
k.EnableInputDataType(Datatype::INT32);
k.EnableInputDataType(Datatype::UINT8);
k.EnableInputDataType(Datatype::INT8);
k.EnableOutputDataType(Datatype::F16);
k.EnableOutputDataType(Datatype::F32);
k.EnableOutputDataType(Datatype::INT32);

View File

@@ -74,12 +74,20 @@ attach_gather_impl::attach_gather_impl() {
std::make_tuple(data_types::f32, format::bfyx),
std::make_tuple(data_types::f16, format::bfyx),
std::make_tuple(data_types::i32, format::bfyx),
std::make_tuple(data_types::i8, format::bfyx),
std::make_tuple(data_types::u8, format::bfyx),
std::make_tuple(data_types::f32, format::bfzyx),
std::make_tuple(data_types::f16, format::bfzyx),
std::make_tuple(data_types::i32, format::bfzyx),
std::make_tuple(data_types::i8, format::bfzyx),
std::make_tuple(data_types::u8, format::bfzyx),
std::make_tuple(data_types::f32, format::bfwzyx),
std::make_tuple(data_types::f16, format::bfwzyx),
std::make_tuple(data_types::i32, format::bfwzyx),
std::make_tuple(data_types::i8, format::bfwzyx),
std::make_tuple(data_types::u8, format::bfwzyx),
});
}

View File

@@ -1609,3 +1609,46 @@ TEST(gather_gpu_fp32, 322_axisF) {
EXPECT_EQ(expected_results[i], output_ptr[i]) << i;
}
}
TEST(gather_gpu_u8, 322_axisF) {
    // Gather over uint8_t data along the feature axis.
    //   data (dictionary) : 3x3x1x1, u8
    //   indices           : 2x2x1x1, i32
    //   axis              : along_f
    //   output            : tensor(3, 2, 1, 2), read back as uint8_t
    auto &engine = get_test_engine();

    auto dictionary_mem = engine.allocate_memory({data_types::u8, format::bfyx, {3, 3, 1, 1}});
    auto indices_mem = engine.allocate_memory({data_types::i32, format::bfyx, {2, 2, 1, 1}});
    auto gather_along = cldnn::gather::gather_axis::along_f;

    // Three rows of three values each; the row number is encoded in the tens digit.
    set_values<uint8_t>(dictionary_mem, {0, 1, 2,
                                         10, 11, 12,
                                         20, 21, 22});
    set_values(indices_mem, {1, 0,
                             2, 1});

    topology topo;
    topo.add(input_layout("InputDictionary", dictionary_mem->get_layout()));
    topo.add(input_layout("InputText", indices_mem->get_layout()));
    topo.add(gather("gather",
                    "InputDictionary",
                    "InputText",
                    gather_along,
                    format::bfyx,
                    tensor(3, 2, 1, 2)));

    network net(engine, topo);
    net.set_input_data("InputDictionary", dictionary_mem);
    net.set_input_data("InputText", indices_mem);

    auto outputs = net.execute();
    auto result_mem = outputs.at("gather").get_memory();
    cldnn::mem_lock<uint8_t> actual(result_mem, get_test_stream());

    // Each dictionary row contributes its entries at positions 1, 0, 2, 1.
    std::vector<uint8_t> expected_results = {
        1, 0, 2, 1,
        11, 10, 12, 11,
        21, 20, 22, 21};

    ASSERT_EQ(expected_results.size(), actual.size());
    for (size_t idx = 0; idx < expected_results.size(); ++idx) {
        EXPECT_EQ(expected_results[idx], actual[idx]) << idx;
    }
}