[GPU] Convert activation's slope buffer data type for OneDNN fusions (#13444)
This commit is contained in:
parent
0b8f1f8c00
commit
5fea4c3fc3
@ -789,4 +789,31 @@ void reorder_inputs::run(program& p, layout_optimizer& lo, reorder_factory& rf)
|
|||||||
n->set_preferred_impl_type(preferred_impl);
|
n->set_preferred_impl_type(preferred_impl);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WA for OneDNN PRelu activation fusions: convert activation's slope buffer to expected f32 data type
|
||||||
|
for (auto& node : p.get_processing_order()) {
|
||||||
|
if (node->get_preferred_impl_type() == impl_types::onednn) {
|
||||||
|
auto fused_prims = node->get_fused_primitives();
|
||||||
|
for (auto& fused_desc : fused_prims) {
|
||||||
|
if (!fused_desc.is_type<activation>())
|
||||||
|
continue;
|
||||||
|
|
||||||
|
auto activation_desc = fused_desc.typed_desc<activation>();
|
||||||
|
if (activation_desc->activation_function == cldnn::activation_func::relu_negative_slope &&
|
||||||
|
!activation_desc->additional_params_input.empty()) {
|
||||||
|
const auto expected_dt = data_types::f32;
|
||||||
|
const auto dep_idx = fused_desc.dep_start_idx;
|
||||||
|
const auto orig_layout = node->get_dependency(dep_idx).get_output_layout();
|
||||||
|
if (orig_layout.data_type == expected_dt)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
auto new_layout = orig_layout;
|
||||||
|
new_layout.data_type = expected_dt;
|
||||||
|
auto new_input = rf.get_reorder(node->get_dependency(dep_idx).id(), orig_layout, new_layout);
|
||||||
|
if (new_input.first)
|
||||||
|
p.add_intermediate(new_input.first, *node, dep_idx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -265,6 +265,10 @@ public:
|
|||||||
layout get_per_channel_layout(convolution_test_params& p) {
|
layout get_per_channel_layout(convolution_test_params& p) {
|
||||||
return layout{ p.default_type, p.default_format, tensor{1, p.out_shape.feature[0], 1, 1} };
|
return layout{ p.default_type, p.default_format, tensor{1, p.out_shape.feature[0], 1, 1} };
|
||||||
}
|
}
|
||||||
|
|
||||||
|
layout get_prelu_slope_layout(convolution_test_params& p) {
|
||||||
|
return layout{ p.default_type, p.default_format, tensor{1, p.out_shape.feature[0], 1, 1} };
|
||||||
|
}
|
||||||
};
|
};
|
||||||
#endif // ENABLE_ONEDNN_FOR_GPU
|
#endif // ENABLE_ONEDNN_FOR_GPU
|
||||||
|
|
||||||
@ -2914,6 +2918,30 @@ INSTANTIATE_TEST_SUITE_P(fusings_gpu, conv_fp32_reorder_bfyx_to_fsv32_conv_data_
|
|||||||
}));
|
}));
|
||||||
|
|
||||||
#ifdef ENABLE_ONEDNN_FOR_GPU
|
#ifdef ENABLE_ONEDNN_FOR_GPU
|
||||||
|
class conv_fp16_prelu_onednn : public WeightsPrimitiveFusingTestOneDNN {};
|
||||||
|
TEST_P(conv_fp16_prelu_onednn, basic_activation_eltwise) {
|
||||||
|
auto p = GetParam();
|
||||||
|
|
||||||
|
create_topologies(
|
||||||
|
input_layout("input", get_input_layout(p)),
|
||||||
|
data("weights", get_mem(get_weights_layout(p))),
|
||||||
|
data("bias", get_mem(get_bias_layout(p))),
|
||||||
|
data("slope_data", get_mem(get_prelu_slope_layout(p))),
|
||||||
|
data("eltwise_data", get_mem(get_output_layout(p))),
|
||||||
|
convolution("conv_prim", "input", { "weights" }, { "bias" }, p.groups, p.stride, p.pad, p.dilation),
|
||||||
|
activation("activation", "conv_prim", "slope_data", activation_func::relu_negative_slope),
|
||||||
|
eltwise("eltwise", "activation", "eltwise_data", eltwise_mode::sum),
|
||||||
|
reorder("reorder_bfyx", "eltwise", p.default_format, data_types::f32)
|
||||||
|
);
|
||||||
|
|
||||||
|
tolerance = default_tolerance(p.default_type);
|
||||||
|
execute(p);
|
||||||
|
}
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(fusings_gpu, conv_fp16_prelu_onednn, ::testing::ValuesIn(std::vector<convolution_test_params>{
|
||||||
|
convolution_test_params{ CASE_CONV_FP16_1, 2, 4 },
|
||||||
|
}));
|
||||||
|
|
||||||
class conv_int8_eltwise_onednn : public WeightsPrimitiveFusingTestOneDNN {};
|
class conv_int8_eltwise_onednn : public WeightsPrimitiveFusingTestOneDNN {};
|
||||||
TEST_P(conv_int8_eltwise_onednn, u8_eltwise_sum_out) {
|
TEST_P(conv_int8_eltwise_onednn, u8_eltwise_sum_out) {
|
||||||
auto p = GetParam();
|
auto p = GetParam();
|
||||||
|
Loading…
Reference in New Issue
Block a user