[GPU] Softmax blocked layouts support (#12467)
This commit is contained in:
parent
621bf375c1
commit
0145865301
@ -69,16 +69,25 @@ struct softmax_impl : typed_primitive_impl_ocl<softmax> {
|
||||
namespace detail {
|
||||
|
||||
attach_softmax_impl::attach_softmax_impl() {
    // Register the OCL softmax implementation once per supported
    // (data type, format) pair.  The previous explicit per-tuple list
    // duplicated the plain layouts already present in `formats`, which
    // registered the same pairs twice; the two registrations are merged
    // here and format::byxf (previously only in the tuple list) is kept
    // so no layout support is lost.
    auto types = {data_types::f16, data_types::f32};
    auto formats = {
        // plain 4D/5D layouts
        format::bfyx,
        format::byxf,
        format::yxfb,
        format::bfzyx,
        // blocked 4D layouts
        format::b_fs_yx_fsv16,
        format::b_fs_yx_fsv32,
        format::bs_fs_yx_bsv16_fsv16,
        format::bs_fs_yx_bsv32_fsv16,
        format::bs_fs_yx_bsv32_fsv32,
        // blocked 5D layouts
        format::b_fs_zyx_fsv16,
        format::b_fs_zyx_fsv32,
        format::bs_fs_zyx_bsv16_fsv32,
        format::bs_fs_zyx_bsv16_fsv16,
        format::bs_fs_zyx_bsv32_fsv32,
        format::bs_fs_zyx_bsv32_fsv16
    };

    implementation_map<softmax>::add(impl_types::ocl, softmax_impl::create, types, formats);
}
|
||||
|
||||
} // namespace detail
|
||||
|
@ -22,16 +22,12 @@ ParamsKey SoftmaxItemsClassKernelBase::GetDefaultSupportedKey() {
|
||||
k.EnableInputLayout(DataLayout::bfzyx);
|
||||
k.EnableInputLayout(DataLayout::f);
|
||||
k.EnableOutputLayout(DataLayout::f);
|
||||
k.EnableInputLayout(DataLayout::b_fs_zyx_fsv16);
|
||||
k.EnableInputLayout(DataLayout::bs_fs_zyx_bsv16_fsv16);
|
||||
k.EnableOutputLayout(DataLayout::bfyx);
|
||||
k.EnableOutputLayout(DataLayout::byxf);
|
||||
k.EnableOutputLayout(DataLayout::yxfb);
|
||||
k.EnableOutputLayout(DataLayout::bf);
|
||||
k.EnableOutputLayout(DataLayout::fb);
|
||||
k.EnableOutputLayout(DataLayout::bfzyx);
|
||||
k.EnableOutputLayout(DataLayout::b_fs_zyx_fsv16);
|
||||
k.EnableOutputLayout(DataLayout::bs_fs_zyx_bsv16_fsv16);
|
||||
k.EnableSoftmaxDim(SoftmaxDim::X);
|
||||
k.EnableSoftmaxDim(SoftmaxDim::Y);
|
||||
k.EnableSoftmaxDim(SoftmaxDim::Z);
|
||||
|
@ -6,7 +6,26 @@
|
||||
#include "kernel_selector_utils.h"
|
||||
|
||||
namespace kernel_selector {
|
||||
ParamsKey SoftmaxKernelRef::GetSupportedKey() const {
    // NOTE: the superseded one-line definition (`return GetDefaultSupportedKey();`)
    // was still present alongside this body, i.e. the function was defined
    // twice — a compile/ODR error.  Only the expanded version is kept.
    //
    // Start from the layouts/dims common to all softmax kernels, then
    // additionally enable the blocked layouts this reference kernel handles
    // via SOFTMAX_DIM_* indexing (see GetJitConstants).
    auto k = GetDefaultSupportedKey();

    k.EnableInputLayout(DataLayout::b_fs_yx_fsv16);
    k.EnableOutputLayout(DataLayout::b_fs_yx_fsv16);
    k.EnableInputLayout(DataLayout::b_fs_yx_fsv32);
    k.EnableOutputLayout(DataLayout::b_fs_yx_fsv32);
    k.EnableInputLayout(DataLayout::bs_fs_yx_bsv16_fsv16);
    k.EnableOutputLayout(DataLayout::bs_fs_yx_bsv16_fsv16);
    k.EnableInputLayout(DataLayout::bs_fs_yx_bsv32_fsv16);
    k.EnableOutputLayout(DataLayout::bs_fs_yx_bsv32_fsv16);
    k.EnableInputLayout(DataLayout::bs_fs_yx_bsv32_fsv32);
    k.EnableOutputLayout(DataLayout::bs_fs_yx_bsv32_fsv32);
    k.EnableInputLayout(DataLayout::b_fs_zyx_fsv16);
    k.EnableInputLayout(DataLayout::bs_fs_zyx_bsv16_fsv16);
    k.EnableOutputLayout(DataLayout::b_fs_zyx_fsv16);
    k.EnableOutputLayout(DataLayout::bs_fs_zyx_bsv16_fsv16);

    return k;
}
|
||||
|
||||
SoftmaxKernelRef::Parent::DispatchData SoftmaxKernelRef::SetDefault(const softmax_params& params,
|
||||
const optional_params& optParams) const {
|
||||
@ -28,4 +47,12 @@ KernelsPriority SoftmaxKernelRef::GetKernelsPriority(const Params& /*params*/, c
|
||||
KernelsData SoftmaxKernelRef::GetKernelsData(const Params& params, const optional_params& options) const {
    // All of the assembly work is shared across softmax kernels;
    // simply delegate to the common builder.
    KernelsData kernels_data = GetCommonKernelsData(params, options);
    return kernels_data;
}
|
||||
|
||||
JitConstants SoftmaxKernelRef::GetJitConstants(const softmax_params& params, DispatchData dispatchData) const {
    JitConstants jit_constants = Parent::GetJitConstants(params, dispatchData);

    // For non-simple (blocked) input layouts the kernel cannot use plain
    // pitch-based indexing, so emit SOFTMAX_DIM_<AXIS>=1 telling it which
    // axis softmax runs over; the kernel then computes offsets through
    // INPUT0_GET_INDEX / OUTPUT_GET_INDEX instead.
    const bool blocked_input = !SimpleLayout(params.inputs[0].GetLayout());
    if (blocked_input) {
        jit_constants.AddConstant(MakeJitConstant("SOFTMAX_DIM_" + toString(params.dim), "1"));
    }

    return jit_constants;
}
|
||||
} // namespace kernel_selector
|
||||
|
@ -19,5 +19,6 @@ public:
|
||||
|
||||
protected:
|
||||
DispatchData SetDefault(const softmax_params& params, const optional_params& optParams) const override;
|
||||
JitConstants GetJitConstants(const softmax_params& params, DispatchData dispatchData) const override;
|
||||
};
|
||||
} // namespace kernel_selector
|
@ -20,4 +20,4 @@ softmax_kernel_selector::softmax_kernel_selector() {
|
||||
KernelsData softmax_kernel_selector::GetBestKernels(const Params& params, const optional_params& options) const {
    // Pick the best softmax kernel from the statically registered list.
    return GetNaiveBestKernel(params, options, KernelType::SOFT_MAX);
}
// FIX: the namespace close was duplicated in the source; exactly one is kept.
}  // namespace kernel_selector
|
||||
|
@ -13,6 +13,8 @@ KERNEL(softmax)(
|
||||
, FUSED_OPS_DECLS
|
||||
#endif
|
||||
) {
|
||||
uint cls = 0;
|
||||
#if INPUT0_SIMPLE == 1
|
||||
#if INPUT0_DIMS == 5
|
||||
const uint other0 = (uint)get_global_id(0) % INPUT0_OTHER0_SIZE;
|
||||
const uint other2 = (uint)get_global_id(0) / INPUT0_OTHER0_SIZE;
|
||||
@ -25,13 +27,69 @@ KERNEL(softmax)(
|
||||
|
||||
const uint in_depth_offset = other3*INPUT0_OTHER3_PITCH + other2*INPUT0_OTHER2_PITCH + other1*INPUT0_OTHER1_PITCH + other0*INPUT0_OTHER0_PITCH + INPUT0_OFFSET;
|
||||
const uint out_depth_offset = other3*OUTPUT_OTHER3_PITCH + other2*OUTPUT_OTHER2_PITCH + other1*OUTPUT_OTHER1_PITCH + other0*OUTPUT_OTHER0_PITCH + OUTPUT_OFFSET;
|
||||
#else // blocked format
|
||||
const uint no_offset = 0;
|
||||
uint *b_offset, *f_offset, *z_offset, *y_offset, *x_offset;
|
||||
b_offset = f_offset = z_offset = y_offset = x_offset = &no_offset;
|
||||
|
||||
#if SOFTMAX_DIM_X
|
||||
x_offset = &cls;
|
||||
|
||||
const uint b = get_global_id(2);
|
||||
const uint f = get_global_id(1);
|
||||
const uint z = (uint)get_global_id(0) % INPUT0_SIZE_Z;
|
||||
const uint y = (uint)get_global_id(0) / INPUT0_SIZE_Z;
|
||||
const uint x = 0;
|
||||
#elif SOFTMAX_DIM_Y
|
||||
y_offset = &cls;
|
||||
|
||||
const uint b = get_global_id(2);
|
||||
const uint f = get_global_id(1);
|
||||
const uint z = (uint)get_global_id(0) / INPUT0_SIZE_X;
|
||||
const uint y = 0;
|
||||
const uint x = (uint)get_global_id(0) % INPUT0_SIZE_X;
|
||||
#elif SOFTMAX_DIM_Z
|
||||
z_offset = &cls;
|
||||
|
||||
const uint b = get_global_id(2);
|
||||
const uint f = get_global_id(1);
|
||||
const uint z = 0;
|
||||
const uint y = (uint)get_global_id(0) / INPUT0_SIZE_X;
|
||||
const uint x = (uint)get_global_id(0) % INPUT0_SIZE_X;
|
||||
#elif SOFTMAX_DIM_FEATURE
|
||||
f_offset = &cls;
|
||||
|
||||
const uint b = get_global_id(2);
|
||||
const uint f = 0;
|
||||
const uint z = (uint)get_global_id(0) / INPUT0_SIZE_X;
|
||||
const uint y = get_global_id(1);
|
||||
const uint x = (uint)get_global_id(0) % INPUT0_SIZE_X;
|
||||
#elif SOFTMAX_DIM_BATCH
|
||||
b_offset = &cls;
|
||||
|
||||
const uint b = 0;
|
||||
const uint f = get_global_id(2);
|
||||
const uint z = (uint)get_global_id(0) / INPUT0_SIZE_X;
|
||||
const uint y = get_global_id(1);
|
||||
const uint x = (uint)get_global_id(0) % INPUT0_SIZE_X;
|
||||
#else
|
||||
#error Wrong axis
|
||||
#endif
|
||||
#endif
|
||||
ACCUMULATOR_TYPE max_value = UNIT_VAL_MIN;
|
||||
ACCUMULATOR_TYPE data[INPUT0_CLASS_NUM];
|
||||
|
||||
for (uint cls = 0; cls < INPUT0_CLASS_NUM; ++cls)
|
||||
for (cls = 0; cls < INPUT0_CLASS_NUM; ++cls)
|
||||
{
|
||||
#if INPUT0_SIMPLE == 1
|
||||
const uint index = in_depth_offset + cls*INPUT0_CLASS_PITCH;
|
||||
#else
|
||||
#if INPUT0_DIMS == 5
|
||||
const uint index = INPUT0_GET_INDEX(b + *b_offset, f + *f_offset, z + *z_offset, y + *y_offset, x + *x_offset);
|
||||
#else
|
||||
const uint index = INPUT0_GET_INDEX(b + *b_offset, f + *f_offset, y + *y_offset, x + *x_offset);
|
||||
#endif
|
||||
#endif
|
||||
ACCUMULATOR_TYPE in = input[index];
|
||||
max_value = max(max_value, in);
|
||||
data[cls] = in;
|
||||
@ -39,16 +97,24 @@ KERNEL(softmax)(
|
||||
|
||||
// TODO: currently we calculate on float32 because it's lot of "add" operation and it stuck on the value "8192.0f"
|
||||
ACCUMULATOR_TYPE denominator = 0.0;
|
||||
for (uint cls = 0; cls < INPUT0_CLASS_NUM; ++cls)
|
||||
for (cls = 0; cls < INPUT0_CLASS_NUM; ++cls)
|
||||
{
|
||||
data[cls] = native_exp(data[cls] - max_value);
|
||||
denominator += data[cls];
|
||||
}
|
||||
|
||||
for (uint cls = 0; cls < INPUT0_CLASS_NUM; ++cls)
|
||||
for (cls = 0; cls < INPUT0_CLASS_NUM; ++cls)
|
||||
{
|
||||
const ACCUMULATOR_TYPE res = data[cls] / denominator;
|
||||
#if INPUT0_SIMPLE == 1
|
||||
const uint output_idx = out_depth_offset + cls*OUTPUT_CLASS_PITCH;
|
||||
#else
|
||||
#if INPUT0_DIMS == 5
|
||||
const uint output_idx = OUTPUT_GET_INDEX(b + *b_offset, f + *f_offset, z + *z_offset, y + *y_offset, x + *x_offset);
|
||||
#else
|
||||
const uint output_idx = OUTPUT_GET_INDEX(b + *b_offset, f + *f_offset, y + *y_offset, x + *x_offset);
|
||||
#endif
|
||||
#endif
|
||||
#if HAS_FUSED_OPS
|
||||
FUSED_OPS;
|
||||
output[output_idx] = FUSED_OPS_RESULT;
|
||||
|
@ -658,3 +658,321 @@ INSTANTIATE_TEST_SUITE_P(DISABLED_SOFTMAX,
|
||||
softmax_test,
|
||||
::testing::Combine(::testing::ValuesIn(softmax_test::generate_generic_test_params()), ::testing::ValuesIn(softmax_test::generate_specific_test_params())),
|
||||
softmax_test::custom_param_name);
|
||||
|
||||
|
||||
|
||||
namespace {
|
||||
// One reference softmax test case: input data plus the expected output
// computed along `axis`, both in plain-layout element order.
template<typename T>
struct SoftmaxParams {
    int64_t axis;             // dimension softmax is normalized over
    tensor input_tensor;      // input shape
    std::vector<T> input;     // input values (plain-layout order)
    std::vector<T> expected;  // reference output values (plain-layout order)
};

// Full test parameter: reference case + the plain layout the data is fed in
// + the (typically blocked) layout the softmax primitive actually runs in.
template<typename T>
using SoftmaxParamsWithFormat = std::tuple<
    SoftmaxParams<T>,
    format::type,   // source (plain) layout
    format::type    // target (blocked) layout
>;
|
||||
|
||||
// 4D target layouts exercised by the 2D (bfyx-sourced) test cases.
const std::vector<format::type> formats2D{
    format::bfyx,
    format::b_fs_yx_fsv16,
    format::b_fs_yx_fsv32,
    format::bs_fs_yx_bsv16_fsv16,
    format::bs_fs_yx_bsv32_fsv16,
    format::bs_fs_yx_bsv32_fsv32
};

// 5D target layouts exercised by the 3D (bfzyx-sourced) test cases.
const std::vector<format::type> formats3D{
    format::bfzyx,
    format::b_fs_zyx_fsv16,
    format::bs_fs_zyx_bsv16_fsv16
};
|
||||
|
||||
// Converts reference values (authored as float) to the tested element
// type T, one element at a time.
template<typename T>
std::vector<T> getValues(const std::vector<float> &values) {
    std::vector<T> converted;
    converted.reserve(values.size());
    for (const float value : values) {
        converted.push_back(static_cast<T>(value));
    }
    return converted;
}
|
||||
|
||||
// Reference cases for 4D (bfyx) softmax: one case per normalization axis
// (0 = batch, 1 = feature, 2 = y, 3 = x).  `expected` is the reference
// softmax of `input` along that axis, in plain bfyx order.
template<typename T>
std::vector<SoftmaxParams<T>> generateSoftmaxParams2D() {
    const std::vector<SoftmaxParams<T>> result = {
        {   // axis 0 (batch), 3x2x2x2
            0,
            tensor(3, 2, 2, 2),
            getValues<T>({
                0.1f, -0.1f, 0.9f, 1.5f, 3.f, 0.5f, 7.f, 12.f, 0.2f, 0.2f, -10.f, 5.2f,
                4.f, 0.5f, 8.f, 8.2f, 0.2f, 0.2f, -10.f, 5.2f, 0.2f, 0.2f, -10.f, 5.2f}),
            getValues<T>({
                0.311493f, 0.270291f, 0.999963f, 0.0122108f,
                0.264614f, 0.364855f, 0.268941f, 0.977054f,
                0.344253f, 0.364855f, 1.84576e-05f, 0.493895f,
                0.719295f, 0.364855f, 0.731059f, 0.0218575f,
                0.344253f, 0.364855f, 1.84576e-05f, 0.493895f,
                0.0160912f, 0.270291f, 1.1134e-08f, 0.00108822f})
        },
        {   // axis 1 (feature), 2x3x2x2
            1,
            tensor(2, 3, 2, 2),
            getValues<T>({
                0.1f, -0.1f, 0.9f, 1.5f, 0.2f, 0.2f, -10.f, 5.2f, 0.2f, 0.2f, -10.f, 5.2f,
                3.f, 0.5f, 7.f, 12.f, 4.f, 0.5f, 8.f, 8.2f, 0.2f, 0.2f, -10.f, 5.2f}),
            getValues<T>({
                0.311493f, 0.270291f, 0.999963f, 0.0122108f,
                0.344253f, 0.364855f, 1.84576e-05f, 0.493895f,
                0.344253f, 0.364855f, 1.84576e-05f, 0.493895f,
                0.264614f, 0.364855f, 0.268941f, 0.977054f,
                0.719295f, 0.364855f, 0.731059f, 0.0218575f,
                0.0160912f, 0.270291f, 1.1134e-08f, 0.00108822f})
        },
        {   // axis 2 (y), 2x3x2x2 — same input as axis-1 case
            2,
            tensor(2, 3, 2, 2),
            getValues<T>({
                0.1f, -0.1f, 0.9f, 1.5f, 0.2f, 0.2f, -10.f, 5.2f, 0.2f, 0.2f, -10.f, 5.2f,
                3.f, 0.5f, 7.f, 12.f, 4.f, 0.5f, 8.f, 8.2f, 0.2f, 0.2f, -10.f, 5.2f}),
            getValues<T>({
                0.310026f, 0.167982f, 0.689974f, 0.832018f,
                0.999963f, 0.00669285f, 3.71689e-05f, 0.993307f,
                0.999963f, 0.00669285f, 3.71689e-05f, 0.993307f,
                0.0179862f, 1.013e-05f, 0.982014f, 0.99999f,
                0.0179862f, 0.000452622f, 0.982014f, 0.999547f,
                0.999963f, 0.00669285f, 3.71689e-05f, 0.993307f})
        },
        {   // axis 3 (x), 2x3x2x2 — same input as axis-1 case
            3,
            tensor(2, 3, 2, 2),
            getValues<T>({
                0.1f, -0.1f, 0.9f, 1.5f, 0.2f, 0.2f, -10.f, 5.2f, 0.2f, 0.2f, -10.f, 5.2f,
                3.f, 0.5f, 7.f, 12.f, 4.f, 0.5f, 8.f, 8.2f, 0.2f, 0.2f, -10.f, 5.2f}),
            getValues<T>({
                0.549834f, 0.450166f, 0.354344f, 0.645656f,
                0.5f, 0.5f, 2.50452e-07f, 1.0f,
                0.5f, 0.5f, 2.50452e-07f, 1.0f,
                0.924142f, 0.0758582f, 0.00669285f, 0.993307f,
                0.970688f, 0.0293122f, 0.450166f, 0.549834f,
                0.5f, 0.5f, 2.50452e-07f, 1.0f})
        },

    };
    return result;
}
|
||||
|
||||
// Reference cases for 5D (bfzyx) softmax: one case per normalization axis
// (0 = batch, 1 = feature, 2 = z, 3 = y, 4 = x).  All five cases share the
// same 2x3x2x2x2 input; only the axis and the expected output differ.
template<typename T>
std::vector<SoftmaxParams<T>> generateSoftmaxParams3D() {
    const std::vector<SoftmaxParams<T>> result = {
        {   // axis 0 (batch)
            0,
            tensor(2, 3, 2, 2, 2),
            getValues<T>({
                0.1f, -0.1f, 0.9f, 1.5f, 0.2f, -0.2f, 0.9f, 2.5f,
                0.2f, 0.2f, -10.f, 5.2f, 0.3f, 0.1f, -11.f, 6.2f,
                0.2f, 0.2f, -10.f, 5.2f, 0.1f, 0.3f, -9.f, 4.2f,
                3.f, 0.5f, 7.f, 12.f, 5.f, 0.1f, 6.f, 22.f,
                4.f, 0.5f, 8.f, 8.2f, 2.2f, 0.3f, 6.f, 5.2f,
                0.2f, 0.2f, -10.f, 5.2f, 1.2f, 0.3f, -12.f, 2.2f}),
            getValues<T>({
                0.0521536f, 0.354344f, 0.00223785f, 2.75357e-05f, 0.00816257f, 0.425557f, 0.0060598f, 3.39827e-09f,
                0.0218813f, 0.425557f, 1.523e-08f, 0.0474259f, 0.130108f, 0.450166f, 4.13994e-08f, 0.731059f,
                0.5f, 0.5f, 0.5f, 0.5f, 0.24974f, 0.5f, 0.952574f, 0.880797f,
                0.947846f, 0.645656f, 0.997762f, 0.999972f, 0.991837f, 0.574443f, 0.99394f, 1.0f,
                0.978119f, 0.574443f, 1.0f, 0.952574f, 0.869892f, 0.549834f, 1.0f, 0.268941f,
                0.5f, 0.5f, 0.5f, 0.5f, 0.75026f, 0.5f, 0.0474259f, 0.119203f})
        },
        {   // axis 1 (feature)
            1,
            tensor(2, 3, 2, 2, 2),
            getValues<T>({
                0.1f, -0.1f, 0.9f, 1.5f, 0.2f, -0.2f, 0.9f, 2.5f,
                0.2f, 0.2f, -10.f, 5.2f, 0.3f, 0.1f, -11.f, 6.2f,
                0.2f, 0.2f, -10.f, 5.2f, 0.1f, 0.3f, -9.f, 4.2f,
                3.f, 0.5f, 7.f, 12.f, 5.f, 0.1f, 6.f, 22.f,
                4.f, 0.5f, 8.f, 8.2f, 2.2f, 0.3f, 6.f, 5.2f,
                0.2f, 0.2f, -10.f, 5.2f, 1.2f, 0.3f, -12.f, 2.2f}),
            getValues<T>({
                0.311493f, 0.270291f, 0.999963f, 0.0122108f, 0.332225f, 0.250089f, 0.999943f, 0.0213123f,
                0.344253f, 0.364855f, 1.84576e-05f, 0.493895f, 0.367165f, 0.337585f, 6.79002e-06f, 0.862025f,
                0.344253f, 0.364855f, 1.84576e-05f, 0.493895f, 0.30061f, 0.412327f, 5.01718e-05f, 0.116662f,
                0.264614f, 0.364855f, 0.268941f, 0.977054f, 0.923207f, 0.290461f, 0.5f, 1.0f,
                0.719295f, 0.364855f, 0.731059f, 0.0218575f, 0.0561403f, 0.35477f, 0.5f, 5.05653e-08f,
                0.0160912f, 0.270291f, 1.1134e-08f, 0.00108822f, 0.0206528f, 0.35477f, 7.615e-09f, 2.5175e-09f})
        },
        {   // axis 2 (z)
            2,
            tensor(2, 3, 2, 2, 2),
            getValues<T>({
                0.1f, -0.1f, 0.9f, 1.5f, 0.2f, -0.2f, 0.9f, 2.5f,
                0.2f, 0.2f, -10.f, 5.2f, 0.3f, 0.1f, -11.f, 6.2f,
                0.2f, 0.2f, -10.f, 5.2f, 0.1f, 0.3f, -9.f, 4.2f,
                3.f, 0.5f, 7.f, 12.f, 5.f, 0.1f, 6.f, 22.f,
                4.f, 0.5f, 8.f, 8.2f, 2.2f, 0.3f, 6.f, 5.2f,
                0.2f, 0.2f, -10.f, 5.2f, 1.2f, 0.3f, -12.f, 2.2f}),
            getValues<T>({
                0.475021f, 0.524979f, 0.5f, 0.268941f, 0.524979f, 0.475021f, 0.5f, 0.731059f,
                0.475021f, 0.524979f, 0.731059f, 0.268941f, 0.524979f, 0.475021f, 0.268941f, 0.731059f,
                0.524979f, 0.475021f, 0.268941f, 0.731059f, 0.475021f, 0.524979f, 0.731059f, 0.268941f,
                0.119203f, 0.598688f, 0.731059f, 4.53979e-05f, 0.880797f, 0.401312f, 0.268941f, 0.999955f,
                0.858149f, 0.549834f, 0.880797f, 0.952574f, 0.141851f, 0.450166f, 0.119203f, 0.0474259f,
                0.268941f, 0.475021f, 0.880797f, 0.952574f, 0.731059f, 0.524979f, 0.119203f, 0.0474259f})
        },
        {   // axis 3 (y)
            3,
            tensor(2, 3, 2, 2, 2),
            getValues<T>({
                0.1f, -0.1f, 0.9f, 1.5f, 0.2f, -0.2f, 0.9f, 2.5f,
                0.2f, 0.2f, -10.f, 5.2f, 0.3f, 0.1f, -11.f, 6.2f,
                0.2f, 0.2f, -10.f, 5.2f, 0.1f, 0.3f, -9.f, 4.2f,
                3.f, 0.5f, 7.f, 12.f, 5.f, 0.1f, 6.f, 22.f,
                4.f, 0.5f, 8.f, 8.2f, 2.2f, 0.3f, 6.f, 5.2f,
                0.2f, 0.2f, -10.f, 5.2f, 1.2f, 0.3f, -12.f, 2.2f}),
            getValues<T>({
                0.310026f, 0.167982f, 0.689974f, 0.832018f, 0.331812f, 0.0629734f, 0.668188f, 0.937027f,
                0.999963f, 0.00669285f, 3.71689e-05f, 0.993307f, 0.999988f, 0.00223785f, 1.23728e-05f, 0.997762f,
                0.999963f, 0.00669285f, 3.71689e-05f, 0.993307f, 0.999888f, 0.0198403f, 0.000111653f, 0.98016f,
                0.0179862f, 1.013e-05f, 0.982014f, 0.99999f, 0.268941f, 3.08284e-10f, 0.731059f, 1.0f,
                0.0179862f, 0.000452622f, 0.982014f, 0.999547f, 0.0218813f, 0.00739154f, 0.978119f, 0.992609f,
                0.999963f, 0.00669285f, 3.71689e-05f, 0.993307f, 0.999998f, 0.130108f, 1.8506e-06f, 0.869892f})
        },
        {   // axis 4 (x)
            4,
            tensor(2, 3, 2, 2, 2),
            getValues<T>({
                0.1f, -0.1f, 0.9f, 1.5f, 0.2f, -0.2f, 0.9f, 2.5f,
                0.2f, 0.2f, -10.f, 5.2f, 0.3f, 0.1f, -11.f, 6.2f,
                0.2f, 0.2f, -10.f, 5.2f, 0.1f, 0.3f, -9.f, 4.2f,
                3.f, 0.5f, 7.f, 12.f, 5.f, 0.1f, 6.f, 22.f,
                4.f, 0.5f, 8.f, 8.2f, 2.2f, 0.3f, 6.f, 5.2f,
                0.2f, 0.2f, -10.f, 5.2f, 1.2f, 0.3f, -12.f, 2.2f}),
            getValues<T>({
                0.549834f, 0.450166f, 0.354344f, 0.645656f, 0.598688f, 0.401312f, 0.167982f, 0.832018f,
                0.5f, 0.5f, 2.50452e-07f, 1.0f, 0.549834f, 0.450166f, 3.38949e-08f, 1.0f,
                0.5f, 0.5f, 2.50452e-07f, 1.0f, 0.450166f, 0.549834f, 1.8506e-06f, 0.999998f,
                0.924142f, 0.0758582f, 0.00669285f, 0.993307f, 0.992609f, 0.00739154f, 1.12535e-07f, 1.0f,
                0.970688f, 0.0293122f, 0.450166f, 0.549834f, 0.869892f, 0.130108f, 0.689974f, 0.310025f,
                0.5f, 0.5f, 2.50452e-07f, 1.0f, 0.710949f, 0.28905f, 6.80798e-07f, 0.999999f})
        }
    };
    return result;
}
|
||||
|
||||
// Element-wise tolerance used by EXPECT_NEAR when comparing softmax
// results of element type T against the reference values.
template<typename T>
float getError();

// f32: results should match the reference closely.
template<>
float getError<float>() {
    return 0.001;
}

// f16: much coarser tolerance — half precision loses accuracy in
// exp/normalization.
template<>
float getError<half_t>() {
    return 0.2;
}
|
||||
|
||||
struct PrintToStringParamName {
|
||||
template<class T>
|
||||
std::string operator()(const testing::TestParamInfo<SoftmaxParamsWithFormat<T> > ¶m) {
|
||||
std::stringstream buf;
|
||||
SoftmaxParams<T> p;
|
||||
format::type plain_format;
|
||||
format::type target_format;
|
||||
std::tie(p, plain_format, target_format) = param.param;
|
||||
buf << "_inputTensor=" << p.input_tensor.to_string()
|
||||
<< "_axis=" << p.axis
|
||||
<< "_plainFormat=" << fmt_to_str(plain_format)
|
||||
<< "_targetFormat=" << fmt_to_str(target_format);
|
||||
return buf.str();
|
||||
}
|
||||
};
|
||||
}; // namespace
|
||||
|
||||
|
||||
|
||||
template<typename T>
|
||||
struct softmax_gpu_formats_test
|
||||
: public ::testing::TestWithParam<SoftmaxParamsWithFormat<T> > {
|
||||
public:
|
||||
void test() {
|
||||
const auto data_type = type_to_data_type<T>::value;
|
||||
SoftmaxParams<T> params;
|
||||
format::type plain_format;
|
||||
format::type target_format;
|
||||
|
||||
std::tie(params, plain_format, target_format) = this->GetParam();
|
||||
|
||||
auto& engine = get_test_engine();
|
||||
const auto input = engine.allocate_memory({data_type, plain_format, params.input_tensor});
|
||||
|
||||
topology topology;
|
||||
topology.add(input_layout("input", input->get_layout()));
|
||||
topology.add(reorder("reordered_input", "input", target_format, data_type));
|
||||
topology.add(softmax("blocked_softmax", "reordered_input", params.axis));
|
||||
topology.add(reorder("softmax", "blocked_softmax", plain_format, data_type));
|
||||
|
||||
set_values(input, params.input);
|
||||
|
||||
build_options bo;
|
||||
bo.set_option(build_option::optimize_data(false));
|
||||
network network(engine, topology);
|
||||
|
||||
network.set_input_data("input", input);
|
||||
const auto outputs = network.execute();
|
||||
const auto output = outputs.at("softmax").get_memory();
|
||||
const cldnn::mem_lock<T> output_ptr(output, get_test_stream());
|
||||
|
||||
ASSERT_EQ(params.input_tensor.count(), output_ptr.size());
|
||||
for (uint32_t i = 0; i < output_ptr.size(); i++) {
|
||||
EXPECT_NEAR(output_ptr[i], params.expected[i], getError<T>()) << "target_format=" << target_format << ", i=" << i;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Concrete fixtures for the two tested element types.
using softmax_gpu_formats_test_f32 = softmax_gpu_formats_test<float>;
using softmax_gpu_formats_test_f16 = softmax_gpu_formats_test<half_t>;

TEST_P(softmax_gpu_formats_test_f32, softmax_gpu_formats_test_f32) {
    ASSERT_NO_FATAL_FAILURE(test());
}

TEST_P(softmax_gpu_formats_test_f16, softmax_gpu_formats_test_f16) {
    ASSERT_NO_FATAL_FAILURE(test());
}
|
||||
|
||||
// 2D cases: plain bfyx data reordered into each 4D layout under test.
INSTANTIATE_TEST_SUITE_P(softmax_gpu_formats_test_f32_2d,
                         softmax_gpu_formats_test_f32,
                         ::testing::Combine(
                                 ::testing::ValuesIn(generateSoftmaxParams2D<float>()),
                                 ::testing::Values(format::bfyx),
                                 ::testing::ValuesIn(formats2D)
                         ),
                         PrintToStringParamName());

INSTANTIATE_TEST_SUITE_P(softmax_gpu_formats_test_f16_2d,
                         softmax_gpu_formats_test_f16,
                         ::testing::Combine(
                                 ::testing::ValuesIn(generateSoftmaxParams2D<half_t>()),
                                 ::testing::Values(format::bfyx),
                                 ::testing::ValuesIn(formats2D)
                         ),
                         PrintToStringParamName());

// 3D cases: plain bfzyx data reordered into each 5D layout under test.
INSTANTIATE_TEST_SUITE_P(softmax_gpu_formats_test_f32_3d,
                         softmax_gpu_formats_test_f32,
                         ::testing::Combine(
                                 ::testing::ValuesIn(generateSoftmaxParams3D<float>()),
                                 ::testing::Values(format::bfzyx),
                                 ::testing::ValuesIn(formats3D)
                         ),
                         PrintToStringParamName());

INSTANTIATE_TEST_SUITE_P(softmax_gpu_formats_test_f16_3d,
                         softmax_gpu_formats_test_f16,
                         ::testing::Combine(
                                 ::testing::ValuesIn(generateSoftmaxParams3D<half_t>()),
                                 ::testing::Values(format::bfzyx),
                                 ::testing::ValuesIn(formats3D)
                         ),
                         PrintToStringParamName());
|
||||
|
Loading…
Reference in New Issue
Block a user