[GPU] fix gen9_common_conv_fwd_data_f32
kernel to support op fusion when #input_channel == 3 (#17318)
* added op fusion code * fixed typo * added unit tests * size reduction
This commit is contained in:
parent
48604e9092
commit
634c58903d
@ -476,10 +476,16 @@ const float sum_scale = 1;
|
||||
#if OW % OW_BLOCK != 0
|
||||
if (ow + OW_BLOCK > OW) {
|
||||
for (int i = 0; i < OW - OW_LAST; i++) {
|
||||
#if HAS_FUSED_OPS
|
||||
{ FUSED_OPS_SCALAR0; blockC00[i] = FUSED_OPS_RESULT_SCALAR0; }
|
||||
#endif
|
||||
_sub_group_block_write((__global unsigned int *)(&dst_write0[i
|
||||
* OC_BLOCK * MB_BLOCK]),
|
||||
as_uint(blockC00[i]));
|
||||
#if OCB == 32
|
||||
#if HAS_FUSED_OPS
|
||||
{ FUSED_OPS_SCALAR1; blockC01[i] = FUSED_OPS_RESULT_SCALAR1; }
|
||||
#endif
|
||||
_sub_group_block_write(
|
||||
(__global unsigned int
|
||||
*)(&dst_write0[i * OC_BLOCK * MB_BLOCK
|
||||
@ -492,10 +498,16 @@ const float sum_scale = 1;
|
||||
#if OW_BLOCK != 8 || MB_BLOCK != 1
|
||||
__attribute__((opencl_unroll_hint(OW_BLOCK))) // attr:no-format
|
||||
for (int i = 0; i < OW_BLOCK; i++) {
|
||||
#if HAS_FUSED_OPS
|
||||
{ FUSED_OPS_SCALAR0; blockC00[i] = FUSED_OPS_RESULT_SCALAR0; }
|
||||
#endif
|
||||
_sub_group_block_write((__global unsigned int *)(&dst_write0[i
|
||||
* OC_BLOCK * MB_BLOCK]),
|
||||
as_uint(blockC00[i]));
|
||||
#if OCB == 32
|
||||
#if HAS_FUSED_OPS
|
||||
{ FUSED_OPS_SCALAR1; blockC01[i] = FUSED_OPS_RESULT_SCALAR1; }
|
||||
#endif
|
||||
_sub_group_block_write(
|
||||
(__global unsigned int
|
||||
*)(&dst_write0[i * OC_BLOCK * MB_BLOCK
|
||||
@ -504,9 +516,15 @@ const float sum_scale = 1;
|
||||
#endif
|
||||
}
|
||||
#else
|
||||
#if HAS_FUSED_OPS
|
||||
{ FUSED_OPS_VEC0; blockC00 = FUSED_OPS_RESULT_VEC0; }
|
||||
#endif
|
||||
_sub_group_block_write8(
|
||||
(__global unsigned int *)(&dst_write0[0]), as_uint8(blockC00));
|
||||
#if OCB == 32
|
||||
#if HAS_FUSED_OPS
|
||||
{ FUSED_OPS_VEC1; blockC01 = FUSED_OPS_RESULT_VEC1; }
|
||||
#endif
|
||||
_sub_group_block_write8((__global unsigned int *)(&dst_write0[OC_BLOCK
|
||||
* MB_BLOCK * ODHW_SIZE]),
|
||||
as_uint8(blockC01));
|
||||
|
@ -223,6 +223,17 @@ bool ConvolutionKernel_b_fs_zyx_fsv16::Validate(const Params& p, const optional_
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check if operation fusion is supported
|
||||
if (!params.fused_ops.empty()) {
|
||||
const bool is_1stconv = input.Feature().v == 3 && input.GetLayout() == DataLayout::bfzyx;
|
||||
const bool ver_16mb16c = !is_1stconv && ((output.GetDType() == Datatype::F16 && output.Batch().v % 32 == 0) ||
|
||||
(output.GetDType() == Datatype::F32 && output.Batch().v % 16 == 0));
|
||||
|
||||
if (!ver_16mb16c && is_1stconv && output.GetDType() == Datatype::F16) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -310,7 +321,7 @@ JitConstants ConvolutionKernel_b_fs_zyx_fsv16::GetJitConstants(const convolution
|
||||
jit.Merge(MakeFusedOpsJitConstants(params, {conf_vec0, conf_vec1, conf_vec2, conf_vec3,
|
||||
conf_scalar0, conf_scalar1, conf_scalar2, conf_scalar3}));
|
||||
}
|
||||
} else if (!is_1stconv && !params.fused_ops.empty()) {
|
||||
} else if ((!is_1stconv || output.GetDType() == Datatype::F32) && !params.fused_ops.empty()) {
|
||||
FusedOpsConfiguration conf_vec0 = GenerateFusedOpsConfiguration_f16(0, "blockC0", input_dt, true);
|
||||
FusedOpsConfiguration conf_vec1 = GenerateFusedOpsConfiguration_f16(1, "blockC0", input_dt, true);
|
||||
FusedOpsConfiguration conf_scalar0 = GenerateFusedOpsConfiguration_f16(0, "blockC0", input_dt, false);
|
||||
|
@ -361,6 +361,7 @@ public:
|
||||
#define CASE_CONV_FP32_13 { 1, 16, 18, 5, 4 }, { 1, 16, 16, 3, 2 }, { 1, 1, 3, 3, 3 }, { 1, 1, 1 }, { 0, 0, 0 }, { 1, 1, 1 }, 2, data_types::f32, format::b_fs_zyx_fsv16, data_types::f32, format::g_os_is_zyx_isv16_osv16, data_types::f32, format::bfzyx
|
||||
#define CASE_CONV_FP32_14 { 1, 3, 4, 5 }, { 1, 30, 2, 3 }, { 1, 1, 3, 3 }, { 1, 1 }, { 0, 0 }, { 1, 1 }, 1, data_types::f32, format::bfyx, data_types::f32, format::bfyx, data_types::f32, format::bfyx
|
||||
#define CASE_CONV_FP32_15 { 1, 6, 4, 4 }, { 1, 16, 4, 4 }, { 1, 1, 3, 3 }, { 1, 1 }, { 1, 1 }, { 1, 1 }, 1, data_types::f32, format::bfyx, data_types::f32, format::bfyx, data_types::f32, format::bfyx
|
||||
#define CASE_CONV_FP32_16 { 1, 3, 112, 112, 8 }, { 1, 16, 56, 56, 8 }, { 1, 1, 3, 3, 1 }, { 2, 2, 1 }, { 1, 1, 0 }, { 1, 1, 1 }, 1, data_types::f32, format::bfzyx, data_types::f32, format::bfzyx, data_types::f32, format::bfzyx
|
||||
|
||||
|
||||
#define CASE_CONV_FP16_1 { 1, 15, 4, 5 }, { 1, 30, 2, 3 }, { 1, 1, 3, 3 }, { 1, 1 }, { 0, 0 }, { 1, 1 }, 1, data_types::f16, format::bfyx, data_types::f16, format::bfyx, data_types::f16, format::bfyx
|
||||
@ -378,6 +379,7 @@ public:
|
||||
#define CASE_CONV_FP16_13 { 16, 32, 4, 5 }, { 16, 64, 2, 3 }, { 1, 1, 3, 3 }, { 1, 1 }, { 0, 0 }, { 1, 1 }, 1, data_types::f16, format::fs_b_yx_fsv32, data_types::f16, format::bfyx, data_types::f16, format::bfyx
|
||||
#define CASE_CONV_FP16_14 { 1, 32, 55, 1 }, { 1, 32, 55, 1 }, { 1, 1, 3, 1 }, { 1, 1 }, { 1, 1 }, { 1, 1 }, 32, data_types::f16, format::b_fs_yx_fsv16, data_types::f16, format::gs_oiyx_gsv16, data_types::f16, format::bfyx
|
||||
#define CASE_CONV_FP16_15 { 1, 39, 55, 1 }, { 1, 39, 55, 1 }, { 1, 1, 3, 1 }, { 1, 1 }, { 1, 1 }, { 1, 1 }, 39, data_types::f16, format::b_fs_yx_fsv16, data_types::f16, format::gs_oiyx_gsv16, data_types::f16, format::bfyx
|
||||
#define CASE_CONV_FP16_16 { 1, 3, 112, 112, 8 }, { 1, 32, 56, 56, 8 }, { 1, 1, 3, 3, 1 }, { 2, 2, 1 }, { 1, 1, 0 }, { 1, 1, 1 }, 1, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx
|
||||
|
||||
#define CASE_CONV_U8S8_1 { 1, 15, 4, 5 }, { 1, 30, 2, 3 }, { 1, 1, 3, 3 }, { 1, 1 }, { 0, 0 }, { 1, 1 }, 1, data_types::u8, format::bfyx, data_types::i8, format::bfyx, data_types::f32, format::bfyx
|
||||
#define CASE_CONV_U8S8_2 { 1, 15, 5, 5 }, { 1, 30, 3, 3 }, { 1, 1, 3, 3 }, { 1, 1 }, { 0, 0 }, { 1, 1 }, 1, data_types::u8, format::bfyx, data_types::i8, format::bfyx, data_types::f32, format::bfyx
|
||||
@ -3248,6 +3250,27 @@ INSTANTIATE_TEST_SUITE_P(fusings_gpu, conv_fp32_reorder_bfyx_to_fsv32_conv_data_
|
||||
convolution_test_params{ FSV32_CASE_CONV_FP32_1, 5, 5, 5 }
|
||||
}));
|
||||
|
||||
class conv_gen9_common_conv_fwd_data_1stconv : public ConvFusingTest {};
|
||||
TEST_P(conv_gen9_common_conv_fwd_data_1stconv, basic) {
|
||||
auto p = GetParam();
|
||||
create_topologies(
|
||||
input_layout("input", get_input_layout(p)),
|
||||
data("weights", get_mem(get_weights_layout(p))),
|
||||
data("bias", get_mem(get_bias_layout(p))),
|
||||
convolution("conv_prim", input_info("input"), { "weights" }, { "bias" }, p.groups, p.stride, p.pad, p.dilation),
|
||||
activation("activation", input_info("conv_prim"), activation_func::hswish),
|
||||
reorder("reorder_bfyx", input_info("activation"), p.default_format, data_types::f32)
|
||||
);
|
||||
|
||||
tolerance = default_tolerance(p.default_type);
|
||||
execute(p);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(fusings_gpu, conv_gen9_common_conv_fwd_data_1stconv, ::testing::ValuesIn(std::vector<convolution_test_params>{
|
||||
convolution_test_params{ CASE_CONV_FP32_16, 2, 2, 3 },
|
||||
convolution_test_params{ CASE_CONV_FP16_16, 2, 2, 3 },
|
||||
}));
|
||||
|
||||
#ifdef ENABLE_ONEDNN_FOR_GPU
|
||||
class conv_fp16_prelu_onednn : public WeightsPrimitiveFusingTestOneDNN {};
|
||||
TEST_P(conv_fp16_prelu_onednn, basic_activation_eltwise) {
|
||||
|
Loading…
Reference in New Issue
Block a user