diff --git a/inference-engine/src/cldnn_engine/cldnn_engine.cpp b/inference-engine/src/cldnn_engine/cldnn_engine.cpp
index 41126fa4289..a1d89d7b6b1 100644
--- a/inference-engine/src/cldnn_engine/cldnn_engine.cpp
+++ b/inference-engine/src/cldnn_engine/cldnn_engine.cpp
@@ -43,6 +43,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -241,11 +242,17 @@ InferenceEngine::ICNNNetwork::Ptr clDNNEngine::CloneAndTransformNetwork(const In
         transformer.transform(nGraphFunc);
     }
 
+    const auto reshape_fc_callback = [](const std::shared_ptr<const ngraph::Node>& node) -> bool {
+        return node->input_value(0).get_shape().size() <= 3lu;
+    };
+
     {
         ngraph::pass::Manager manager = ngraph::pass::Manager();
         manager.register_pass();
         manager.register_pass();
         manager.set_callback(transformations_callback);
+        auto pass_config = manager.get_pass_config();
+        pass_config->set_callback<ngraph::pass::ReshapeFullyConnected>(reshape_fc_callback);
         manager.run_passes(nGraphFunc);
     }
 
diff --git a/inference-engine/src/cldnn_engine/cldnn_program.cpp b/inference-engine/src/cldnn_engine/cldnn_program.cpp
index 535821e187a..1e08c0893d4 100644
--- a/inference-engine/src/cldnn_engine/cldnn_program.cpp
+++ b/inference-engine/src/cldnn_engine/cldnn_program.cpp
@@ -1084,6 +1084,8 @@ void Program::CreateWeightAndBiasPrimitives(cldnn::topology& topology,
     case FullyConnected: {
         groupSize = 1;
         outFeatures = static_cast(layer->outData[0]->getTensorDesc().getDims()[1]);
+        if (in0dims.size() == 3)
+            outFeatures = static_cast(layer->outData[0]->getTensorDesc().getDims()[2]);
         switch (in0dims.size()) {
             case 4:
                 weightDimsVec = { TensorValue(layer->outData[0]->getTensorDesc().getDims().back()),
@@ -1093,8 +1095,8 @@ void Program::CreateWeightAndBiasPrimitives(cldnn::topology& topology,
                 break;
             case 3:
                 weightDimsVec = { TensorValue(layer->outData[0]->getTensorDesc().getDims().back()),
-                                  TensorValue(in0dims[1]),
                                   TensorValue(in0dims[2]),
+                                  1,
                                   1 };
                 break;
             case 2:
@@ -2927,11 +2929,14 @@ void Program::CreateFullyConnectedPrimitive(cldnn::topology& topology, Inference
     IE_ASSERT(weightPrimID.size() == 1);
     IE_ASSERT(biasPrimID.size() <= 1);
 
+    auto outDims = layer->outData[0]->getTensorDesc().getDims().size();
     auto fcPrim = cldnn::fully_connected(fcLayerName,
                                          inputPrimitives[0],
                                          weightPrimID[0],
                                          biasPrimID.empty() ? "" : biasPrimID[0],
-                                         DataTypeFromPrecision(fcLayer->outData[0]->getTensorDesc().getPrecision()));
+                                         DataTypeFromPrecision(fcLayer->outData[0]->getTensorDesc().getPrecision()),
+                                         cldnn::padding(),
+                                         outDims);
 
     topology.add(fcPrim);
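Before the clDNN-side changes below, it may help to fix the shape convention: a 3d fully connected layer consumes an input of shape [B, F, Y] and a weight matrix of shape [out_y, Y], and produces [B, F, out_y], i.e. the same dot product applied independently to every (batch, feature) row. A minimal stand-alone sketch of that mapping (the helper name is illustrative, not part of the sources):

    #include <array>
    #include <cstddef>

    // Input [B, F, Y] x Weights [out_y, Y] -> Output [B, F, out_y].
    // Each of the B*F rows of length Y is multiplied by the same weight matrix.
    std::array<std::size_t, 3> fc_3d_output_shape(const std::array<std::size_t, 3>& input,    // {B, F, Y}
                                                  const std::array<std::size_t, 2>& weights)  // {out_y, Y}
    {
        return { input[0], input[1], weights[0] };
    }

This is why CreateWeightAndBiasPrimitives above takes outFeatures from dims[2] for a 3d input, and why the 3d weight tensor becomes {out_y, in_y, 1, 1}.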
diff --git a/inference-engine/thirdparty/clDNN/api/fully_connected.hpp b/inference-engine/thirdparty/clDNN/api/fully_connected.hpp
index 3c2f22e26eb..0801e8fd03a 100644
--- a/inference-engine/thirdparty/clDNN/api/fully_connected.hpp
+++ b/inference-engine/thirdparty/clDNN/api/fully_connected.hpp
@@ -61,10 +61,12 @@ struct fully_connected : public primitive_base<fully_connected> {
         const primitive_id& input,
         const primitive_id& weights,
         const primitive_id& bias = "",
-        const padding& output_padding = padding())
+        const padding& output_padding = padding(),
+        const size_t input_size = 2)
         : primitive_base(id, {input}, output_padding),
           weights(weights),
-          bias(bias)
+          bias(bias),
+          input_size(input_size)
     {}
 
     /// @brief Constructs fully connected layer.
@@ -77,16 +79,20 @@ struct fully_connected : public primitive_base<fully_connected> {
         const primitive_id& weights,
         const primitive_id& bias,
         const data_types data_type,
-        const padding& output_padding = padding())
+        const padding& output_padding = padding(),
+        const size_t input_size = 2)
         : primitive_base(id, { input }, output_padding, optional_data_type{data_type}),
           weights(weights),
-          bias(bias)
+          bias(bias),
+          input_size(input_size)
     {}
 
     /// @brief Primitive id containing weights data.
     primitive_id weights;
     /// @brief Primitive id containing bias data.
     primitive_id bias;
+    /// @brief Rank of the primitive's input tensor (e.g. 3 for a [B, F, Y] input).
+    size_t input_size;
 
 protected:
     std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override {
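A hypothetical construction call using the extended API above; the trailing input_size argument is what tells clDNN not to flatten the input to 2d (the primitive ids are made up):

    #include <api/fully_connected.hpp>
    #include <api/topology.hpp>

    void add_3d_fc(cldnn::topology& topology) {
        // id, input, weights, bias, output padding, input rank
        topology.add(cldnn::fully_connected("fc3d", "input", "weights", "bias",
                                            cldnn::padding(),
                                            3));  // input is [B, F, Y]
    }

With the default input_size = 2, existing callers keep the old flattening behavior.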
diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/fully_connected/fully_connected_kernel_bf_tiled.cpp b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/fully_connected/fully_connected_kernel_bf_tiled.cpp
index 858c43b4c56..69cfdc0f1d0 100644
--- a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/fully_connected/fully_connected_kernel_bf_tiled.cpp
+++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/fully_connected/fully_connected_kernel_bf_tiled.cpp
@@ -52,6 +52,7 @@ ParamsKey FullyConnected_bf_tiled::GetSupportedKey() const {
     k.EnableInputLayout(DataLayout::bf);
     k.EnableInputLayout(DataLayout::bfyx);
     k.EnableOutputLayout(DataLayout::bf);
+    k.EnableOutputLayout(DataLayout::bfyx);
     k.EnableBatching();
     k.EnableBiasPerFeature();
     k.EnableNonBiasTerm();
@@ -68,11 +69,17 @@ bool FullyConnected_bf_tiled::Validate(const Params& params, const optional_para
     auto& fc_params = static_cast<const fully_connected_params&>(params);
     auto& input = fc_params.inputs[0];
+    auto& output = fc_params.output;
 
     // Block reads must be aligned to 4 bytes, for fp16 we can correct for offset misalignment,
     // but we need to ensure that batch pitch preserves alignment.
-    if (input.GetDType() == Datatype::F16 && input.Batch().pitch % 2 != 0 && input.Batch().v > 1)
-        return false;
+    if (input.GetDType() == Datatype::F16) {
+        if (input.Batch().pitch % 2 != 0 && input.Batch().v > 1)
+            return false;
+        // for the 3d case we have to check feature alignment as well
+        if (output.GetLayout() == DataLayout::bfyx && input.Feature().pitch % 2 != 0 && input.Feature().v > 1)
+            return false;
+    }
 
     if (input.GetLayout() == DataLayout::bfyx) {
         // Padding on input is not supported.
         return false;
     }
 
+    // We don't support 4d output
+    if (fc_params.output.GetLayout() == DataLayout::bfyx) {
+        if (input.X().v > 1)
+            return false;
+    }
+
     return true;
 }
@@ -127,13 +140,20 @@ struct TuneParamsSelector {
 bool TuneParamsSelector::VerifyTuneParams(const fully_connected_params& params, const tune_params& tparams) {
     // Check divisibility by dispatch tile sizes.
-    if (params.output.Batch().v % (tparams.tile_b * tparams.dispatch_bsv) != 0)
+    size_t output_f = params.output.Feature().v;
+    size_t output_b = params.output.Batch().v;
+    if (params.output.GetLayout() == DataLayout::bfyx) {
+        output_b *= params.output.Feature().v;
+        output_f = params.output.Y().v;
+    }
+
+    if (output_b % (tparams.tile_b * tparams.dispatch_bsv) != 0)
         return false;
-    if (CeilDiv(params.output.Feature().v, tparams.tile_ofm * simd) % tparams.dispatch_fsv != 0)
+    if (CeilDiv(output_f, tparams.tile_ofm * simd) % tparams.dispatch_fsv != 0)
         return false;
 
     // Same result can be achieved with smaller tile_ofm.
-    if (params.output.Feature().v <= (tparams.tile_ofm / 2) * simd)
+    if (output_f <= (tparams.tile_ofm / 2) * simd)
        return false;
 
     // No weights layout for such huge tile ofm.
     if (tparams.tile_ofm * simd > 64)
         return false;
@@ -163,6 +183,12 @@ FullyConnected_bf_tiled::GetAutoTuneParams(const fully_connected_params& params,
     size_t batch = params.output.Batch().v;
     size_t output_f = params.output.Feature().v;
+
+    // 3d output
+    if (params.output.GetLayout() == DataLayout::bfyx) {
+        batch *= params.output.Feature().v;
+        output_f = params.output.Y().v;
+    }
     Datatype dtype = params.inputs[0].GetDType();
 
     auto selector = TuneParamsSelector(params);
@@ -219,6 +245,10 @@ FullyConnected_bf_tiled::SetDefault(const fully_connected_params& params, int au
     size_t feature_threads = CeilDiv(params.output.Feature().v, tparams.tile_ofm * simd);
     size_t batch_threads = params.output.Batch().v / tparams.tile_b;
+    if (params.output.GetLayout() == DataLayout::bfyx) {
+        feature_threads = CeilDiv(params.output.Y().v, tparams.tile_ofm * simd);
+        batch_threads = (params.output.Batch().v * params.output.Feature().v) / tparams.tile_b;
+    }
 
     dispatchData.gws[0] = feature_threads * batch_threads * simd;
     dispatchData.gws[1] = 1;
@@ -252,7 +282,7 @@ JitConstants FullyConnected_bf_tiled::GetJitConstants(const fully_connected_para
 
     jit.Merge(MakeConstantLoopUnrollJitConstants(dispatchData.tile_m));
 
-    bool realign_fp16_offset = params.inputs[0].GetDType() == Datatype::F16 && params.output.GetFirstElementOffset() % 2 != 0;
+    bool realign_fp16_offset = params.inputs[0].GetDType() == Datatype::F16 && params.inputs[0].GetFirstElementOffset() % 2 != 0;
     jit.AddConstant(MakeJitConstant("REALIGN_FP16_OFFSET", realign_fp16_offset));
 
     auto activation_dt = GetActivationType(params);
@@ -260,13 +290,32 @@ JitConstants FullyConnected_bf_tiled::GetJitConstants(const fully_connected_para
     jit.Merge(MakeTypeJitConstants(activation_dt, "ACTIVATION"));
     jit.Merge(MakeActivationJitConstants(params.activations, activation_dt, "_TYPED"));
 
+    // for 3d output we are treating spatial as features
+    if (params.output.GetLayout() == DataLayout::bfyx) {
+        jit.AddConstant(MakeJitConstant("TILE_OUT_F_NUM", params.output.Y().v));
+        jit.AddConstant(MakeJitConstant("TILE_OUT_F_PITCH", params.output.Y().pitch));
+        jit.AddConstant(MakeJitConstant("TILE_IN_B_PITCH", params.inputs[0].Feature().pitch));
+        jit.AddConstant(MakeJitConstant("TILE_OUT_B_PITCH", params.output.Feature().pitch));
+        jit.AddConstant(MakeJitConstant("OUTPUT_3D", true));
+    }
+    else {
+        jit.AddConstant(MakeJitConstant("TILE_OUT_F_NUM", params.output.Feature().v));
+        jit.AddConstant(MakeJitConstant("TILE_OUT_F_PITCH", params.output.Feature().pitch));
+        jit.AddConstant(MakeJitConstant("TILE_IN_B_PITCH", params.inputs[0].Batch().pitch));
+        jit.AddConstant(MakeJitConstant("TILE_OUT_B_PITCH", params.output.Batch().pitch));
+    }
+
+    size_t output_f = params.output.GetLayout() == DataLayout::bfyx ? params.output.Y().v : params.output.Feature().v;
     if (!params.fused_ops.empty()) {
         auto boundary_check = BoundaryCheck::DISABLED;
-        if (params.output.Feature().v % (dispatchData.tile_n * simd) != 0)
+        if (output_f % (dispatchData.tile_n * simd) != 0)
             boundary_check = BoundaryCheck::ENABLED;
 
+        std::vector<std::string> idx_order = {"(out_b + bi)", "out_f", "0", "0"};
+        if (params.output.GetLayout() == DataLayout::bfyx)
+            idx_order = {"(out_b + bi) % OUTPUT_BATCH_NUM", "(out_b + bi) / OUTPUT_BATCH_NUM", "out_f", "0"};
         FusedOpsConfiguration conf = { "",
-            {"(out_b + bi)", "out_f", "0", "0"},
+            idx_order,
             "activated[bi]",
             activation_dt,
             dispatchData.tile_n,
@@ -284,6 +333,9 @@ KernelsData FullyConnected_bf_tiled::GetTunedKernelsDataByIndex(const Params &pa
                                                                 const optional_params &options,
                                                                 const int autoTuneIndex) const {
     auto& fc_params = static_cast<const fully_connected_params&>(params);
+    size_t output_b = fc_params.output.Batch().v;
+    if (fc_params.output.GetLayout() == DataLayout::bfyx)
+        output_b *= fc_params.output.Feature().v;
 
     if (autoTuneIndex >= 0 && autoTuneIndex < (int)auto_tune_params.size()
         && !TuneParamsSelector::VerifyTuneParams(fc_params, auto_tune_params[autoTuneIndex]))
@@ -298,9 +350,9 @@ KernelsData FullyConnected_bf_tiled::GetTunedKernelsDataByIndex(const Params &pa
         weights_layout = WeightsLayout::os_iyx_osv64;
 
     float estimated_time = DONT_USE_IF_HAVE_SOMETHING_ELSE;
-    if (fc_params.output.Batch().v > 1 && fc_params.inputs[0].GetDType() == Datatype::F32)
+    if (output_b > 1 && fc_params.inputs[0].GetDType() == Datatype::F32)
         estimated_time = FORCE_PRIORITY_3;
-    if (fc_params.output.Batch().v > 1 && fc_params.inputs[0].GetDType() == Datatype::F16)
+    if (output_b > 1 && fc_params.inputs[0].GetDType() == Datatype::F16)
         estimated_time = FORCE_PRIORITY_4;
 
     return GetCommonKernelsData(params,
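All of the bf_tiled changes above apply the same remapping before tuning and dispatch: for a 3d output the batch dimension absorbs the feature count, and Y takes the role of the output feature. A stand-alone sketch of that rule (names are illustrative, not part of the sources):

    #include <cstddef>
    #include <utility>

    // Returns {effective_batch, effective_feature} used by the tuner.
    // A [B, F, Y] problem is tiled exactly like a 2d (B*F) x Y problem.
    std::pair<std::size_t, std::size_t> remap_3d_for_tuning(std::size_t b, std::size_t f,
                                                            std::size_t y, bool output_is_3d) {
        if (output_is_3d)
            return { b * f, y };
        return { b, f };
    }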
diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/fully_connected/fully_connected_kernel_bfyx_ref.cpp b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/fully_connected/fully_connected_kernel_bfyx_ref.cpp
index 4937335e345..5665f3db65b 100644
--- a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/fully_connected/fully_connected_kernel_bfyx_ref.cpp
+++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/fully_connected/fully_connected_kernel_bfyx_ref.cpp
@@ -34,8 +34,10 @@ ParamsKey FullyConnected_bfyx_Ref::GetSupportedKey() const {
     k.EnableDifferentInputWeightsTypes();
     k.EnableDifferentTypes();
     k.EnableInputLayout(DataLayout::bf);
+    k.EnableInputLayout(DataLayout::bfyx);
     k.EnableOutputLayout(DataLayout::bf);
     k.EnableOutputLayout(DataLayout::fb);
+    k.EnableOutputLayout(DataLayout::bfyx);
     k.EnableBiasPerOutput();
     k.EnableBiasPerFeature();
     k.EnableNonBiasTerm();
@@ -50,7 +52,11 @@ FullyConnected_bfyx_Ref::DispatchData FullyConnected_bfyx_Ref::SetDefault(const
                                                                           int) const {
     auto dispatchData = Parent::SetDefault(params);
 
-    dispatchData.gws = { params.output.Feature().v, params.output.Batch().v, 1 };
+    std::vector<size_t> global = {params.output.Feature().v, params.output.Batch().v, 1};
+    if (params.output.GetLayout() == DataLayout::bfyx)
+        global = {params.output.Feature().v * params.output.Y().v, params.output.Batch().v, 1};
+
+    dispatchData.gws = global;
     dispatchData.lws = GetOptimalLocalWorkGroupSizes(dispatchData.gws, params.engineInfo);
 
     return dispatchData;
@@ -69,13 +75,14 @@ JitConstants FullyConnected_bfyx_Ref::GetJitConstants(const fully_connected_para
         accumulator_dt = Datatype::F32;
         activation_dt = Datatype::F32;
     }
-
+    if (params.output.GetLayout() == DataLayout::bfyx)
+        jit.AddConstant(MakeJitConstant("OUTPUT_3D", true));
     jit.Merge(MakeTypeJitConstants(activation_dt, "ACTIVATION"));
     jit.Merge(MakeTypeJitConstants(accumulator_dt, "ACCUMULATOR"));
     jit.Merge(MakeActivationJitConstants(params.activations, activation_dt, "_TYPED"));
 
     if (!params.fused_ops.empty()) {
-        FusedOpsConfiguration conf = { "", {"b", "ofm", "y", "x"}, "dequantized", activation_dt, 1 };
+        FusedOpsConfiguration conf = { "", {"b", "ofm", "oym", "0"}, "dequantized", activation_dt, 1 };
         jit.Merge(MakeFusedOpsJitConstants(params, { conf }));
     }
     return jit;
@@ -126,6 +133,10 @@ bool FullyConnected_bfyx_Ref::Validate(const Params& params, const optional_para
     if (!is_quantization && !has_fused_op)
         return false;
 
+    // We don't support 4d output
+    if (fc_params.output.GetLayout() == DataLayout::bfyx && fc_params.output.X().v > 1)
+        return false;
+
     return true;
 }
diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/fully_connected/fully_connected_kernel_mmad.cpp b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/fully_connected/fully_connected_kernel_mmad.cpp
index 306d4b60d23..6c7890ce77a 100644
--- a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/fully_connected/fully_connected_kernel_mmad.cpp
+++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/fully_connected/fully_connected_kernel_mmad.cpp
@@ -66,9 +66,21 @@ FullyConnectedKernelMMAD::FullyConnectedTuningData FullyConnectedKernelMMAD::Get
     const auto& input = params.inputs[0];
     const auto& output = params.output;
 
+    size_t input_feature = input.Feature().v;
+    size_t input_batch = input.Batch().v;
+    size_t output_feature = output.Feature().v;
+    size_t output_batch = output.Batch().v;
+    // for the 3d case
+    if (output.GetLayout() == DataLayout::bfyx) {
+        input_batch *= input.Feature().v;
+        input_feature = input.Y().v;
+        output_batch *= output.Feature().v;
+        output_feature = output.Y().v;
+    }
+
     tuning_data.sub_group_size = 8;
-    if (input.X().v == 1 && input.Y().v == 1 && input.Z().v == 1 && input.Batch().v == 1) {
+    if (input.X().v == 1 && input.Z().v == 1 && input.Batch().v == 1 &&
+        ((input.Y().v == 1 && output.GetLayout() != DataLayout::bfyx) || (input.Feature().v == 1 && output.GetLayout() == DataLayout::bfyx))) {
         // Known cases for TGL where simd16 works better than simd8
         bool simd16_exception_1 = input.Feature().v == 25088 && output.Feature().v == 512;
         bool simd16_exception_2 = input.Feature().v == 21504 && output.Feature().v == 512;
@@ -79,14 +91,14 @@ FullyConnectedKernelMMAD::FullyConnectedTuningData FullyConnectedKernelMMAD::Get
 
     size_t sub_group_pack_size = tuning_data.sub_group_size * tuning_data.pack_size;
 
-    tuning_data.feature_blocks_count = input.GetLayout() == DataLayout::bfyx && input.Feature().v % sub_group_pack_size != 0 ?
-                                       input.Feature().v / sub_group_pack_size :
+    tuning_data.feature_blocks_count = input.GetLayout() == DataLayout::bfyx && input_feature % sub_group_pack_size != 0 ?
+                                       input_feature / sub_group_pack_size :
                                        input.GetLayout() != DataLayout::bfyx && tuning_data.sub_group_size == 16 ?
-                                       CeilDiv(input.Feature().v, 32) % 2 == 0 ? CeilDiv(input.Feature().v, 64) : CeilDiv(input.Feature().v, 64) - 1 :
-                                       CeilDiv(input.Feature().v, sub_group_pack_size);
+                                       CeilDiv(input_feature, 32) % 2 == 0 ? CeilDiv(input_feature, 64) : CeilDiv(input_feature, 64) - 1 :
+                                       CeilDiv(input_feature, sub_group_pack_size);
 
-    bool slm_div_factor_exception = input.Batch().v == 300 && input.Feature().v == 2048 &&
-                                    output.Batch().v == 300 && (output.Feature().v == 324 || output.Feature().v == 81);
+    bool slm_div_factor_exception = input_batch == 300 && input_feature == 2048 &&
+                                    output_batch == 300 && (output_feature == 324 || output_feature == 81);
 
     if (tuning_data.feature_blocks_count && tuning_data.sub_group_size == 8 && !slm_div_factor_exception)
         while (tuning_data.feature_blocks_count % (tuning_data.slm_div_factor * 2) == 0 &&
@@ -120,7 +132,11 @@ FullyConnectedKernelMMAD::DispatchData FullyConnectedKernelMMAD::SetDefault(cons
     auto dispatchData = Parent::SetDefault(params);
     const auto& output = params.output;
 
-    dispatchData.gws = { Align(output.Feature().v, tuning_data.sub_group_size) * tuning_data.slm_div_factor, output.Batch().v, 1 };
+    std::vector<size_t> global = { Align(output.Feature().v, tuning_data.sub_group_size) * tuning_data.slm_div_factor, output.Batch().v, 1 };
+    if (output.GetLayout() == DataLayout::bfyx)
+        global = { Align(output.Y().v, tuning_data.sub_group_size) * tuning_data.slm_div_factor, output.Batch().v, output.Feature().v };
+
+    dispatchData.gws = global;
     dispatchData.lws = { tuning_data.work_group_size, 1, 1 };
 
     return dispatchData;
@@ -133,6 +149,7 @@ JitConstants FullyConnectedKernelMMAD::GetJitConstants(const fully_connected_par
     auto jit = Parent::GetJitConstants(params, runInfo);
 
     auto& input = params.inputs[0];
+    auto& output = params.output;
     auto& weights = params.weights;
 
     size_t sub_group_pack_size = tuning_data.sub_group_size * tuning_data.pack_size;
@@ -181,6 +198,9 @@ JitConstants FullyConnectedKernelMMAD::GetJitConstants(const fully_connected_par
     bool has_feature_leftovers = (input.GetLayout() == DataLayout::bfyx && input.Feature().v % sub_group_pack_size) ||
                                  (input.GetLayout() != DataLayout::bfyx && tuning_data.sub_group_size == 16 && CeilDiv(input.Feature().v, 32) % 2);
 
+    if (output.GetLayout() == DataLayout::bfyx)
+        has_feature_leftovers = input.Y().v % sub_group_pack_size;
+
     jit.AddConstant(MakeJitConstant("HAS_FEATURE_LEFTOVERS", has_feature_leftovers));
     jit.AddConstant(MakeJitConstant("FEATURE_BLOCKS_COUNT", tuning_data.feature_blocks_count));
     jit.AddConstant(MakeJitConstant("SLM_DIV_FACTOR", tuning_data.slm_div_factor));
@@ -200,9 +220,24 @@ JitConstants FullyConnectedKernelMMAD::GetJitConstants(const fully_connected_par
     jit.AddConstant(MakeJitConstant("SPLIT_SPATIAL", split_spatial));
     jit.AddConstant(MakeJitConstant("SPATIAL_MAJOR", spatial_major));
 
+    if (output.GetLayout() == DataLayout::bfyx) {
+        jit.AddConstant(MakeJitConstant("FEATURE_PITCH", input.Y().pitch));
+        jit.AddConstant(MakeJitConstant("OUT_FEATURE_NUM", output.Y().v));
+        jit.AddConstant(MakeJitConstant("IN_FEATURE_NUM", input.Y().v));
+        jit.AddConstant(MakeJitConstant("IS_3D", true));
+    } else {
+        jit.AddConstant(MakeJitConstant("FEATURE_PITCH", input.Feature().pitch));
+        jit.AddConstant(MakeJitConstant("OUT_FEATURE_NUM", output.Feature().v));
+        jit.AddConstant(MakeJitConstant("IN_FEATURE_NUM", input.Feature().v));
+    }
+
     if (!params.fused_ops.empty()) {
         auto input_dt = GetActivationType(params);
-        FusedOpsConfiguration conf = { "", {"batch", "feature", "0", "0"}, "dequantized", input_dt, 1 };
+        std::vector<std::string> idx_order = {"batch", "feature", "0", "0"};
+        if (output.GetLayout() == DataLayout::bfyx)
+            idx_order = {"batch", "skip_f", "feature", "0"};
+
+        FusedOpsConfiguration conf = { "", idx_order, "dequantized", input_dt, 1 };
         jit.Merge(MakeFusedOpsJitConstants(params, { conf }));
     }
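The dispatch rule used by the MMAD SetDefault above, restated as a sketch: for a 3d output the third global dimension enumerates the input feature rows, which the kernel reads back as skip_f = get_global_id(2) (the helper and parameter names are illustrative):

    #include <cstddef>
    #include <vector>

    std::vector<std::size_t> mmad_gws_3d(std::size_t out_y, std::size_t batch, std::size_t feature,
                                         std::size_t sub_group_size, std::size_t slm_div_factor) {
        auto align = [](std::size_t v, std::size_t a) { return (v + a - 1) / a * a; };
        // gws = { aligned output width * slm factor, batches, per-batch feature rows }
        return { align(out_y, sub_group_size) * slm_div_factor, batch, feature };
    }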
diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/fully_connected_gpu_MMAD.cl b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/fully_connected_gpu_MMAD.cl
index 7b59f7e15d5..3dadc5fe9f9 100644
--- a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/fully_connected_gpu_MMAD.cl
+++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/fully_connected_gpu_MMAD.cl
@@ -49,6 +49,7 @@ KERNEL(fully_connected_gpu_MMAD)(
     const uint feature = (uint)get_group_id(0) * feature_per_wg + (uint)get_global_id(0) % feature_per_wg;
     const uint feature_block = lid0 / feature_per_wg;
     const uint batch = (uint)get_global_id(1);
+    const uint skip_f = (uint)get_global_id(2);
 
     int dotProd = 0;
 
@@ -56,7 +57,7 @@ KERNEL(fully_connected_gpu_MMAD)(
 #if INPUT0_DIMS == 5
     const uint input_offset = INPUT0_GET_INDEX(batch, 0, 0, 0, 0);
 #else
-    const uint input_offset = INPUT0_GET_INDEX(batch, 0, 0, 0);
+    const uint input_offset = INPUT0_GET_INDEX(batch, skip_f, 0, 0);
 #endif
 
 #if SLM_DIV_FACTOR > 1
@@ -221,17 +222,17 @@ KERNEL(fully_connected_gpu_MMAD)(
 #endif  // SPATIAL_MAJOR
 
 #if !SPLIT_SPATIAL
-            uint input_idx = input_offset + spatial * MMAD_INPUT_SPATIAL_PITCH + FEATURE_BLOCKS_COUNT * INPUT0_FEATURE_PITCH;
+            uint input_idx = input_offset + spatial * MMAD_INPUT_SPATIAL_PITCH + FEATURE_BLOCKS_COUNT * FEATURE_PITCH;
 #else
-            uint input_idx = input_offset + FEATURE_BLOCKS_COUNT * INPUT0_FEATURE_PITCH +
+            uint input_idx = input_offset + FEATURE_BLOCKS_COUNT * FEATURE_PITCH +
                              zi * MMAD_INPUT_Z_PITCH + yi * MMAD_INPUT_Y_PITCH + xi * MMAD_INPUT_X_PITCH;
 #endif  // !SPLIT_SPATIAL
 
             uint filter_idx = filter_offset + spatial * MMAD_FILTER_SPATIAL_PITCH + FEATURE_BLOCKS_COUNT * MMAD_FILTER_FBLOCK_PITCH;
 
             MAKE_VECTOR_TYPE(INPUT0_TYPE, 4) input_data_u = (0, 0, 0, 0);
             for (uint i = 0; i < 4; i++) {
-                if (FEATURE_BLOCKS_COUNT * SUB_GROUP_SIZE * 4 + sglid * 4 + i < INPUT0_FEATURE_NUM) {
-                    input_data_u[i] = input[input_idx + (sglid * 4 + i) * INPUT0_FEATURE_PITCH];
+                if (FEATURE_BLOCKS_COUNT * SUB_GROUP_SIZE * 4 + sglid * 4 + i < IN_FEATURE_NUM) {
+                    input_data_u[i] = input[input_idx + (sglid * 4 + i) * FEATURE_PITCH];
                 }
             }
             INPUT_PACKED_TYPE input_data = AS_TYPE(INPUT_PACKED_TYPE, input_data_u);
@@ -269,7 +270,7 @@ KERNEL(fully_connected_gpu_MMAD)(
     }
 #endif  // HAS_FEATURE_LEFTOVERS
 
-    if (OUTPUT_FEATURE_NUM % SUB_GROUP_SIZE != 0 && feature >= OUTPUT_FEATURE_NUM)
+    if (OUT_FEATURE_NUM % SUB_GROUP_SIZE != 0 && feature >= OUT_FEATURE_NUM)
         return;
 
 #if BIAS_TERM
@@ -284,7 +285,11 @@ KERNEL(fully_connected_gpu_MMAD)(
     float dequantized = (float)dotProd;
 #endif  // BIAS_TERM
 
+#if IS_3D
+    const uint out_idx = OUTPUT_GET_INDEX(batch, skip_f, feature, 0);
+#else
     const uint out_idx = OUTPUT_GET_INDEX(batch, feature, 0, 0);
+#endif
 
 #if HAS_FUSED_OPS
     FUSED_OPS;
"dequantized", input_dt, 1 }; jit.Merge(MakeFusedOpsJitConstants(params, { conf })); } diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/fully_connected_gpu_MMAD.cl b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/fully_connected_gpu_MMAD.cl index 7b59f7e15d5..3dadc5fe9f9 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/fully_connected_gpu_MMAD.cl +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/fully_connected_gpu_MMAD.cl @@ -49,6 +49,7 @@ KERNEL(fully_connected_gpu_MMAD)( const uint feature = (uint)get_group_id(0) * feature_per_wg + (uint)get_global_id(0) % feature_per_wg; const uint feature_block = lid0 / feature_per_wg; const uint batch = (uint)get_global_id(1); + const uint skip_f = (uint)get_global_id(2); int dotProd = 0; @@ -56,7 +57,7 @@ KERNEL(fully_connected_gpu_MMAD)( #if INPUT0_DIMS == 5 const uint input_offset = INPUT0_GET_INDEX(batch, 0, 0, 0, 0); #else - const uint input_offset = INPUT0_GET_INDEX(batch, 0, 0, 0); + const uint input_offset = INPUT0_GET_INDEX(batch, skip_f, 0, 0); #endif #if SLM_DIV_FACTOR > 1 @@ -221,17 +222,17 @@ KERNEL(fully_connected_gpu_MMAD)( #endif // SPATIAL_MAJOR #if !SPLIT_SPATIAL - uint input_idx = input_offset + spatial * MMAD_INPUT_SPATIAL_PITCH + FEATURE_BLOCKS_COUNT * INPUT0_FEATURE_PITCH; + uint input_idx = input_offset + spatial * MMAD_INPUT_SPATIAL_PITCH + FEATURE_BLOCKS_COUNT * FEATURE_PITCH; #else - uint input_idx = input_offset + FEATURE_BLOCKS_COUNT * INPUT0_FEATURE_PITCH + + uint input_idx = input_offset + FEATURE_BLOCKS_COUNT * FEATURE_PITCH + zi * MMAD_INPUT_Z_PITCH + yi * MMAD_INPUT_Y_PITCH + xi * MMAD_INPUT_X_PITCH; #endif // !SPLIT_SPATIAL uint filter_idx = filter_offset + spatial * MMAD_FILTER_SPATIAL_PITCH + FEATURE_BLOCKS_COUNT * MMAD_FILTER_FBLOCK_PITCH; MAKE_VECTOR_TYPE(INPUT0_TYPE, 4) input_data_u = (0, 0, 0, 0); for (uint i = 0; i < 4; i++) { - if (FEATURE_BLOCKS_COUNT * SUB_GROUP_SIZE * 4 + sglid * 4 + i < INPUT0_FEATURE_NUM) { - input_data_u[i] = input[input_idx + (sglid * 4 + i) * INPUT0_FEATURE_PITCH]; + if (FEATURE_BLOCKS_COUNT * SUB_GROUP_SIZE * 4 + sglid * 4 + i < IN_FEATURE_NUM) { + input_data_u[i] = input[input_idx + (sglid * 4 + i) * FEATURE_PITCH]; } } INPUT_PACKED_TYPE input_data = AS_TYPE(INPUT_PACKED_TYPE, input_data_u); @@ -269,7 +270,7 @@ KERNEL(fully_connected_gpu_MMAD)( } #endif // HAS_FEATURE_LEFTOVERS - if (OUTPUT_FEATURE_NUM % SUB_GROUP_SIZE != 0 && feature >= OUTPUT_FEATURE_NUM) + if (OUT_FEATURE_NUM % SUB_GROUP_SIZE != 0 && feature >= OUT_FEATURE_NUM) return; #if BIAS_TERM @@ -284,7 +285,11 @@ KERNEL(fully_connected_gpu_MMAD)( float dequantized = (float)dotProd; #endif // BIAS_TERM +#if IS_3D + const uint out_idx = OUTPUT_GET_INDEX(batch, skip_f, feature, 0); +#else const uint out_idx = OUTPUT_GET_INDEX(batch, feature, 0, 0); +#endif #if HAS_FUSED_OPS FUSED_OPS; diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/fully_connected_gpu_bf_tiled.cl b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/fully_connected_gpu_bf_tiled.cl index 4879fcf6e11..e0705cdfcd3 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/fully_connected_gpu_bf_tiled.cl +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/fully_connected_gpu_bf_tiled.cl @@ -67,13 +67,27 @@ #define MAX(a, b) ((a) > (b) ? (a) : (b)) // Check alignment restrictions for using block writes on output. 
-#define USE_BLOCK_WRITE ((OUTPUT_TYPE_SIZE * OUTPUT_BATCH_PITCH) % 16 == 0 && (OUTPUT_TYPE_SIZE * OUTPUT_OFFSET) % 16 == 0)
+#define USE_BLOCK_WRITE ((OUTPUT_TYPE_SIZE * TILE_OUT_B_PITCH) % 16 == 0 && (OUTPUT_TYPE_SIZE * OUTPUT_OFFSET) % 16 == 0)
 
 #if !REALIGN_FP16_OFFSET
-#   define MAIN_LOOP_ELEMENTS_COUNT INPUT0_ELEMENTS_COUNT
+#   if OUTPUT_3D
+#       define MAIN_LOOP_ELEMENTS_COUNT INPUT0_SIZE_Y
+#   else
+#       define MAIN_LOOP_ELEMENTS_COUNT INPUT0_ELEMENTS_COUNT
+#   endif
 #else
 // For REALIGN_FP16_OFFSET one feature is processed separately before entering main loop to correct alignment.
-#   define MAIN_LOOP_ELEMENTS_COUNT (INPUT0_ELEMENTS_COUNT - 1)
+#   if OUTPUT_3D
+#       define MAIN_LOOP_ELEMENTS_COUNT (INPUT0_SIZE_Y - 1)
+#   else
+#       define MAIN_LOOP_ELEMENTS_COUNT (INPUT0_ELEMENTS_COUNT - 1)
+#   endif
+#endif
+
+#if OUTPUT_3D
+#   define INPUT_ELEMENTS_COUNT INPUT0_SIZE_Y
+#else
+#   define INPUT_ELEMENTS_COUNT INPUT0_ELEMENTS_COUNT
 #endif
 
 __attribute__((intel_reqd_sub_group_size(SIMD)))
 KERNEL(fc)(
@@ -97,25 +111,24 @@ KERNEL(fc)(
     // full dispatch pipeline.
     uint feature_mini_block = gid % DISPATCH_FSV;
     uint batch_mini_block = gid / DISPATCH_FSV % DISPATCH_BSV;
-    uint feature_mega_block = gid / (DISPATCH_FSV * DISPATCH_BSV) % (CEIL_DIV(OUTPUT_FEATURE_NUM, TILE_OFM * SIMD) / DISPATCH_FSV);
-    uint batch_mega_block = gid / (DISPATCH_FSV * DISPATCH_BSV * CEIL_DIV(OUTPUT_FEATURE_NUM, TILE_OFM * SIMD) / DISPATCH_FSV);
+    uint feature_mega_block = gid / (DISPATCH_FSV * DISPATCH_BSV) % (CEIL_DIV(TILE_OUT_F_NUM, TILE_OFM * SIMD) / DISPATCH_FSV);
+    uint batch_mega_block = gid / (DISPATCH_FSV * DISPATCH_BSV * CEIL_DIV(TILE_OUT_F_NUM, TILE_OFM * SIMD) / DISPATCH_FSV);
 
     uint out_f = (feature_mega_block * DISPATCH_FSV + feature_mini_block) * (TILE_OFM * SIMD);
-    uint out_b = (batch_mega_block * DISPATCH_BSV + batch_mini_block) * TILE_B;
+    uint out_b = ((batch_mega_block * DISPATCH_BSV + batch_mini_block) * TILE_B);
 
     ACCUMULATOR_VEC_TYPE acc[TILE_B] = { };
     INPUT_VEC_TYPE       in_0[TILE_B] = { };
 
     FILTER_VEC_TYPE wei = 0;
-
-    uint weights_offset = out_f * INPUT0_ELEMENTS_COUNT;
-    uint input_offset = out_b * INPUT0_BATCH_PITCH + INPUT0_OFFSET;
+    uint input_offset = out_b * TILE_IN_B_PITCH + INPUT0_OFFSET;
+    uint weights_offset = out_f * INPUT_ELEMENTS_COUNT;
 
 #if REALIGN_FP16_OFFSET
     // For fp16 we need to ensure that all block reads are aligned to 4 byte (2 words) boundary.
     // To do this solve first input feature separately.
     {
-        INPUT0_TYPE tmp_input = input[input_offset + get_sub_group_local_id() % TILE_B * INPUT0_BATCH_PITCH];
+        INPUT0_TYPE tmp_input = input[input_offset + get_sub_group_local_id() % TILE_B * TILE_IN_B_PITCH];
         MAKE_VECTOR_TYPE(FILTER_TYPE, TILE_OFM) tmp_wei = BLOCK_READN(FILTER_TYPE, TILE_OFM, weights, weights_offset);
 
         __attribute__((opencl_unroll_hint))
@@ -135,13 +148,12 @@ KERNEL(fc)(
         // Load input.
 #define LOAD_IN_0(bi) do {                                  \
         in_0[bi] = INPUT_BLOCK_READ(input, input_offset);   \
-        input_offset += INPUT0_BATCH_PITCH;                 \
+        input_offset += TILE_IN_B_PITCH;                    \
     } while (false)
 
         CONST_LOOP(TILE_B, LOAD_IN_0);
 #undef LOAD_IN_0
-        input_offset += TILE_IFM * SIMD - INPUT0_BATCH_PITCH * TILE_B;
-
+        input_offset += TILE_IFM * SIMD - TILE_IN_B_PITCH * TILE_B;
         // NOTE: Manually unrolling multiplication loop leads to lower register pressure and allows for bigger block sizes,
         //       but significantly degrades readability and generality of code.
         //       It doesn't also show noticable performance improvement on tested configurations.
@@ -172,13 +184,12 @@ KERNEL(fc)(
     {
#define LOAD_IN_0(bi) do {                                  \
         in_0[bi] = INPUT_BLOCK_READ(input, input_offset);   \
-        input_offset += INPUT0_BATCH_PITCH;                 \
+        input_offset += TILE_IN_B_PITCH;                    \
     } while (false)
 
         CONST_LOOP(TILE_B, LOAD_IN_0);
 #undef LOAD_IN_0
-        input_offset += TILE_IFM * SIMD - INPUT0_BATCH_PITCH * TILE_B;
-
+        input_offset += TILE_IFM * SIMD - TILE_IN_B_PITCH * TILE_B;
         __attribute__((opencl_unroll_hint))
         for (uint ki = 0; ki < CEIL_DIV(LEFTOVER_IFM, TILE_K); ++ki) {
             wei = FILTER_BLOCK_READ(weights, weights_offset);
@@ -210,7 +221,7 @@ KERNEL(fc)(
     }
 
 #if BIAS_TERM
-    #if OUTPUT_FEATURE_NUM % (TILE_OFM * SIMD) == 0
+    #if TILE_OUT_F_NUM % (TILE_OFM * SIMD) == 0
         BIAS_VEC_TYPE bias = BIAS_BLOCK_READ(biases, out_f);
     #else
         BIAS_VEC_TYPE bias = 0;
@@ -242,12 +253,12 @@ KERNEL(fc)(
 #endif
     // =====================================================================================================================================
     // Write results
-    uint output_offset = out_f * OUTPUT_FEATURE_PITCH + out_b * OUTPUT_BATCH_PITCH + OUTPUT_OFFSET;
+    uint output_offset = out_f * TILE_OUT_F_PITCH + out_b * TILE_OUT_B_PITCH + OUTPUT_OFFSET;
 
-    if (USE_BLOCK_WRITE && (OUTPUT_FEATURE_NUM % (TILE_OFM * SIMD) == 0 || out_f + (TILE_OFM * SIMD) <= OUTPUT_FEATURE_NUM)) {
+    if (USE_BLOCK_WRITE && (TILE_OUT_F_NUM % (TILE_OFM * SIMD) == 0 || out_f + (TILE_OFM * SIMD) <= TILE_OUT_F_NUM)) {
 #define WRITE_OUTPUT(bi) do {                                   \
         OUTPUT_BLOCK_WRITE(output, output_offset, result[bi]);  \
-        output_offset += OUTPUT_BATCH_PITCH;                    \
+        output_offset += TILE_OUT_B_PITCH;                      \
     } while (false)
 
         CONST_LOOP(TILE_B, WRITE_OUTPUT);
@@ -258,8 +269,8 @@ KERNEL(fc)(
         // TODO: Investigate why below code doesn't compile and check how it affects performance.
         //#define WRITE_OUTPUT_FEATURE(fi) do {                                       \
         //        const bool should_write =                                           \
-        //            OUTPUT_FEATURE_NUM % (TILE_OFM * SIMD) == 0 ||                  \
-        //            out_f + (fi) * SIMD + get_sub_group_local_id() < OUTPUT_FEATURE_NUM;  \
+        //            TILE_OUT_F_NUM % (TILE_OFM * SIMD) == 0 ||                      \
+        //            out_f + (fi) * SIMD + get_sub_group_local_id() < TILE_OUT_F_NUM;  \
         //        if (should_write) {                                                 \
         //            output[output_offset] = result[out_bi][fi];                     \
         //        }                                                                   \
         //} while (false)
         //
         //#define WRITE_OUTPUT(bi) do {                                               \
         //        const uint out_bi = bi;                                             \
         //        CONST_LOOP(TILE_OFM, WRITE_OUTPUT_FEATURE);                         \
-        //        output_offset += OUTPUT_BATCH_PITCH - TILE_OFM * SIMD;              \
+        //        output_offset += TILE_OUT_B_PITCH - TILE_OFM * SIMD;                \
         //} while (false)
         //
         //CONST_LOOP(TILE_B, WRITE_OUTPUT);
@@ -279,14 +290,14 @@ KERNEL(fc)(
         for (uint bi = 0; bi < TILE_B; ++bi) {
             for (uint fi = 0; fi < TILE_OFM; ++fi) {
                 const bool should_write =
-                    OUTPUT_FEATURE_NUM % (TILE_OFM * SIMD) == 0 ||
-                    out_f + fi * SIMD + get_sub_group_local_id() < OUTPUT_FEATURE_NUM;
+                    TILE_OUT_F_NUM % (TILE_OFM * SIMD) == 0 ||
+                    out_f + fi * SIMD + get_sub_group_local_id() < TILE_OUT_F_NUM;
                 if (should_write) {
                     output[output_offset] = ((OUTPUT_TYPE*)(&result[bi]))[fi];
                 }
                 output_offset += SIMD;
             }
-            output_offset += OUTPUT_BATCH_PITCH - TILE_OFM * SIMD;
+            output_offset += TILE_OUT_B_PITCH - TILE_OFM * SIMD;
         }
     }
     // =====================================================================================================================================
diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/fully_connected_gpu_bfyx_ref.cl b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/fully_connected_gpu_bfyx_ref.cl
index 00d2722f116..c8a6076b15f 100644
--- a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/fully_connected_gpu_bfyx_ref.cl
+++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/fully_connected_gpu_bfyx_ref.cl
@@ -27,10 +27,31 @@ KERNEL(fc)(
 #endif
     )
 {
+#if OUTPUT_3D
+    const uint oxfm = get_global_id(0);
+    const uint b = get_global_id(1);
+    const uint oym = oxfm % OUTPUT_SIZE_Y;
+    const uint ofm = oxfm / OUTPUT_SIZE_Y;
+
+    ACCUMULATOR_TYPE dotProd = ACCUMULATOR_VAL_ZERO;
+
+    for (uint y = 0; y < INPUT0_SIZE_Y; ++y)
+    {
+        for (uint x = 0; x < INPUT0_SIZE_X; ++x)
+        {
+            const uint input0_idx = GET_DATA_INDEX(INPUT0, b, ofm, y, x);
+            const uint filter_idx = GET_FILTER_INDEX(FILTER, 0, oym, y, 0, 0);
+            dotProd += (ACCUMULATOR_TYPE)(input[input0_idx] * weights[filter_idx]);
+        }
+    }
+
+    const uint dst_index = GET_DATA_INDEX(OUTPUT, b, ofm, oym, 0);
+    const uint bias_index = oym;
+#else
     const uint ofm = get_global_id(0);
     const uint b = get_global_id(1);
 
-    ACCUMULATOR_TYPE dotProd = (ACCUMULATOR_TYPE)0;
+    ACCUMULATOR_TYPE dotProd = ACCUMULATOR_VAL_ZERO;
 
     for (uint ifm = 0; ifm < INPUT0_FEATURE_NUM; ++ifm)
     {
@@ -46,9 +67,10 @@ KERNEL(fc)(
     }
 
     const uint dst_index = GET_DATA_INDEX(OUTPUT, b, ofm, 0, 0);
+    const uint bias_index = ofm;
+#endif
 
 #if BIAS_TERM
-    const uint bias_index = ofm;
     ACTIVATION_TYPE dequantized = dotProd + biases[bias_index];
 #else
     ACTIVATION_TYPE dequantized = dotProd;
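The OUTPUT_3D branch of the reference kernel above packs (input feature row, output neuron) into global id 0. The decomposition, restated in plain C++ as a sketch (struct and function names are illustrative):

    #include <cstddef>

    struct fc3d_index { std::size_t ofm; std::size_t oym; };

    // global id 0 runs over OUTPUT_FEATURE_NUM * OUTPUT_SIZE_Y positions;
    // global id 1 runs over batches.
    fc3d_index decompose(std::size_t oxfm, std::size_t output_size_y) {
        return { oxfm / output_size_y,    // ofm: which input feature row
                 oxfm % output_size_y };  // oym: which output neuron
    }

Note that the bias index in the 3d branch is oym, since there is one bias value per output neuron rather than per input feature.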
diff --git a/inference-engine/thirdparty/clDNN/src/fully_connected.cpp b/inference-engine/thirdparty/clDNN/src/fully_connected.cpp
index 040d631d252..13c4cbdb6c7 100644
--- a/inference-engine/thirdparty/clDNN/src/fully_connected.cpp
+++ b/inference-engine/thirdparty/clDNN/src/fully_connected.cpp
@@ -54,6 +54,10 @@ bool is_batch_after_spatial(const std::string order) {
 format::type get_preferred_format(const fully_connected_node& node) {
     auto input_layout = node.input().get_output_layout();
 
+    // for 3d output we have to choose bfyx format
+    if (node.get_primitive()->input_size == 3)
+        return format::bfyx;
+
     if (data_type_traits::is_floating_point(input_layout.data_type) &&
         (is_batch_after_spatial(input_layout.format.order()) ||
          input_layout.format == format::bs_x_bsv16 ||
@@ -107,6 +111,9 @@ layout fully_connected_inst::calc_output_layout(fully_connected_node const& node
     }
 
     auto output_size = tensor(input_layout.size.batch[0], weights_layout.size.batch[0], 1, 1);
+    if (desc->input_size == 3) {
+        output_size = tensor(input_layout.size.batch[0], input_layout.size.feature[0], 1, weights_layout.size.batch[0]);
+    }
     format output_format = get_preferred_format(node);
 
     return layout(output_type, output_format, output_size);
diff --git a/inference-engine/thirdparty/clDNN/src/gpu/fully_connected_gpu.cpp b/inference-engine/thirdparty/clDNN/src/gpu/fully_connected_gpu.cpp
index 069501d11f8..677e93224fe 100644
--- a/inference-engine/thirdparty/clDNN/src/gpu/fully_connected_gpu.cpp
+++ b/inference-engine/thirdparty/clDNN/src/gpu/fully_connected_gpu.cpp
@@ -57,10 +57,11 @@ public:
                                                                             arg.get_program());
         fc_optional_params.allowInputReordering = true;
 
-        fc_params.output = fc_params.output.FlattenFeatureAndSpatials();
-
         const auto primitive = arg.get_primitive();
 
+        if (primitive->input_size != 3)
+            fc_params.output = fc_params.output.FlattenFeatureAndSpatials();
+
         if (arg.get_output_layout().data_type == data_types::i8 ||
             arg.get_output_layout().data_type == data_types::u8) {
             fc_params.quantization = kernel_selector::QuantizationType::SYMMETRIC;
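The calc_output_layout change above encodes the 3d result in a 4d cldnn::tensor: batch and feature are preserved, X stays 1, and the output width lands in Y. A hypothetical example with made-up sizes:

    // input [B=2, F=3, Y=K], weights [N=64, K]
    // 2d FC:  tensor(batch, out_f, 1, 1)
    // 3d FC:  tensor(batch, in_f, 1, out_y)
    auto out_size = cldnn::tensor(2 /*b*/, 3 /*f*/, 1 /*x*/, 64 /*y = out_y*/);

This is why the kernels read output.Y() / OUTPUT_SIZE_Y as the output width whenever the output layout is bfyx.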
diff --git a/inference-engine/thirdparty/clDNN/tests/test_cases/fully_connected_gpu_test.cpp b/inference-engine/thirdparty/clDNN/tests/test_cases/fully_connected_gpu_test.cpp
index f0bf5fc59d8..e129a31d908 100644
--- a/inference-engine/thirdparty/clDNN/tests/test_cases/fully_connected_gpu_test.cpp
+++ b/inference-engine/thirdparty/clDNN/tests/test_cases/fully_connected_gpu_test.cpp
@@ -1150,6 +1150,109 @@ INSTANTIATE_TEST_CASE_P(smoke_bfyx_batched,
                             ::testing::Values(format::bfyx),
                             ::testing::Values("")), );
 
+
+template
+struct fully_connected_random_test_3d : ::testing::TestWithParam {
+    void run_test() {
+        size_t batch, input_f, input_x, input_y, output_y;
+        format::type input_format, output_format;
+        std::string kernel;
+
+        std::tie(batch, input_f, input_x, input_y, output_y, input_format, output_format, kernel) = GetParam();
+
+        auto input_data = generate_smart_random_4d(batch, input_f, input_y, input_x);
+        auto weights_data = generate_smart_random_4d(output_y, input_y, 1, 1);
+        auto bias_data = generate_smart_random_2d(1, output_y);
+
+        auto eng = get_test_engine();
+        auto net = network_test(eng);
+        auto input = net.add_input_layout("input", input_format, std::move(input_data));
+        auto weights = net.add_data("weights", format::oiyx, std::move(weights_data));
+        auto bias = net.add_data("bias", format::bfyx, std::move(bias_data));
+        auto fc = net.add_fully_connected_3d("fc", input, weights, bias, implementation_desc{ output_format, kernel }, 3);
+
+        net.run(build_options(build_option::optimize_data(true)));
+    }
+};
+
+
+using fully_connected_random_test_f32_3d = fully_connected_random_test_3d;
+using fully_connected_random_test_f16_3d = fully_connected_random_test_3d;
+using fully_connected_random_test_i8_3d = fully_connected_random_test_3d;
+
+TEST_P(fully_connected_random_test_f32_3d, basic) {
+    run_test();
+}
+
+INSTANTIATE_TEST_CASE_P(smoke,
+                        fully_connected_random_test_f32_3d,
+                        ::testing::Combine(
+                            ::testing::Values(1,3),
+                            ::testing::Values(1,3),
+                            ::testing::Values(1),
+                            ::testing::Values(1,3,16),
+                            ::testing::Values(1,3,16),
+                            ::testing::Values(format::bfyx),
+                            ::testing::Values(format::any),
+                            ::testing::Values("")), );
+
+INSTANTIATE_TEST_CASE_P(smoke_big,
+                        fully_connected_random_test_f32_3d,
+                        ::testing::Combine(
+                            ::testing::Values(3),
+                            ::testing::Values(16, 17, 32),
+                            ::testing::Values(1),
+                            ::testing::Values(17, 32),
+                            ::testing::Values(17, 32),
+                            ::testing::Values(format::bfyx),
+                            ::testing::Values(format::any),
+                            ::testing::Values("")), );
+
+TEST_P(fully_connected_random_test_f16_3d, basic) {
+    run_test();
+}
+
+INSTANTIATE_TEST_CASE_P(smoke,
+                        fully_connected_random_test_f16_3d,
+                        ::testing::Combine(
+                            ::testing::Values(1,3),
+                            ::testing::Values(1,3),
+                            ::testing::Values(1),
+                            ::testing::Values(1,3,16),
+                            ::testing::Values(1,3,16),
+                            ::testing::Values(format::bfyx),
+                            ::testing::Values(format::any),
+                            ::testing::Values("")), );
+
+TEST_P(fully_connected_random_test_i8_3d, basic) {
+    run_test();
+}
+
+INSTANTIATE_TEST_CASE_P(smoke,
+                        fully_connected_random_test_i8_3d,
+                        ::testing::Combine(
+                            ::testing::Values(1,3),
+                            ::testing::Values(1,3),
+                            ::testing::Values(1),
+                            ::testing::Values(1,3,16),
+                            ::testing::Values(1,3,16),
+                            ::testing::Values(format::bfyx),
+                            ::testing::Values(format::any),
+                            ::testing::Values("")), );
+
+INSTANTIATE_TEST_CASE_P(smoke_big,
+                        fully_connected_random_test_i8_3d,
+                        ::testing::Combine(
+                            ::testing::Values(1,3),
+                            ::testing::Values(16,17),
+                            ::testing::Values(1),
+                            ::testing::Values(17, 32),
+                            ::testing::Values(17, 32),
+                            ::testing::Values(format::bfyx),
+                            ::testing::Values(format::any),
+                            ::testing::Values("")), );
+
+
 struct quantization_t {
     VF input_low;
     VF input_high;
diff --git a/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp b/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp
index 6d043cf5d5a..aee13eed135 100644
--- a/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp
+++ b/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp
@@ -368,6 +368,38 @@ public:
     layout get_per_channel_layout(T& p) {
         return layout{ p.default_type, p.default_format, tensor{1, p.out_shape.feature[0], 1, 1} };
     }
+
+    size_t get_fc_output_dim_size(bc_test_params& p) {
+        size_t size = 2;
+        for (auto i : p.out_shape.spatial) {
+            if (i > 1)
+                size++;
+        }
+        return size;
+    }
+
+    layout get_fc_weights_layout(T& p) {
+        cldnn::tensor weights_tensor;
+        if (p.out_shape.spatial[1] > 1) {
+            // 3d case
+            weights_tensor = cldnn::tensor(p.kernel.batch[0], p.kernel.feature[0], 1, 1);
+        }
+        else {
+            weights_tensor = cldnn::tensor(batch(p.out_shape.feature[0]), feature(p.in_shape.feature[0]),
+                                           spatial(p.kernel.spatial[0], p.kernel.spatial[1], p.kernel.spatial[2]));
+        }
+        return layout{p.weights_type, p.weights_format, weights_tensor};
+    }
+
+    layout get_fc_bias_layout(T& p) {
+        if (p.out_shape.spatial[1] > 1) {
+            // 3d case
+            return layout{ p.default_type, format::bfyx, tensor{1, p.out_shape.spatial[1], 1, 1} };
+        }
+        else {
+            return layout{ p.default_type, format::bfyx, tensor{1, p.out_shape.feature[0], 1, 1} };
+        }
+    }
 };
 
 class ResamplePrimitiveFusingTest : public ::BaseFusingTest {
@@ -557,10 +589,16 @@ public:
 #define CASE_FC_FP32_1 {1, 1, 3, 1}, {1, 4, 1, 1}, {4, 1, 3, 1}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::f32, format::bfyx, data_types::f32, format::oiyx, data_types::f32, format::bfyx
 #define CASE_FC_FP32_2 {2, 1, 3, 1}, {2, 4, 1, 1}, {4, 1, 3, 1}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::f32, format::yxfb, data_types::f32, format::oiyx, data_types::f32, format::bfyx
 #define CASE_FC_FP32_3 {2, 32, 1, 1}, {2, 16, 1, 1}, {16, 32, 1, 1}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::f32, format::bfyx, data_types::i8, format::oiyx, data_types::f32, format::bfyx
+#define CASE_FC_FP32_3D_1 {5, 3, 1, 3}, {5, 3, 1, 5}, {5, 3, 1, 1}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::f32, format::bfyx, data_types::f32, format::os_iyx_osv16, data_types::f32, format::bfyx
+#define CASE_FC_FP32_3D_2 {2, 1, 1, 1}, {2, 1, 1, 32}, {32, 1, 1, 1}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::f32, format::bfyx, data_types::f32, format::os_iyx_osv16, data_types::f32, format::bfyx
+#define CASE_FC_FP32_3D_3 {2, 32, 1, 32}, {2, 32, 1, 16}, {16, 32, 1, 1}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::f32, format::bfyx, data_types::f32, format::os_iyx_osv16, data_types::f32, format::bfyx
 
 #define CASE_FC_U8S8_1 {1, 1, 3, 1}, {1, 4, 1, 1}, {4, 1, 3, 1}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::u8, format::bfyx, data_types::i8, format::oiyx, data_types::f32, format::bfyx
 #define CASE_FC_U8S8_2 {2, 1, 3, 1}, {2, 4, 1, 1}, {4, 1, 3, 1}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::u8, format::b_fs_yx_fsv4, data_types::i8, format::oiyx, data_types::f32, format::bfyx
 #define CASE_FC_U8S8_3 {2, 32, 1, 1}, {2, 16, 1, 1}, {16, 32, 1, 1}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::u8, format::b_fs_yx_fsv4, data_types::i8, format::oiyx, data_types::f32, format::bfyx
+#define CASE_FC_U8S8_3D_1 {2, 32, 1, 3}, {2, 32, 1, 16}, {16, 3, 1, 1}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::u8, format::bfyx, data_types::i8, format::oiyx, data_types::f32, format::bfyx
+#define CASE_FC_U8S8_3D_2 {1, 1, 1, 3}, {1, 1, 1, 32}, {32, 3, 1, 1}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::u8, format::bfyx, data_types::i8, format::oiyx, data_types::f32, format::bfyx
+#define CASE_FC_U8S8_3D_3 {2, 3, 1, 1}, {2, 3, 1, 15}, {15, 1, 1, 1}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::u8, format::bfyx, data_types::i8, format::oiyx, data_types::f32, format::bfyx
 
 #define CASE_NORMALIZE_I8_1 {1, 2, 3, 3}, data_types::u8, format::bfyx, data_types::f32, format::bfyx
32}, {32, 3, 1, 1}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::u8, format::bfyx, data_types::i8, format::oiyx, data_types::f32, format::bfyx +#define CASE_FC_U8S8_3D_3 {2, 3, 1, 1}, {2, 3, 1, 15}, {15, 1, 1, 1}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::u8, format::bfyx, data_types::i8, format::oiyx, data_types::f32, format::bfyx #define CASE_NORMALIZE_I8_1 {1, 2, 3, 3}, data_types::u8, format::bfyx, data_types::f32, format::bfyx @@ -2475,9 +2513,9 @@ class fc_fp32_activation : public FCFusingTest {}; TEST_P(fc_fp32_activation, basic) { auto p = GetParam(); create_topologies(input_layout("input", get_input_layout(p)), - data("weights", get_mem(get_weights_layout(p))), - data("bias", get_mem(get_bias_layout(p))), - fully_connected("fc_prim", "input", "weights", "bias"), + data("weights", get_mem(get_fc_weights_layout(p))), + data("bias", get_mem(get_fc_bias_layout(p))), + fully_connected("fc_prim", "input", "weights", "bias", padding(), get_fc_output_dim_size(p)), activation("activation", "fc_prim", activation_func::abs), reorder("reorder_bfyx", "activation", p.default_format, data_types::f32) ); @@ -2490,14 +2528,17 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, fc_fp32_activation, ::testing::ValuesIn(std bc_test_params{ CASE_FC_FP32_1, 2, 3 }, bc_test_params{ CASE_FC_FP32_2, 2, 3 }, bc_test_params{ CASE_FC_FP32_3, 2, 3 }, + bc_test_params{ CASE_FC_FP32_3D_1, 2, 3 }, + bc_test_params{ CASE_FC_FP32_3D_2, 2, 3 }, + bc_test_params{ CASE_FC_FP32_3D_3, 2, 3 }, }), ); class fc_fp32_bias : public FCFusingTest {}; TEST_P(fc_fp32_bias, basic) { auto p = GetParam(); create_topologies(input_layout("input", get_input_layout(p)), - data("weights", get_mem(get_weights_layout(p))), - data("bias", get_mem(get_bias_layout(p))), + data("weights", get_mem(get_fc_weights_layout(p))), + data("bias", get_mem(get_fc_bias_layout(p))), fully_connected("fc_prim", "input", "weights", ""), eltwise("bias_add", {"fc_prim", "bias"}, eltwise_mode::sum), reorder("reorder_bfyx", "bias_add", p.default_format, data_types::f32) @@ -2517,10 +2558,10 @@ class fc_int8_scale : public FCFusingTest {}; TEST_P(fc_int8_scale, basic) { auto p = GetParam(); create_topologies(input_layout("input", get_input_layout(p)), - data("weights", get_mem(get_weights_layout(p))), - data("bias", get_mem(get_bias_layout(p))), + data("weights", get_mem(get_fc_weights_layout(p))), + data("bias", get_mem(get_fc_bias_layout(p))), data("scale_data", get_mem(get_per_channel_layout(p), 1.0f / p.kernel.count())), - fully_connected("fc_prim", "input", "weights", "bias", data_types::f32), + fully_connected("fc_prim", "input", "weights", "bias", data_types::f32, padding(), get_fc_output_dim_size(p)), scale("scale", "fc_prim", "scale_data"), reorder("reorder_bfyx", "scale", p.default_format, data_types::f32) ); @@ -2532,10 +2573,10 @@ TEST_P(fc_int8_scale, basic) { TEST_P(fc_int8_scale, fp16_scale_out) { auto p = GetParam(); create_topologies(input_layout("input", get_input_layout(p)), - data("weights", get_mem(get_weights_layout(p))), - data("bias", get_mem(get_bias_layout(p))), + data("weights", get_mem(get_fc_weights_layout(p))), + data("bias", get_mem(get_fc_bias_layout(p))), data("scale_data", get_mem(get_per_channel_layout(p), 1.0f / p.kernel.count())), - fully_connected("fc_prim", "input", "weights", "bias", data_types::f32), + fully_connected("fc_prim", "input", "weights", "bias", data_types::f32, padding(), get_fc_output_dim_size(p)), scale("scale", "fc_prim", "scale_data", optional_data_type{data_types::f16}), reorder("reorder_bfyx", "scale", 
@@ -2549,19 +2590,22 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, fc_int8_scale,
                         bc_test_params{ CASE_FC_U8S8_1, 2, 3 },
                         bc_test_params{ CASE_FC_U8S8_2, 2, 3 },
                         bc_test_params{ CASE_FC_U8S8_3, 2, 3 },
+                        bc_test_params{ CASE_FC_U8S8_3D_1, 2, 3 },
+                        bc_test_params{ CASE_FC_U8S8_3D_2, 2, 3 },
+                        bc_test_params{ CASE_FC_U8S8_3D_3, 2, 3 },
 }), );
 
 class fc_int8_quantize_u8 : public FCFusingTest {};
 TEST_P(fc_int8_quantize_u8, basic) {
     auto p = GetParam();
     create_topologies(input_layout("input", get_input_layout(p)),
-        data("weights", get_mem(get_weights_layout(p))),
-        data("bias", get_mem(get_bias_layout(p))),
+        data("weights", get_mem(get_fc_weights_layout(p))),
+        data("bias", get_mem(get_fc_bias_layout(p))),
         data("in_lo", get_mem(get_per_channel_layout(p), min_random, 0)),
         data("in_hi", get_mem(get_per_channel_layout(p), 1, max_random)),
         data("out_lo", get_mem(get_single_element_layout(p), 0)),
         data("out_hi", get_mem(get_single_element_layout(p), 255)),
-        fully_connected("fc_prim", "input", "weights", "bias", data_types::f32),
+        fully_connected("fc_prim", "input", "weights", "bias", data_types::f32, padding(), get_fc_output_dim_size(p)),
         quantize("quantize", "fc_prim", "in_lo", "in_hi", "out_lo", "out_hi", 256, data_types::u8),
         reorder("reorder_bfyx", "quantize", p.default_format, data_types::f32)
     );
@@ -2575,20 +2619,23 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu_fc, fc_int8_quantize_u8,
                         bc_test_params{CASE_FC_U8S8_1, 2, 3},
                         bc_test_params{CASE_FC_U8S8_2, 2, 3},
                         bc_test_params{CASE_FC_U8S8_3, 2, 3},
+                        bc_test_params{ CASE_FC_U8S8_3D_1, 2, 3 },
+                        bc_test_params{ CASE_FC_U8S8_3D_2, 2, 3 },
+                        bc_test_params{ CASE_FC_U8S8_3D_3, 2, 3 },
 }), );
 
 class fc_int8_scale_quantize_i8 : public FCFusingTest {};
 TEST_P(fc_int8_scale_quantize_i8, basic) {
     auto p = GetParam();
     create_topologies(input_layout("input", get_input_layout(p)),
-        data("weights", get_mem(get_weights_layout(p))),
-        data("bias", get_mem(get_bias_layout(p))),
+        data("weights", get_mem(get_fc_weights_layout(p))),
+        data("bias", get_mem(get_fc_bias_layout(p))),
         data("in_lo", get_mem(get_per_channel_layout(p), min_random, 0)),
         data("in_hi", get_mem(get_per_channel_layout(p), 1, max_random)),
         data("out_lo", get_mem(get_single_element_layout(p), -127)),
         data("out_hi", get_mem(get_single_element_layout(p), 127)),
         data("scale_data", get_mem(get_per_channel_layout(p), 1.0f / p.kernel.count() / 255)),
-        fully_connected("fc_prim", "input", "weights", "bias", data_types::f32),
+        fully_connected("fc_prim", "input", "weights", "bias", data_types::f32, padding(), get_fc_output_dim_size(p)),
         scale("scale", "fc_prim", "scale_data"),
         quantize("quantize", "scale", "in_lo", "in_hi", "out_lo", "out_hi", 255, data_types::i8),
         reorder("reorder_bfyx", "quantize", p.default_format, data_types::f32)
     );
@@ -2602,6 +2649,9 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, fc_int8_scale_quantize_i8,
                         bc_test_params{CASE_FC_U8S8_1, 2, 4},
                         bc_test_params{CASE_FC_U8S8_2, 2, 4},
                         bc_test_params{CASE_FC_U8S8_3, 2, 4},
+                        bc_test_params{ CASE_FC_U8S8_3D_1, 2, 4 },
+                        bc_test_params{ CASE_FC_U8S8_3D_2, 2, 4 },
+                        bc_test_params{ CASE_FC_U8S8_3D_3, 2, 4 },
 }), );
 
@@ -2610,14 +2660,14 @@ class fc_int8_scale_activation_quantize_i8 : public FCFusingTest {};
 TEST_P(fc_int8_scale_activation_quantize_i8, basic) {
     auto p = GetParam();
     create_topologies(input_layout("input", get_input_layout(p)),
-        data("weights", get_mem(get_weights_layout(p))),
-        data("bias", get_mem(get_bias_layout(p))),
+        data("weights", get_mem(get_fc_weights_layout(p))),
+        data("bias", get_mem(get_fc_bias_layout(p))),
         data("in_lo", get_mem(get_per_channel_layout(p), min_random, 0)),
         data("in_hi", get_mem(get_per_channel_layout(p), 1, max_random)),
         data("out_lo", get_mem(get_single_element_layout(p), -127)),
         data("out_hi", get_mem(get_single_element_layout(p), 127)),
         data("scale_data", get_mem(get_per_channel_layout(p), 1.0f / p.kernel.count() / 255)),
-        fully_connected("fc_prim", "input", "weights", "bias", data_types::f32),
+        fully_connected("fc_prim", "input", "weights", "bias", data_types::f32, padding(), get_fc_output_dim_size(p)),
         scale("scale", "fc_prim", "scale_data"),
         activation("activation_scale", "scale", activation_func::exp),
         quantize("quantize", "activation_scale", "in_lo", "in_hi", "out_lo", "out_hi", 255, data_types::i8),
@@ -2633,6 +2683,14 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, fc_int8_scale_activation_quantize_i8,
                         bc_test_params{CASE_FC_U8S8_1, 2, 5},
                         bc_test_params{CASE_FC_U8S8_2, 2, 5},
                         bc_test_params{CASE_FC_U8S8_3, 2, 5},
+
+                        bc_test_params{ CASE_FC_U8S8_3D_1, 2, 5 },
+                        bc_test_params{ CASE_FC_U8S8_3D_2, 2, 5 },
+                        bc_test_params{ CASE_FC_U8S8_3D_3, 2, 5 },
+
+                        bc_test_params{ CASE_FC_FP32_3D_1, 3, 5 },
+                        bc_test_params{ CASE_FC_FP32_3D_2, 3, 5 },
+                        bc_test_params{ CASE_FC_FP32_3D_3, 3, 5 },
 }), );
 
diff --git a/inference-engine/thirdparty/clDNN/tests/test_utils/network_test.h b/inference-engine/thirdparty/clDNN/tests/test_utils/network_test.h
index 1848a32ad09..74cd673fde8 100644
--- a/inference-engine/thirdparty/clDNN/tests/test_utils/network_test.h
+++ b/inference-engine/thirdparty/clDNN/tests/test_utils/network_test.h
@@ -144,7 +144,6 @@ template struct reference_tensor_typed : reference_tensor {
     using vector_type = VVVVF;
     reference_tensor_typed(vector_type data) : reference(std::move(data)) {}
-
     void compare(cldnn::memory actual) override {
         auto ptr = actual.pointer();
         for (size_t bi = 0; bi < reference.size(); ++bi) {
@@ -231,6 +230,36 @@
 VVF fully_connected_reference_typed(VVVVF& input, VVVVF& weights, VF& bias) {
+template ::type>
+VVVVF fully_connected_reference_typed_3d(VVVVF& input, VVVVF& weights, VF& bias) {
+    size_t input_f = input[0].size();
+    size_t input_y = input[0][0].size();
+    size_t input_x = input[0][0][0].size();
+    size_t output_b = input.size();         // input is assumed to be bfyx
+    size_t output_f = weights.size();       // weights is assumed to be bfyx
+    size_t weights_f = weights[0].size();   // weights is assumed to be bfyx
+    VVVVF output(output_b, VVVF(input_f, VVF(output_f, VF(1))));
+    OutputT res;
+    for (size_t b = 0; b < output_b; ++b) {
+        for (size_t n = 0; n < input_f; ++n) {
+            for (size_t f = 0; f < output_f; ++f) {
+                res = bias[f];
+                for (size_t y = 0; y < input_y; ++y) {
+                    for (size_t x = 0; x < input_x; ++x) {
+                        res += (OutputT)input[b][n][y][x] * (OutputT)weights[f][y][0][0];
+                    }
+                }
+                output[b][n][f][0] = (OutputT)res;
+            }
+        }
+    }
+    return output;
+}
+
 // =====================================================================================================================
 // Network test
 struct reference_node_interface {
@@ -300,6 +329,22 @@ public:
         return add_node(id, reference_tensor_typed(output_data), { input, weights, bias });
     }
 
+    template
+    typename reference_node::ptr add_fully_connected_3d(cldnn::primitive_id id,
+                                                        std::shared_ptr> input,
+                                                        std::shared_ptr> weights,
+                                                        std::shared_ptr> bias,
+                                                        cldnn::implementation_desc force = cldnn::implementation_desc{cldnn::format::any, ""},
+                                                        size_t input_dim_size = 3) {
+        topo.add(cldnn::fully_connected(id, input->id, weights->id, bias->id, cldnn::type_to_data_type::value, cldnn::padding(), input_dim_size));
+        if (force.output_format != cldnn::format::any || force.kernel_name != "")
+            forced_impls[id] = force;
+        VVVVF output_data = fully_connected_reference_typed_3d(input->reference.reference,
+                                                               weights->reference.reference,
+                                                               bias->reference.reference[0]);
+        return add_node(id, reference_tensor_typed(output_data), {input, weights, bias});
+    }
+
     cldnn::network build_network(cldnn::build_options opts) {
         opts.set_option(cldnn::build_option::force_implementations(forced_impls));
         auto net = cldnn::network(eng, topo, opts);
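A worked instance of fully_connected_reference_typed_3d above, with tiny made-up values (B = 1, F = 2, Y = 3, out_y = 2); both feature rows are multiplied by the same 2x3 weight matrix:

    // input[0]  = { {1, 2, 3},          // shape [1][2][3][1], x-dim elided
    //               {4, 5, 6} }
    // weights   = { {1, 0, 0},          // shape [2][3][1][1]
    //               {0, 1, 0} }
    // bias      = { 10, 20 }
    //
    // output[0] = { {1 + 10, 2 + 20},   // = { {11, 22},
    //               {4 + 10, 5 + 20} }  //     {14, 25} }, shape [1][2][2][1]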