[IE CLDNN] Implement NormalizeL2 int8 kernels (#720)
This commit is contained in:
parent
a705f0c358
commit
2100521a14
@ -1,4 +1,4 @@
|
||||
// Copyright (c) 2016 Intel Corporation
|
||||
// Copyright (c) 2016-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -20,14 +20,19 @@ ParamsKey NormalizeKernelAcrossSpatialRef::GetSupportedKey() const {
|
||||
ParamsKey k;
|
||||
k.EnableInputDataType(Datatype::F16);
|
||||
k.EnableInputDataType(Datatype::F32);
|
||||
k.EnableInputDataType(Datatype::INT8);
|
||||
k.EnableInputDataType(Datatype::UINT8);
|
||||
k.EnableOutputDataType(Datatype::F16);
|
||||
k.EnableOutputDataType(Datatype::F32);
|
||||
k.EnableOutputDataType(Datatype::INT8);
|
||||
k.EnableOutputDataType(Datatype::UINT8);
|
||||
k.EnableInputLayout(DataLayout::bfyx);
|
||||
k.EnableInputLayout(DataLayout::yxfb);
|
||||
k.EnableInputLayout(DataLayout::byxf);
|
||||
k.EnableOutputLayout(DataLayout::bfyx);
|
||||
k.EnableOutputLayout(DataLayout::yxfb);
|
||||
k.EnableOutputLayout(DataLayout::byxf);
|
||||
k.EnableDifferentTypes();
|
||||
k.EnableTensorOffset();
|
||||
k.EnableTensorPitches();
|
||||
k.EnableBatching();
|
||||
@ -39,4 +44,4 @@ KernelsData NormalizeKernelAcrossSpatialRef::GetKernelsData(const Params& params
|
||||
const optional_params& optParams) const {
|
||||
return GetCommonKernelsData(params, optParams, FORCE_PRIORITY_9);
|
||||
}
|
||||
} // namespace kernel_selector
|
||||
} // namespace kernel_selector
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright (c) 2016 Intel Corporation
|
||||
// Copyright (c) 2016-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -28,6 +28,14 @@ JitConstants NormalizeKernelBase::GetJitConstants(const normalize_params& np) co
|
||||
MakeJitConstant("THRESHOLD", 0.0001f),
|
||||
});
|
||||
|
||||
auto activation_dt = GetActivationType(np);
|
||||
jit.Merge(MakeTypeJitConstants(activation_dt, "ACTIVATION"));
|
||||
if (!np.fused_ops.empty()) {
|
||||
std::vector<std::string> idx_order = { "b", "f", "y", "x" };
|
||||
auto conf = FusedOpsConfiguration("", idx_order, "result", activation_dt);
|
||||
jit.Merge(MakeFusedOpsJitConstants(np, { conf }));
|
||||
}
|
||||
|
||||
return jit;
|
||||
}
|
||||
|
||||
@ -63,6 +71,8 @@ KernelsData NormalizeKernelBase::GetCommonKernelsData(const Params& params,
|
||||
const optional_params& options,
|
||||
float estimated_time) const {
|
||||
assert(params.GetType() == KernelType::NORMALIZE);
|
||||
if (!Validate(params, options))
|
||||
return {};
|
||||
|
||||
const normalize_params& orgParams = static_cast<const normalize_params&>(params);
|
||||
|
||||
@ -77,11 +87,39 @@ KernelsData NormalizeKernelBase::GetCommonKernelsData(const Params& params,
|
||||
auto jit = CreateJit(kernelName, cldnn_jit, entry_point);
|
||||
|
||||
auto& kernel = kd.kernels[0];
|
||||
FillCLKernelData(kernel, runInfo, params.engineInfo, kernelName, jit, entry_point);
|
||||
FillCLKernelData(kernel,
|
||||
runInfo,
|
||||
params.engineInfo,
|
||||
kernelName,
|
||||
jit,
|
||||
entry_point,
|
||||
"",
|
||||
false,
|
||||
false,
|
||||
1,
|
||||
GetFusedPrimitiveInputsCount(params));
|
||||
|
||||
kernel.arguments.push_back({ArgumentDescriptor::Types::SCALE_TABLE, 0});
|
||||
|
||||
kd.estimatedTime = estimated_time;
|
||||
|
||||
return {kd};
|
||||
}
|
||||
|
||||
bool NormalizeKernelBase::Validate(const Params& params, const optional_params&) const {
|
||||
const normalize_params& orgParams = static_cast<const normalize_params&>(params);
|
||||
|
||||
for (auto& fused_op : orgParams.fused_ops) {
|
||||
if (!IsFusedPrimitiveSupported(fused_op))
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
Datatype NormalizeKernelBase::GetActivationType(const normalize_params& params) const {
|
||||
if (params.output.GetDType() == Datatype::F16)
|
||||
return Datatype::F16;
|
||||
return Datatype::F32;
|
||||
}
|
||||
} // namespace kernel_selector
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright (c) 2016 Intel Corporation
|
||||
// Copyright (c) 2016-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -12,7 +12,6 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "common_kernel_base.h"
|
||||
@ -59,5 +58,12 @@ protected:
|
||||
JitConstants GetJitConstants(const normalize_params& params) const;
|
||||
DispatchData SetDefault(const normalize_params& params) const;
|
||||
KernelsData GetCommonKernelsData(const Params& params, const optional_params&, float estimated_time) const;
|
||||
std::vector<FusedOpType> GetSupportedFusedOps() const override {
|
||||
return { FusedOpType::QUANTIZE,
|
||||
FusedOpType::ACTIVATION,
|
||||
FusedOpType::SCALE };
|
||||
}
|
||||
bool Validate(const Params& params, const optional_params&) const override;
|
||||
Datatype GetActivationType(const normalize_params& params) const;
|
||||
};
|
||||
} // namespace kernel_selector
|
||||
} // namespace kernel_selector
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright (c) 2016 Intel Corporation
|
||||
// Copyright (c) 2016-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -20,14 +20,19 @@ ParamsKey NormalizeKernelWithinSpatialRef::GetSupportedKey() const {
|
||||
ParamsKey k;
|
||||
k.EnableInputDataType(Datatype::F16);
|
||||
k.EnableInputDataType(Datatype::F32);
|
||||
k.EnableInputDataType(Datatype::INT8);
|
||||
k.EnableInputDataType(Datatype::UINT8);
|
||||
k.EnableOutputDataType(Datatype::F16);
|
||||
k.EnableOutputDataType(Datatype::F32);
|
||||
k.EnableOutputDataType(Datatype::INT8);
|
||||
k.EnableOutputDataType(Datatype::UINT8);
|
||||
k.EnableInputLayout(DataLayout::bfyx);
|
||||
k.EnableInputLayout(DataLayout::yxfb);
|
||||
k.EnableInputLayout(DataLayout::byxf);
|
||||
k.EnableOutputLayout(DataLayout::bfyx);
|
||||
k.EnableOutputLayout(DataLayout::yxfb);
|
||||
k.EnableOutputLayout(DataLayout::byxf);
|
||||
k.EnableDifferentTypes();
|
||||
k.EnableTensorOffset();
|
||||
k.EnableTensorPitches();
|
||||
k.EnableBatching();
|
||||
@ -39,4 +44,4 @@ KernelsData NormalizeKernelWithinSpatialRef::GetKernelsData(const Params& params
|
||||
const optional_params& optParams) const {
|
||||
return GetCommonKernelsData(params, optParams, FORCE_PRIORITY_9);
|
||||
}
|
||||
} // namespace kernel_selector
|
||||
} // namespace kernel_selector
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright (c) 2016-2017 Intel Corporation
|
||||
// Copyright (c) 2016-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -15,15 +15,14 @@
|
||||
#include "include/common.cl"
|
||||
#include "include/data_types.cl"
|
||||
|
||||
|
||||
#if FP16_UNIT_USED
|
||||
#define UNIT_CVT_FUNC(val) convert_half(val)
|
||||
#else
|
||||
#define UNIT_CVT_FUNC(val) (val)
|
||||
KERNEL (normalize_gpu_across_spatial_bfyx)(
|
||||
const __global INPUT0_TYPE* input,
|
||||
__global OUTPUT_TYPE* output,
|
||||
#if HAS_FUSED_OPS_DECLS
|
||||
FUSED_OPS_DECLS,
|
||||
#endif
|
||||
|
||||
|
||||
KERNEL (normalize_gpu_across_spatial_bfyx)(const __global UNIT_TYPE* input, __global UNIT_TYPE* output, const __global UNIT_TYPE* scale_input)
|
||||
const __global SCALE_TABLE_TYPE* scale_input
|
||||
)
|
||||
{
|
||||
const uint b = get_global_id(0);
|
||||
|
||||
@ -68,13 +67,19 @@ KERNEL (normalize_gpu_across_spatial_bfyx)(const __global UNIT_TYPE* input, __gl
|
||||
const uint scale_index = f;
|
||||
#else
|
||||
const uint scale_index = f % SCALE_TABLE_FEATURE_NUM;
|
||||
#endif
|
||||
#endif
|
||||
|
||||
for (uint y = 0; y < INPUT0_SIZE_Y; y++)
|
||||
{
|
||||
for (uint x = 0; x < INPUT0_SIZE_X; x++)
|
||||
{
|
||||
output[output_idx] = ACTIVATION(UNIT_CVT_FUNC(norm) * input[input_idx] * scale_input[scale_index], ACTIVATION_PARAMS);
|
||||
ACTIVATION_TYPE result = TO_ACTIVATION_TYPE(norm) * TO_ACTIVATION_TYPE(input[input_idx]) * TO_ACTIVATION_TYPE(scale_input[scale_index]);
|
||||
#if HAS_FUSED_OPS
|
||||
FUSED_OPS;
|
||||
output[output_idx] = FUSED_OPS_RESULT;
|
||||
#else
|
||||
output[output_idx] = TO_OUTPUT_TYPE(ACTIVATION(result, ACTIVATION_PARAMS));
|
||||
#endif
|
||||
input_idx += INPUT0_X_PITCH;
|
||||
output_idx += OUTPUT_X_PITCH;
|
||||
}
|
||||
@ -85,6 +90,3 @@ KERNEL (normalize_gpu_across_spatial_bfyx)(const __global UNIT_TYPE* input, __gl
|
||||
output_idx += OUTPUT_FEATURE_PITCH - INPUT0_SIZE_Y*OUTPUT_Y_PITCH;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#undef UNIT_CVT_FUNC
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright (c) 2016-2017 Intel Corporation
|
||||
// Copyright (c) 2016-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -15,15 +15,14 @@
|
||||
#include "include/common.cl"
|
||||
#include "include/data_types.cl"
|
||||
|
||||
|
||||
#if FP16_UNIT_USED
|
||||
#define UNIT_CVT_FUNC(val) convert_half(val)
|
||||
#else
|
||||
#define UNIT_CVT_FUNC(val) (val)
|
||||
KERNEL (normalize_gpu_within_spatial_bfyx)(
|
||||
const __global INPUT0_TYPE* input,
|
||||
__global OUTPUT_TYPE* output,
|
||||
#if HAS_FUSED_OPS_DECLS
|
||||
FUSED_OPS_DECLS,
|
||||
#endif
|
||||
|
||||
|
||||
KERNEL (normalize_gpu_within_spatial_bfyx)(const __global UNIT_TYPE* input, __global UNIT_TYPE* output, const __global UNIT_TYPE* scale_input)
|
||||
const __global SCALE_TABLE_TYPE* scale_input
|
||||
)
|
||||
{
|
||||
const uint x = get_global_id(0);
|
||||
const uint y = get_global_id(1);
|
||||
@ -40,7 +39,7 @@ KERNEL (normalize_gpu_within_spatial_bfyx)(const __global UNIT_TYPE* input, __gl
|
||||
norm = mad(value, value, norm);
|
||||
input_idx += INPUT0_FEATURE_PITCH;
|
||||
}
|
||||
|
||||
|
||||
uint output_idx = OUTPUT_OFFSET + b*OUTPUT_BATCH_PITCH + y*OUTPUT_Y_PITCH + x*OUTPUT_X_PITCH;
|
||||
|
||||
if(norm <= THRESHOLD)
|
||||
@ -62,13 +61,16 @@ KERNEL (normalize_gpu_within_spatial_bfyx)(const __global UNIT_TYPE* input, __gl
|
||||
const uint scale_index = f;
|
||||
#else
|
||||
const uint scale_index = f % SCALE_TABLE_FEATURE_NUM;
|
||||
#endif
|
||||
#endif
|
||||
|
||||
output[output_idx] = ACTIVATION(UNIT_CVT_FUNC(norm) * input[input_idx] * scale_input[scale_index], ACTIVATION_PARAMS);
|
||||
ACTIVATION_TYPE result = TO_ACTIVATION_TYPE(norm) * TO_ACTIVATION_TYPE(input[input_idx]) * TO_ACTIVATION_TYPE(scale_input[scale_index]);
|
||||
#if HAS_FUSED_OPS
|
||||
FUSED_OPS;
|
||||
output[output_idx] = FUSED_OPS_RESULT;
|
||||
#else
|
||||
output[output_idx] = TO_OUTPUT_TYPE(ACTIVATION(result, ACTIVATION_PARAMS));
|
||||
#endif
|
||||
output_idx += OUTPUT_FEATURE_PITCH;
|
||||
input_idx += INPUT0_FEATURE_PITCH;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#undef UNIT_CVT_FUNC
|
||||
|
@ -1,5 +1,5 @@
|
||||
/*
|
||||
// Copyright (c) 2016 Intel Corporation
|
||||
// Copyright (c) 2016-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -75,14 +75,26 @@ attach_normalize_gpu::attach_normalize_gpu() {
|
||||
normalize_gpu::create);
|
||||
implementation_map<normalize>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfyx),
|
||||
normalize_gpu::create);
|
||||
implementation_map<normalize>::add(std::make_tuple(engine_types::ocl, data_types::i8, format::bfyx),
|
||||
normalize_gpu::create);
|
||||
implementation_map<normalize>::add(std::make_tuple(engine_types::ocl, data_types::u8, format::bfyx),
|
||||
normalize_gpu::create);
|
||||
implementation_map<normalize>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::yxfb),
|
||||
normalize_gpu::create);
|
||||
implementation_map<normalize>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::yxfb),
|
||||
normalize_gpu::create);
|
||||
implementation_map<normalize>::add(std::make_tuple(engine_types::ocl, data_types::i8, format::yxfb),
|
||||
normalize_gpu::create);
|
||||
implementation_map<normalize>::add(std::make_tuple(engine_types::ocl, data_types::u8, format::yxfb),
|
||||
normalize_gpu::create);
|
||||
implementation_map<normalize>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::byxf),
|
||||
normalize_gpu::create);
|
||||
implementation_map<normalize>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::byxf),
|
||||
normalize_gpu::create);
|
||||
implementation_map<normalize>::add(std::make_tuple(engine_types::ocl, data_types::i8, format::byxf),
|
||||
normalize_gpu::create);
|
||||
implementation_map<normalize>::add(std::make_tuple(engine_types::ocl, data_types::u8, format::byxf),
|
||||
normalize_gpu::create);
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
@ -359,6 +359,10 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) {
|
||||
|
||||
should_fuse |= input_data.is_type<mvn>();
|
||||
|
||||
should_fuse |= input_data.is_type<normalize>() &&
|
||||
(input_data.get_dependency(0).get_output_layout().data_type == data_types::u8 ||
|
||||
input_data.get_dependency(0).get_output_layout().data_type == data_types::i8);
|
||||
|
||||
should_fuse |= input_data.is_type<deconvolution>();
|
||||
|
||||
should_fuse |= input_data.is_type<permute>();
|
||||
@ -396,6 +400,10 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) {
|
||||
|
||||
should_fuse |= input_data.is_type<mvn>() && mvn_supports_fusings(input_data.as<mvn>());
|
||||
|
||||
should_fuse |= input_data.is_type<normalize>() &&
|
||||
(input_data.get_dependency(0).get_output_layout().data_type == data_types::u8 ||
|
||||
input_data.get_dependency(0).get_output_layout().data_type == data_types::i8);
|
||||
|
||||
should_fuse |= input_data.is_type<deconvolution>();
|
||||
|
||||
should_fuse |= input_data.is_type<permute>();
|
||||
@ -449,8 +457,7 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) {
|
||||
quantize_node.get_scale_shift_opt() &&
|
||||
(out_layout.data_type == data_types::u8 || out_layout.data_type == data_types::i8);
|
||||
|
||||
should_fuse |= input_data.is_type<lrn>() &&
|
||||
quantize_node.get_scale_shift_opt();
|
||||
should_fuse |= input_data.is_type<lrn>() && quantize_node.get_scale_shift_opt();
|
||||
|
||||
should_fuse |= input_data.is_type<gemm>() && gemm_supports_fusings(input_data.as<gemm>()) &&
|
||||
quantize_node.get_scale_shift_opt() &&
|
||||
@ -465,6 +472,10 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) {
|
||||
|
||||
should_fuse |= input_data.is_type<activation>() && quantize_node.get_scale_shift_opt();
|
||||
|
||||
should_fuse |= input_data.is_type<normalize>() && quantize_node.get_scale_shift_opt() &&
|
||||
(input_data.get_dependency(0).get_output_layout().data_type == data_types::u8 ||
|
||||
input_data.get_dependency(0).get_output_layout().data_type == data_types::i8);
|
||||
|
||||
should_fuse |= input_data.is_type<deconvolution>() && quantize_node.get_scale_shift_opt() &&
|
||||
// fp16/fp32 optimized kernels don't support chaning data type
|
||||
(input_data.get_dependency(0).get_output_layout().data_type == data_types::u8 ||
|
||||
|
@ -29,7 +29,16 @@ primitive_type_id normalize::type_id() {
|
||||
layout normalize_inst::calc_output_layout(normalize_node const& node) {
|
||||
assert(static_cast<bool>(node.get_primitive()->output_data_type) == false &&
|
||||
"Output data type forcing is not supported for normalize_node!");
|
||||
return node.input().get_non_padded_output_layout();
|
||||
auto input_node_layout = node.input().get_non_padded_output_layout();
|
||||
auto output_type = input_node_layout.data_type;
|
||||
|
||||
if (node.has_fused_primitives()) {
|
||||
output_type = node.get_fused_output_layout().data_type;
|
||||
} else if (input_node_layout.data_type == data_types::u8 || input_node_layout.data_type == data_types::i8) {
|
||||
output_type = data_types::f32;
|
||||
}
|
||||
|
||||
return layout(output_type, input_node_layout.format, input_node_layout.size);
|
||||
}
|
||||
|
||||
std::string normalize_inst::to_string(normalize_node const& node) {
|
||||
|
@ -85,6 +85,17 @@ struct gemm_test_params {
|
||||
size_t expected_not_fused_primitives;
|
||||
};
|
||||
|
||||
struct normalize_test_params {
|
||||
tensor in_shape;
|
||||
data_types data_type;
|
||||
format input_format;
|
||||
data_types default_type;
|
||||
format default_format;
|
||||
bool across_spatial;
|
||||
size_t expected_fused_primitives;
|
||||
size_t expected_not_fused_primitives;
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
class BaseFusingTest : public ::testing::TestWithParam<T> {
|
||||
public:
|
||||
@ -430,6 +441,8 @@ public:
|
||||
#define CASE_GEMM_2IN_U8S8_1 {{1, 1, 4, 2}, {1, 1, 8, 4}}, tensor{1}, tensor{0}, data_types::u8, data_types::i8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
|
||||
#define CASE_GEMM_2IN_S8U8_1 {{1, 2, 64, 128}, {1, 2, 256, 64}}, tensor{1}, tensor{0}, data_types::i8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
|
||||
|
||||
#define CASE_NORMALIZE_I8_1 {1, 2, 3, 3}, data_types::u8, format::bfyx, data_types::f32, format::bfyx
|
||||
|
||||
/* ----------------------------------------------------------------------------------------------------- */
|
||||
/* ---------------------------------------- FP32 convolution cases ------------------------------------- */
|
||||
/* ----------------------------------------------------------------------------------------------------- */
|
||||
@ -4161,7 +4174,7 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, permute_activation_scale_eltwise,
|
||||
permute_params{CASE_PERMUTE_F32_5, 2, 5},
|
||||
permute_params{CASE_PERMUTE_F32_6, 2, 5},
|
||||
permute_params{CASE_PERMUTE_F32_7, 2, 5},
|
||||
|
||||
|
||||
permute_params{CASE_PERMUTE_F16_0, 2, 5},
|
||||
permute_params{CASE_PERMUTE_F16_1, 2, 5},
|
||||
permute_params{CASE_PERMUTE_F16_2, 2, 5},
|
||||
@ -4169,34 +4182,34 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, permute_activation_scale_eltwise,
|
||||
permute_params{CASE_PERMUTE_F16_4, 2, 5},
|
||||
permute_params{CASE_PERMUTE_F16_5, 2, 5},
|
||||
permute_params{CASE_PERMUTE_F16_6, 2, 5},
|
||||
|
||||
|
||||
permute_params{CASE_PERMUTE_S8_0, 2, 5},
|
||||
permute_params{CASE_PERMUTE_S8_1, 2, 5},
|
||||
permute_params{CASE_PERMUTE_S8_2, 2, 5},
|
||||
permute_params{CASE_PERMUTE_S8_3, 2, 5},
|
||||
|
||||
|
||||
permute_params{CASE_PERMUTE_U8_0, 2, 5},
|
||||
permute_params{CASE_PERMUTE_U8_1, 2, 5},
|
||||
permute_params{CASE_PERMUTE_U8_2, 2, 5},
|
||||
permute_params{CASE_PERMUTE_U8_3, 2, 5},
|
||||
|
||||
|
||||
permute_params{CASE_PERMUTE_F32_3D_0, 2, 5},
|
||||
permute_params{CASE_PERMUTE_F32_3D_1, 2, 5},
|
||||
permute_params{CASE_PERMUTE_F32_3D_2, 2, 5},
|
||||
permute_params{CASE_PERMUTE_F32_3D_3, 2, 5},
|
||||
permute_params{CASE_PERMUTE_F32_3D_4, 2, 5},
|
||||
|
||||
|
||||
permute_params{CASE_PERMUTE_F16_3D_0, 2, 5},
|
||||
permute_params{CASE_PERMUTE_F16_3D_1, 2, 5},
|
||||
permute_params{CASE_PERMUTE_F16_3D_2, 2, 5},
|
||||
permute_params{CASE_PERMUTE_F16_3D_3, 2, 5},
|
||||
permute_params{CASE_PERMUTE_F16_3D_4, 2, 5},
|
||||
|
||||
|
||||
permute_params{CASE_PERMUTE_S8_3D_0, 2, 5},
|
||||
permute_params{CASE_PERMUTE_S8_3D_1, 2, 5},
|
||||
permute_params{CASE_PERMUTE_S8_3D_2, 2, 5},
|
||||
permute_params{CASE_PERMUTE_S8_3D_3, 2, 5},
|
||||
|
||||
|
||||
permute_params{CASE_PERMUTE_U8_3D_0, 2, 5},
|
||||
permute_params{CASE_PERMUTE_U8_3D_1, 2, 5},
|
||||
permute_params{CASE_PERMUTE_U8_3D_2, 2, 5},
|
||||
@ -4235,7 +4248,7 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, permute_scale_eltwise_quant_u8,
|
||||
permute_params{CASE_PERMUTE_F32_5, 2, 5},
|
||||
permute_params{CASE_PERMUTE_F32_6, 2, 5},
|
||||
permute_params{CASE_PERMUTE_F32_7, 2, 5},
|
||||
|
||||
|
||||
permute_params{CASE_PERMUTE_F16_0, 2, 5},
|
||||
permute_params{CASE_PERMUTE_F16_1, 2, 5},
|
||||
permute_params{CASE_PERMUTE_F16_2, 2, 5},
|
||||
@ -4248,18 +4261,18 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, permute_scale_eltwise_quant_u8,
|
||||
permute_params{CASE_PERMUTE_S8_1, 2, 5},
|
||||
permute_params{CASE_PERMUTE_S8_2, 2, 5},
|
||||
permute_params{CASE_PERMUTE_S8_3, 2, 5},
|
||||
|
||||
|
||||
permute_params{CASE_PERMUTE_U8_0, 2, 5},
|
||||
permute_params{CASE_PERMUTE_U8_1, 2, 5},
|
||||
permute_params{CASE_PERMUTE_U8_2, 2, 5},
|
||||
permute_params{CASE_PERMUTE_U8_3, 2, 5},
|
||||
|
||||
|
||||
permute_params{CASE_PERMUTE_F32_3D_0, 2, 5},
|
||||
permute_params{CASE_PERMUTE_F32_3D_1, 2, 5},
|
||||
permute_params{CASE_PERMUTE_F32_3D_2, 2, 5},
|
||||
permute_params{CASE_PERMUTE_F32_3D_3, 2, 5},
|
||||
permute_params{CASE_PERMUTE_F32_3D_4, 2, 5},
|
||||
|
||||
|
||||
permute_params{CASE_PERMUTE_F16_3D_0, 2, 5},
|
||||
permute_params{CASE_PERMUTE_F16_3D_1, 2, 5},
|
||||
permute_params{CASE_PERMUTE_F16_3D_2, 2, 5},
|
||||
@ -4270,7 +4283,7 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, permute_scale_eltwise_quant_u8,
|
||||
permute_params{CASE_PERMUTE_S8_3D_1, 2, 5},
|
||||
permute_params{CASE_PERMUTE_S8_3D_2, 2, 5},
|
||||
permute_params{CASE_PERMUTE_S8_3D_3, 2, 5},
|
||||
|
||||
|
||||
permute_params{CASE_PERMUTE_U8_3D_0, 2, 5},
|
||||
permute_params{CASE_PERMUTE_U8_3D_1, 2, 5},
|
||||
permute_params{CASE_PERMUTE_U8_3D_2, 2, 5},
|
||||
@ -4321,19 +4334,19 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, permute_scale_actv_eltw_scale_actv_quant_i8
|
||||
permute_params{CASE_PERMUTE_F16_4, 2, 8},
|
||||
permute_params{CASE_PERMUTE_F16_5, 2, 8},
|
||||
permute_params{CASE_PERMUTE_F16_6, 2, 8},
|
||||
|
||||
|
||||
permute_params{CASE_PERMUTE_S8_0, 2, 8},
|
||||
permute_params{CASE_PERMUTE_S8_1, 2, 8},
|
||||
permute_params{CASE_PERMUTE_S8_2, 2, 8},
|
||||
permute_params{CASE_PERMUTE_S8_3, 2, 8},
|
||||
|
||||
|
||||
permute_params{CASE_PERMUTE_U8_0, 2, 8},
|
||||
permute_params{CASE_PERMUTE_U8_1, 2, 8},
|
||||
permute_params{CASE_PERMUTE_U8_2, 2, 8},
|
||||
permute_params{CASE_PERMUTE_U8_3, 2, 8},
|
||||
|
||||
|
||||
permute_params{CASE_PERMUTE_F32_3D_0, 2, 8},
|
||||
permute_params{CASE_PERMUTE_F32_3D_1, 2, 8},
|
||||
permute_params{CASE_PERMUTE_F32_3D_1, 2, 8},
|
||||
permute_params{CASE_PERMUTE_F32_3D_2, 2, 8},
|
||||
permute_params{CASE_PERMUTE_F32_3D_3, 2, 8},
|
||||
permute_params{CASE_PERMUTE_F32_3D_4, 2, 8},
|
||||
@ -4348,7 +4361,7 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, permute_scale_actv_eltw_scale_actv_quant_i8
|
||||
permute_params{CASE_PERMUTE_S8_3D_1, 2, 8},
|
||||
permute_params{CASE_PERMUTE_S8_3D_2, 2, 8},
|
||||
permute_params{CASE_PERMUTE_S8_3D_3, 2, 8},
|
||||
|
||||
|
||||
permute_params{CASE_PERMUTE_U8_3D_0, 2, 8},
|
||||
permute_params{CASE_PERMUTE_U8_3D_1, 2, 8},
|
||||
permute_params{CASE_PERMUTE_U8_3D_2, 2, 8},
|
||||
@ -4387,7 +4400,7 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, permute_scale_eltwise_actv_scale_actv,
|
||||
permute_params{CASE_PERMUTE_F32_5, 2, 7},
|
||||
permute_params{CASE_PERMUTE_F32_6, 2, 7},
|
||||
permute_params{CASE_PERMUTE_F32_7, 2, 7},
|
||||
|
||||
|
||||
permute_params{CASE_PERMUTE_F16_0, 2, 7},
|
||||
permute_params{CASE_PERMUTE_F16_1, 2, 7},
|
||||
permute_params{CASE_PERMUTE_F16_2, 2, 7},
|
||||
@ -4395,36 +4408,102 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, permute_scale_eltwise_actv_scale_actv,
|
||||
permute_params{CASE_PERMUTE_F16_4, 2, 7},
|
||||
permute_params{CASE_PERMUTE_F16_5, 2, 7},
|
||||
permute_params{CASE_PERMUTE_F16_6, 2, 7},
|
||||
|
||||
|
||||
permute_params{CASE_PERMUTE_S8_0, 2, 7},
|
||||
permute_params{CASE_PERMUTE_S8_1, 2, 7},
|
||||
permute_params{CASE_PERMUTE_S8_2, 2, 7},
|
||||
permute_params{CASE_PERMUTE_S8_3, 2, 7},
|
||||
|
||||
|
||||
permute_params{CASE_PERMUTE_U8_0, 2, 7},
|
||||
permute_params{CASE_PERMUTE_U8_1, 2, 7},
|
||||
permute_params{CASE_PERMUTE_U8_2, 2, 7},
|
||||
permute_params{CASE_PERMUTE_U8_3, 2, 7},
|
||||
|
||||
|
||||
permute_params{CASE_PERMUTE_F32_3D_0, 2, 7},
|
||||
permute_params{CASE_PERMUTE_F32_3D_1, 2, 7},
|
||||
permute_params{CASE_PERMUTE_F32_3D_2, 2, 7},
|
||||
permute_params{CASE_PERMUTE_F32_3D_3, 2, 7},
|
||||
permute_params{CASE_PERMUTE_F32_3D_4, 2, 7},
|
||||
|
||||
|
||||
permute_params{CASE_PERMUTE_F16_3D_0, 2, 7},
|
||||
permute_params{CASE_PERMUTE_F16_3D_1, 2, 7},
|
||||
permute_params{CASE_PERMUTE_F16_3D_2, 2, 7},
|
||||
permute_params{CASE_PERMUTE_F16_3D_3, 2, 7},
|
||||
permute_params{CASE_PERMUTE_F16_3D_4, 2, 7},
|
||||
|
||||
|
||||
permute_params{CASE_PERMUTE_S8_3D_0, 2, 7},
|
||||
permute_params{CASE_PERMUTE_S8_3D_1, 2, 7},
|
||||
permute_params{CASE_PERMUTE_S8_3D_2, 2, 7},
|
||||
permute_params{CASE_PERMUTE_S8_3D_3, 2, 7},
|
||||
|
||||
|
||||
permute_params{CASE_PERMUTE_U8_3D_0, 2, 7},
|
||||
permute_params{CASE_PERMUTE_U8_3D_1, 2, 7},
|
||||
permute_params{CASE_PERMUTE_U8_3D_2, 2, 7},
|
||||
permute_params{CASE_PERMUTE_U8_3D_3, 2, 7},
|
||||
}), );
|
||||
|
||||
class NormalizeFusingTest : public ::BaseFusingTest<normalize_test_params> {
|
||||
public:
|
||||
void execute(normalize_test_params& p) {
|
||||
auto input_prim = get_mem(get_input_layout(p));
|
||||
network network_not_fused(this->engine, this->topology_non_fused, bo_not_fused);
|
||||
network network_fused(this->engine, this->topology_fused, bo_fused);
|
||||
network_fused.set_input_data("input", input_prim);
|
||||
network_not_fused.set_input_data("input", input_prim);
|
||||
|
||||
compare(network_not_fused, network_fused, p);
|
||||
}
|
||||
layout get_input_layout(normalize_test_params& p) { return layout{p.data_type, p.input_format, p.in_shape}; }
|
||||
layout get_per_channel_layout(normalize_test_params& p) {
|
||||
return layout{p.default_type, p.default_format, tensor{1, p.in_shape.feature[0], 1, 1}};
|
||||
}
|
||||
layout get_weights_layout(normalize_test_params& p) { return layout {p.default_type, p.default_format, tensor{1, p.in_shape.feature[0], 1, 1}}; }
|
||||
};
|
||||
|
||||
class normalize_i8_quantize : public NormalizeFusingTest {};
|
||||
TEST_P(normalize_i8_quantize, basic) {
|
||||
auto p = GetParam();
|
||||
create_topologies(
|
||||
input_layout("input", get_input_layout(p)),
|
||||
data("weights", get_mem(get_weights_layout(p))),
|
||||
data("in_lo", get_mem(get_single_element_layout(p), min_random, 0)),
|
||||
data("in_hi", get_mem(get_single_element_layout(p), 1, max_random)),
|
||||
data("out_lo", get_mem(get_single_element_layout(p), 0)),
|
||||
data("out_hi", get_mem(get_single_element_layout(p), 255)),
|
||||
normalize("normalizel2", "input", "weights", p.across_spatial),
|
||||
quantize("quantize", "normalizel2", "in_lo", "in_hi", "out_lo", "out_hi", 255, data_types::u8),
|
||||
reorder("output_reorder", "quantize", p.default_format, data_types::f32));
|
||||
|
||||
tolerance = 1;
|
||||
execute(p);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(fusings_gpu,
|
||||
normalize_i8_quantize,
|
||||
::testing::ValuesIn(std::vector<normalize_test_params>{
|
||||
normalize_test_params{CASE_NORMALIZE_I8_1, false, 2, 3},
|
||||
normalize_test_params{CASE_NORMALIZE_I8_1, true, 2, 3},
|
||||
}), );
|
||||
|
||||
class normalize_i8_float : public NormalizeFusingTest {};
|
||||
TEST_P(normalize_i8_float, basic) {
|
||||
auto p = GetParam();
|
||||
create_topologies(
|
||||
input_layout("input", get_input_layout(p)),
|
||||
data("weights", get_mem(get_weights_layout(p))),
|
||||
data("scale_data", get_mem(get_per_channel_layout(p), 1.0f/255)),
|
||||
normalize("normalizel2", "input", "weights", p.across_spatial),
|
||||
scale("scale", "normalizel2", "scale_data"),
|
||||
activation("activation", "scale", activation_func::abs),
|
||||
reorder("output_reorder", "activation", p.default_format, data_types::f32));
|
||||
|
||||
tolerance = 1e-05f;
|
||||
execute(p);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(fusings_gpu,
|
||||
normalize_i8_float,
|
||||
::testing::ValuesIn(std::vector<normalize_test_params>{
|
||||
normalize_test_params{CASE_NORMALIZE_I8_1, false, 2, 4},
|
||||
normalize_test_params{CASE_NORMALIZE_I8_1, true, 2, 4},
|
||||
}), );
|
||||
|
290
inference-engine/thirdparty/clDNN/tests/test_cases/normalizel2_gpu_test.cpp
vendored
Normal file
290
inference-engine/thirdparty/clDNN/tests/test_cases/normalizel2_gpu_test.cpp
vendored
Normal file
@ -0,0 +1,290 @@
|
||||
// Copyright (c) 2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <api/input_layout.hpp>
|
||||
#include <api/normalize.hpp>
|
||||
#include <api/topology.hpp>
|
||||
#include <api/network.hpp>
|
||||
#include <api/engine.hpp>
|
||||
#include "test_utils/test_utils.h"
|
||||
#include <api/data.hpp>
|
||||
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
|
||||
using namespace cldnn;
|
||||
using namespace ::tests;
|
||||
|
||||
TEST(normalizel2_f32_gpu, basic) {
|
||||
// Input : 1x2x3x3
|
||||
// Output : 1x2x3x3
|
||||
|
||||
const auto& engine = get_test_engine();
|
||||
|
||||
const unsigned b = 1;
|
||||
const unsigned f = 2;
|
||||
const unsigned y = 3;
|
||||
const unsigned x = 3;
|
||||
|
||||
auto input = memory::allocate(engine, {data_types::f32, format::bfyx, {b, f, y, x}});
|
||||
auto weights = memory::allocate(engine, {data_types::f32, format::bfyx, {1, f, 1, 1}});
|
||||
|
||||
std::vector<float> inputVals(b * f * y * x);
|
||||
std::generate(inputVals.begin(), inputVals.end(), []() {
|
||||
static float n = 0;
|
||||
return n++;
|
||||
});
|
||||
std::vector<float> weightVals(f);
|
||||
for (auto& it : weightVals) {
|
||||
it = 1.f;
|
||||
}
|
||||
|
||||
set_values(input, inputVals);
|
||||
set_values(weights, weightVals);
|
||||
|
||||
topology topology;
|
||||
topology.add(input_layout("Input0", input.get_layout()));
|
||||
topology.add(data("Input1", weights));
|
||||
topology.add(normalize("normalizel2", "Input0", "Input1", false));
|
||||
|
||||
network network(engine, topology);
|
||||
|
||||
network.set_input_data("Input0", input);
|
||||
|
||||
auto outputs = network.execute();
|
||||
|
||||
auto output = outputs.at("normalizel2").get_memory();
|
||||
auto output_ptr = output.pointer<float>();
|
||||
|
||||
std::vector<float> expected_results = {0.f,
|
||||
0.0995037f,
|
||||
0.178885f,
|
||||
0.242536f,
|
||||
0.294086f,
|
||||
0.336336f,
|
||||
0.371391f,
|
||||
0.400819f,
|
||||
0.425797f,
|
||||
1.f,
|
||||
0.995037f,
|
||||
0.98387f,
|
||||
0.970143f,
|
||||
0.955779f,
|
||||
0.941742f,
|
||||
0.928477f,
|
||||
0.916157f,
|
||||
0.904819f};
|
||||
|
||||
for (size_t i = 0; i < expected_results.size(); ++i) {
|
||||
EXPECT_TRUE(are_equal(expected_results[i], output_ptr[i]));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(normalizel2_f32_gpu, basic2) {
|
||||
// Input : 1x2x3x3
|
||||
// Output : 1x2x3x3
|
||||
|
||||
const auto& engine = get_test_engine();
|
||||
|
||||
const unsigned b = 1;
|
||||
const unsigned f = 2;
|
||||
const unsigned y = 3;
|
||||
const unsigned x = 3;
|
||||
|
||||
auto input = memory::allocate(engine, {data_types::f32, format::bfyx, {b, f, y, x}});
|
||||
auto weights = memory::allocate(engine, {data_types::f32, format::bfyx, {1, f, 1, 1}});
|
||||
|
||||
std::vector<float> inputVals(b * f * y * x);
|
||||
std::generate(inputVals.begin(), inputVals.end(), []() {
|
||||
static float n = 0;
|
||||
return n++;
|
||||
});
|
||||
std::vector<float> weightVals(f);
|
||||
for (auto& it : weightVals) {
|
||||
it = 1.f;
|
||||
}
|
||||
|
||||
set_values(input, inputVals);
|
||||
set_values(weights, weightVals);
|
||||
|
||||
topology topology;
|
||||
topology.add(input_layout("Input0", input.get_layout()));
|
||||
topology.add(data("Input1", weights));
|
||||
topology.add(normalize("normalizel2", "Input0", "Input1", true));
|
||||
|
||||
network network(engine, topology);
|
||||
|
||||
network.set_input_data("Input0", input);
|
||||
|
||||
auto outputs = network.execute();
|
||||
|
||||
auto output = outputs.at("normalizel2").get_memory();
|
||||
auto output_ptr = output.pointer<float>();
|
||||
|
||||
std::vector<float> expected_results = {0.f,
|
||||
0.0236691f,
|
||||
0.0473381f,
|
||||
0.0710072f,
|
||||
0.0946762f,
|
||||
0.118345f,
|
||||
0.142014f,
|
||||
0.165683f,
|
||||
0.189352f,
|
||||
0.213021f,
|
||||
0.236691f,
|
||||
0.26036f,
|
||||
0.284029f,
|
||||
0.307698f,
|
||||
0.331367f,
|
||||
0.355036f,
|
||||
0.378705f,
|
||||
0.402374f};
|
||||
|
||||
for (size_t i = 0; i < expected_results.size(); ++i) {
|
||||
EXPECT_TRUE(are_equal(expected_results[i], output_ptr[i]));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(normalizel2_int8_gpu, basic) {
|
||||
// Input : 1x2x3x3
|
||||
// Output : 1x2x3x3
|
||||
|
||||
const auto& engine = get_test_engine();
|
||||
|
||||
const unsigned b = 1;
|
||||
const unsigned f = 2;
|
||||
const unsigned y = 3;
|
||||
const unsigned x = 3;
|
||||
|
||||
auto input = memory::allocate(engine, {data_types::i8, format::bfyx, {b, f, y, x}});
|
||||
auto weights = memory::allocate(engine, {data_types::f32, format::bfyx, {1, f, 1, 1}});
|
||||
|
||||
std::vector<int8_t> inputVals(b * f * y * x);
|
||||
std::generate(inputVals.begin(), inputVals.end(), []() {
|
||||
static int8_t n = 0;
|
||||
return n++;
|
||||
});
|
||||
std::vector<float> weightVals(f);
|
||||
for (auto& it : weightVals) {
|
||||
it = 1;
|
||||
}
|
||||
|
||||
set_values(input, inputVals);
|
||||
set_values(weights, weightVals);
|
||||
|
||||
topology topology;
|
||||
topology.add(input_layout("Input0", input.get_layout()));
|
||||
topology.add(data("Input1", weights));
|
||||
topology.add(normalize("normalizel2", "Input0", "Input1", false));
|
||||
|
||||
network network(engine, topology);
|
||||
|
||||
network.set_input_data("Input0", input);
|
||||
|
||||
auto outputs = network.execute();
|
||||
|
||||
auto output = outputs.at("normalizel2").get_memory();
|
||||
auto output_ptr = output.pointer<float>();
|
||||
|
||||
std::vector<float> expected_results = {0.f,
|
||||
0.0995037f,
|
||||
0.178885f,
|
||||
0.242536f,
|
||||
0.294086f,
|
||||
0.336336f,
|
||||
0.371391f,
|
||||
0.400819f,
|
||||
0.425797f,
|
||||
1.f,
|
||||
0.995037f,
|
||||
0.98387f,
|
||||
0.970143f,
|
||||
0.955779f,
|
||||
0.941742f,
|
||||
0.928477f,
|
||||
0.916157f,
|
||||
0.904819f};
|
||||
|
||||
for (size_t i = 0; i < expected_results.size(); ++i) {
|
||||
EXPECT_TRUE(are_equal(expected_results[i], output_ptr[i]));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(normalizel2_int8_gpu, basic2) {
|
||||
// Input : 1x2x3x3
|
||||
// Output : 1x2x3x3
|
||||
|
||||
const auto& engine = get_test_engine();
|
||||
|
||||
const unsigned b = 1;
|
||||
const unsigned f = 2;
|
||||
const unsigned y = 3;
|
||||
const unsigned x = 3;
|
||||
|
||||
auto input = memory::allocate(engine, {data_types::i8, format::bfyx, {b, f, y, x}});
|
||||
auto weights = memory::allocate(engine, {data_types::f32, format::bfyx, {1, f, 1, 1}});
|
||||
|
||||
std::vector<int8_t> inputVals(b * f * y * x);
|
||||
std::generate(inputVals.begin(), inputVals.end(), []() {
|
||||
static int8_t n = 0;
|
||||
return n++;
|
||||
});
|
||||
std::vector<float> weightVals(f);
|
||||
for (auto& it : weightVals) {
|
||||
it = 1.f;
|
||||
}
|
||||
|
||||
set_values(input, inputVals);
|
||||
set_values(weights, weightVals);
|
||||
|
||||
topology topology;
|
||||
topology.add(input_layout("Input0", input.get_layout()));
|
||||
topology.add(data("Input1", weights));
|
||||
topology.add(normalize("normalizel2", "Input0", "Input1", true));
|
||||
|
||||
network network(engine, topology);
|
||||
|
||||
network.set_input_data("Input0", input);
|
||||
|
||||
auto outputs = network.execute();
|
||||
|
||||
auto output = outputs.at("normalizel2").get_memory();
|
||||
auto output_ptr = output.pointer<float>();
|
||||
|
||||
std::vector<float> expected_results = {0.f,
|
||||
0.0236691f,
|
||||
0.0473381f,
|
||||
0.0710072f,
|
||||
0.0946762f,
|
||||
0.118345f,
|
||||
0.142014f,
|
||||
0.165683f,
|
||||
0.189352f,
|
||||
0.213021f,
|
||||
0.236691f,
|
||||
0.26036f,
|
||||
0.284029f,
|
||||
0.307698f,
|
||||
0.331367f,
|
||||
0.355036f,
|
||||
0.378705f,
|
||||
0.402374f};
|
||||
|
||||
for (size_t i = 0; i < expected_results.size(); ++i) {
|
||||
EXPECT_TRUE(are_equal(expected_results[i], output_ptr[i]));
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user