[GPU] Softmax blocked layouts support (#12467)

commit 0145865301 (parent 621bf375c1)
Oleksii Khovan, 2022-09-12 12:48:45 +03:00, committed by GitHub
7 changed files with 436 additions and 19 deletions


@@ -69,16 +69,25 @@ struct softmax_impl : typed_primitive_impl_ocl<softmax> {
namespace detail {
attach_softmax_impl::attach_softmax_impl() {
implementation_map<softmax>::add(impl_types::ocl, softmax_impl::create, {
std::make_tuple(data_types::f32, format::yxfb),
std::make_tuple(data_types::f16, format::yxfb),
std::make_tuple(data_types::f32, format::bfyx),
std::make_tuple(data_types::f16, format::bfyx),
std::make_tuple(data_types::f32, format::byxf),
std::make_tuple(data_types::f16, format::byxf),
std::make_tuple(data_types::f32, format::bfzyx),
std::make_tuple(data_types::f16, format::bfzyx),
});
auto types = {data_types::f16, data_types::f32};
auto formats = {
format::bfyx,
format::yxfb,
format::b_fs_yx_fsv16,
format::b_fs_yx_fsv32,
format::bs_fs_yx_bsv16_fsv16,
format::bs_fs_yx_bsv32_fsv16,
format::bs_fs_yx_bsv32_fsv32,
format::bfzyx,
format::b_fs_zyx_fsv16,
format::b_fs_zyx_fsv32,
format::bs_fs_zyx_bsv16_fsv32,
format::bs_fs_zyx_bsv16_fsv16,
format::bs_fs_zyx_bsv32_fsv32,
format::bs_fs_zyx_bsv32_fsv16
};
implementation_map<softmax>::add(impl_types::ocl, softmax_impl::create, types, formats);
}
} // namespace detail
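
Note: the (types, formats) overload used above is assumed to register the full cross product, so the net effect matches the removed per-tuple list. A minimal sketch under that assumption (register_all and register_one are hypothetical names, not the real helper):

#include <tuple>

// Sketch: pair every data type with every format, one registration each.
template <typename Types, typename Formats, typename RegisterFn>
void register_all(const Types& types, const Formats& formats, RegisterFn register_one) {
    for (auto type : types)
        for (auto fmt : formats)
            register_one(std::make_tuple(type, fmt));
}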


@@ -22,16 +22,12 @@ ParamsKey SoftmaxItemsClassKernelBase::GetDefaultSupportedKey() {
k.EnableInputLayout(DataLayout::bfzyx);
k.EnableInputLayout(DataLayout::f);
k.EnableOutputLayout(DataLayout::f);
k.EnableInputLayout(DataLayout::b_fs_zyx_fsv16);
k.EnableInputLayout(DataLayout::bs_fs_zyx_bsv16_fsv16);
k.EnableOutputLayout(DataLayout::bfyx);
k.EnableOutputLayout(DataLayout::byxf);
k.EnableOutputLayout(DataLayout::yxfb);
k.EnableOutputLayout(DataLayout::bf);
k.EnableOutputLayout(DataLayout::fb);
k.EnableOutputLayout(DataLayout::bfzyx);
k.EnableOutputLayout(DataLayout::b_fs_zyx_fsv16);
k.EnableOutputLayout(DataLayout::bs_fs_zyx_bsv16_fsv16);
k.EnableSoftmaxDim(SoftmaxDim::X);
k.EnableSoftmaxDim(SoftmaxDim::Y);
k.EnableSoftmaxDim(SoftmaxDim::Z);
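
These Enable* calls declare which layouts and axes a kernel family can handle; the selector then matches each kernel's key against the key required by the request. An illustrative bitmask model of that check (ParamsKeySketch is invented for illustration, not the real kernel_selector class):

#include <bitset>
#include <cstddef>

struct ParamsKeySketch {
    std::bitset<64> input_layouts;
    std::bitset<64> output_layouts;
    void EnableInputLayout(std::size_t layout)  { input_layouts.set(layout); }
    void EnableOutputLayout(std::size_t layout) { output_layouts.set(layout); }
    // A kernel is a candidate only if it enables every layout the request needs.
    bool Support(const ParamsKeySketch& required) const {
        return (required.input_layouts & ~input_layouts).none() &&
               (required.output_layouts & ~output_layouts).none();
    }
};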


@@ -6,7 +6,26 @@
#include "kernel_selector_utils.h"
namespace kernel_selector {
ParamsKey SoftmaxKernelRef::GetSupportedKey() const { return GetDefaultSupportedKey(); }
ParamsKey SoftmaxKernelRef::GetSupportedKey() const {
auto k = GetDefaultSupportedKey();
k.EnableInputLayout(DataLayout::b_fs_yx_fsv16);
k.EnableOutputLayout(DataLayout::b_fs_yx_fsv16);
k.EnableInputLayout(DataLayout::b_fs_yx_fsv32);
k.EnableOutputLayout(DataLayout::b_fs_yx_fsv32);
k.EnableInputLayout(DataLayout::bs_fs_yx_bsv16_fsv16);
k.EnableOutputLayout(DataLayout::bs_fs_yx_bsv16_fsv16);
k.EnableInputLayout(DataLayout::bs_fs_yx_bsv32_fsv16);
k.EnableOutputLayout(DataLayout::bs_fs_yx_bsv32_fsv16);
k.EnableInputLayout(DataLayout::bs_fs_yx_bsv32_fsv32);
k.EnableOutputLayout(DataLayout::bs_fs_yx_bsv32_fsv32);
k.EnableInputLayout(DataLayout::b_fs_zyx_fsv16);
k.EnableInputLayout(DataLayout::bs_fs_zyx_bsv16_fsv16);
k.EnableOutputLayout(DataLayout::b_fs_zyx_fsv16);
k.EnableOutputLayout(DataLayout::bs_fs_zyx_bsv16_fsv16);
return k;
}
SoftmaxKernelRef::Parent::DispatchData SoftmaxKernelRef::SetDefault(const softmax_params& params,
const optional_params& optParams) const {
@@ -28,4 +47,12 @@ KernelsPriority SoftmaxKernelRef::GetKernelsPriority(const Params& /*params*/, c
KernelsData SoftmaxKernelRef::GetKernelsData(const Params& params, const optional_params& options) const {
return GetCommonKernelsData(params, options);
}
JitConstants SoftmaxKernelRef::GetJitConstants(const softmax_params& params, DispatchData dispatchData) const {
auto jit = Parent::GetJitConstants(params, dispatchData);
if (!SimpleLayout(params.inputs[0].GetLayout())) {
jit.AddConstant(MakeJitConstant("SOFTMAX_DIM_" + toString(params.dim), "1"));
}
return jit;
}
} // namespace kernel_selector
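
The new GetJitConstants branch is what ties the host side to the kernel: for non-simple (blocked) inputs it emits a SOFTMAX_DIM_<AXIS> define that selects the matching #elif branch in the softmax ref kernel. A hypothetical helper mirroring that decision:

#include <string>

// Sketch: simple layouts keep pitch-based indexing (no define); blocked
// layouts get e.g. "SOFTMAX_DIM_Y=1" appended to the kernel build options.
std::string softmax_dim_define(bool simple_layout, const std::string& axis) {
    return simple_layout ? std::string() : "SOFTMAX_DIM_" + axis + "=1";
}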


@@ -19,5 +19,6 @@ public:
protected:
DispatchData SetDefault(const softmax_params& params, const optional_params& optParams) const override;
JitConstants GetJitConstants(const softmax_params& params, DispatchData dispatchData) const override;
};
} // namespace kernel_selector


@@ -20,4 +20,4 @@ softmax_kernel_selector::softmax_kernel_selector() {
KernelsData softmax_kernel_selector::GetBestKernels(const Params& params, const optional_params& options) const {
return GetNaiveBestKernel(params, options, KernelType::SOFT_MAX);
}
} // namespace kernel_selector


@@ -13,6 +13,8 @@ KERNEL(softmax)(
, FUSED_OPS_DECLS
#endif
) {
uint cls = 0;
#if INPUT0_SIMPLE == 1
#if INPUT0_DIMS == 5
const uint other0 = (uint)get_global_id(0) % INPUT0_OTHER0_SIZE;
const uint other2 = (uint)get_global_id(0) / INPUT0_OTHER0_SIZE;
@@ -25,13 +27,69 @@ KERNEL(softmax)(
const uint in_depth_offset = other3*INPUT0_OTHER3_PITCH + other2*INPUT0_OTHER2_PITCH + other1*INPUT0_OTHER1_PITCH + other0*INPUT0_OTHER0_PITCH + INPUT0_OFFSET;
const uint out_depth_offset = other3*OUTPUT_OTHER3_PITCH + other2*OUTPUT_OTHER2_PITCH + other1*OUTPUT_OTHER1_PITCH + other0*OUTPUT_OTHER0_PITCH + OUTPUT_OFFSET;
#else // blocked format
const uint no_offset = 0;
const uint *b_offset, *f_offset, *z_offset, *y_offset, *x_offset;
b_offset = f_offset = z_offset = y_offset = x_offset = &no_offset;
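// Exactly one of these pointers is redirected below to the class-loop
// counter cls, so coord + *<axis>_offset advances only along the softmax axis.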
#if SOFTMAX_DIM_X
x_offset = &cls;
const uint b = get_global_id(2);
const uint f = get_global_id(1);
const uint z = (uint)get_global_id(0) % INPUT0_SIZE_Z;
const uint y = (uint)get_global_id(0) / INPUT0_SIZE_Z;
const uint x = 0;
#elif SOFTMAX_DIM_Y
y_offset = &cls;
const uint b = get_global_id(2);
const uint f = get_global_id(1);
const uint z = (uint)get_global_id(0) / INPUT0_SIZE_X;
const uint y = 0;
const uint x = (uint)get_global_id(0) % INPUT0_SIZE_X;
#elif SOFTMAX_DIM_Z
z_offset = &cls;
const uint b = get_global_id(2);
const uint f = get_global_id(1);
const uint z = 0;
const uint y = (uint)get_global_id(0) / INPUT0_SIZE_X;
const uint x = (uint)get_global_id(0) % INPUT0_SIZE_X;
#elif SOFTMAX_DIM_FEATURE
f_offset = &cls;
const uint b = get_global_id(2);
const uint f = 0;
const uint z = (uint)get_global_id(0) / INPUT0_SIZE_X;
const uint y = get_global_id(1);
const uint x = (uint)get_global_id(0) % INPUT0_SIZE_X;
#elif SOFTMAX_DIM_BATCH
b_offset = &cls;
const uint b = 0;
const uint f = get_global_id(2);
const uint z = (uint)get_global_id(0) / INPUT0_SIZE_X;
const uint y = get_global_id(1);
const uint x = (uint)get_global_id(0) % INPUT0_SIZE_X;
#else
#error Wrong axis
#endif
#endif
ACCUMULATOR_TYPE max_value = UNIT_VAL_MIN;
ACCUMULATOR_TYPE data[INPUT0_CLASS_NUM];
for (uint cls = 0; cls < INPUT0_CLASS_NUM; ++cls)
for (cls = 0; cls < INPUT0_CLASS_NUM; ++cls)
{
#if INPUT0_SIMPLE == 1
const uint index = in_depth_offset + cls*INPUT0_CLASS_PITCH;
#else
#if INPUT0_DIMS == 5
const uint index = INPUT0_GET_INDEX(b + *b_offset, f + *f_offset, z + *z_offset, y + *y_offset, x + *x_offset);
#else
const uint index = INPUT0_GET_INDEX(b + *b_offset, f + *f_offset, y + *y_offset, x + *x_offset);
#endif
#endif
ACCUMULATOR_TYPE in = input[index];
max_value = max(max_value, in);
data[cls] = in;
@@ -39,16 +97,24 @@ KERNEL(softmax)(
// TODO: accumulate in float32 for now: the reduction is a long chain of adds, and an fp16 accumulator saturates around the value 8192.0f
ACCUMULATOR_TYPE denominator = 0.0;
for (uint cls = 0; cls < INPUT0_CLASS_NUM; ++cls)
for (cls = 0; cls < INPUT0_CLASS_NUM; ++cls)
{
data[cls] = native_exp(data[cls] - max_value);
denominator += data[cls];
}
for (uint cls = 0; cls < INPUT0_CLASS_NUM; ++cls)
for (cls = 0; cls < INPUT0_CLASS_NUM; ++cls)
{
const ACCUMULATOR_TYPE res = data[cls] / denominator;
#if INPUT0_SIMPLE == 1
const uint output_idx = out_depth_offset + cls*OUTPUT_CLASS_PITCH;
#else
#if INPUT0_DIMS == 5
const uint output_idx = OUTPUT_GET_INDEX(b + *b_offset, f + *f_offset, z + *z_offset, y + *y_offset, x + *x_offset);
#else
const uint output_idx = OUTPUT_GET_INDEX(b + *b_offset, f + *f_offset, y + *y_offset, x + *x_offset);
#endif
#endif
#if HAS_FUSED_OPS
FUSED_OPS;
output[output_idx] = FUSED_OPS_RESULT;

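The blocked-format branch above rests on a small pointer trick: every per-axis offset aliases a constant zero except the softmax axis, which aliases the loop counter cls, so INPUT0_GET_INDEX sees one moving coordinate and four fixed ones. A standalone C++ model of the idea (sizes are made up):

#include <cstdio>

int main() {
    unsigned cls = 0;
    const unsigned no_offset = 0;
    const unsigned *y_offset = &no_offset, *x_offset = &no_offset;
    x_offset = &cls;  // softmax axis is X in this model
    const unsigned y = 1, x = 0, class_num = 4;
    for (cls = 0; cls < class_num; ++cls)
        std::printf("(y=%u, x=%u)\n", y + *y_offset, x + *x_offset);
    // prints (1,0) (1,1) (1,2) (1,3): only x advances
    return 0;
}
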

@@ -658,3 +658,321 @@ INSTANTIATE_TEST_SUITE_P(DISABLED_SOFTMAX,
softmax_test,
::testing::Combine(::testing::ValuesIn(softmax_test::generate_generic_test_params()), ::testing::ValuesIn(softmax_test::generate_specific_test_params())),
softmax_test::custom_param_name);
namespace {
template<typename T>
struct SoftmaxParams {
int64_t axis;
tensor input_tensor;
std::vector<T> input;
std::vector<T> expected;
};
template<typename T>
using SoftmaxParamsWithFormat = std::tuple<
SoftmaxParams<T>,
format::type, // source (plain) layout
format::type // target (blocked) layout
>;
const std::vector<format::type> formats2D{
format::bfyx,
format::b_fs_yx_fsv16,
format::b_fs_yx_fsv32,
format::bs_fs_yx_bsv16_fsv16,
format::bs_fs_yx_bsv32_fsv16,
format::bs_fs_yx_bsv32_fsv32
};
const std::vector<format::type> formats3D{
format::bfzyx,
format::b_fs_zyx_fsv16,
format::bs_fs_zyx_bsv16_fsv16
};
template<typename T>
std::vector<T> getValues(const std::vector<float> &values) {
std::vector<T> result(values.begin(), values.end());
return result;
}
template<typename T>
std::vector<SoftmaxParams<T>> generateSoftmaxParams2D() {
const std::vector<SoftmaxParams<T>> result = {
{
0,
tensor(3, 2, 2, 2),
getValues<T>({
0.1f, -0.1f, 0.9f, 1.5f, 3.f, 0.5f, 7.f, 12.f, 0.2f, 0.2f, -10.f, 5.2f,
4.f, 0.5f, 8.f, 8.2f, 0.2f, 0.2f, -10.f, 5.2f, 0.2f, 0.2f, -10.f, 5.2f}),
getValues<T>({
0.311493f, 0.270291f, 0.999963f, 0.0122108f,
0.264614f, 0.364855f, 0.268941f, 0.977054f,
0.344253f, 0.364855f, 1.84576e-05f, 0.493895f,
0.719295f, 0.364855f, 0.731059f, 0.0218575f,
0.344253f, 0.364855f, 1.84576e-05f, 0.493895f,
0.0160912f, 0.270291f, 1.1134e-08f, 0.00108822f})
},
{
1,
tensor(2, 3, 2, 2),
getValues<T>({
0.1f, -0.1f, 0.9f, 1.5f, 0.2f, 0.2f, -10.f, 5.2f, 0.2f, 0.2f, -10.f, 5.2f,
3.f, 0.5f, 7.f, 12.f, 4.f, 0.5f, 8.f, 8.2f, 0.2f, 0.2f, -10.f, 5.2f}),
getValues<T>({
0.311493f, 0.270291f, 0.999963f, 0.0122108f,
0.344253f, 0.364855f, 1.84576e-05f, 0.493895f,
0.344253f, 0.364855f, 1.84576e-05f, 0.493895f,
0.264614f, 0.364855f, 0.268941f, 0.977054f,
0.719295f, 0.364855f, 0.731059f, 0.0218575f,
0.0160912f, 0.270291f, 1.1134e-08f, 0.00108822f})
},
{
2,
tensor(2, 3, 2, 2),
getValues<T>({
0.1f, -0.1f, 0.9f, 1.5f, 0.2f, 0.2f, -10.f, 5.2f, 0.2f, 0.2f, -10.f, 5.2f,
3.f, 0.5f, 7.f, 12.f, 4.f, 0.5f, 8.f, 8.2f, 0.2f, 0.2f, -10.f, 5.2f}),
getValues<T>({
0.310026f, 0.167982f, 0.689974f, 0.832018f,
0.999963f, 0.00669285f, 3.71689e-05f, 0.993307f,
0.999963f, 0.00669285f, 3.71689e-05f, 0.993307f,
0.0179862f, 1.013e-05f, 0.982014f, 0.99999f,
0.0179862f, 0.000452622f, 0.982014f, 0.999547f,
0.999963f, 0.00669285f, 3.71689e-05f, 0.993307f})
},
{
3,
tensor(2, 3, 2, 2),
getValues<T>({
0.1f, -0.1f, 0.9f, 1.5f, 0.2f, 0.2f, -10.f, 5.2f, 0.2f, 0.2f, -10.f, 5.2f,
3.f, 0.5f, 7.f, 12.f, 4.f, 0.5f, 8.f, 8.2f, 0.2f, 0.2f, -10.f, 5.2f}),
getValues<T>({
0.549834f, 0.450166f, 0.354344f, 0.645656f,
0.5f, 0.5f, 2.50452e-07f, 1.0f,
0.5f, 0.5f, 2.50452e-07f, 1.0f,
0.924142f, 0.0758582f, 0.00669285f, 0.993307f,
0.970688f, 0.0293122f, 0.450166f, 0.549834f,
0.5f, 0.5f, 2.50452e-07f, 1.0f})
},
};
return result;
}
template<typename T>
std::vector<SoftmaxParams<T>> generateSoftmaxParams3D() {
const std::vector<SoftmaxParams<T>> result = {
{
0,
tensor(2, 3, 2, 2, 2),
getValues<T>({
0.1f, -0.1f, 0.9f, 1.5f, 0.2f, -0.2f, 0.9f, 2.5f,
0.2f, 0.2f, -10.f, 5.2f, 0.3f, 0.1f, -11.f, 6.2f,
0.2f, 0.2f, -10.f, 5.2f, 0.1f, 0.3f, -9.f, 4.2f,
3.f, 0.5f, 7.f, 12.f, 5.f, 0.1f, 6.f, 22.f,
4.f, 0.5f, 8.f, 8.2f, 2.2f, 0.3f, 6.f, 5.2f,
0.2f, 0.2f, -10.f, 5.2f, 1.2f, 0.3f, -12.f, 2.2f}),
getValues<T>({
0.0521536f, 0.354344f, 0.00223785f, 2.75357e-05f, 0.00816257f, 0.425557f, 0.0060598f, 3.39827e-09f,
0.0218813f, 0.425557f, 1.523e-08f, 0.0474259f, 0.130108f, 0.450166f, 4.13994e-08f, 0.731059f,
0.5f, 0.5f, 0.5f, 0.5f, 0.24974f, 0.5f, 0.952574f, 0.880797f,
0.947846f, 0.645656f, 0.997762f, 0.999972f, 0.991837f, 0.574443f, 0.99394f, 1.0f,
0.978119f, 0.574443f, 1.0f, 0.952574f, 0.869892f, 0.549834f, 1.0f, 0.268941f,
0.5f, 0.5f, 0.5f, 0.5f, 0.75026f, 0.5f, 0.0474259f, 0.119203f})
},
{
1,
tensor(2, 3, 2, 2, 2),
getValues<T>({
0.1f, -0.1f, 0.9f, 1.5f, 0.2f, -0.2f, 0.9f, 2.5f,
0.2f, 0.2f, -10.f, 5.2f, 0.3f, 0.1f, -11.f, 6.2f,
0.2f, 0.2f, -10.f, 5.2f, 0.1f, 0.3f, -9.f, 4.2f,
3.f, 0.5f, 7.f, 12.f, 5.f, 0.1f, 6.f, 22.f,
4.f, 0.5f, 8.f, 8.2f, 2.2f, 0.3f, 6.f, 5.2f,
0.2f, 0.2f, -10.f, 5.2f, 1.2f, 0.3f, -12.f, 2.2f}),
getValues<T>({
0.311493f, 0.270291f, 0.999963f, 0.0122108f, 0.332225f, 0.250089f, 0.999943f, 0.0213123f,
0.344253f, 0.364855f, 1.84576e-05f, 0.493895f, 0.367165f, 0.337585f, 6.79002e-06f, 0.862025f,
0.344253f, 0.364855f, 1.84576e-05f, 0.493895f, 0.30061f, 0.412327f, 5.01718e-05f, 0.116662f,
0.264614f, 0.364855f, 0.268941f, 0.977054f, 0.923207f, 0.290461f, 0.5f, 1.0f,
0.719295f, 0.364855f, 0.731059f, 0.0218575f, 0.0561403f, 0.35477f, 0.5f, 5.05653e-08f,
0.0160912f, 0.270291f, 1.1134e-08f, 0.00108822f, 0.0206528f, 0.35477f, 7.615e-09f, 2.5175e-09f})
},
{
2,
tensor(2, 3, 2, 2, 2),
getValues<T>({
0.1f, -0.1f, 0.9f, 1.5f, 0.2f, -0.2f, 0.9f, 2.5f,
0.2f, 0.2f, -10.f, 5.2f, 0.3f, 0.1f, -11.f, 6.2f,
0.2f, 0.2f, -10.f, 5.2f, 0.1f, 0.3f, -9.f, 4.2f,
3.f, 0.5f, 7.f, 12.f, 5.f, 0.1f, 6.f, 22.f,
4.f, 0.5f, 8.f, 8.2f, 2.2f, 0.3f, 6.f, 5.2f,
0.2f, 0.2f, -10.f, 5.2f, 1.2f, 0.3f, -12.f, 2.2f}),
getValues<T>({
0.475021f, 0.524979f, 0.5f, 0.268941f, 0.524979f, 0.475021f, 0.5f, 0.731059f,
0.475021f, 0.524979f, 0.731059f, 0.268941f, 0.524979f, 0.475021f, 0.268941f, 0.731059f,
0.524979f, 0.475021f, 0.268941f, 0.731059f, 0.475021f, 0.524979f, 0.731059f, 0.268941f,
0.119203f, 0.598688f, 0.731059f, 4.53979e-05f, 0.880797f, 0.401312f, 0.268941f, 0.999955f,
0.858149f, 0.549834f, 0.880797f, 0.952574f, 0.141851f, 0.450166f, 0.119203f, 0.0474259f,
0.268941f, 0.475021f, 0.880797f, 0.952574f, 0.731059f, 0.524979f, 0.119203f, 0.0474259f})
},
{
3,
tensor(2, 3, 2, 2, 2),
getValues<T>({
0.1f, -0.1f, 0.9f, 1.5f, 0.2f, -0.2f, 0.9f, 2.5f,
0.2f, 0.2f, -10.f, 5.2f, 0.3f, 0.1f, -11.f, 6.2f,
0.2f, 0.2f, -10.f, 5.2f, 0.1f, 0.3f, -9.f, 4.2f,
3.f, 0.5f, 7.f, 12.f, 5.f, 0.1f, 6.f, 22.f,
4.f, 0.5f, 8.f, 8.2f, 2.2f, 0.3f, 6.f, 5.2f,
0.2f, 0.2f, -10.f, 5.2f, 1.2f, 0.3f, -12.f, 2.2f}),
getValues<T>({
0.310026f, 0.167982f, 0.689974f, 0.832018f, 0.331812f, 0.0629734f, 0.668188f, 0.937027f,
0.999963f, 0.00669285f, 3.71689e-05f, 0.993307f, 0.999988f, 0.00223785f, 1.23728e-05f, 0.997762f,
0.999963f, 0.00669285f, 3.71689e-05f, 0.993307f, 0.999888f, 0.0198403f, 0.000111653f, 0.98016f,
0.0179862f, 1.013e-05f, 0.982014f, 0.99999f, 0.268941f, 3.08284e-10f, 0.731059f, 1.0f,
0.0179862f, 0.000452622f, 0.982014f, 0.999547f, 0.0218813f, 0.00739154f, 0.978119f, 0.992609f,
0.999963f, 0.00669285f, 3.71689e-05f, 0.993307f, 0.999998f, 0.130108f, 1.8506e-06f, 0.869892f})
},
{
4,
tensor(2, 3, 2, 2, 2),
getValues<T>({
0.1f, -0.1f, 0.9f, 1.5f, 0.2f, -0.2f, 0.9f, 2.5f,
0.2f, 0.2f, -10.f, 5.2f, 0.3f, 0.1f, -11.f, 6.2f,
0.2f, 0.2f, -10.f, 5.2f, 0.1f, 0.3f, -9.f, 4.2f,
3.f, 0.5f, 7.f, 12.f, 5.f, 0.1f, 6.f, 22.f,
4.f, 0.5f, 8.f, 8.2f, 2.2f, 0.3f, 6.f, 5.2f,
0.2f, 0.2f, -10.f, 5.2f, 1.2f, 0.3f, -12.f, 2.2f}),
getValues<T>({
0.549834f, 0.450166f, 0.354344f, 0.645656f, 0.598688f, 0.401312f, 0.167982f, 0.832018f,
0.5f, 0.5f, 2.50452e-07f, 1.0f, 0.549834f, 0.450166f, 3.38949e-08f, 1.0f,
0.5f, 0.5f, 2.50452e-07f, 1.0f, 0.450166f, 0.549834f, 1.8506e-06f, 0.999998f,
0.924142f, 0.0758582f, 0.00669285f, 0.993307f, 0.992609f, 0.00739154f, 1.12535e-07f, 1.0f,
0.970688f, 0.0293122f, 0.450166f, 0.549834f, 0.869892f, 0.130108f, 0.689974f, 0.310025f,
0.5f, 0.5f, 2.50452e-07f, 1.0f, 0.710949f, 0.28905f, 6.80798e-07f, 0.999999f})
}
};
return result;
}
template<typename T>
float getError();
template<>
float getError<float>() {
return 0.001f;
}
template<>
float getError<half_t>() {
return 0.2f;
}
struct PrintToStringParamName {
template<class T>
std::string operator()(const testing::TestParamInfo<SoftmaxParamsWithFormat<T> > &param) {
std::stringstream buf;
SoftmaxParams<T> p;
format::type plain_format;
format::type target_format;
std::tie(p, plain_format, target_format) = param.param;
buf << "_inputTensor=" << p.input_tensor.to_string()
<< "_axis=" << p.axis
<< "_plainFormat=" << fmt_to_str(plain_format)
<< "_targetFormat=" << fmt_to_str(target_format);
return buf.str();
}
};
} // namespace
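
For reference, the expected vectors above follow from an ordinary softmax taken along one axis of the dense input buffer; a minimal sketch that can reproduce them (not part of the test suite; extents are listed outermost first):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> softmax_ref(const std::vector<float>& in,
                               const std::vector<std::size_t>& dims,
                               std::size_t axis) {
    std::size_t stride = 1;  // distance between neighbours along `axis`
    for (std::size_t d = axis + 1; d < dims.size(); ++d)
        stride *= dims[d];
    const std::size_t n = dims[axis];
    std::vector<float> out(in.size());
    for (std::size_t base = 0; base < in.size(); ++base) {
        if ((base / stride) % n != 0)
            continue;  // visit each line along `axis` exactly once
        float max_v = in[base];
        for (std::size_t c = 1; c < n; ++c)
            max_v = std::max(max_v, in[base + c * stride]);
        float denom = 0.f;
        for (std::size_t c = 0; c < n; ++c)
            denom += std::exp(in[base + c * stride] - max_v);
        for (std::size_t c = 0; c < n; ++c)
            out[base + c * stride] = std::exp(in[base + c * stride] - max_v) / denom;
    }
    return out;
}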
template<typename T>
struct softmax_gpu_formats_test
: public ::testing::TestWithParam<SoftmaxParamsWithFormat<T> > {
public:
void test() {
const auto data_type = type_to_data_type<T>::value;
SoftmaxParams<T> params;
format::type plain_format;
format::type target_format;
std::tie(params, plain_format, target_format) = this->GetParam();
auto& engine = get_test_engine();
const auto input = engine.allocate_memory({data_type, plain_format, params.input_tensor});
topology topology;
topology.add(input_layout("input", input->get_layout()));
topology.add(reorder("reordered_input", "input", target_format, data_type));
topology.add(softmax("blocked_softmax", "reordered_input", params.axis));
topology.add(reorder("softmax", "blocked_softmax", plain_format, data_type));
set_values(input, params.input);
build_options bo;
bo.set_option(build_option::optimize_data(false));
network network(engine, topology, bo);
network.set_input_data("input", input);
const auto outputs = network.execute();
const auto output = outputs.at("softmax").get_memory();
const cldnn::mem_lock<T> output_ptr(output, get_test_stream());
ASSERT_EQ(params.input_tensor.count(), output_ptr.size());
for (uint32_t i = 0; i < output_ptr.size(); i++) {
EXPECT_NEAR(output_ptr[i], params.expected[i], getError<T>()) << "target_format=" << target_format << ", i=" << i;
}
}
};
using softmax_gpu_formats_test_f32 = softmax_gpu_formats_test<float>;
using softmax_gpu_formats_test_f16 = softmax_gpu_formats_test<half_t>;
TEST_P(softmax_gpu_formats_test_f32, softmax_gpu_formats_test_f32) {
ASSERT_NO_FATAL_FAILURE(test());
}
TEST_P(softmax_gpu_formats_test_f16, softmax_gpu_formats_test_f16) {
ASSERT_NO_FATAL_FAILURE(test());
}
INSTANTIATE_TEST_SUITE_P(softmax_gpu_formats_test_f32_2d,
softmax_gpu_formats_test_f32,
::testing::Combine(
::testing::ValuesIn(generateSoftmaxParams2D<float>()),
::testing::Values(format::bfyx),
::testing::ValuesIn(formats2D)
),
PrintToStringParamName());
INSTANTIATE_TEST_SUITE_P(softmax_gpu_formats_test_f16_2d,
softmax_gpu_formats_test_f16,
::testing::Combine(
::testing::ValuesIn(generateSoftmaxParams2D<half_t>()),
::testing::Values(format::bfyx),
::testing::ValuesIn(formats2D)
),
PrintToStringParamName());
INSTANTIATE_TEST_SUITE_P(softmax_gpu_formats_test_f32_3d,
softmax_gpu_formats_test_f32,
::testing::Combine(
::testing::ValuesIn(generateSoftmaxParams3D<float>()),
::testing::Values(format::bfzyx),
::testing::ValuesIn(formats3D)
),
PrintToStringParamName());
INSTANTIATE_TEST_SUITE_P(softmax_gpu_formats_test_f16_3d,
softmax_gpu_formats_test_f16,
::testing::Combine(
::testing::ValuesIn(generateSoftmaxParams3D<half_t>()),
::testing::Values(format::bfzyx),
::testing::ValuesIn(formats3D)
),
PrintToStringParamName());