diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl/gather_tree.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl/gather_tree.cpp index 7ee17e191b2..907c8ecb38e 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl/gather_tree.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl/gather_tree.cpp @@ -39,7 +39,12 @@ struct gather_tree_impl : typed_primitive_impl_ocl { namespace detail { attach_gather_tree_impl::attach_gather_tree_impl() { - auto types = {data_types::i32, data_types::f32}; + auto types = { + data_types::f32, + data_types::f16, + data_types::i32 + }; + auto formats = { format::yxfb, format::bfyx, diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/gather_tree_gpu_ref.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/gather_tree_gpu_ref.cl index 0fe1b3d5f75..f560d7655fc 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/gather_tree_gpu_ref.cl +++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/gather_tree_gpu_ref.cl @@ -26,8 +26,8 @@ KERNEL(gather_tree_gpu_ref)( } for (int parent = beam; time >= 0; time--) { - output[OUTPUT_GET_INDEX(time, batch, beam, 0)] = step_input[INPUT0_GET_INDEX(time, batch, parent, 0)]; - parent = parent_input[INPUT1_GET_INDEX(time, batch, parent, 0)]; + output[OUTPUT_GET_INDEX(time, batch, beam, 0)] = TO_OUTPUT_TYPE(step_input[INPUT0_GET_INDEX(time, batch, parent, 0)]); + parent = (int)parent_input[INPUT1_GET_INDEX(time, batch, parent, 0)]; } bool finished = false; for (int time = 0; time < max_sequence_in_beam; time++) { diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/gather_tree/gather_tree_kernel_ref.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/gather_tree/gather_tree_kernel_ref.cpp index f8676f9128b..93a22d11069 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/gather_tree/gather_tree_kernel_ref.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/gather_tree/gather_tree_kernel_ref.cpp @@ -14,9 +14,13 @@ ParamsKey GatherTreeKernelRef::GetSupportedKey() const { k.EnableInputDataType(Datatype::INT32); k.EnableOutputDataType(Datatype::INT32); + k.EnableInputDataType(Datatype::F32); k.EnableOutputDataType(Datatype::F32); + k.EnableInputDataType(Datatype::F16); + k.EnableOutputDataType(Datatype::F16); + k.EnableInputLayout(DataLayout::bfyx); k.EnableOutputLayout(DataLayout::bfyx); diff --git a/src/tests/functional/plugin/gpu/shared_tests_instances/single_layer_tests/gather_tree.cpp b/src/tests/functional/plugin/gpu/shared_tests_instances/single_layer_tests/gather_tree.cpp index 82a94204d31..f70fb29ac97 100644 --- a/src/tests/functional/plugin/gpu/shared_tests_instances/single_layer_tests/gather_tree.cpp +++ b/src/tests/functional/plugin/gpu/shared_tests_instances/single_layer_tests/gather_tree.cpp @@ -13,6 +13,7 @@ namespace { const std::vector netPrecisions = { InferenceEngine::Precision::FP32, + InferenceEngine::Precision::FP16, InferenceEngine::Precision::I32 };