[IE CLDNN] Resample opt caffe (#4770)
commit 94eeb99ad0 (parent a533680a60)
@@ -124,14 +124,7 @@ void CreateInterpolateOp(Program& p, const std::shared_ptr<ngraph::op::v4::Inter
    int antialias = attrs.antialias;
    float cube_coeff = attrs.cube_coeff;

    // [WA] Replace linear mode with linear_onnx to emulate the old behavior from the v1->v4 Interpolate conversion
    // This WA must be removed as soon as the optimized kernel supports linear mode
    auto input_shape_rank = op->get_input_shape(0).size();
    auto mode = attrs.mode;
    if (mode == ngraph::op::v4::Interpolate::InterpolateMode::linear && input_shape_rank < 5) {
        mode = ngraph::op::v4::Interpolate::InterpolateMode::linear_onnx;
    }

    auto cldnnSampleType = GetResampleType(mode);
    auto shapeCalcMode = GetShapeCalculationMode(attrs.shape_calculation_mode);
    auto coordTransMode = GetCoordinateTransformationMode(attrs.coordinate_transformation_mode);
@@ -1,4 +1,4 @@
// Copyright (C) 2018-2021 Intel Corporation
// Copyright (C) 2020-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

@@ -49,9 +49,9 @@ ngraph::pass::ConvertInterpolate1ToInterpolate4::ConvertInterpolate1ToInterpolat
    // If we write only
    //     attrsV4.mode = ngraph::op::v4::Interpolate::InterpolateMode::linear;
    // instead of the conditional statement below when attrsV0.mode == "linear",
    // then we have a performance drop, because CPU and GPU have no optimized
    // then we have a performance drop, because CPU has no optimized
    // version of the 'linear' mode.
    // TODO: delete this conditional statement, when CPU and GPU will have
    // TODO: delete this conditional statement, when CPU will have
    // optimized version of the 'linear' mode.
    if (input_shape_rank < 5) {
        attrsV4.mode = ngraph::op::v4::Interpolate::InterpolateMode::linear_onnx;
@@ -175,6 +175,22 @@ JitConstants ResampleKernelBase::GetJitConstants(const resample_params& params)
        MakeJitConstant("CUBE_COEFF", params.cube_coeff),
    });

    if (params.resampleType == ResampleType::CAFFE_BILINEAR_INTERP) {
        if (axesUsed[0] == 1) jit.AddConstant(MakeJitConstant("AXES_USED_B", 1));
        if (axesUsed[1] == 1) jit.AddConstant(MakeJitConstant("AXES_USED_F", 1));
        if (axesUsed[2] == 1) jit.AddConstant(MakeJitConstant("AXES_USED_Z", 1));
        if (axesUsed[3] == 1) jit.AddConstant(MakeJitConstant("AXES_USED_Y", 1));
        if (axesUsed[4] == 1) jit.AddConstant(MakeJitConstant("AXES_USED_X", 1));

        jit.AddConstants({
            MakeJitConstant("PADDED_B", b_size_padded),
            MakeJitConstant("PADDED_F", f_size_padded),
            MakeJitConstant("PADDED_X", x_size_padded),
            MakeJitConstant("PADDED_Y", y_size_padded),
            MakeJitConstant("PADDED_Z", z_size_padded),
        });
    }

    size_t feature_block_size = GetFeatureBlockSize(params);

    if (params.resampleType == ResampleType::CAFFE_BILINEAR_INTERP) {
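For context, each MakeJitConstant above turns into a preprocessor define in the generated OpenCL source. A hedged illustration (values invented for the example, not taken from the commit) of what the CAFFE_BILINEAR_INTERP kernel could see for a resample over the spatial axes of a padded 1x16 input:

    #define AXES_USED_Y 1
    #define AXES_USED_X 1
    #define PADDED_B 1
    #define PADDED_F 16
    #define PADDED_Y 34   // input Y plus pads_begin/pads_end
    #define PADDED_X 34

The kernels below consume these when mapping output coordinates back into the padded input space.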
@@ -37,6 +37,7 @@ ParamsKey ResampleKernelOpt::GetSupportedKey() const {
    k.EnableReampleType(ResampleType::BILINEAR_INTERP);
    k.EnableReampleType(ResampleType::NEAREST_NEIGHBOR);
    k.EnableReampleType(ResampleType::LINEAR_ONNX);
    k.EnableReampleType(ResampleType::CAFFE_BILINEAR_INTERP);
    k.EnableSubGroup();
    k.EnableSubGroupShort();
    return k;
@@ -46,13 +47,21 @@ ResampleKernelBase::DispatchData ResampleKernelOpt::SetDefault(const kernel_sele
    DispatchData dispatchData;
    const auto& out = arg.output;

    dispatchData.gws[0] = CeilDiv(out.X().v, GetOptimalBlockSize(arg)) * out.Y().v;
    dispatchData.gws[1] = Align(out.Feature().v, sub_group_size);
    dispatchData.gws[2] = arg.output.Batch().v;
    if (arg.resampleType == ResampleType::CAFFE_BILINEAR_INTERP) {
        dispatchData.gws[0] = out.X().v * out.Y().v;
        dispatchData.gws[1] = CeilDiv(out.Feature().v, GetFeatureBlockSize(arg));
        dispatchData.gws[2] = arg.output.Batch().v;

    dispatchData.lws[0] = 1;
    dispatchData.lws[1] = sub_group_size;
    dispatchData.lws[2] = 1;
        dispatchData.lws = GetOptimalLocalWorkGroupSizes(dispatchData.gws, arg.engineInfo);
    } else {
        dispatchData.gws[0] = CeilDiv(out.X().v, GetOptimalBlockSize(arg)) * out.Y().v;
        dispatchData.gws[1] = Align(out.Feature().v, sub_group_size);
        dispatchData.gws[2] = arg.output.Batch().v;

        dispatchData.lws[0] = 1;
        dispatchData.lws[1] = sub_group_size;
        dispatchData.lws[2] = 1;
    }

    return dispatchData;
}
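As a rough worked example of the new CAFFE_BILINEAR_INTERP branch (numbers assumed, not taken from the commit): a 1x16x32x32 output with a feature block size of 16 yields gws = { 32 * 32, CeilDiv(16, 16), 1 } = { 1024, 1, 1 }, and GetOptimalLocalWorkGroupSizes then picks a compatible lws. A minimal standalone sketch of the same arithmetic:

    #include <cstddef>

    // Hypothetical stand-in for kernel_selector's CeilDiv helper.
    static size_t ceil_div(size_t a, size_t b) { return (a + b - 1) / b; }

    int main() {
        size_t x = 32, y = 32, f = 16, b = 1, feature_block = 16;
        size_t gws[3] = { x * y, ceil_div(f, feature_block), b };   // {1024, 1, 1}
        return (gws[0] == 1024 && gws[1] == 1 && gws[2] == 1) ? 0 : 1;
    }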
@@ -98,10 +107,26 @@ JitConstants ResampleKernelOpt::GetJitConstants(const resample_params &params) c
    jit.AddConstant(MakeJitConstant("VEC_SIZE", vec_size));

    if (!params.fused_ops.empty()) {
        std::vector<std::string> idx_order = {"b", "feature_block", "y", "(x + out_x)"};
        FusedOpsConfiguration conf = {"", idx_order, "res", GetAccumulatorType(params), vec_size, LoadType::LT_ALIGNED_READ};
        conf.SetVectorAxis(Tensor::DataChannelName::FEATURE);
        jit.Merge(MakeFusedOpsJitConstants(params, {conf}));
        if (params.resampleType != ResampleType::CAFFE_BILINEAR_INTERP) {
            std::vector<std::string> idx_order = {"b", "feature_block", "y", "(x + out_x)"};
            FusedOpsConfiguration conf = {"", idx_order, "res", GetAccumulatorType(params), vec_size, LoadType::LT_ALIGNED_READ};
            conf.SetVectorAxis(Tensor::DataChannelName::FEATURE);
            jit.Merge(MakeFusedOpsJitConstants(params, {conf}));
        } else {
            std::vector<std::string> idx_order;
            idx_order = {"batch", "OF_ID", "oy", "ox"};

            FusedOpsConfiguration conf = {"", idx_order, "res", GetAccumulatorType(params), 1};
            jit.Merge(MakeFusedOpsJitConstants(params, {conf}));
        }
    }

    if (params.resampleType == ResampleType::CAFFE_BILINEAR_INTERP) {
        if (GetFeatureBlockSize(params) == 8) {
            jit.AddConstant(MakeJitConstant("VEC_BLOCK_SIZE", 8));
        } else {
            jit.AddConstant(MakeJitConstant("VEC_BLOCK_SIZE", 16));
        }
    }

    return jit;
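The VEC_BLOCK_SIZE constant chosen above pairs with the vload8/vstore8 versus vload16/vstore16 paths in the kernel below, so that each loop iteration covers one whole block of contiguous features. A hedged one-line restatement of the selection rule in plain C++ (GetFeatureBlockSize is the helper named above):

    // Vector width follows the feature block size reported by GetFeatureBlockSize.
    int select_vec_block_size(int feature_block_size) {
        return feature_block_size == 8 ? 8 : 16;
    }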
@@ -2,12 +2,17 @@
// SPDX-License-Identifier: Apache-2.0
//

#include "include/common.cl"
#include "include/fetch.cl"
#include "include/data_types.cl"
#include "include/include_all.cl"

#define unroll_for __attribute__((opencl_unroll_hint)) for

#if ANTIALIAS == 1
#define TRIANGLE_COEFF(a, x) ( (a) * ACCUMULATOR_MAX_FUNC(ACCUMULATOR_VAL_ZERO, ACCUMULATOR_VAL_ONE - ACCUMULATOR_ABS_FUNC((a) * (x))))
#else
#define TRIANGLE_COEFF(a, x) (ACCUMULATOR_MAX_FUNC(ACCUMULATOR_VAL_ZERO, ACCUMULATOR_VAL_ONE - ACCUMULATOR_ABS_FUNC(x)))
#endif

#define READ_FUNC(ptr, offset) BLOCK_READN(INPUT0_TYPE, VEC_SIZE, ptr, offset)
#define WRITE_FUNC(ptr, offset, val) BLOCK_WRITEN(OUTPUT_TYPE, VEC_SIZE, ptr, offset, val)
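TRIANGLE_COEFF is a tent (triangle) filter weight, max(0, 1 - |x|), with the argument and amplitude rescaled by a when antialiasing is enabled. A minimal host-side sketch of the same two branches, for reference only:

    #include <algorithm>
    #include <cmath>

    // Tent-filter weight mirroring the TRIANGLE_COEFF macro above.
    inline float triangle_coeff(float a, float x, bool antialias) {
        return antialias ? a * std::max(0.0f, 1.0f - std::fabs(a * x))
                         : std::max(0.0f, 1.0f - std::fabs(x));
    }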
@@ -18,6 +23,25 @@
#define OUT_VEC_TYPE MAKE_VECTOR_TYPE(OUTPUT_TYPE, VEC_SIZE)
#define TO_OUT_VEC_TYPE(x) CAT(convert_, OUT_VEC_TYPE)(x)

inline uint FUNC(get_input_index)(uint b, uint f, uint y, uint x)
{
#if INPUT0_DIMS < 5
    return INPUT0_GET_INDEX(b, f, y, x);
#else
#error [clDNN resample_ref.cl]: input format - not supported
#endif
}

inline uint FUNC(get_output_index)(uint b, uint f, uint y, uint x)
{
#if OUTPUT_DIMS < 5
    return OUTPUT_GET_INDEX(b, f, y, x);
#else
#error [clDNN resample_ref.cl]: output format - not supported
#endif
}

inline float FUNC(get_original_coordinate)(float num, float scale, int length_resized, int length_original)
{
#if defined(COORD_TRANS_MODE_HALF_PIXEL)

@@ -35,6 +59,171 @@ inline float FUNC(get_original_coordinate)(float num, float scale, int length_re
#endif
}
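// For reference (hedged, since the diff elides the branch bodies): the
// half_pixel mode conventionally maps an output index back to input space as
//     original = (num + 0.5f) * scale - 0.5f;
// assuming 'scale' carries the input/output length ratio; the remaining
// COORD_TRANS_MODE_* branches follow the Interpolate-4 definitions.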
#ifdef SAMPLE_TYPE_CAFFE_INTERP
KERNEL (resample_opt)(__global INPUT0_TYPE* input,
                      __global OUTPUT_TYPE* output
#if HAS_FUSED_OPS_DECLS
                      , FUSED_OPS_DECLS
#endif
)
{
    const int in_size[4] = { INPUT0_BATCH_NUM, INPUT0_FEATURE_NUM, INPUT0_SIZE_Y, INPUT0_SIZE_X };
    const int out_size[4] = { OUTPUT_BATCH_NUM, OUTPUT_FEATURE_NUM, OUTPUT_SIZE_Y, OUTPUT_SIZE_X };

    const int ox = (int)get_global_id(0) % OUTPUT_SIZE_X;
    const int oy = (int)get_global_id(0) / OUTPUT_SIZE_X;
    const int feature_block_num = get_global_id(1);
    const int feature = feature_block_num * FEATURE_BLOCK_SIZE;

#if OUTPUT_DIMS <= 4
    const int batch = get_global_id(2);
#else
#error [clDNN resample_ref.cl]: Unsupported data dimension
#endif

    ACCUMULATOR_TYPE i_b = AXES_USED[0] ? FUNC_CALL(get_original_coordinate)(batch, SCALES[0], out_size[0], PADDED_B) : batch;
    ACCUMULATOR_TYPE i_f = AXES_USED[1] ? FUNC_CALL(get_original_coordinate)(feature, SCALES[1], out_size[1], PADDED_F) : feature;
    ACCUMULATOR_TYPE i_y = AXES_USED[3] ? FUNC_CALL(get_original_coordinate)(oy, SCALES[3], out_size[2], PADDED_Y) : oy;
    ACCUMULATOR_TYPE i_x = AXES_USED[4] ? FUNC_CALL(get_original_coordinate)(ox, SCALES[4], out_size[3], PADDED_X) : ox;

#if PADDING_USED == 1
    i_b -= PADS_BEGIN[0];
    i_f -= PADS_BEGIN[1];
    i_y -= PADS_BEGIN[3];
    i_x -= PADS_BEGIN[4];
#endif

    const int ib_r = (int)i_b;
    const int if_r = (int)i_f;
    const int iy_r = (int)i_y;
    const int ix_r = (int)i_x;
#if ANTIALIAS == 1
    const ACCUMULATOR_TYPE ab = 1.0f / SCALES[0];
    const ACCUMULATOR_TYPE af = 1.0f / SCALES[1];
    const ACCUMULATOR_TYPE ay = 1.0f / SCALES[3];
    const ACCUMULATOR_TYPE ax = 1.0f / SCALES[4];

    const int rb = (SCALES[0] < 1.0f) ? 2 : (int)ceil(TO_ACCUMULATOR_TYPE(KERNEL_W) / ab);
    const int rf = (SCALES[1] < 1.0f) ? 2 : (int)ceil(TO_ACCUMULATOR_TYPE(KERNEL_W) / af);
    const int ry = (SCALES[3] < 1.0f) ? 2 : (int)ceil(TO_ACCUMULATOR_TYPE(KERNEL_W) / ay);
    const int rx = (SCALES[4] < 1.0f) ? 2 : (int)ceil(TO_ACCUMULATOR_TYPE(KERNEL_W) / ax);
#else
    const ACCUMULATOR_TYPE ab = 1.0f;
    const ACCUMULATOR_TYPE af = 1.0f;
    const ACCUMULATOR_TYPE ay = 1.0f;
    const ACCUMULATOR_TYPE ax = 1.0f;

    const int rb = (SCALES[0] < 1.0f) ? 1 : (int)ceil(TO_ACCUMULATOR_TYPE(KERNEL_W) / ab);
    const int rf = (SCALES[1] < 1.0f) ? 1 : (int)ceil(TO_ACCUMULATOR_TYPE(KERNEL_W) / af);
    const int ry = (SCALES[3] < 1.0f) ? 1 : (int)ceil(TO_ACCUMULATOR_TYPE(KERNEL_W) / ay);
    const int rx = (SCALES[4] < 1.0f) ? 1 : (int)ceil(TO_ACCUMULATOR_TYPE(KERNEL_W) / ax);
#endif
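    // A hedged reading of the radii above (SCALES[] holds input/output ratios,
    // so SCALES < 1 means upscaling): upscaling needs only the nearest taps,
    // while downscaling widens the support to ceil(KERNEL_W / a), roughly
    // KERNEL_W * scale taps; e.g. a 2x downscale with KERNEL_W == 2 gives
    // a radius of 4.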
    int const b_init = max(-PADS_BEGIN[0], ib_r - rb);
    int const f_init = max(-PADS_BEGIN[1], if_r - rf);
    int const y_init = max(-PADS_BEGIN[3], iy_r - ry);
    int const x_init = max(-PADS_BEGIN[4], ix_r - rx);

    int const b_max = min(PADS_END[0] + INPUT0_BATCH_NUM, ib_r + rb + 1);
    int const f_max = min(PADS_END[1] + INPUT0_FEATURE_NUM, if_r + rf + 1);
    int const y_max = min(PADS_END[3] + INPUT0_SIZE_Y, iy_r + ry + 1);
    int const x_max = min(PADS_END[4] + INPUT0_SIZE_X, ix_r + rx + 1);

    const int fp_max = FEATURE_BLOCK_SIZE;

    ACCUMULATOR_TYPE wb = ACCUMULATOR_VAL_ZERO;
    ACCUMULATOR_TYPE wf = ACCUMULATOR_VAL_ZERO;
    ACCUMULATOR_TYPE wy = ACCUMULATOR_VAL_ZERO;
    ACCUMULATOR_TYPE wx = ACCUMULATOR_VAL_ZERO;
    ACCUMULATOR_TYPE w = ACCUMULATOR_VAL_ZERO;
    for (int fp = 0; fp < fp_max; fp += VEC_BLOCK_SIZE) {
        MAKE_VECTOR_TYPE(ACCUMULATOR_TYPE, VEC_BLOCK_SIZE) sum = ACCUMULATOR_VAL_ZERO;
        ACCUMULATOR_TYPE wsum = ACCUMULATOR_VAL_ZERO;

        for (int b = b_init; b < b_max; b++) {
            wb = TRIANGLE_COEFF(ab, i_b - b);

            for (int f = f_init; f < f_max; f++) {
                wf = wb * TRIANGLE_COEFF(af, i_f - f);

                if (wf != 0) {
                    for (int y = y_init; y < y_max; y++) {
                        wy = wf * TRIANGLE_COEFF(ay, i_y - y);

                        if (wy != 0) {
                            for (int x = x_init; x < x_max; x++) {
                                wx = TRIANGLE_COEFF(ax, i_x - x);
                                w = wx * wy;

#if PADDING_USED == 1
                                bool isOutOfBounds = b < 0 || f < 0 || y < 0 || x < 0 ||
                                                     b >= in_size[0] || f >= in_size[1] ||
                                                     y >= in_size[2] || x >= in_size[3];
#endif
                                if (w != 0) {
                                    wsum += w;

#if PADDING_USED == 1
                                    if (!isOutOfBounds)
#endif
                                    {
#if VEC_BLOCK_SIZE == 8
                                        MAKE_VECTOR_TYPE(INPUT0_TYPE, VEC_BLOCK_SIZE) input_vec = vload8(0, &input[FUNC_CALL(get_input_index)(b, f+fp, y, x)]);
                                        sum = fma(convert_float8(input_vec), (float8)w, sum);
#else
                                        MAKE_VECTOR_TYPE(INPUT0_TYPE, VEC_BLOCK_SIZE) input_vec = vload16(0, &input[FUNC_CALL(get_input_index)(b, f+fp, y, x)]);
                                        sum = fma(convert_float16(input_vec), (float16)w, sum);
#endif
                                    }
                                } // w != 0
                            } // for (int x = x_init; x < x_max; x++)
                        }
                    } // for (int y = y_init; y < y_max; y++)
                }
            } // for (int f = f_init; f < f_max; f++)
        } // for (int b = b_init; b < b_max; b++)
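        // Net effect of the loops above (a hedged summary): a separable
        // tent-filter average,
        //   out = sum(wb * wf * wy * wx * in[b, f + fp, y, x]) / wsum.
        // wsum accumulates every nonzero weight, including padded positions
        // whose data read is skipped, so the padded border acts as zeros.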
        MAKE_VECTOR_TYPE(OUTPUT_TYPE, VEC_BLOCK_SIZE) out;
        ACCUMULATOR_TYPE res;

        if (wsum == 0) {
            res = ACCUMULATOR_VAL_ZERO;
            for (int f = 0; f < VEC_BLOCK_SIZE; f++) {
#if HAS_FUSED_OPS
                #define OF_ID (feature+fp+f)
                FUSED_OPS;
                out[f] = FUSED_OPS_RESULT;
                #undef OF_ID
#else
                out[f] = ACTIVATION(TO_OUTPUT_TYPE(res), ACTIVATION_PARAMS);
#endif
            }
        } else {
            for (int f = 0; f < VEC_BLOCK_SIZE; f++) {
                res = sum[f] / wsum;
#if HAS_FUSED_OPS
                #define OF_ID (feature+fp+f)
                FUSED_OPS;
                out[f] = FUSED_OPS_RESULT;
                #undef OF_ID
#else
                out[f] = ACTIVATION(TO_OUTPUT_TYPE(res), ACTIVATION_PARAMS);
#endif
            }
        }

#if VEC_BLOCK_SIZE == 8
        vstore8(out, 0, &output[FUNC_CALL(get_output_index)(batch, feature+fp, oy, ox)]);
#else
        vstore16(out, 0, &output[FUNC_CALL(get_output_index)(batch, feature+fp, oy, ox)]);
#endif
    } // fp
}
#endif // SAMPLE_TYPE_CAFFE_INTERP
#ifndef SAMPLE_TYPE_CAFFE_INTERP
__attribute__((intel_reqd_sub_group_size(SUB_GROUP_SIZE)))
KERNEL (resample_opt)(__global INPUT0_TYPE* input,
                      __global OUTPUT_TYPE* output

@@ -140,7 +329,9 @@ KERNEL (resample_opt)(__global INPUT0_TYPE* input,
        WRITE_FUNC(output, OUTPUT_GET_INDEX(b, feature_block, y, (x + out_x)), out);
    }
}
#endif // !SAMPLE_TYPE_CAFFE_INTERP

#undef unroll_for
#undef TRIANGLE_COEFF
#undef READ_FUNC
#undef WRITE_FUNC
@@ -2,9 +2,8 @@
// SPDX-License-Identifier: Apache-2.0
//

#include "include/common.cl"
#include "include/fetch.cl"
#include "include/data_types.cl"
#include "include/include_all.cl"

inline uint FUNC(get_input_index)(uint b, uint f, uint z, uint y, uint x)
{
@@ -373,11 +372,6 @@ KERNEL (resample_gpu_ref)(__global INPUT0_TYPE* input,
    const int batch = (int)get_global_id(2) % OUTPUT_BATCH_NUM;
    const int oz = (int)get_global_id(2) / OUTPUT_BATCH_NUM;
#endif
    const int PADDED_B = in_size[0] + PADS_BEGIN[0] + PADS_END[0];
    const int PADDED_F = in_size[1] + PADS_BEGIN[1] + PADS_END[1];
    const int PADDED_Z = in_size[2] + PADS_BEGIN[2] + PADS_END[2];
    const int PADDED_Y = in_size[3] + PADS_BEGIN[3] + PADS_END[3];
    const int PADDED_X = in_size[4] + PADS_BEGIN[4] + PADS_END[4];

    ACCUMULATOR_TYPE i_b = AXES_USED[0] ? FUNC_CALL(get_original_coordinate)(batch, SCALES[0], out_size[0], PADDED_B) : batch;
    ACCUMULATOR_TYPE i_f = AXES_USED[1] ? FUNC_CALL(get_original_coordinate)(feature, SCALES[1], out_size[1], PADDED_F) : feature;
@@ -2956,6 +2956,7 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, gemm_2in_act_scale_eltwise,
#define CASE_RESAMPLE_FP32_7 {1, 16, 4, 5, 4}, {1, 16, 2, 3, 2}, data_types::f32, format::bfzyx, resample_type::nearest, data_types::f32, format::bfzyx
#define CASE_RESAMPLE_FP32_8 {1, 16, 4, 5, 4}, {1, 16, 2, 3, 2}, data_types::f32, format::bfzyx, resample_type::caffe_bilinear, data_types::f32, format::bfzyx
#define CASE_RESAMPLE_FP32_9 {1, 16, 4, 5}, {1, 16, 7, 8}, data_types::f32, format::b_fs_yx_fsv16, resample_type::bilinear, data_types::f32, format::bfyx
#define CASE_RESAMPLE_FP32_10 {1, 16, 4, 5}, {1, 16, 7, 8}, data_types::f32, format::b_fs_yx_fsv16, resample_type::caffe_bilinear, data_types::f32, format::bfyx

#define CASE_RESAMPLE_FP16_1 {1, 15, 4, 5}, {1, 15, 2, 3}, data_types::f16, format::bfyx, resample_type::nearest, data_types::f16, format::bfyx
#define CASE_RESAMPLE_FP16_2 {1, 15, 4, 5}, {1, 15, 2, 3}, data_types::f16, format::bfyx, resample_type::bilinear, data_types::f16, format::bfyx

@@ -2967,6 +2968,10 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, gemm_2in_act_scale_eltwise,
#define CASE_RESAMPLE_FP16_8 {1, 16, 4, 5, 4}, {1, 16, 2, 3, 2}, data_types::f16, format::bfzyx, resample_type::caffe_bilinear, data_types::f16, format::bfzyx
#define CASE_RESAMPLE_FP16_9 {1, 16, 4, 5}, {1, 16, 7, 8}, data_types::f16, format::b_fs_yx_fsv16, resample_type::bilinear, data_types::f16, format::bfyx
#define CASE_RESAMPLE_FP16_10 {2, 32, 4, 5}, {2, 32, 7, 8}, data_types::f16, format::fs_b_yx_fsv32, resample_type::bilinear, data_types::f16, format::bfyx
#define CASE_RESAMPLE_FP16_11 {1, 16, 4, 5}, {1, 16, 7, 8}, data_types::f16, format::b_fs_yx_fsv16, resample_type::caffe_bilinear, data_types::f16, format::bfyx
#define CASE_RESAMPLE_FP16_12 {2, 32, 4, 5}, {2, 32, 7, 8}, data_types::f16, format::fs_b_yx_fsv32, resample_type::caffe_bilinear, data_types::f16, format::bfyx
#define CASE_RESAMPLE_FP16_13 {1, 16, 4, 5}, {1, 16, 7, 8}, data_types::f16, format::b_fs_yx_fsv16, resample_type::caffe_bilinear, data_types::f16, format::bfyx
#define CASE_RESAMPLE_FP16_14 {1, 32, 4, 5}, {1, 32, 2, 3}, data_types::f16, format::fs_b_yx_fsv32, resample_type::caffe_bilinear, data_types::f16, format::bfyx

#define CASE_RESAMPLE_I8_1 {1, 16, 4, 5}, {1, 16, 2, 3}, data_types::i8, format::b_fs_yx_fsv16, resample_type::nearest, data_types::f32, format::bfyx
#define CASE_RESAMPLE_I8_2 {2, 32, 4, 5}, {2, 32, 2, 3}, data_types::i8, format::b_fs_yx_fsv16, resample_type::nearest, data_types::f32, format::bfyx
@@ -3006,6 +3011,7 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, resample_quantize,
    resample_test_params{ CASE_RESAMPLE_FP32_7, 2, 3 },
    resample_test_params{ CASE_RESAMPLE_FP32_8, 2, 3 },
    resample_test_params{ CASE_RESAMPLE_FP32_9, 2, 3 },
    resample_test_params{ CASE_RESAMPLE_FP32_10, 2, 3 },

    // FQ can't be fused to FP16 primitive for now
    // resample_test_params{ CASE_RESAMPLE_FP16_1, 2, 3 },

@@ -3047,6 +3053,7 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, resample_scale_activation_eltwise,
    resample_test_params{ CASE_RESAMPLE_FP32_7, 2, 5 },
    resample_test_params{ CASE_RESAMPLE_FP32_8, 2, 5 },
    resample_test_params{ CASE_RESAMPLE_FP32_9, 2, 5 },
    resample_test_params{ CASE_RESAMPLE_FP32_10, 2, 5 },

    resample_test_params{ CASE_RESAMPLE_FP16_1, 2, 5 },
    resample_test_params{ CASE_RESAMPLE_FP16_2, 2, 5 },

@@ -3058,6 +3065,10 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, resample_scale_activation_eltwise,
    resample_test_params{ CASE_RESAMPLE_FP16_8, 2, 5 },
    resample_test_params{ CASE_RESAMPLE_FP16_9, 2, 5 },
    resample_test_params{ CASE_RESAMPLE_FP16_10, 2, 5 },
    resample_test_params{ CASE_RESAMPLE_FP16_11, 2, 5 },
    resample_test_params{ CASE_RESAMPLE_FP16_12, 2, 5 },
    resample_test_params{ CASE_RESAMPLE_FP16_13, 2, 5 },
    resample_test_params{ CASE_RESAMPLE_FP16_14, 2, 5 },

    resample_test_params{ CASE_RESAMPLE_I8_1, 2, 5 },
    resample_test_params{ CASE_RESAMPLE_I8_2, 2, 5 },

@@ -3106,6 +3117,7 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, resample_quantize_concat,
    resample_test_params{ CASE_RESAMPLE_FP32_7, 3, 6 },
    resample_test_params{ CASE_RESAMPLE_FP32_8, 3, 6 },
    resample_test_params{ CASE_RESAMPLE_FP32_9, 3, 6 },
    resample_test_params{ CASE_RESAMPLE_FP32_10, 3, 6 },

    resample_test_params{ CASE_RESAMPLE_FP16_1, 3, 6 },
    resample_test_params{ CASE_RESAMPLE_FP16_2, 3, 6 },

@@ -3117,6 +3129,10 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, resample_quantize_concat,
    resample_test_params{ CASE_RESAMPLE_FP16_8, 3, 6 },
    resample_test_params{ CASE_RESAMPLE_FP16_9, 3, 6 },
    resample_test_params{ CASE_RESAMPLE_FP16_10, 3, 6 },
    resample_test_params{ CASE_RESAMPLE_FP16_11, 3, 6 },
    resample_test_params{ CASE_RESAMPLE_FP16_12, 3, 6 },
    resample_test_params{ CASE_RESAMPLE_FP16_13, 3, 6 },
    resample_test_params{ CASE_RESAMPLE_FP16_14, 3, 6 },

    resample_test_params{ CASE_RESAMPLE_I8_3, 3, 6 },
    resample_test_params{ CASE_RESAMPLE_I8_4, 3, 6 },

@@ -3157,6 +3173,7 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, resample_scale_concat,
    resample_test_params{ CASE_RESAMPLE_FP32_7, 3, 6 },
    resample_test_params{ CASE_RESAMPLE_FP32_8, 3, 6 },
    resample_test_params{ CASE_RESAMPLE_FP32_9, 3, 6 },
    resample_test_params{ CASE_RESAMPLE_FP32_10, 3, 6 },

    resample_test_params{ CASE_RESAMPLE_FP16_1, 3, 6 },
    resample_test_params{ CASE_RESAMPLE_FP16_2, 3, 6 },

@@ -3168,6 +3185,10 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, resample_scale_concat,
    resample_test_params{ CASE_RESAMPLE_FP16_8, 3, 6 },
    resample_test_params{ CASE_RESAMPLE_FP16_9, 3, 6 },
    resample_test_params{ CASE_RESAMPLE_FP16_10, 3, 6 },
    resample_test_params{ CASE_RESAMPLE_FP16_11, 3, 6 },
    resample_test_params{ CASE_RESAMPLE_FP16_12, 3, 6 },
    resample_test_params{ CASE_RESAMPLE_FP16_13, 3, 6 },
    resample_test_params{ CASE_RESAMPLE_FP16_14, 3, 6 },

    resample_test_params{ CASE_RESAMPLE_I8_1, 3, 6 },
    resample_test_params{ CASE_RESAMPLE_I8_2, 3, 6 },
@@ -5855,7 +5876,7 @@ class ScatterElementsUpdatePrimitiveFusingTest : public ::BaseFusingTest<scatter
public:
    void execute(scatter_elements_update_test_params& p) {

        auto input_prim = get_mem(get_input_layout(p));
        auto input_prim = get_mem(get_input_layout(p), -5, 5);
        network network_not_fused(this->engine, this->topology_non_fused, bo_not_fused);
        network network_fused(this->engine, this->topology_fused, bo_fused);
        network_fused.set_input_data("input", input_prim);

@@ -5942,7 +5963,7 @@ TEST_P(scatter_elements_update_scale_activation_eltwise, basic) {
    auto p = GetParam();
    create_topologies(input_layout("input", get_input_layout(p)),
        data("scatter_elements_update_indices", get_repeatless_mem(get_indices_layout(p), 0, static_cast<int>(get_axis_dim(p)) - 1)),
        data("scatter_elements_update_updates", get_mem(get_updates_layout(p), 0, 100)),
        data("scatter_elements_update_updates", get_mem(get_updates_layout(p), 0, 5)),
        data("scale_data", get_mem(get_per_channel_layout(p), -1, 1)),
        data("eltwise_data", get_mem(layout{ p.data_type, p.input_format, p.input_shape })),
        scatter_elements_update("scatter_elements_update_prim", "input", "scatter_elements_update_indices", "scatter_elements_update_updates", p.axis),

@@ -5951,7 +5972,7 @@ TEST_P(scatter_elements_update_scale_activation_eltwise, basic) {
        eltwise("eltwise", {"scale", "eltwise_data"}, eltwise_mode::sum, p.data_type),
        reorder("reorder_bfyx", "eltwise", p.default_format, data_types::f32)
    );
    tolerance = 1e-5f;
    tolerance = 1e-2f;
    execute(p);
}

@@ -7749,7 +7770,7 @@ TEST_P(scatter_nd_update_scale_activation_eltwise, basic) {
    create_topologies(input_layout("input", get_input_layout(p)),
        data("scatter_nd_update_indices", get_mem(get_indices_layout(p), 0, p.max_number_in_indices)),
        data("scatter_nd_update_updates", get_mem(get_updates_layout(p), 0, 100)),
        data("scale_data", get_mem(get_per_channel_layout(p), -10, 10)),
        data("scale_data", get_mem(get_per_channel_layout(p), -1, 1)),
        data("eltwise_data", get_mem(layout{ p.data_type, p.input_format, p.input_shape })),
        scatter_nd_update("scatter_nd_update_prim", "input", "scatter_nd_update_indices", "scatter_nd_update_updates", p.indices_rank),
        activation("activation", "scatter_nd_update_prim", activation_func::abs),

@@ -7778,8 +7799,8 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, scatter_nd_update_scale_activation_eltwise,
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_5, 2, 5 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_6, 2, 5 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_7, 2, 5 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_8, 2, 5 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_9, 2, 5 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_8, 2, 5 },

    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_6D_1, 2, 5 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_6D_2, 2, 5 },
@@ -718,7 +718,6 @@ struct resample_random_test : testing::TestWithParam<resample_random_test_params
            if (info.original_id == "resample")
                kernel = info.kernel_id;
        }
        SCOPED_TRACE("kernel: " + kernel);

        compare(in_mem, output, params.operation_type, params.align_corners);
    }

@@ -750,7 +749,7 @@ struct resample_random_test_param_generator : std::vector<resample_random_test_p

};

INSTANTIATE_TEST_CASE_P(smoke,
INSTANTIATE_TEST_CASE_P(smoke_resample,
                        resample_random_test,
                        testing::ValuesIn(
                            resample_random_test_param_generator()

@@ -765,6 +764,206 @@ INSTANTIATE_TEST_CASE_P(smoke,
                                .smoke_params(data_types::u8, format::b_fs_yx_fsv16, format::b_fs_yx_fsv16)
                        ), );

/////////////////////////////////////////////////////////////////////////

struct caffe_resample_random_test_params {
    data_types input_type;
    tensor input_size;
    tensor output_size;
    uint32_t num_filter;
    resample_type operation_type;
    uint32_t align_corners;
    format::type in_format;
    format::type out_format;
    std::vector<int32_t> pads_begin;
    std::vector<int32_t> pads_end;
};

struct caffe_resample_random_test : testing::TestWithParam<caffe_resample_random_test_params>
{
    template <typename T>
    void fill_random_typed(memory& mem, int min, int max, int k) {
        auto size = mem.get_layout().size;
        size_t b = size.batch[0];
        size_t f = size.feature[0];
        size_t x = size.spatial[0];
        size_t y = size.spatial[1];

        auto data = generate_random_4d<T>(b, f, y, x, min, max, k);
        auto ptr = mem.pointer<T>();
        for (size_t bi = 0; bi < b; ++bi) {
            for (size_t fi = 0; fi < f; ++fi) {
                for (size_t yi = 0; yi < y; ++yi) {
                    for (size_t xi = 0; xi < x; ++xi) {
                        auto coords = tensor(batch(bi), feature(fi), spatial(xi, yi, 0, 0));
                        auto offset = mem.get_layout().get_linear_offset(coords);
                        ptr[offset] = data[bi][fi][yi][xi];
                    }
                }
            }
        }
    }
    void fill_random(memory& mem) {
        auto dt = mem.get_layout().data_type;
        switch (dt) {
        case data_types::f32:
            fill_random_typed<float>(mem, -127, 127, 2);
            break;
        case data_types::f16:
            fill_random_typed<FLOAT16>(mem, -127, 127, 2);
            break;
        case data_types::i8:
            fill_random_typed<int8_t>(mem, -127, 127, 1);
            break;
        case data_types::u8:
            fill_random_typed<uint8_t>(mem, 0, 255, 1);
            break;
        default:
            break;
        }
    }
    template <typename T>
    bool compare_outputs(const memory& out_ref, const memory& out_opt) {
        auto output_lay = out_ref.get_layout();
        auto opt_output_lay = out_opt.get_layout();

        size_t b = output_lay.size.batch[0];
        size_t f = output_lay.size.feature[0];
        size_t x = output_lay.size.spatial[0];
        size_t y = output_lay.size.spatial[1];
        auto ref_ptr = out_ref.pointer<T>();
        auto opt_ptr = out_opt.pointer<T>();
        for (size_t bi = 0; bi < b; ++bi) {
            for (size_t fi = 0; fi < f; ++fi) {
                for (size_t yi = 0; yi < y; ++yi) {
                    for (size_t xi = 0; xi < x; ++xi) {
                        auto ref_out_coords = tensor(batch(bi), feature(fi), spatial(xi, yi, 0, 0));
                        auto ref_out_offset = output_lay.get_linear_offset(ref_out_coords);
                        auto ref_out_val = ref_ptr[ref_out_offset];

                        auto opt_out_offset = opt_output_lay.get_linear_offset(ref_out_coords);
                        auto opt_out_val = opt_ptr[opt_out_offset];

                        EXPECT_EQ(ref_out_offset, opt_out_offset);
                        EXPECT_EQ(opt_out_val, ref_out_val);
                        // EXPECT_NEAR(static_cast<float>(opt_out_val), static_cast<float>(ref_out_val), 1.e-1f);
                    }
                }
            }
        }

        return true;
    }
    void execute_compare(const caffe_resample_random_test_params& params, bool check_result) {
        auto eng = cldnn::engine();

        auto in_layout = layout(params.input_type, params.in_format, params.input_size);
        auto in_mem = memory::allocate(eng, in_layout);
        fill_random(in_mem);

        cldnn::topology topo;
        topo.add(input_layout("in", in_layout));
        auto prim = resample("resample", "in", params.output_size, params.num_filter, params.operation_type);
        prim.align_corners = params.align_corners;
        prim.pads_begin = params.pads_begin;
        prim.pads_end = params.pads_end;
        topo.add(prim);

        auto build_opts = build_options();
        build_opts.set_option(build_option::outputs({"resample"}));
        build_opts.set_option(build_option::force_implementations({ {"resample", {params.in_format, "resample_ref"}} }));

        auto net = network(eng, topo, build_opts);
        net.set_input_data("in", in_mem);

        auto result = net.execute();
        auto output = result.at("resample").get_memory();

        // Execute resample_opt
        auto eng_opt = cldnn::engine();

        cldnn::topology topo_opt;
        topo_opt.add(input_layout("in", in_layout));
        auto prim_opt = resample("resample_opt", "in", params.output_size, params.num_filter, params.operation_type);
        prim_opt.align_corners = params.align_corners;
        prim_opt.pads_begin = params.pads_begin;
        prim_opt.pads_end = params.pads_end;
        topo_opt.add(prim_opt);

        auto build_opts_opt = build_options();
        build_opts_opt.set_option(build_option::outputs({"resample_opt"}));
        build_opts_opt.set_option(build_option::force_implementations({ {"resample_opt", {params.in_format, "resample_opt"}} }));

        auto net_opt = network(eng_opt, topo_opt, build_opts_opt);

        // Use in_mem from ref network
        net_opt.set_input_data("in", in_mem);

        auto result_opt = net_opt.execute();
        auto output_opt = result_opt.at("resample_opt").get_memory();

        if (check_result == true) {
            // Check data_types
            if (params.input_type == data_types::f32) {
                compare_outputs<float>(output, output_opt);
            } else if (params.input_type == data_types::f16) {
                compare_outputs<FLOAT16>(output, output_opt);
            } else if (params.input_type == data_types::i8) {
                compare_outputs<int8_t>(output, output_opt);
            } else if (params.input_type == data_types::u8) {
                compare_outputs<uint8_t>(output, output_opt);
            } else {
                FAIL() << "Not supported data type: " << static_cast<size_t>(params.input_type);
            }
        }
    }
};
struct caffe_resample_random_test_param_generator : std::vector<caffe_resample_random_test_params> {
    caffe_resample_random_test_param_generator& add(caffe_resample_random_test_params params) {
        push_back(params);
        return *this;
    }

    caffe_resample_random_test_param_generator& smoke_params(data_types type, format::type input_format, format::type output_format) {
        push_back(caffe_resample_random_test_params{ type, {1, 512, 16, 16}, {1, 512, 32, 32}, 1, resample_type::caffe_bilinear, 1, input_format, output_format, {}, {}});
        push_back(caffe_resample_random_test_params{ type, {1, 512, 32, 32}, {1, 512, 16, 16}, 1, resample_type::caffe_bilinear, 1, input_format, output_format, {}, {}});
        push_back(caffe_resample_random_test_params{ type, {1, 24, 32, 32}, {1, 24, 64, 64}, 1, resample_type::caffe_bilinear, 1, input_format, output_format, {}, {}});
        push_back(caffe_resample_random_test_params{ type, {1, 24, 96, 96}, {1, 24, 32, 32}, 1, resample_type::caffe_bilinear, 1, input_format, output_format, {}, {}});
        push_back(caffe_resample_random_test_params{ type, {1, 8, 64, 64}, {1, 8, 32, 32}, 1, resample_type::caffe_bilinear, 1, input_format, output_format, {}, {}});
        push_back(caffe_resample_random_test_params{ type, {1, 20, 10, 10}, {1, 20, 20, 20}, 1, resample_type::caffe_bilinear, 1, input_format, output_format, {}, {}});
        push_back(caffe_resample_random_test_params{ type, {1, 20, 20, 20}, {1, 20, 10, 10}, 1, resample_type::caffe_bilinear, 1, input_format, output_format, {}, {}});
        // Padding applied
        push_back(caffe_resample_random_test_params{ type, {1, 96, 16, 16}, {1, 96, 32, 32}, 1, resample_type::caffe_bilinear, 1, input_format, output_format, {0, 0, 1, 1}, {0, 0, 1, 1}});
        push_back(caffe_resample_random_test_params{ type, {1, 96, 32, 32}, {1, 96, 16, 16}, 1, resample_type::caffe_bilinear, 1, input_format, output_format, {0, 0, 1, 1}, {0, 0, 1, 1}});
        return *this;
    }
};

TEST_P(caffe_resample_random_test, random) {
    auto param = GetParam();
    execute_compare(param, true);
}

INSTANTIATE_TEST_CASE_P(caffe_smoke_caffe_fsv16,
                        caffe_resample_random_test,
                        testing::ValuesIn(
                            caffe_resample_random_test_param_generator()
                            .smoke_params(data_types::f32, format::b_fs_yx_fsv16, format::b_fs_yx_fsv16)
                            .smoke_params(data_types::f16, format::b_fs_yx_fsv16, format::b_fs_yx_fsv16)
                        ), );

INSTANTIATE_TEST_CASE_P(caffe_smoke_caffe_fsv32,
                        caffe_resample_random_test,
                        testing::ValuesIn(
                            caffe_resample_random_test_param_generator()
                            .smoke_params(data_types::f16, format::fs_b_yx_fsv32, format::fs_b_yx_fsv32)
                        ), );
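To run only the suites added above, the standard gtest filter flag should work, e.g. --gtest_filter=caffe_smoke_caffe_fsv16/caffe_resample_random_test.* (the exact test binary name depends on the local build, so it is not spelled out here).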
TEST(resample_gpu, interpolate_in2x2x3x2_nearest1) {
    // Input : 2x2x3x2
    // Output : 2x2x6x4