[GPU] i8/u8 data types support for gather (#6116)
This commit is contained in:
committed by
GitHub
parent
8f487c7f63
commit
5812a150c0
@@ -38,6 +38,8 @@ ParamsKey GatherKernelRef::GetSupportedKey() const {
|
||||
k.EnableInputDataType(Datatype::F16);
|
||||
k.EnableInputDataType(Datatype::F32);
|
||||
k.EnableInputDataType(Datatype::INT32);
|
||||
k.EnableInputDataType(Datatype::UINT8);
|
||||
k.EnableInputDataType(Datatype::INT8);
|
||||
k.EnableOutputDataType(Datatype::F16);
|
||||
k.EnableOutputDataType(Datatype::F32);
|
||||
k.EnableOutputDataType(Datatype::INT32);
|
||||
|
||||
@@ -74,12 +74,20 @@ attach_gather_impl::attach_gather_impl() {
|
||||
std::make_tuple(data_types::f32, format::bfyx),
|
||||
std::make_tuple(data_types::f16, format::bfyx),
|
||||
std::make_tuple(data_types::i32, format::bfyx),
|
||||
std::make_tuple(data_types::i8, format::bfyx),
|
||||
std::make_tuple(data_types::u8, format::bfyx),
|
||||
|
||||
std::make_tuple(data_types::f32, format::bfzyx),
|
||||
std::make_tuple(data_types::f16, format::bfzyx),
|
||||
std::make_tuple(data_types::i32, format::bfzyx),
|
||||
std::make_tuple(data_types::i8, format::bfzyx),
|
||||
std::make_tuple(data_types::u8, format::bfzyx),
|
||||
|
||||
std::make_tuple(data_types::f32, format::bfwzyx),
|
||||
std::make_tuple(data_types::f16, format::bfwzyx),
|
||||
std::make_tuple(data_types::i32, format::bfwzyx),
|
||||
std::make_tuple(data_types::i8, format::bfwzyx),
|
||||
std::make_tuple(data_types::u8, format::bfwzyx),
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -1609,3 +1609,46 @@ TEST(gather_gpu_fp32, 322_axisF) {
|
||||
EXPECT_EQ(expected_results[i], output_ptr[i]) << i;
|
||||
}
|
||||
}
|
||||
|
||||
// Checks gather along the feature axis when the dictionary holds u8 values.
//   Dictionary : 3x3x1x1 (u8)
//   Indices    : 2x2x1x1 (i32)
//   Axis       : 1
//   Output     : 3x2x2x1
TEST(gather_gpu_u8, 322_axisF) {
    auto& engine = get_test_engine();

    // Device buffers: the u8 data to gather from and the i32 positions to pick.
    auto dictionary_mem = engine.allocate_memory({data_types::u8, format::bfyx, {3, 3, 1, 1}});
    auto indices_mem = engine.allocate_memory({data_types::i32, format::bfyx, {2, 2, 1, 1}});
    const auto feature_axis = cldnn::gather::gather_axis::along_f;

    set_values<uint8_t>(dictionary_mem, {0, 1, 2, 10, 11, 12, 20, 21, 22});

    set_values(indices_mem, {1, 0,
                             2, 1});

    // Two inputs feeding a single gather primitive.
    topology gather_topology;
    gather_topology.add(input_layout("InputDictionary", dictionary_mem->get_layout()));
    gather_topology.add(input_layout("InputText", indices_mem->get_layout()));
    gather_topology.add(
        gather("gather", "InputDictionary", "InputText", feature_axis, format::bfyx, tensor(3, 2, 1, 2)));

    network gather_network(engine, gather_topology);

    gather_network.set_input_data("InputDictionary", dictionary_mem);
    gather_network.set_input_data("InputText", indices_mem);

    auto results = gather_network.execute();

    auto gathered = results.at("gather").get_memory();
    cldnn::mem_lock<uint8_t> output_ptr(gathered, get_test_stream());

    // Each consecutive triple of dictionary values is sampled at positions 1, 0, 2, 1.
    const std::vector<uint8_t> expected = {
        1, 0, 2, 1, 11, 10, 12, 11, 21, 20, 22, 21};

    // Stop immediately on a size mismatch so the element loop cannot read out of range.
    ASSERT_EQ(expected.size(), output_ptr.size());
    for (size_t idx = 0; idx != expected.size(); ++idx) {
        EXPECT_EQ(expected[idx], output_ptr[idx]) << idx;
    }
}
|
||||
|
||||
Reference in New Issue
Block a user