[GPU] Added fp16 support for GatherTree (#15983)

This commit is contained in:
Roman Lyamin 2023-02-28 09:54:56 +04:00 committed by GitHub
parent 913f616964
commit 1070a3b6c1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 13 additions and 3 deletions

View File

@ -39,7 +39,12 @@ struct gather_tree_impl : typed_primitive_impl_ocl<gather_tree> {
namespace detail {
attach_gather_tree_impl::attach_gather_tree_impl() {
auto types = {data_types::i32, data_types::f32};
auto types = {
data_types::f32,
data_types::f16,
data_types::i32
};
auto formats = {
format::yxfb,
format::bfyx,

View File

@ -26,8 +26,8 @@ KERNEL(gather_tree_gpu_ref)(
}
for (int parent = beam; time >= 0; time--) {
output[OUTPUT_GET_INDEX(time, batch, beam, 0)] = step_input[INPUT0_GET_INDEX(time, batch, parent, 0)];
parent = parent_input[INPUT1_GET_INDEX(time, batch, parent, 0)];
output[OUTPUT_GET_INDEX(time, batch, beam, 0)] = TO_OUTPUT_TYPE(step_input[INPUT0_GET_INDEX(time, batch, parent, 0)]);
parent = (int)parent_input[INPUT1_GET_INDEX(time, batch, parent, 0)];
}
bool finished = false;
for (int time = 0; time < max_sequence_in_beam; time++) {

View File

@ -14,9 +14,13 @@ ParamsKey GatherTreeKernelRef::GetSupportedKey() const {
k.EnableInputDataType(Datatype::INT32);
k.EnableOutputDataType(Datatype::INT32);
k.EnableInputDataType(Datatype::F32);
k.EnableOutputDataType(Datatype::F32);
k.EnableInputDataType(Datatype::F16);
k.EnableOutputDataType(Datatype::F16);
k.EnableInputLayout(DataLayout::bfyx);
k.EnableOutputLayout(DataLayout::bfyx);

View File

@ -13,6 +13,7 @@ namespace {
const std::vector<InferenceEngine::Precision> netPrecisions = {
InferenceEngine::Precision::FP32,
InferenceEngine::Precision::FP16,
InferenceEngine::Precision::I32
};