From e003bf3af712d617f5303c01e6db081ca5a71f5a Mon Sep 17 00:00:00 2001 From: Sergey Shlyapnikov Date: Mon, 6 Feb 2023 12:17:55 +0400 Subject: [PATCH] [GPU] Shape agnostic FC opt tiled kernel (#15396) --- .../intel_gpu/src/graph/fully_connected.cpp | 6 +- .../intel_gpu/src/graph/primitive_inst.cpp | 5 +- .../fully_connected_gpu_bf_tiled.cl | 25 ++- .../fully_connected_kernel_bf_tiled.cpp | 103 ++++++---- .../fake_alignment/fc_fake_alignment_test.cpp | 6 +- .../test_cases/fully_connected_gpu_test.cpp | 186 +++++++++++++++++- .../tests/test_cases/hash_key_gpu_test.cpp | 4 +- 7 files changed, 270 insertions(+), 65 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/fully_connected.cpp b/src/plugins/intel_gpu/src/graph/fully_connected.cpp index f5e8422f6be..fc7bf008f37 100644 --- a/src/plugins/intel_gpu/src/graph/fully_connected.cpp +++ b/src/plugins/intel_gpu/src/graph/fully_connected.cpp @@ -166,7 +166,7 @@ std::vector fully_connected_inst::calc_output_layouts(fully_connected_no kernel_impl_params fully_connected_inst::get_fake_aligned_params(kernel_impl_params const& orig_impl_param) { - // fc_tiled_opt kernel is optimized for row shape aligned by 16. + // fc_tiled_opt kernel is optimized for row shape aligned by 8. // Thus, use fake aligned shape at kernel execution for better performance. auto orig_input_layout = orig_impl_param.get_input_layout(); auto orig_output_layout = orig_impl_param.get_output_layout(); @@ -176,10 +176,10 @@ kernel_impl_params fully_connected_inst::get_fake_aligned_params(kernel_impl_par auto updated_param = orig_impl_param; auto input_shape = orig_input_layout.get_partial_shape().to_shape(); auto input_row_idx = input_shape.size() - 2; - input_shape[input_row_idx] = align_to(input_shape[input_row_idx], 16); + input_shape[input_row_idx] = align_to(input_shape[input_row_idx], 8); auto output_shape = orig_output_layout.get_partial_shape().to_shape(); auto output_row_idx = output_shape.size() - 2; - output_shape[output_row_idx] = align_to(output_shape[output_row_idx], 16); + output_shape[output_row_idx] = align_to(output_shape[output_row_idx], 8); updated_param.input_layouts[0] = layout(ov::PartialShape(input_shape), orig_input_layout.data_type, diff --git a/src/plugins/intel_gpu/src/graph/primitive_inst.cpp b/src/plugins/intel_gpu/src/graph/primitive_inst.cpp index 41a26c71797..f2a17f6be1f 100644 --- a/src/plugins/intel_gpu/src/graph/primitive_inst.cpp +++ b/src/plugins/intel_gpu/src/graph/primitive_inst.cpp @@ -356,9 +356,9 @@ bool primitive_inst::update_impl() { }); _impl = _dynamic_impl->clone(); - _impl->update_dispatch_data(updated_params); + _impl->update_dispatch_data(*_impl_params); - update_shape_info(updated_params); + update_shape_info(*_impl_params); } else { _impl = _node->type()->choose_impl(*_node, updated_params); auto& kernels_cache = get_network().get_kernels_cache(); @@ -1274,6 +1274,7 @@ size_t primitive_inst::get_impl_key(const kernel_impl_params& params) const { } return seed; } + size_t primitive_inst::get_impl_key() const { auto updated_params = _node->type()->get_fake_aligned_params(*_impl_params); return get_impl_key(updated_params); diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/fully_connected_gpu_bf_tiled.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/fully_connected_gpu_bf_tiled.cl index c58e8cc6199..62f8548514f 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/fully_connected_gpu_bf_tiled.cl +++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/fully_connected_gpu_bf_tiled.cl @@ -79,6 +79,7 @@ REQD_SUB_GROUP_SIZE(SIMD) KERNEL(fc)( + OPTIONAL_SHAPE_INFO_ARG const __global INPUT0_TYPE* input, __global OUTPUT_TYPE* output, const __global FILTER_TYPE* weights @@ -149,10 +150,10 @@ KERNEL(fc)( weights_offset += TILE_K_OFM * SIMD; unroll_for (uint kii = 0; kii < TILE_K; ++kii) { - unroll_for (uint fi = 0; fi < TILE_OFM; ++fi) { - unroll_for (uint bi = 0; bi < TILE_B; ++bi) { - const uint total_k = ki * TILE_K + kii; - INPUT0_TYPE in_val = _sub_group_shuffle(((INPUT0_TYPE*)(&in_0[bi]))[total_k / SIMD], total_k % SIMD); + const uint total_k = ki * TILE_K + kii; + unroll_for (uint bi = 0; bi < TILE_B; ++bi) { + INPUT0_TYPE in_val = _sub_group_shuffle(((INPUT0_TYPE*)(&in_0[bi]))[total_k / SIMD], total_k % SIMD); + unroll_for (uint fi = 0; fi < TILE_OFM; ++fi) { ((ACCUMULATOR_TYPE*)(&acc[bi]))[fi] += in_val * ((FILTER_TYPE*)(&wei))[kii * TILE_OFM + fi]; } } @@ -236,11 +237,18 @@ KERNEL(fc)( uint output_offset = out_f * TILE_OUT_F_PITCH + out_b * TILE_OUT_B_PITCH + OUTPUT_OFFSET; if (USE_BLOCK_WRITE && (TILE_OUT_F_NUM % (TILE_OFM * SIMD) == 0 || out_f + (TILE_OFM * SIMD) <= TILE_OUT_F_NUM)) { +#if IS_DYNAMIC + #define WRITE_OUTPUT(bi) do { \ + if (bi + out_b < BATCH_SIZE) \ + OUTPUT_BLOCK_WRITE(output, output_offset, result[bi]); \ + output_offset += TILE_OUT_B_PITCH; \ + } while (false) +#else #define WRITE_OUTPUT(bi) do { \ OUTPUT_BLOCK_WRITE(output, output_offset, result[bi]); \ output_offset += TILE_OUT_B_PITCH; \ } while (false) - +#endif CONST_LOOP(TILE_B, WRITE_OUTPUT); #undef WRITE_OUTPUT } else { @@ -270,8 +278,11 @@ KERNEL(fc)( for (uint bi = 0; bi < TILE_B; ++bi) { for (uint fi = 0; fi < TILE_OFM; ++fi) { const bool should_write = - TILE_OUT_F_NUM % (TILE_OFM * SIMD) == 0 || - out_f + fi * SIMD + sglid < TILE_OUT_F_NUM; +#if IS_DYNAMIC + bi + out_b < BATCH_SIZE && +#endif + (TILE_OUT_F_NUM % (TILE_OFM * SIMD) == 0 || + out_f + fi * SIMD + sglid < TILE_OUT_F_NUM); if (should_write) { output[output_offset] = ((OUTPUT_TYPE*)(&result[bi]))[fi]; } diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/fully_connected/fully_connected_kernel_bf_tiled.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/fully_connected/fully_connected_kernel_bf_tiled.cpp index 60f174ddd1d..a4a3f1c94f0 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/fully_connected/fully_connected_kernel_bf_tiled.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/fully_connected/fully_connected_kernel_bf_tiled.cpp @@ -51,6 +51,7 @@ ParamsKey FullyConnected_bf_tiled::GetSupportedKey() const { k.EnableTensorPitches(); k.EnableDifferentTypes(); k.EnableDifferentInputWeightsTypes(); + k.EnableDynamicShapesSupport(); return k; } @@ -72,10 +73,17 @@ bool FullyConnected_bf_tiled::Validate(const Params& params, const optional_para // Block reads must be aligned to 4 bytes, for fp16 we can correct for offset misalignment, // but we need to ensure that batch pitch preserves alignment. if (input.GetDType() == Datatype::F16) { - if (input.Batch().pitch % 2 != 0 && input.Batch().v > 1) + if (input.Batch().pitch % 2 != 0 && (input.Batch().v > 1 || fc_params.is_shape_agnostic)) return false; // for 3d case we have to check feature alignment as well - if (output.GetLayout() == DataLayout::bfyx && input.Feature().pitch % 2 != 0 && input.Feature().v > 1) + if (output.GetLayout() == DataLayout::bfyx && input.Feature().pitch % 2 != 0 && (input.Feature().v > 1 || fc_params.is_shape_agnostic)) + return false; + } + + // Dynamic kernel doesn't support dynamic weights yet + if (fc_params.is_shape_agnostic && input.is_dynamic()) { + if ((output.GetLayout() == DataLayout::bfyx && input.Y().v == 0) || + (output.GetLayout() == DataLayout::bf && input.Feature().v == 0)) return false; } @@ -141,7 +149,8 @@ bool TuneParamsSelector::VerifyTuneParams(const fully_connected_params& params, output_f = params.outputs[0].Y().v; } - if (output_b % (tparams.tile_b * tparams.dispatch_bsv) != 0) + auto batch_size = params.is_shape_agnostic ? Align(output_b, tparams.tile_b) : output_b; + if (batch_size % (tparams.tile_b * tparams.dispatch_bsv) != 0) return false; if (CeilDiv(output_f, tparams.tile_ofm * simd) % tparams.dispatch_fsv != 0) return false; @@ -191,44 +200,52 @@ FullyConnected_bf_tiled::GetAutoTuneParams(const fully_connected_params& params, while (max_tile_ofm * 2 * simd <= output_f && max_tile_ofm < 4) max_tile_ofm *= 2; - if (dtype == Datatype::F16) { - // tune_params(tile_b, tile_ofm, tile_ifm, tile_k, dispatch_bsv, dispatch_fsv, exec_options) - selector.Case(tune_params(8, std::min(max_tile_ofm, 2u), 1, 2, 16, 2, EXE_MODE_AGE_BASED)) - .Case(tune_params(8, std::min(max_tile_ofm, 2u), 1, 2, 16, 1, EXE_MODE_AGE_BASED)) - .Case(tune_params(16, std::min(max_tile_ofm, 2u), 1, 2, 4, 2, EXE_MODE_AGE_BASED)) - .Case(tune_params(8, std::min(max_tile_ofm, 2u), 1, 2, 8, 1, EXE_MODE_AGE_BASED)) - .Case(tune_params(16, std::min(max_tile_ofm, 2u), 1, 2, 2, 2, EXE_MODE_AGE_BASED)) - .Case(tune_params(8, std::min(max_tile_ofm, 2u), 1, 2, 4, 1, EXE_MODE_AGE_BASED)) - .Case(tune_params(16, std::min(max_tile_ofm, 2u), 1, 2, 1, 1, EXE_MODE_AGE_BASED)) - .Case(tune_params(8, std::min(max_tile_ofm, 2u), 1, 2, 1, 1, EXE_MODE_AGE_BASED)); + if (params.is_shape_agnostic) { + if (dtype == Datatype::F16) { + // tune_params(tile_b, tile_ofm, tile_ifm, tile_k, dispatch_bsv, dispatch_fsv, exec_options) + selector.Case(tune_params(8, std::min(max_tile_ofm, 2u), 1, 2, 1, 1, EXE_MODE_AGE_BASED)); + } else if (dtype == Datatype::F32) { + // tune_params(tile_b, tile_ofm, tile_ifm, tile_k, dispatch_bsv, dispatch_fsv, exec_options) + selector.Case(tune_params(8, std::min(max_tile_ofm, 2u), 1, 1, 1, 1, EXE_MODE_AGE_BASED)); + } + } else { + if (dtype == Datatype::F16) { + // tune_params(tile_b, tile_ofm, tile_ifm, tile_k, dispatch_bsv, dispatch_fsv, exec_options) + selector.Case(tune_params(8, std::min(max_tile_ofm, 2u), 1, 2, 16, 2, EXE_MODE_AGE_BASED)) + .Case(tune_params(8, std::min(max_tile_ofm, 2u), 1, 2, 16, 1, EXE_MODE_AGE_BASED)) + .Case(tune_params(16, std::min(max_tile_ofm, 2u), 1, 2, 4, 2, EXE_MODE_AGE_BASED)) + .Case(tune_params(8, std::min(max_tile_ofm, 2u), 1, 2, 8, 1, EXE_MODE_AGE_BASED)) + .Case(tune_params(16, std::min(max_tile_ofm, 2u), 1, 2, 2, 2, EXE_MODE_AGE_BASED)) + .Case(tune_params(8, std::min(max_tile_ofm, 2u), 1, 2, 4, 1, EXE_MODE_AGE_BASED)) + .Case(tune_params(16, std::min(max_tile_ofm, 2u), 1, 2, 1, 1, EXE_MODE_AGE_BASED)) + .Case(tune_params(8, std::min(max_tile_ofm, 2u), 1, 2, 1, 1, EXE_MODE_AGE_BASED)); + } else if (dtype == Datatype::F32) { + // tune_params(tile_b, tile_ofm, tile_ifm, tile_k, dispatch_bsv, dispatch_fsv, exec_options) + selector.Case(tune_params(8, std::min(max_tile_ofm, 2u), 1, 1, 16, 2, EXE_MODE_AGE_BASED)) + .Case(tune_params(8, std::min(max_tile_ofm, 2u), 1, 1, 16, 1, EXE_MODE_AGE_BASED)) + .Case(tune_params(8, std::min(max_tile_ofm, 2u), 1, 1, 8, 1, EXE_MODE_AGE_BASED)) + .Case(tune_params(8, std::min(max_tile_ofm, 2u), 1, 1, 4, 1, EXE_MODE_AGE_BASED)) + .Case(tune_params(8, std::min(max_tile_ofm, 2u), 1, 1, 2, 1, EXE_MODE_AGE_BASED)) + .Case(tune_params(8, std::min(max_tile_ofm, 2u), 1, 1, 1, 1, EXE_MODE_AGE_BASED)); + } + + selector.Case([&](const fully_connected_params&) -> tune_params { + tune_params result(8, std::min(max_tile_ofm, 2u), 1, 2, 1, 1, EXE_MODE_DEFAULT); + + while (batch % result.tile_b != 0) + result.tile_b--; + + result.dispatch_bsv = 16; + while (batch % (result.tile_b * result.dispatch_bsv) != 0) + result.dispatch_bsv--; + + if (result.tile_b >= 8) + result.exec_options = EXE_MODE_AGE_BASED; + + return result; + }); } - if (dtype == Datatype::F32) { - // tune_params(tile_b, tile_ofm, tile_ifm, tile_k, dispatch_bsv, dispatch_fsv, exec_options) - selector.Case(tune_params(8, std::min(max_tile_ofm, 2u), 1, 1, 16, 2, EXE_MODE_AGE_BASED)) - .Case(tune_params(8, std::min(max_tile_ofm, 2u), 1, 1, 16, 1, EXE_MODE_AGE_BASED)) - .Case(tune_params(8, std::min(max_tile_ofm, 2u), 1, 1, 8, 1, EXE_MODE_AGE_BASED)) - .Case(tune_params(8, std::min(max_tile_ofm, 2u), 1, 1, 4, 1, EXE_MODE_AGE_BASED)) - .Case(tune_params(8, std::min(max_tile_ofm, 2u), 1, 1, 2, 1, EXE_MODE_AGE_BASED)) - .Case(tune_params(8, std::min(max_tile_ofm, 2u), 1, 1, 1, 1, EXE_MODE_AGE_BASED)); - } - - selector.Case([&](const fully_connected_params&) -> tune_params { - tune_params result(8, std::min(max_tile_ofm, 2u), 1, 2, 1, 1, EXE_MODE_DEFAULT); - - while (batch % result.tile_b != 0) - result.tile_b--; - - result.dispatch_bsv = 16; - while (batch % (result.tile_b * result.dispatch_bsv) != 0) - result.dispatch_bsv--; - - if (result.tile_b >= 8) - result.exec_options = EXE_MODE_AGE_BASED; - - return result; - }); - return selector.Default(tune_params(1, 1, 1, 1, 1, 1, EXE_MODE_DEFAULT)); } @@ -238,12 +255,14 @@ FullyConnected_bf_tiled::SetDefault(const fully_connected_params& params, int au auto tparams = GetAutoTuneParams(params, autoTuneIndex); size_t feature_threads = CeilDiv(params.outputs[0].Feature().v, tparams.tile_ofm * simd); - size_t batch_threads = params.outputs[0].Batch().v / tparams.tile_b; + size_t batch_threads = params.outputs[0].Batch().v; if (params.outputs[0].GetLayout() == DataLayout::bfyx) { feature_threads = CeilDiv(params.outputs[0].Y().v, tparams.tile_ofm * simd); - batch_threads = (params.outputs[0].Batch().v * params.outputs[0].Feature().v) / tparams.tile_b; + batch_threads = params.outputs[0].Batch().v * params.outputs[0].Feature().v; } + batch_threads = CeilDiv(batch_threads, tparams.tile_b); + dispatchData.gws[0] = feature_threads * batch_threads * simd; dispatchData.gws[1] = 1; dispatchData.gws[2] = 1; @@ -308,11 +327,13 @@ JitConstants FullyConnected_bf_tiled::GetJitConstants(const fully_connected_para jit.AddConstant(MakeJitConstant("TILE_IN_B_PITCH", params.inputs[0].Feature().pitch)); jit.AddConstant(MakeJitConstant("TILE_OUT_B_PITCH", params.outputs[0].Feature().pitch)); jit.AddConstant(MakeJitConstant("OUTPUT_3D", true)); + jit.AddConstant(MakeJitConstant("BATCH_SIZE", "(OUTPUT_BATCH_NUM * OUTPUT_FEATURE_NUM)")); } else { jit.AddConstant(MakeJitConstant("TILE_OUT_F_NUM", params.outputs[0].Feature().v)); jit.AddConstant(MakeJitConstant("TILE_OUT_F_PITCH", params.outputs[0].Feature().pitch)); jit.AddConstant(MakeJitConstant("TILE_IN_B_PITCH", params.inputs[0].Batch().pitch)); jit.AddConstant(MakeJitConstant("TILE_OUT_B_PITCH", params.outputs[0].Batch().pitch)); + jit.AddConstant(MakeJitConstant("BATCH_SIZE", "(OUTPUT_BATCH_NUM)")); } if (!params.fused_ops.empty()) { diff --git a/src/plugins/intel_gpu/tests/fake_alignment/fc_fake_alignment_test.cpp b/src/plugins/intel_gpu/tests/fake_alignment/fc_fake_alignment_test.cpp index 2e26e39460b..8cc0318cc1b 100644 --- a/src/plugins/intel_gpu/tests/fake_alignment/fc_fake_alignment_test.cpp +++ b/src/plugins/intel_gpu/tests/fake_alignment/fc_fake_alignment_test.cpp @@ -77,10 +77,10 @@ INSTANTIATE_TEST_SUITE_P(smoke, fully_connected_fake_align_test, }, { layout{ov::PartialShape{133, 511}, data_types::i8, format::bfyx, padding{{1,1,1,1}, 0}}, // input_layout - layout{ov::PartialShape{800, 511}, data_types::i8, format::bfyx}, // weight layout + layout{ov::PartialShape{800, 511}, data_types::i8, format::bfyx}, // weight layout data_types::f16, - layout{ov::PartialShape{144, 511}, data_types::i8, format::bfyx, padding{{1,1,1,1}, 0}}, // fake_aligned input layout - layout{ov::PartialShape{144, 800}, data_types::f16, format::bfyx} // fake_aligned output layout + layout{ov::PartialShape{136, 511}, data_types::i8, format::bfyx, padding{{1,1,1,1}, 0}}, // fake_aligned input layout + layout{ov::PartialShape{136, 800}, data_types::f16, format::bfyx} // fake_aligned output layout }, { layout{ov::PartialShape::dynamic(2), data_types::i8, format::bfyx, padding{{1,1,1,1}, 0}}, // input_layout diff --git a/src/plugins/intel_gpu/tests/test_cases/fully_connected_gpu_test.cpp b/src/plugins/intel_gpu/tests/test_cases/fully_connected_gpu_test.cpp index 263131223e6..3aedd83a2fb 100644 --- a/src/plugins/intel_gpu/tests/test_cases/fully_connected_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/test_cases/fully_connected_gpu_test.cpp @@ -1791,7 +1791,7 @@ TEST(fully_connected_gpu, dynamic) { auto output_prim_mem = outputs.begin()->second.get_memory(); auto out_l = network.get_output_layout(outputs.begin()->first); - ASSERT_EQ(output_prim_mem->get_layout().batch(), align_to(input_b, 16)); // fake_alignment + ASSERT_EQ(output_prim_mem->get_layout().batch(), align_to(input_b, 8)); // fake_alignment ASSERT_EQ(out_l.batch(), input_b); ASSERT_EQ(out_l.feature(), weight_b); ASSERT_EQ(out_l.spatial(0), 1); @@ -1843,7 +1843,7 @@ TEST(fully_connected_gpu, dynamic_multi_inference_same_shape) { auto output_prim_mem = outputs.begin()->second.get_memory(); auto out_l = network.get_output_layout(outputs.begin()->first); - ASSERT_EQ(output_prim_mem->get_layout().batch(), align_to(input_b, 16)); // fake_alignment + ASSERT_EQ(output_prim_mem->get_layout().batch(), align_to(input_b, 8)); // fake_alignment ASSERT_EQ(out_l.batch(), input_b); ASSERT_EQ(out_l.feature(), weight_b); ASSERT_EQ(out_l.spatial(0), 1); @@ -1867,7 +1867,7 @@ TEST(fully_connected_gpu, dynamic_multi_inference_same_shape) { auto output_prim_mem = outputs.begin()->second.get_memory(); auto out_l = network.get_output_layout(outputs.begin()->first); - ASSERT_EQ(output_prim_mem->get_layout().batch(), align_to(input_b, 16)); // fake_alignment + ASSERT_EQ(output_prim_mem->get_layout().batch(), align_to(input_b, 8)); // fake_alignment ASSERT_EQ(out_l.batch(), input_b); ASSERT_EQ(out_l.feature(), weight_b); ASSERT_EQ(out_l.spatial(0), 1); @@ -1928,7 +1928,7 @@ TEST(fully_connected_gpu, dynamic_multi_inference_different_shape) { auto output_prim_mem = outputs.begin()->second.get_memory(); auto out_l = network.get_output_layout(outputs.begin()->first); - ASSERT_EQ(output_prim_mem->get_layout().batch(), align_to(2, 16)); // fake_alignment + ASSERT_EQ(output_prim_mem->get_layout().batch(), align_to(2, 8)); // fake_alignment ASSERT_EQ(out_l.batch(), 2); ASSERT_EQ(out_l.feature(), weight_b); ASSERT_EQ(out_l.spatial(0), 1); @@ -1957,7 +1957,7 @@ TEST(fully_connected_gpu, dynamic_multi_inference_different_shape) { auto output_prim_mem = outputs.begin()->second.get_memory(); auto out_l = network.get_output_layout(outputs.begin()->first); - ASSERT_EQ(output_prim_mem->get_layout().batch(), align_to(1, 16)); // fake_alignment + ASSERT_EQ(output_prim_mem->get_layout().batch(), align_to(1, 8)); // fake_alignment ASSERT_EQ(out_l.batch(), 1); ASSERT_EQ(out_l.feature(), weight_b); ASSERT_EQ(out_l.spatial(0), 1); @@ -2015,7 +2015,7 @@ TEST(fully_connected_gpu, dynamic_multi_inference_multiple_shapes) { auto output_prim_mem = outputs.begin()->second.get_memory(); auto out_l = network.get_output_layout(outputs.begin()->first); - ASSERT_EQ(output_prim_mem->get_layout().batch(), align_to(2, 16)); // fake_alignment + ASSERT_EQ(output_prim_mem->get_layout().batch(), align_to(2, 8)); // fake_alignment ASSERT_EQ(out_l.batch(), 2); // fake_alignment ASSERT_EQ(out_l.feature(), weight_b); ASSERT_EQ(out_l.spatial(0), 1); @@ -2044,7 +2044,7 @@ TEST(fully_connected_gpu, dynamic_multi_inference_multiple_shapes) { auto output_prim_mem = outputs.begin()->second.get_memory(); auto out_l = network.get_output_layout(outputs.begin()->first); - ASSERT_EQ(output_prim_mem->get_layout().batch(), align_to(1, 16)); // fake_alignment + ASSERT_EQ(output_prim_mem->get_layout().batch(), align_to(1, 8)); // fake_alignment ASSERT_EQ(out_l.batch(), 1); // fake_alignment ASSERT_EQ(out_l.feature(), weight_b); ASSERT_EQ(out_l.spatial(0), 1); @@ -2059,3 +2059,175 @@ TEST(fully_connected_gpu, dynamic_multi_inference_multiple_shapes) { } } } + +namespace { + template + VF dynamic_fully_connected_reference_calc(ov::Dimension::value_type batch, + ov::Dimension::value_type input_f, + ov::Dimension::value_type output_f, + VF& input, + VF& weights, + VF& bias) { + VF result(batch * output_f); + for (int b = 0; b < batch; b++) { + for (int ofm = 0; ofm < output_f; ofm++) { + AccT acc = static_cast(bias[ofm]); + for (int ifm = 0; ifm < input_f; ifm++) { + acc += weights[ofm * input_f + ifm] * input[b * input_f + ifm]; + } + result[b * output_f + ofm] = acc; + } + } + + return result; + } +} // namespace + +using fully_connected_dynamic_test_params = std::tuple< + std::vector, // batch_sizes + ov::Dimension::value_type, // input_f + ov::Dimension::value_type, // output_f + bool // 3D case +>; + +template +struct dynamic_fully_connected_gpu : ::testing::TestWithParam { + void run_test() { + std::vector batch_sizes; + ov::Dimension::value_type input_f; + ov::Dimension::value_type output_f; + bool fc_3d = false; + + std::tie(batch_sizes, input_f, output_f, fc_3d) = GetParam(); + + auto input_dt = cldnn::type_to_data_type::value; + auto weights_dt = cldnn::type_to_data_type::value; + auto output_dt = cldnn::type_to_data_type::value; + + auto& engine = get_test_engine(); + auto input_dyn_layout = layout{ ov::PartialShape{ ov::Dimension(), input_f }, input_dt, format::bfyx }; + if (fc_3d) + input_dyn_layout = layout{ ov::PartialShape{ ov::Dimension(), ov::Dimension(), input_f }, input_dt, format::bfyx }; + + auto weights_mem = engine.allocate_memory({ ov::PartialShape{ output_f, input_f }, weights_dt, format::bfyx }); + auto weights_data_vec = generate_random_1d(output_f * input_f, -1, 1); + + auto bias_mem = engine.allocate_memory({ ov::PartialShape{ output_f }, output_dt, format::bfyx }); + auto bias_data_vec = generate_random_1d(output_f, 0, 1); + + set_values(weights_mem, weights_data_vec); + set_values(bias_mem, bias_data_vec); + + cldnn::topology topology{ + input_layout("input", input_dyn_layout), + data("weights", weights_mem), + data("bias", bias_mem), + }; + + if (fc_3d) + topology.add(fully_connected("fc", input_info("input"), "weights", "bias", padding(), 3)); + else + topology.add(fully_connected("fc", input_info("input"), "weights", "bias")); + + ExecutionConfig config; + config.set_property(ov::intel_gpu::optimize_data(true)); + config.set_property(ov::intel_gpu::allow_new_shape_infer(true)); + network network(engine, topology, config); + + for (const auto& batch_size : batch_sizes) { + auto input_actual_layout = layout{ ov::PartialShape{ batch_size, input_f }, input_dt, format::bfyx }; + if (fc_3d) + input_actual_layout = layout{ ov::PartialShape{ 1, batch_size, input_f }, input_dt, format::bfyx }; + cldnn::memory_ptr input_mem = engine.allocate_memory(input_actual_layout); + std::vector input_data_vec = generate_random_1d(batch_size * input_f, 0, 1); + set_values(input_mem, input_data_vec); + network.set_input_data("input", input_mem); + + auto outputs = network.execute(); + ASSERT_EQ(outputs.size(), size_t(1)); + ASSERT_EQ(outputs.begin()->first, "fc"); + + auto output_prim_mem = outputs.begin()->second.get_memory(); + + auto out_l = network.get_output_layout(outputs.begin()->first); + ASSERT_EQ(out_l.batch(), fc_3d ? 1 : batch_size); + ASSERT_EQ(out_l.feature(), fc_3d ? batch_size : output_f); + ASSERT_EQ(out_l.spatial(0), 1); + ASSERT_EQ(out_l.spatial(1), fc_3d ? output_f : 1); + + cldnn::mem_lock output_ptr(output_prim_mem, get_test_stream()); + + auto ref_result = dynamic_fully_connected_reference_calc(batch_size, + input_f, + output_f, + input_data_vec, + weights_data_vec, + bias_data_vec); + for (int b = 0; b < batch_size; b++) { + for (int ofm = 0; ofm < output_f; ofm++) { + ASSERT_EQ(ref_result[b * output_f + ofm], output_ptr[b * output_f + ofm]); + } + } + } + } +}; + +using dynamic_fully_connected_gpu_f32_3d = dynamic_fully_connected_gpu; +using dynamic_fully_connected_gpu_f16_3d = dynamic_fully_connected_gpu; + +static const std::vector + dyn_batches_full = {1, 2, 4, 7, 8, 9, 15, 16, 31, 32, 33, 47, 48, 49, 58, 63, 64}; +static const std::vector + dyn_batches_smoke = {1, 2, 7, 8, 9, 16, 32, 33, 47, 48, 58}; + +TEST_P(dynamic_fully_connected_gpu_f32_3d, basic) { + run_test(); +} + +TEST_P(dynamic_fully_connected_gpu_f16_3d, basic) { + run_test(); +} + +INSTANTIATE_TEST_SUITE_P( + smoke, + dynamic_fully_connected_gpu_f32_3d, + ::testing::Combine( + ::testing::Values(dyn_batches_smoke), + ::testing::Values(10, 32, 42, 53, 64, 128), + ::testing::Values(2, 9, 128), + ::testing::Values(false, true)) +); + +INSTANTIATE_TEST_SUITE_P( + smoke, + dynamic_fully_connected_gpu_f16_3d, + ::testing::Combine( + ::testing::Values(dyn_batches_smoke), + ::testing::Values(10, 32, 42, 53, 64, 128), + ::testing::Values(2, 9, 128), + ::testing::Values(false, true)) +); + +INSTANTIATE_TEST_SUITE_P( + full, + dynamic_fully_connected_gpu_f32_3d, + ::testing::Combine( + ::testing::Values(dyn_batches_full), + ::testing::Values(10, 32, 42, 53, 64, 128), + ::testing::Values(2, 9, 16, 32, 64, 128), + ::testing::Values(false, true)) +); + +INSTANTIATE_TEST_SUITE_P( + full, + dynamic_fully_connected_gpu_f16_3d, + ::testing::Combine( + ::testing::Values(dyn_batches_full), + ::testing::Values(10, 32, 42, 53, 64, 128), + ::testing::Values(2, 9, 16, 32, 64, 128), + ::testing::Values(false, true)) +); diff --git a/src/plugins/intel_gpu/tests/test_cases/hash_key_gpu_test.cpp b/src/plugins/intel_gpu/tests/test_cases/hash_key_gpu_test.cpp index bf273f613d8..11aa1de291c 100644 --- a/src/plugins/intel_gpu/tests/test_cases/hash_key_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/test_cases/hash_key_gpu_test.cpp @@ -59,7 +59,7 @@ TEST(check_hash_value, fc_basic) { auto weights_prim = engine.allocate_memory({ data_types::f32, format::bfyx, { out_f, in_f, in_y, in_x } }); auto bias_prim = engine.allocate_memory({ data_types::f32, format::bfyx, { 1, 1, out_f, 1 } }); - auto key_prim_id = "eltwise"; + const auto key_prim_id = "fc"; topology topology( input_layout("input", input_prim->get_layout()), data("weights", weights_prim), @@ -79,7 +79,7 @@ TEST(check_hash_value, fc_basic) { ASSERT_EQ(primitive_hash, 7881065839556591629UL); ASSERT_EQ(prog_node_hash, 7881065839556591629UL); - ASSERT_EQ(prim_inst_hash, 2803059017090178132UL); + ASSERT_EQ(prim_inst_hash, 12327057149074647711UL); } TEST(check_hash_value, gather_basic) {