[GPU] Added fp16 support for GatherTree (#15983)
This commit is contained in:
parent
913f616964
commit
1070a3b6c1
@ -39,7 +39,12 @@ struct gather_tree_impl : typed_primitive_impl_ocl<gather_tree> {
|
||||
|
||||
namespace detail {
|
||||
attach_gather_tree_impl::attach_gather_tree_impl() {
|
||||
auto types = {data_types::i32, data_types::f32};
|
||||
auto types = {
|
||||
data_types::f32,
|
||||
data_types::f16,
|
||||
data_types::i32
|
||||
};
|
||||
|
||||
auto formats = {
|
||||
format::yxfb,
|
||||
format::bfyx,
|
||||
|
@ -26,8 +26,8 @@ KERNEL(gather_tree_gpu_ref)(
|
||||
}
|
||||
|
||||
for (int parent = beam; time >= 0; time--) {
|
||||
output[OUTPUT_GET_INDEX(time, batch, beam, 0)] = step_input[INPUT0_GET_INDEX(time, batch, parent, 0)];
|
||||
parent = parent_input[INPUT1_GET_INDEX(time, batch, parent, 0)];
|
||||
output[OUTPUT_GET_INDEX(time, batch, beam, 0)] = TO_OUTPUT_TYPE(step_input[INPUT0_GET_INDEX(time, batch, parent, 0)]);
|
||||
parent = (int)parent_input[INPUT1_GET_INDEX(time, batch, parent, 0)];
|
||||
}
|
||||
bool finished = false;
|
||||
for (int time = 0; time < max_sequence_in_beam; time++) {
|
||||
|
@ -14,9 +14,13 @@ ParamsKey GatherTreeKernelRef::GetSupportedKey() const {
|
||||
|
||||
k.EnableInputDataType(Datatype::INT32);
|
||||
k.EnableOutputDataType(Datatype::INT32);
|
||||
|
||||
k.EnableInputDataType(Datatype::F32);
|
||||
k.EnableOutputDataType(Datatype::F32);
|
||||
|
||||
k.EnableInputDataType(Datatype::F16);
|
||||
k.EnableOutputDataType(Datatype::F16);
|
||||
|
||||
k.EnableInputLayout(DataLayout::bfyx);
|
||||
k.EnableOutputLayout(DataLayout::bfyx);
|
||||
|
||||
|
@ -13,6 +13,7 @@ namespace {
|
||||
|
||||
const std::vector<InferenceEngine::Precision> netPrecisions = {
|
||||
InferenceEngine::Precision::FP32,
|
||||
InferenceEngine::Precision::FP16,
|
||||
InferenceEngine::Precision::I32
|
||||
};
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user