From 2100521a145e240d8c61e408c398e06ac8554042 Mon Sep 17 00:00:00 2001 From: Egor Churaev Date: Fri, 5 Jun 2020 10:16:27 +0300 Subject: [PATCH] [IE CLDNN] Implement NormalizeL2 int8 kernels (#720) --- .../normalize_kernel_across_spatial_ref.cpp | 9 +- .../normalize/normalize_kernel_base.cpp | 42 ++- .../normalize/normalize_kernel_base.h | 12 +- .../normalize_kernel_within_spatial_ref.cpp | 9 +- .../normalize_gpu_across_spatial_ref.cl | 30 +- .../normalize_gpu_within_spatial_ref.cl | 32 +- .../clDNN/src/gpu/normalize_gpu.cpp | 14 +- .../prepare_primitive_fusing.cpp | 15 +- .../thirdparty/clDNN/src/normalize.cpp | 11 +- .../tests/test_cases/fusings_gpu_test.cpp | 127 ++++++-- .../tests/test_cases/normalizel2_gpu_test.cpp | 290 ++++++++++++++++++ 11 files changed, 525 insertions(+), 66 deletions(-) create mode 100644 inference-engine/thirdparty/clDNN/tests/test_cases/normalizel2_gpu_test.cpp diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/normalize/normalize_kernel_across_spatial_ref.cpp b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/normalize/normalize_kernel_across_spatial_ref.cpp index 2cb1d03643f..fdc116648f6 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/normalize/normalize_kernel_across_spatial_ref.cpp +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/normalize/normalize_kernel_across_spatial_ref.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2016 Intel Corporation +// Copyright (c) 2016-2020 Intel Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -20,14 +20,19 @@ ParamsKey NormalizeKernelAcrossSpatialRef::GetSupportedKey() const { ParamsKey k; k.EnableInputDataType(Datatype::F16); k.EnableInputDataType(Datatype::F32); + k.EnableInputDataType(Datatype::INT8); + k.EnableInputDataType(Datatype::UINT8); k.EnableOutputDataType(Datatype::F16); k.EnableOutputDataType(Datatype::F32); + k.EnableOutputDataType(Datatype::INT8); + k.EnableOutputDataType(Datatype::UINT8); k.EnableInputLayout(DataLayout::bfyx); k.EnableInputLayout(DataLayout::yxfb); k.EnableInputLayout(DataLayout::byxf); k.EnableOutputLayout(DataLayout::bfyx); k.EnableOutputLayout(DataLayout::yxfb); k.EnableOutputLayout(DataLayout::byxf); + k.EnableDifferentTypes(); k.EnableTensorOffset(); k.EnableTensorPitches(); k.EnableBatching(); @@ -39,4 +44,4 @@ KernelsData NormalizeKernelAcrossSpatialRef::GetKernelsData(const Params& params const optional_params& optParams) const { return GetCommonKernelsData(params, optParams, FORCE_PRIORITY_9); } -} // namespace kernel_selector \ No newline at end of file +} // namespace kernel_selector diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/normalize/normalize_kernel_base.cpp b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/normalize/normalize_kernel_base.cpp index 43e725eef2a..2f1d5eab875 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/normalize/normalize_kernel_base.cpp +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/normalize/normalize_kernel_base.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2016 Intel Corporation +// Copyright (c) 2016-2020 Intel Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
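Editor's note on the ParamsKey additions above: INT8/UINT8 are added to the supported input and output datatypes, and EnableDifferentTypes() allows the input and output datatypes to differ (for example i8 input with f32 output), which is what lets the selector pick these reference kernels for quantized tensors. The following is only a toy, self-contained illustration of that capability-mask idea; the type names, bit layout, and helper below are invented for the sketch and are not the kernel_selector API.

#include <bitset>
#include <cassert>

// Invented stand-in for the ParamsKey concept: a kernel advertises the datatypes it
// supports as bit masks, and a request is accepted only if both masks cover it.
enum DtBit { kF16 = 0, kF32, kI8, kU8, kDtCount };

struct ToyParamsKey {
    std::bitset<kDtCount> in, out;
    bool different_types = false;  // analogue of EnableDifferentTypes()

    bool supports(DtBit in_dt, DtBit out_dt) const {
        if (!in.test(in_dt) || !out.test(out_dt)) return false;
        return different_types || in_dt == out_dt;
    }
};

int main() {
    ToyParamsKey k;
    k.in.set(kF16).set(kF32).set(kI8).set(kU8);
    k.out.set(kF16).set(kF32).set(kI8).set(kU8);
    k.different_types = true;           // the capability this patch turns on for NormalizeL2
    assert(k.supports(kI8, kF32));      // quantized input with fp32 output is now acceptable
    assert(k.supports(kU8, kU8));
    return 0;
}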
@@ -28,6 +28,14 @@ JitConstants NormalizeKernelBase::GetJitConstants(const normalize_params& np) co
         MakeJitConstant("THRESHOLD", 0.0001f),
     });
 
+    auto activation_dt = GetActivationType(np);
+    jit.Merge(MakeTypeJitConstants(activation_dt, "ACTIVATION"));
+    if (!np.fused_ops.empty()) {
+        std::vector<std::string> idx_order = { "b", "f", "y", "x" };
+        auto conf = FusedOpsConfiguration("", idx_order, "result", activation_dt);
+        jit.Merge(MakeFusedOpsJitConstants(np, { conf }));
+    }
+
     return jit;
 }
 
@@ -63,6 +71,8 @@ KernelsData NormalizeKernelBase::GetCommonKernelsData(const Params& params,
                                                       const optional_params& options,
                                                       float estimated_time) const {
     assert(params.GetType() == KernelType::NORMALIZE);
+    if (!Validate(params, options))
+        return {};
 
     const normalize_params& orgParams = static_cast<const normalize_params&>(params);
 
@@ -77,11 +87,39 @@ KernelsData NormalizeKernelBase::GetCommonKernelsData(const Params& params,
     auto jit = CreateJit(kernelName, cldnn_jit, entry_point);
 
     auto& kernel = kd.kernels[0];
-    FillCLKernelData(kernel, runInfo, params.engineInfo, kernelName, jit, entry_point);
+    FillCLKernelData(kernel,
+                     runInfo,
+                     params.engineInfo,
+                     kernelName,
+                     jit,
+                     entry_point,
+                     "",
+                     false,
+                     false,
+                     1,
+                     GetFusedPrimitiveInputsCount(params));
+
     kernel.arguments.push_back({ArgumentDescriptor::Types::SCALE_TABLE, 0});
 
     kd.estimatedTime = estimated_time;
 
     return {kd};
 }
+
+bool NormalizeKernelBase::Validate(const Params& params, const optional_params&) const {
+    const normalize_params& orgParams = static_cast<const normalize_params&>(params);
+
+    for (auto& fused_op : orgParams.fused_ops) {
+        if (!IsFusedPrimitiveSupported(fused_op))
+            return false;
+    }
+
+    return true;
+}
+
+Datatype NormalizeKernelBase::GetActivationType(const normalize_params& params) const {
+    if (params.output.GetDType() == Datatype::F16)
+        return Datatype::F16;
+    return Datatype::F32;
+}
 } // namespace kernel_selector
diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/normalize/normalize_kernel_base.h b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/normalize/normalize_kernel_base.h
index 81e843afc04..043fa1b04af 100644
--- a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/normalize/normalize_kernel_base.h
+++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/normalize/normalize_kernel_base.h
@@ -1,4 +1,4 @@
-// Copyright (c) 2016 Intel Corporation
+// Copyright (c) 2016-2020 Intel Corporation
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -12,7 +12,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
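Editor's note: GetActivationType() above selects the datatype used for the intermediate arithmetic and for the fused-ops chain: only an fp16 output keeps fp16 math, while fp32 and the newly enabled i8/u8 outputs accumulate in fp32 and convert at the very end. A minimal standalone sketch of that rule, using local stand-in types rather than the kernel_selector headers:

#include <cassert>

// Local stand-ins for the kernel_selector datatypes; not the real headers.
enum class Dt { F16, F32, INT8, UINT8 };

// Mirrors NormalizeKernelBase::GetActivationType above: half-precision math only
// when the output is fp16, fp32 math for every other output type.
static Dt activation_type(Dt output) {
    return output == Dt::F16 ? Dt::F16 : Dt::F32;
}

int main() {
    assert(activation_type(Dt::F16) == Dt::F16);
    assert(activation_type(Dt::INT8) == Dt::F32);   // quantized outputs still use fp32 math
    assert(activation_type(Dt::UINT8) == Dt::F32);
    assert(activation_type(Dt::F32) == Dt::F32);
    return 0;
}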
- #pragma once #include "common_kernel_base.h" @@ -59,5 +58,12 @@ protected: JitConstants GetJitConstants(const normalize_params& params) const; DispatchData SetDefault(const normalize_params& params) const; KernelsData GetCommonKernelsData(const Params& params, const optional_params&, float estimated_time) const; + std::vector GetSupportedFusedOps() const override { + return { FusedOpType::QUANTIZE, + FusedOpType::ACTIVATION, + FusedOpType::SCALE }; + } + bool Validate(const Params& params, const optional_params&) const override; + Datatype GetActivationType(const normalize_params& params) const; }; -} // namespace kernel_selector \ No newline at end of file +} // namespace kernel_selector diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/normalize/normalize_kernel_within_spatial_ref.cpp b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/normalize/normalize_kernel_within_spatial_ref.cpp index af8236f5b38..1000b3e130e 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/normalize/normalize_kernel_within_spatial_ref.cpp +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/normalize/normalize_kernel_within_spatial_ref.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2016 Intel Corporation +// Copyright (c) 2016-2020 Intel Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -20,14 +20,19 @@ ParamsKey NormalizeKernelWithinSpatialRef::GetSupportedKey() const { ParamsKey k; k.EnableInputDataType(Datatype::F16); k.EnableInputDataType(Datatype::F32); + k.EnableInputDataType(Datatype::INT8); + k.EnableInputDataType(Datatype::UINT8); k.EnableOutputDataType(Datatype::F16); k.EnableOutputDataType(Datatype::F32); + k.EnableOutputDataType(Datatype::INT8); + k.EnableOutputDataType(Datatype::UINT8); k.EnableInputLayout(DataLayout::bfyx); k.EnableInputLayout(DataLayout::yxfb); k.EnableInputLayout(DataLayout::byxf); k.EnableOutputLayout(DataLayout::bfyx); k.EnableOutputLayout(DataLayout::yxfb); k.EnableOutputLayout(DataLayout::byxf); + k.EnableDifferentTypes(); k.EnableTensorOffset(); k.EnableTensorPitches(); k.EnableBatching(); @@ -39,4 +44,4 @@ KernelsData NormalizeKernelWithinSpatialRef::GetKernelsData(const Params& params const optional_params& optParams) const { return GetCommonKernelsData(params, optParams, FORCE_PRIORITY_9); } -} // namespace kernel_selector \ No newline at end of file +} // namespace kernel_selector diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/normalize_gpu_across_spatial_ref.cl b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/normalize_gpu_across_spatial_ref.cl index 1fa78124b25..c90862f3f1c 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/normalize_gpu_across_spatial_ref.cl +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/normalize_gpu_across_spatial_ref.cl @@ -1,4 +1,4 @@ -// Copyright (c) 2016-2017 Intel Corporation +// Copyright (c) 2016-2020 Intel Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -15,15 +15,14 @@ #include "include/common.cl" #include "include/data_types.cl" - -#if FP16_UNIT_USED - #define UNIT_CVT_FUNC(val) convert_half(val) -#else - #define UNIT_CVT_FUNC(val) (val) +KERNEL (normalize_gpu_across_spatial_bfyx)( + const __global INPUT0_TYPE* input, + __global OUTPUT_TYPE* output, +#if HAS_FUSED_OPS_DECLS + FUSED_OPS_DECLS, #endif - - -KERNEL (normalize_gpu_across_spatial_bfyx)(const __global UNIT_TYPE* input, __global UNIT_TYPE* output, const __global UNIT_TYPE* scale_input) + const __global SCALE_TABLE_TYPE* scale_input + ) { const uint b = get_global_id(0); @@ -68,13 +67,19 @@ KERNEL (normalize_gpu_across_spatial_bfyx)(const __global UNIT_TYPE* input, __gl const uint scale_index = f; #else const uint scale_index = f % SCALE_TABLE_FEATURE_NUM; -#endif +#endif for (uint y = 0; y < INPUT0_SIZE_Y; y++) { for (uint x = 0; x < INPUT0_SIZE_X; x++) { - output[output_idx] = ACTIVATION(UNIT_CVT_FUNC(norm) * input[input_idx] * scale_input[scale_index], ACTIVATION_PARAMS); + ACTIVATION_TYPE result = TO_ACTIVATION_TYPE(norm) * TO_ACTIVATION_TYPE(input[input_idx]) * TO_ACTIVATION_TYPE(scale_input[scale_index]); +#if HAS_FUSED_OPS + FUSED_OPS; + output[output_idx] = FUSED_OPS_RESULT; +#else + output[output_idx] = TO_OUTPUT_TYPE(ACTIVATION(result, ACTIVATION_PARAMS)); +#endif input_idx += INPUT0_X_PITCH; output_idx += OUTPUT_X_PITCH; } @@ -85,6 +90,3 @@ KERNEL (normalize_gpu_across_spatial_bfyx)(const __global UNIT_TYPE* input, __gl output_idx += OUTPUT_FEATURE_PITCH - INPUT0_SIZE_Y*OUTPUT_Y_PITCH; } } - - -#undef UNIT_CVT_FUNC diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/normalize_gpu_within_spatial_ref.cl b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/normalize_gpu_within_spatial_ref.cl index 7f31af6238d..942e22962b1 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/normalize_gpu_within_spatial_ref.cl +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/normalize_gpu_within_spatial_ref.cl @@ -1,4 +1,4 @@ -// Copyright (c) 2016-2017 Intel Corporation +// Copyright (c) 2016-2020 Intel Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
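Editor's note: for reference, the arithmetic the rewritten across-spatial kernel performs per batch is: accumulate the sum of squares over all features and spatial positions, zero out tiny norms instead of dividing by them (the THRESHOLD JIT constant), then multiply each element by the inverse norm and the per-channel scale before the fused ops (or the plain ACTIVATION/TO_OUTPUT_TYPE fallback) run. The scalar C++ sketch below is an illustrative reference only; the helper name is invented and the fused-ops branch is deliberately omitted.

#include <cmath>
#include <cstdio>
#include <vector>

// Scalar reference for the across-spatial variant: one L2 norm per batch over all
// features and spatial positions, then a per-channel scale.
std::vector<float> normalize_across_spatial_ref(const std::vector<float>& in,
                                                const std::vector<float>& scale,
                                                int b, int f, int y, int x,
                                                float threshold = 0.0001f) {
    std::vector<float> out(in.size());
    const int per_batch = f * y * x;
    for (int bi = 0; bi < b; ++bi) {
        float sum_sq = 0.f;
        for (int i = 0; i < per_batch; ++i) {
            float v = in[bi * per_batch + i];
            sum_sq += v * v;
        }
        // Tiny norms are clamped to zero rather than divided by.
        float inv_norm = (sum_sq <= threshold) ? 0.f : 1.f / std::sqrt(sum_sq);
        for (int fi = 0; fi < f; ++fi)
            for (int i = 0; i < y * x; ++i) {
                int idx = bi * per_batch + fi * y * x + i;
                out[idx] = in[idx] * inv_norm * scale[fi % static_cast<int>(scale.size())];
            }
    }
    return out;
}

int main() {
    // 1x2x1x2 input with unit scales: sum of squares is 14, so every value is divided by sqrt(14).
    std::vector<float> in = {0.f, 1.f, 2.f, 3.f};
    auto out = normalize_across_spatial_ref(in, {1.f, 1.f}, 1, 2, 1, 2);
    for (float v : out) std::printf("%f\n", v);  // 0, 0.267261, 0.534522, 0.801784
    return 0;
}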
@@ -15,15 +15,14 @@ #include "include/common.cl" #include "include/data_types.cl" - -#if FP16_UNIT_USED - #define UNIT_CVT_FUNC(val) convert_half(val) -#else - #define UNIT_CVT_FUNC(val) (val) +KERNEL (normalize_gpu_within_spatial_bfyx)( + const __global INPUT0_TYPE* input, + __global OUTPUT_TYPE* output, +#if HAS_FUSED_OPS_DECLS + FUSED_OPS_DECLS, #endif - - -KERNEL (normalize_gpu_within_spatial_bfyx)(const __global UNIT_TYPE* input, __global UNIT_TYPE* output, const __global UNIT_TYPE* scale_input) + const __global SCALE_TABLE_TYPE* scale_input + ) { const uint x = get_global_id(0); const uint y = get_global_id(1); @@ -40,7 +39,7 @@ KERNEL (normalize_gpu_within_spatial_bfyx)(const __global UNIT_TYPE* input, __gl norm = mad(value, value, norm); input_idx += INPUT0_FEATURE_PITCH; } - + uint output_idx = OUTPUT_OFFSET + b*OUTPUT_BATCH_PITCH + y*OUTPUT_Y_PITCH + x*OUTPUT_X_PITCH; if(norm <= THRESHOLD) @@ -62,13 +61,16 @@ KERNEL (normalize_gpu_within_spatial_bfyx)(const __global UNIT_TYPE* input, __gl const uint scale_index = f; #else const uint scale_index = f % SCALE_TABLE_FEATURE_NUM; -#endif +#endif - output[output_idx] = ACTIVATION(UNIT_CVT_FUNC(norm) * input[input_idx] * scale_input[scale_index], ACTIVATION_PARAMS); + ACTIVATION_TYPE result = TO_ACTIVATION_TYPE(norm) * TO_ACTIVATION_TYPE(input[input_idx]) * TO_ACTIVATION_TYPE(scale_input[scale_index]); +#if HAS_FUSED_OPS + FUSED_OPS; + output[output_idx] = FUSED_OPS_RESULT; +#else + output[output_idx] = TO_OUTPUT_TYPE(ACTIVATION(result, ACTIVATION_PARAMS)); +#endif output_idx += OUTPUT_FEATURE_PITCH; input_idx += INPUT0_FEATURE_PITCH; } } - - -#undef UNIT_CVT_FUNC diff --git a/inference-engine/thirdparty/clDNN/src/gpu/normalize_gpu.cpp b/inference-engine/thirdparty/clDNN/src/gpu/normalize_gpu.cpp index c221f45faa8..52d3b545341 100644 --- a/inference-engine/thirdparty/clDNN/src/gpu/normalize_gpu.cpp +++ b/inference-engine/thirdparty/clDNN/src/gpu/normalize_gpu.cpp @@ -1,5 +1,5 @@ /* -// Copyright (c) 2016 Intel Corporation +// Copyright (c) 2016-2020 Intel Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
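Editor's note: the within-spatial variant differs only in the reduction axis; each (b, y, x) position is normalized across the feature dimension. The sketch below mirrors that loop structure (including the mad()-style accumulation) and reproduces by hand the first values expected by the normalizel2_f32_gpu.basic test added later in this patch. It is an illustrative reference, not clDNN code.

#include <cmath>
#include <cstdio>
#include <vector>

// Scalar reference for the within-spatial variant: each (b, y, x) column is
// normalized across the feature axis only, then multiplied by the per-channel scale.
std::vector<float> normalize_within_spatial_ref(const std::vector<float>& in,
                                                const std::vector<float>& scale,
                                                int b, int f, int y, int x,
                                                float threshold = 0.0001f) {
    std::vector<float> out(in.size());
    const int spatial = y * x;
    for (int bi = 0; bi < b; ++bi)
        for (int s = 0; s < spatial; ++s) {
            float sum_sq = 0.f;
            for (int fi = 0; fi < f; ++fi) {
                float v = in[(bi * f + fi) * spatial + s];
                sum_sq = std::fma(v, v, sum_sq);   // same mad() accumulation as the kernel
            }
            float inv_norm = (sum_sq <= threshold) ? 0.f : 1.f / std::sqrt(sum_sq);
            for (int fi = 0; fi < f; ++fi) {
                int idx = (bi * f + fi) * spatial + s;
                out[idx] = in[idx] * inv_norm * scale[fi % static_cast<int>(scale.size())];
            }
        }
    return out;
}

int main() {
    // 1x2x3x3 filled with 0..17 and unit scales, i.e. the setup of the
    // normalizel2_f32_gpu.basic test below; out[1] should be 1/sqrt(1^2 + 10^2).
    std::vector<float> in(18);
    for (int i = 0; i < 18; ++i) in[i] = static_cast<float>(i);
    auto out = normalize_within_spatial_ref(in, {1.f, 1.f}, 1, 2, 3, 3);
    std::printf("%f %f\n", out[1], out[9]);  // ~0.099504 and ~1.0
    return 0;
}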
@@ -75,14 +75,26 @@ attach_normalize_gpu::attach_normalize_gpu() { normalize_gpu::create); implementation_map::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfyx), normalize_gpu::create); + implementation_map::add(std::make_tuple(engine_types::ocl, data_types::i8, format::bfyx), + normalize_gpu::create); + implementation_map::add(std::make_tuple(engine_types::ocl, data_types::u8, format::bfyx), + normalize_gpu::create); implementation_map::add(std::make_tuple(engine_types::ocl, data_types::f32, format::yxfb), normalize_gpu::create); implementation_map::add(std::make_tuple(engine_types::ocl, data_types::f16, format::yxfb), normalize_gpu::create); + implementation_map::add(std::make_tuple(engine_types::ocl, data_types::i8, format::yxfb), + normalize_gpu::create); + implementation_map::add(std::make_tuple(engine_types::ocl, data_types::u8, format::yxfb), + normalize_gpu::create); implementation_map::add(std::make_tuple(engine_types::ocl, data_types::f32, format::byxf), normalize_gpu::create); implementation_map::add(std::make_tuple(engine_types::ocl, data_types::f16, format::byxf), normalize_gpu::create); + implementation_map::add(std::make_tuple(engine_types::ocl, data_types::i8, format::byxf), + normalize_gpu::create); + implementation_map::add(std::make_tuple(engine_types::ocl, data_types::u8, format::byxf), + normalize_gpu::create); } } // namespace detail diff --git a/inference-engine/thirdparty/clDNN/src/graph_optimizer/prepare_primitive_fusing.cpp b/inference-engine/thirdparty/clDNN/src/graph_optimizer/prepare_primitive_fusing.cpp index 0fc778df7f6..d090b135d7f 100644 --- a/inference-engine/thirdparty/clDNN/src/graph_optimizer/prepare_primitive_fusing.cpp +++ b/inference-engine/thirdparty/clDNN/src/graph_optimizer/prepare_primitive_fusing.cpp @@ -359,6 +359,10 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) { should_fuse |= input_data.is_type(); + should_fuse |= input_data.is_type() && + (input_data.get_dependency(0).get_output_layout().data_type == data_types::u8 || + input_data.get_dependency(0).get_output_layout().data_type == data_types::i8); + should_fuse |= input_data.is_type(); should_fuse |= input_data.is_type(); @@ -396,6 +400,10 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) { should_fuse |= input_data.is_type() && mvn_supports_fusings(input_data.as()); + should_fuse |= input_data.is_type() && + (input_data.get_dependency(0).get_output_layout().data_type == data_types::u8 || + input_data.get_dependency(0).get_output_layout().data_type == data_types::i8); + should_fuse |= input_data.is_type(); should_fuse |= input_data.is_type(); @@ -449,8 +457,7 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) { quantize_node.get_scale_shift_opt() && (out_layout.data_type == data_types::u8 || out_layout.data_type == data_types::i8); - should_fuse |= input_data.is_type() && - quantize_node.get_scale_shift_opt(); + should_fuse |= input_data.is_type() && quantize_node.get_scale_shift_opt(); should_fuse |= input_data.is_type() && gemm_supports_fusings(input_data.as()) && quantize_node.get_scale_shift_opt() && @@ -465,6 +472,10 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) { should_fuse |= input_data.is_type() && quantize_node.get_scale_shift_opt(); + should_fuse |= input_data.is_type() && quantize_node.get_scale_shift_opt() && + (input_data.get_dependency(0).get_output_layout().data_type == data_types::u8 || + input_data.get_dependency(0).get_output_layout().data_type == 
data_types::i8); + should_fuse |= input_data.is_type() && quantize_node.get_scale_shift_opt() && // fp16/fp32 optimized kernels don't support chaning data type (input_data.get_dependency(0).get_output_layout().data_type == data_types::u8 || diff --git a/inference-engine/thirdparty/clDNN/src/normalize.cpp b/inference-engine/thirdparty/clDNN/src/normalize.cpp index d0f66cb7f11..1d6e82bd283 100644 --- a/inference-engine/thirdparty/clDNN/src/normalize.cpp +++ b/inference-engine/thirdparty/clDNN/src/normalize.cpp @@ -29,7 +29,16 @@ primitive_type_id normalize::type_id() { layout normalize_inst::calc_output_layout(normalize_node const& node) { assert(static_cast(node.get_primitive()->output_data_type) == false && "Output data type forcing is not supported for normalize_node!"); - return node.input().get_non_padded_output_layout(); + auto input_node_layout = node.input().get_non_padded_output_layout(); + auto output_type = input_node_layout.data_type; + + if (node.has_fused_primitives()) { + output_type = node.get_fused_output_layout().data_type; + } else if (input_node_layout.data_type == data_types::u8 || input_node_layout.data_type == data_types::i8) { + output_type = data_types::f32; + } + + return layout(output_type, input_node_layout.format, input_node_layout.size); } std::string normalize_inst::to_string(normalize_node const& node) { diff --git a/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp b/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp index 86dd962d061..96335583376 100644 --- a/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp +++ b/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp @@ -85,6 +85,17 @@ struct gemm_test_params { size_t expected_not_fused_primitives; }; +struct normalize_test_params { + tensor in_shape; + data_types data_type; + format input_format; + data_types default_type; + format default_format; + bool across_spatial; + size_t expected_fused_primitives; + size_t expected_not_fused_primitives; +}; + template class BaseFusingTest : public ::testing::TestWithParam { public: @@ -430,6 +441,8 @@ public: #define CASE_GEMM_2IN_U8S8_1 {{1, 1, 4, 2}, {1, 1, 8, 4}}, tensor{1}, tensor{0}, data_types::u8, data_types::i8, data_types::u8, format::bfyx, data_types::f32, format::bfyx #define CASE_GEMM_2IN_S8U8_1 {{1, 2, 64, 128}, {1, 2, 256, 64}}, tensor{1}, tensor{0}, data_types::i8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx +#define CASE_NORMALIZE_I8_1 {1, 2, 3, 3}, data_types::u8, format::bfyx, data_types::f32, format::bfyx + /* ----------------------------------------------------------------------------------------------------- */ /* ---------------------------------------- FP32 convolution cases ------------------------------------- */ /* ----------------------------------------------------------------------------------------------------- */ @@ -4161,7 +4174,7 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, permute_activation_scale_eltwise, permute_params{CASE_PERMUTE_F32_5, 2, 5}, permute_params{CASE_PERMUTE_F32_6, 2, 5}, permute_params{CASE_PERMUTE_F32_7, 2, 5}, - + permute_params{CASE_PERMUTE_F16_0, 2, 5}, permute_params{CASE_PERMUTE_F16_1, 2, 5}, permute_params{CASE_PERMUTE_F16_2, 2, 5}, @@ -4169,34 +4182,34 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, permute_activation_scale_eltwise, permute_params{CASE_PERMUTE_F16_4, 2, 5}, permute_params{CASE_PERMUTE_F16_5, 2, 5}, permute_params{CASE_PERMUTE_F16_6, 2, 5}, - + permute_params{CASE_PERMUTE_S8_0, 2, 5}, 
permute_params{CASE_PERMUTE_S8_1, 2, 5}, permute_params{CASE_PERMUTE_S8_2, 2, 5}, permute_params{CASE_PERMUTE_S8_3, 2, 5}, - + permute_params{CASE_PERMUTE_U8_0, 2, 5}, permute_params{CASE_PERMUTE_U8_1, 2, 5}, permute_params{CASE_PERMUTE_U8_2, 2, 5}, permute_params{CASE_PERMUTE_U8_3, 2, 5}, - + permute_params{CASE_PERMUTE_F32_3D_0, 2, 5}, permute_params{CASE_PERMUTE_F32_3D_1, 2, 5}, permute_params{CASE_PERMUTE_F32_3D_2, 2, 5}, permute_params{CASE_PERMUTE_F32_3D_3, 2, 5}, permute_params{CASE_PERMUTE_F32_3D_4, 2, 5}, - + permute_params{CASE_PERMUTE_F16_3D_0, 2, 5}, permute_params{CASE_PERMUTE_F16_3D_1, 2, 5}, permute_params{CASE_PERMUTE_F16_3D_2, 2, 5}, permute_params{CASE_PERMUTE_F16_3D_3, 2, 5}, permute_params{CASE_PERMUTE_F16_3D_4, 2, 5}, - + permute_params{CASE_PERMUTE_S8_3D_0, 2, 5}, permute_params{CASE_PERMUTE_S8_3D_1, 2, 5}, permute_params{CASE_PERMUTE_S8_3D_2, 2, 5}, permute_params{CASE_PERMUTE_S8_3D_3, 2, 5}, - + permute_params{CASE_PERMUTE_U8_3D_0, 2, 5}, permute_params{CASE_PERMUTE_U8_3D_1, 2, 5}, permute_params{CASE_PERMUTE_U8_3D_2, 2, 5}, @@ -4235,7 +4248,7 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, permute_scale_eltwise_quant_u8, permute_params{CASE_PERMUTE_F32_5, 2, 5}, permute_params{CASE_PERMUTE_F32_6, 2, 5}, permute_params{CASE_PERMUTE_F32_7, 2, 5}, - + permute_params{CASE_PERMUTE_F16_0, 2, 5}, permute_params{CASE_PERMUTE_F16_1, 2, 5}, permute_params{CASE_PERMUTE_F16_2, 2, 5}, @@ -4248,18 +4261,18 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, permute_scale_eltwise_quant_u8, permute_params{CASE_PERMUTE_S8_1, 2, 5}, permute_params{CASE_PERMUTE_S8_2, 2, 5}, permute_params{CASE_PERMUTE_S8_3, 2, 5}, - + permute_params{CASE_PERMUTE_U8_0, 2, 5}, permute_params{CASE_PERMUTE_U8_1, 2, 5}, permute_params{CASE_PERMUTE_U8_2, 2, 5}, permute_params{CASE_PERMUTE_U8_3, 2, 5}, - + permute_params{CASE_PERMUTE_F32_3D_0, 2, 5}, permute_params{CASE_PERMUTE_F32_3D_1, 2, 5}, permute_params{CASE_PERMUTE_F32_3D_2, 2, 5}, permute_params{CASE_PERMUTE_F32_3D_3, 2, 5}, permute_params{CASE_PERMUTE_F32_3D_4, 2, 5}, - + permute_params{CASE_PERMUTE_F16_3D_0, 2, 5}, permute_params{CASE_PERMUTE_F16_3D_1, 2, 5}, permute_params{CASE_PERMUTE_F16_3D_2, 2, 5}, @@ -4270,7 +4283,7 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, permute_scale_eltwise_quant_u8, permute_params{CASE_PERMUTE_S8_3D_1, 2, 5}, permute_params{CASE_PERMUTE_S8_3D_2, 2, 5}, permute_params{CASE_PERMUTE_S8_3D_3, 2, 5}, - + permute_params{CASE_PERMUTE_U8_3D_0, 2, 5}, permute_params{CASE_PERMUTE_U8_3D_1, 2, 5}, permute_params{CASE_PERMUTE_U8_3D_2, 2, 5}, @@ -4321,19 +4334,19 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, permute_scale_actv_eltw_scale_actv_quant_i8 permute_params{CASE_PERMUTE_F16_4, 2, 8}, permute_params{CASE_PERMUTE_F16_5, 2, 8}, permute_params{CASE_PERMUTE_F16_6, 2, 8}, - + permute_params{CASE_PERMUTE_S8_0, 2, 8}, permute_params{CASE_PERMUTE_S8_1, 2, 8}, permute_params{CASE_PERMUTE_S8_2, 2, 8}, permute_params{CASE_PERMUTE_S8_3, 2, 8}, - + permute_params{CASE_PERMUTE_U8_0, 2, 8}, permute_params{CASE_PERMUTE_U8_1, 2, 8}, permute_params{CASE_PERMUTE_U8_2, 2, 8}, permute_params{CASE_PERMUTE_U8_3, 2, 8}, - + permute_params{CASE_PERMUTE_F32_3D_0, 2, 8}, - permute_params{CASE_PERMUTE_F32_3D_1, 2, 8}, + permute_params{CASE_PERMUTE_F32_3D_1, 2, 8}, permute_params{CASE_PERMUTE_F32_3D_2, 2, 8}, permute_params{CASE_PERMUTE_F32_3D_3, 2, 8}, permute_params{CASE_PERMUTE_F32_3D_4, 2, 8}, @@ -4348,7 +4361,7 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, permute_scale_actv_eltw_scale_actv_quant_i8 permute_params{CASE_PERMUTE_S8_3D_1, 2, 8}, permute_params{CASE_PERMUTE_S8_3D_2, 2, 
8}, permute_params{CASE_PERMUTE_S8_3D_3, 2, 8}, - + permute_params{CASE_PERMUTE_U8_3D_0, 2, 8}, permute_params{CASE_PERMUTE_U8_3D_1, 2, 8}, permute_params{CASE_PERMUTE_U8_3D_2, 2, 8}, @@ -4387,7 +4400,7 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, permute_scale_eltwise_actv_scale_actv, permute_params{CASE_PERMUTE_F32_5, 2, 7}, permute_params{CASE_PERMUTE_F32_6, 2, 7}, permute_params{CASE_PERMUTE_F32_7, 2, 7}, - + permute_params{CASE_PERMUTE_F16_0, 2, 7}, permute_params{CASE_PERMUTE_F16_1, 2, 7}, permute_params{CASE_PERMUTE_F16_2, 2, 7}, @@ -4395,36 +4408,102 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, permute_scale_eltwise_actv_scale_actv, permute_params{CASE_PERMUTE_F16_4, 2, 7}, permute_params{CASE_PERMUTE_F16_5, 2, 7}, permute_params{CASE_PERMUTE_F16_6, 2, 7}, - + permute_params{CASE_PERMUTE_S8_0, 2, 7}, permute_params{CASE_PERMUTE_S8_1, 2, 7}, permute_params{CASE_PERMUTE_S8_2, 2, 7}, permute_params{CASE_PERMUTE_S8_3, 2, 7}, - + permute_params{CASE_PERMUTE_U8_0, 2, 7}, permute_params{CASE_PERMUTE_U8_1, 2, 7}, permute_params{CASE_PERMUTE_U8_2, 2, 7}, permute_params{CASE_PERMUTE_U8_3, 2, 7}, - + permute_params{CASE_PERMUTE_F32_3D_0, 2, 7}, permute_params{CASE_PERMUTE_F32_3D_1, 2, 7}, permute_params{CASE_PERMUTE_F32_3D_2, 2, 7}, permute_params{CASE_PERMUTE_F32_3D_3, 2, 7}, permute_params{CASE_PERMUTE_F32_3D_4, 2, 7}, - + permute_params{CASE_PERMUTE_F16_3D_0, 2, 7}, permute_params{CASE_PERMUTE_F16_3D_1, 2, 7}, permute_params{CASE_PERMUTE_F16_3D_2, 2, 7}, permute_params{CASE_PERMUTE_F16_3D_3, 2, 7}, permute_params{CASE_PERMUTE_F16_3D_4, 2, 7}, - + permute_params{CASE_PERMUTE_S8_3D_0, 2, 7}, permute_params{CASE_PERMUTE_S8_3D_1, 2, 7}, permute_params{CASE_PERMUTE_S8_3D_2, 2, 7}, permute_params{CASE_PERMUTE_S8_3D_3, 2, 7}, - + permute_params{CASE_PERMUTE_U8_3D_0, 2, 7}, permute_params{CASE_PERMUTE_U8_3D_1, 2, 7}, permute_params{CASE_PERMUTE_U8_3D_2, 2, 7}, permute_params{CASE_PERMUTE_U8_3D_3, 2, 7}, }), ); + +class NormalizeFusingTest : public ::BaseFusingTest { +public: + void execute(normalize_test_params& p) { + auto input_prim = get_mem(get_input_layout(p)); + network network_not_fused(this->engine, this->topology_non_fused, bo_not_fused); + network network_fused(this->engine, this->topology_fused, bo_fused); + network_fused.set_input_data("input", input_prim); + network_not_fused.set_input_data("input", input_prim); + + compare(network_not_fused, network_fused, p); + } + layout get_input_layout(normalize_test_params& p) { return layout{p.data_type, p.input_format, p.in_shape}; } + layout get_per_channel_layout(normalize_test_params& p) { + return layout{p.default_type, p.default_format, tensor{1, p.in_shape.feature[0], 1, 1}}; + } + layout get_weights_layout(normalize_test_params& p) { return layout {p.default_type, p.default_format, tensor{1, p.in_shape.feature[0], 1, 1}}; } +}; + +class normalize_i8_quantize : public NormalizeFusingTest {}; +TEST_P(normalize_i8_quantize, basic) { + auto p = GetParam(); + create_topologies( + input_layout("input", get_input_layout(p)), + data("weights", get_mem(get_weights_layout(p))), + data("in_lo", get_mem(get_single_element_layout(p), min_random, 0)), + data("in_hi", get_mem(get_single_element_layout(p), 1, max_random)), + data("out_lo", get_mem(get_single_element_layout(p), 0)), + data("out_hi", get_mem(get_single_element_layout(p), 255)), + normalize("normalizel2", "input", "weights", p.across_spatial), + quantize("quantize", "normalizel2", "in_lo", "in_hi", "out_lo", "out_hi", 255, data_types::u8), + reorder("output_reorder", "quantize", p.default_format, 
data_types::f32)); + + tolerance = 1; + execute(p); +} + +INSTANTIATE_TEST_CASE_P(fusings_gpu, + normalize_i8_quantize, + ::testing::ValuesIn(std::vector{ + normalize_test_params{CASE_NORMALIZE_I8_1, false, 2, 3}, + normalize_test_params{CASE_NORMALIZE_I8_1, true, 2, 3}, + }), ); + +class normalize_i8_float : public NormalizeFusingTest {}; +TEST_P(normalize_i8_float, basic) { + auto p = GetParam(); + create_topologies( + input_layout("input", get_input_layout(p)), + data("weights", get_mem(get_weights_layout(p))), + data("scale_data", get_mem(get_per_channel_layout(p), 1.0f/255)), + normalize("normalizel2", "input", "weights", p.across_spatial), + scale("scale", "normalizel2", "scale_data"), + activation("activation", "scale", activation_func::abs), + reorder("output_reorder", "activation", p.default_format, data_types::f32)); + + tolerance = 1e-05f; + execute(p); +} + +INSTANTIATE_TEST_CASE_P(fusings_gpu, + normalize_i8_float, + ::testing::ValuesIn(std::vector{ + normalize_test_params{CASE_NORMALIZE_I8_1, false, 2, 4}, + normalize_test_params{CASE_NORMALIZE_I8_1, true, 2, 4}, + }), ); diff --git a/inference-engine/thirdparty/clDNN/tests/test_cases/normalizel2_gpu_test.cpp b/inference-engine/thirdparty/clDNN/tests/test_cases/normalizel2_gpu_test.cpp new file mode 100644 index 00000000000..8de4a1d07d4 --- /dev/null +++ b/inference-engine/thirdparty/clDNN/tests/test_cases/normalizel2_gpu_test.cpp @@ -0,0 +1,290 @@ +// Copyright (c) 2020 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +/////////////////////////////////////////////////////////////////////////////////////////////////// +#include + +#include +#include +#include +#include +#include +#include "test_utils/test_utils.h" +#include + +#include +#include + +using namespace cldnn; +using namespace ::tests; + +TEST(normalizel2_f32_gpu, basic) { + // Input : 1x2x3x3 + // Output : 1x2x3x3 + + const auto& engine = get_test_engine(); + + const unsigned b = 1; + const unsigned f = 2; + const unsigned y = 3; + const unsigned x = 3; + + auto input = memory::allocate(engine, {data_types::f32, format::bfyx, {b, f, y, x}}); + auto weights = memory::allocate(engine, {data_types::f32, format::bfyx, {1, f, 1, 1}}); + + std::vector inputVals(b * f * y * x); + std::generate(inputVals.begin(), inputVals.end(), []() { + static float n = 0; + return n++; + }); + std::vector weightVals(f); + for (auto& it : weightVals) { + it = 1.f; + } + + set_values(input, inputVals); + set_values(weights, weightVals); + + topology topology; + topology.add(input_layout("Input0", input.get_layout())); + topology.add(data("Input1", weights)); + topology.add(normalize("normalizel2", "Input0", "Input1", false)); + + network network(engine, topology); + + network.set_input_data("Input0", input); + + auto outputs = network.execute(); + + auto output = outputs.at("normalizel2").get_memory(); + auto output_ptr = output.pointer(); + + std::vector expected_results = {0.f, + 0.0995037f, + 0.178885f, + 0.242536f, + 0.294086f, + 0.336336f, + 0.371391f, + 0.400819f, + 0.425797f, + 1.f, + 0.995037f, + 0.98387f, + 0.970143f, + 0.955779f, + 0.941742f, + 0.928477f, + 0.916157f, + 0.904819f}; + + for (size_t i = 0; i < expected_results.size(); ++i) { + EXPECT_TRUE(are_equal(expected_results[i], output_ptr[i])); + } +} + +TEST(normalizel2_f32_gpu, basic2) { + // Input : 1x2x3x3 + // Output : 1x2x3x3 + + const auto& engine = get_test_engine(); + + const unsigned b = 1; + const unsigned f = 2; + const unsigned y = 3; + const unsigned x = 3; + + auto input = memory::allocate(engine, {data_types::f32, format::bfyx, {b, f, y, x}}); + auto weights = memory::allocate(engine, {data_types::f32, format::bfyx, {1, f, 1, 1}}); + + std::vector inputVals(b * f * y * x); + std::generate(inputVals.begin(), inputVals.end(), []() { + static float n = 0; + return n++; + }); + std::vector weightVals(f); + for (auto& it : weightVals) { + it = 1.f; + } + + set_values(input, inputVals); + set_values(weights, weightVals); + + topology topology; + topology.add(input_layout("Input0", input.get_layout())); + topology.add(data("Input1", weights)); + topology.add(normalize("normalizel2", "Input0", "Input1", true)); + + network network(engine, topology); + + network.set_input_data("Input0", input); + + auto outputs = network.execute(); + + auto output = outputs.at("normalizel2").get_memory(); + auto output_ptr = output.pointer(); + + std::vector expected_results = {0.f, + 0.0236691f, + 0.0473381f, + 0.0710072f, + 0.0946762f, + 0.118345f, + 0.142014f, + 0.165683f, + 0.189352f, + 0.213021f, + 0.236691f, + 0.26036f, + 0.284029f, + 0.307698f, + 0.331367f, + 0.355036f, + 0.378705f, + 0.402374f}; + + for (size_t i = 0; i < expected_results.size(); ++i) { + EXPECT_TRUE(are_equal(expected_results[i], output_ptr[i])); + } +} + +TEST(normalizel2_int8_gpu, basic) { + // Input : 1x2x3x3 + // Output : 1x2x3x3 + + const auto& engine = get_test_engine(); + + const unsigned b = 1; + const unsigned f = 2; + const unsigned y = 3; + const unsigned x = 3; + + auto input = memory::allocate(engine, {data_types::i8, 
format::bfyx, {b, f, y, x}}); + auto weights = memory::allocate(engine, {data_types::f32, format::bfyx, {1, f, 1, 1}}); + + std::vector inputVals(b * f * y * x); + std::generate(inputVals.begin(), inputVals.end(), []() { + static int8_t n = 0; + return n++; + }); + std::vector weightVals(f); + for (auto& it : weightVals) { + it = 1; + } + + set_values(input, inputVals); + set_values(weights, weightVals); + + topology topology; + topology.add(input_layout("Input0", input.get_layout())); + topology.add(data("Input1", weights)); + topology.add(normalize("normalizel2", "Input0", "Input1", false)); + + network network(engine, topology); + + network.set_input_data("Input0", input); + + auto outputs = network.execute(); + + auto output = outputs.at("normalizel2").get_memory(); + auto output_ptr = output.pointer(); + + std::vector expected_results = {0.f, + 0.0995037f, + 0.178885f, + 0.242536f, + 0.294086f, + 0.336336f, + 0.371391f, + 0.400819f, + 0.425797f, + 1.f, + 0.995037f, + 0.98387f, + 0.970143f, + 0.955779f, + 0.941742f, + 0.928477f, + 0.916157f, + 0.904819f}; + + for (size_t i = 0; i < expected_results.size(); ++i) { + EXPECT_TRUE(are_equal(expected_results[i], output_ptr[i])); + } +} + +TEST(normalizel2_int8_gpu, basic2) { + // Input : 1x2x3x3 + // Output : 1x2x3x3 + + const auto& engine = get_test_engine(); + + const unsigned b = 1; + const unsigned f = 2; + const unsigned y = 3; + const unsigned x = 3; + + auto input = memory::allocate(engine, {data_types::i8, format::bfyx, {b, f, y, x}}); + auto weights = memory::allocate(engine, {data_types::f32, format::bfyx, {1, f, 1, 1}}); + + std::vector inputVals(b * f * y * x); + std::generate(inputVals.begin(), inputVals.end(), []() { + static int8_t n = 0; + return n++; + }); + std::vector weightVals(f); + for (auto& it : weightVals) { + it = 1.f; + } + + set_values(input, inputVals); + set_values(weights, weightVals); + + topology topology; + topology.add(input_layout("Input0", input.get_layout())); + topology.add(data("Input1", weights)); + topology.add(normalize("normalizel2", "Input0", "Input1", true)); + + network network(engine, topology); + + network.set_input_data("Input0", input); + + auto outputs = network.execute(); + + auto output = outputs.at("normalizel2").get_memory(); + auto output_ptr = output.pointer(); + + std::vector expected_results = {0.f, + 0.0236691f, + 0.0473381f, + 0.0710072f, + 0.0946762f, + 0.118345f, + 0.142014f, + 0.165683f, + 0.189352f, + 0.213021f, + 0.236691f, + 0.26036f, + 0.284029f, + 0.307698f, + 0.331367f, + 0.355036f, + 0.378705f, + 0.402374f}; + + for (size_t i = 0; i < expected_results.size(); ++i) { + EXPECT_TRUE(are_equal(expected_results[i], output_ptr[i])); + } +}
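Editor's note: taken together with the calc_output_layout() change in normalize.cpp, the end-to-end typing these tests exercise is: a quantized (i8/u8) input to NormalizeL2 produces an fp32 output by default, unless a fused primitive such as the quantize in the fusing tests fixes a different output type. A small standalone restatement of that rule, using locally defined stand-in types rather than the cldnn data_types enum:

#include <cassert>

// Local stand-in types for illustration only.
enum class Dt { F16, F32, I8, U8 };

// Restates the output-type selection added to normalize_inst::calc_output_layout.
Dt normalize_output_type(Dt input_dt, bool has_fused, Dt fused_dt = Dt::F32) {
    if (has_fused) return fused_dt;                                 // fused ops decide the type
    if (input_dt == Dt::I8 || input_dt == Dt::U8) return Dt::F32;   // widen quantized inputs
    return input_dt;                                                // fp16/fp32 pass through
}

int main() {
    assert(normalize_output_type(Dt::I8, false) == Dt::F32);        // plain int8 NormalizeL2 -> f32
    assert(normalize_output_type(Dt::U8, true, Dt::U8) == Dt::U8);  // normalize + fused quantize -> u8
    assert(normalize_output_type(Dt::F16, false) == Dt::F16);
    return 0;
}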