[GPU] Add additional handling of fused activations for OneDNN primitives (#8004)
This commit is contained in:
parent
897cd09a5a
commit
8ea9986896
@ -783,6 +783,17 @@ protected:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (cldnn_post_ops.size() && arg.get_fused_activations_funcs().size())
|
||||||
|
throw std::runtime_error("Unsupported mix of fused ops and activations");
|
||||||
|
|
||||||
|
for (size_t i = 0; i < arg.get_fused_activations_funcs().size(); i++) {
|
||||||
|
auto activation_type = arg.get_fused_activations_funcs()[i];
|
||||||
|
auto params = arg.get_fused_activations_params()[i];
|
||||||
|
dnnl::algorithm alg = onednn::convert_activation_func(activation_type);
|
||||||
|
post_ops.append_eltwise(1.0f, alg, params.a, params.b);
|
||||||
|
update_onednn_post_op_list(onednn_post_op_type::eltwise_act, empty_mem);
|
||||||
|
}
|
||||||
|
|
||||||
// Update total onednn post-ops info
|
// Update total onednn post-ops info
|
||||||
onednn_fusing_map.emplace(arg.id(), std::move(fused_ops));
|
onednn_fusing_map.emplace(arg.id(), std::move(fused_ops));
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user