[GPU] Softmax for stable diffusion (#15863)

This commit is contained in:
Yaroslav Torzuk 2023-04-03 08:21:02 +02:00 committed by GitHub
parent b64cbff10b
commit 8491f15ba7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 140 additions and 19 deletions

View File

@ -3,11 +3,27 @@
//
#include "include/batch_headers/common.cl"
#include "include/batch_headers/fetch_data.cl"
#include "include/batch_headers/sub_group_block_read.cl"
#include "include/batch_headers/sub_group_block_write.cl"
// Block I/O helpers for the subgroup-optimized read/write paths.
// With SUBGROUP_BLOCK_SIZE == 1 the scalar DT_*_BLOCK_READ/WRITE variants
// are used and BLOCK_TYPE degenerates to a scalar INPUT0_TYPE; otherwise
// the suffixed vector variants (e.g. DT_INPUT_BLOCK_READ8) are selected
// via token pasting and BLOCK_TYPE becomes a SUBGROUP_BLOCK_SIZE-wide
// vector of INPUT0_TYPE.
#if SUBGROUP_BLOCK_SIZE == 1
#define BLOCK_READ(ptr, offset) DT_INPUT_BLOCK_READ(ptr, offset)
#define BLOCK_WRITE(ptr, offset, val) DT_OUTPUT_BLOCK_WRITE(ptr, offset, val)
#define BLOCK_TYPE INPUT0_TYPE
#else
#define BLOCK_READ(ptr, offset) CAT(DT_INPUT_BLOCK_READ, SUBGROUP_BLOCK_SIZE)(ptr, offset)
#define BLOCK_WRITE(ptr, offset, val) CAT(DT_OUTPUT_BLOCK_WRITE, SUBGROUP_BLOCK_SIZE)(ptr, offset, val)
#define BLOCK_TYPE MAKE_VECTOR_TYPE(INPUT0_TYPE, SUBGROUP_BLOCK_SIZE)
#endif
#if IS_DYNAMIC
// Integer floor(log2(n)) for n >= 1, computed with a GCC-style statement
// expression; defined only for dynamic shapes (IS_DYNAMIC).
#define CALC_POWER(n) ({uint pos = 0; uint i = n; do { i >>= 1; ++pos; } while (i); --pos;})
#endif
// The kernel is compiled for a fixed 16-wide subgroup
// (see REQD_SUB_GROUP_SIZE below).
#define SUB_GROUP_SIZE 16
REQD_SUB_GROUP_SIZE(SUB_GROUP_SIZE)
#if !IS_DYNAMIC
__attribute__((reqd_work_group_size(LWS, 1, 1)))
#endif
@ -36,36 +52,54 @@ KERNEL (softmax_gpu_continuous_bfyx)(
#endif
const uint data_set_offset = data_set_idx * data_set_size;
const uint my_data_offset = data_set_offset + in_data_set_idx;
const uint subgroup_offset = get_sub_group_id() * get_sub_group_size() * items_num;
INPUT0_TYPE my_chunk[STACK_SIZE];
INPUT0_TYPE my_maximum = -UNIT_VAL_MAX;
INPUT0_TYPE my_sum = UNIT_VAL_ZERO;
INPUT0_TYPE tmp;
__local INPUT0_TYPE lg_storage[SLM_SIZE];
//each WI reads items_num consecutive items from batch
for (uint i=0; i<items_num; ++i)
uint i=0;
if (workers_per_data_set > SUB_GROUP_SIZE)
{
tmp = input[my_data_offset + i * workers_per_data_set];
for (; i<items_num - (items_num % SUBGROUP_BLOCK_SIZE); i+=SUBGROUP_BLOCK_SIZE)
{
BLOCK_TYPE vec_tmp = BLOCK_READ(input, data_set_offset + subgroup_offset + i * get_sub_group_size());
#if SUBGROUP_BLOCK_SIZE == 1
my_maximum = max(my_maximum, vec_tmp);
my_chunk[i] = vec_tmp;
#else
for (int j = 0; j < SUBGROUP_BLOCK_SIZE; j++)
{
INPUT0_TYPE tmp = vec_tmp[j];
my_maximum = max(my_maximum, tmp);
my_chunk[i+j] = tmp;
}
#endif
}
}
for (; i<items_num; i++)
{
INPUT0_TYPE tmp = input[data_set_offset + subgroup_offset + get_sub_group_local_id() + i * get_sub_group_size()];
my_maximum = max(my_maximum, tmp);
my_chunk[i] = tmp;
}
if (in_data_set_idx < leftovers)
{
tmp = input[data_set_offset + workers_per_data_set * items_num + in_data_set_idx];
INPUT0_TYPE tmp = input[data_set_offset + workers_per_data_set * items_num + in_data_set_idx];
my_maximum = max(my_maximum, tmp);
my_chunk[items_num] = tmp;
}
my_maximum = sub_group_reduce_max(my_maximum);
lg_storage[in_data_set_idx] = my_maximum;
if (get_sub_group_local_id() == 0)
lg_storage[get_sub_group_id()] = my_maximum;
barrier(CLK_LOCAL_MEM_FENCE);
if (in_data_set_idx == 0)
{
for (uint i=1; i<LWS; ++i)
for (uint i=1; i<get_num_sub_groups(); ++i)
my_maximum = max(my_maximum, lg_storage[i]);
lg_storage[0] = my_maximum;
@ -79,24 +113,27 @@ KERNEL (softmax_gpu_continuous_bfyx)(
for (uint i=0; i<items_num; ++i)
{
tmp = native_exp(my_chunk[i] - my_maximum);
INPUT0_TYPE tmp = native_exp(my_chunk[i] - my_maximum);
my_sum += tmp;
my_chunk[i] = tmp;
}
if (in_data_set_idx < leftovers)
{
tmp = native_exp(my_chunk[items_num] - my_maximum);
INPUT0_TYPE tmp = native_exp(my_chunk[items_num] - my_maximum);
my_sum += tmp;
my_chunk[items_num] = tmp;
}
lg_storage[in_data_set_idx] = my_sum;
my_sum = sub_group_reduce_add(my_sum);
if (get_sub_group_local_id() == 0)
lg_storage[get_sub_group_id()] = my_sum;
barrier(CLK_LOCAL_MEM_FENCE);
if (in_data_set_idx == 0)
{
for (uint i=1; i<LWS; ++i)
for (uint i=1; i<get_num_sub_groups(); ++i)
my_sum += lg_storage[i];
lg_storage[0] = my_sum;
@ -105,12 +142,34 @@ KERNEL (softmax_gpu_continuous_bfyx)(
my_sum = lg_storage[0];
i=0;
#if HAS_FUSED_OPS
for (uint i=0; i<items_num; ++i)
if (workers_per_data_set > SUB_GROUP_SIZE)
{
for (; i < items_num - (items_num % SUBGROUP_BLOCK_SIZE); i+=SUBGROUP_BLOCK_SIZE)
{
BLOCK_TYPE vec_tmp;
#if SUBGROUP_BLOCK_SIZE == 1
ACTIVATION_TYPE dequantized = my_chunk[i] / my_sum;
FUSED_OPS_MAIN;
vec_tmp = FUSED_OPS_RESULT_MAIN;
#else
for (int j = 0; j < SUBGROUP_BLOCK_SIZE; j++)
{
ACTIVATION_TYPE dequantized = my_chunk[i + j] / my_sum;
FUSED_OPS_MAIN;
vec_tmp[j] = FUSED_OPS_RESULT_MAIN;
}
#endif
BLOCK_WRITE(output, data_set_offset + subgroup_offset + i * get_sub_group_size(), vec_tmp);
}
}
for (; i<items_num; i++)
{
ACTIVATION_TYPE dequantized = my_chunk[i] / my_sum;
FUSED_OPS_MAIN;
output[my_data_offset + i * workers_per_data_set] = FUSED_OPS_RESULT_MAIN;
output[data_set_offset + subgroup_offset + get_sub_group_local_id() + i * get_sub_group_size()] = FUSED_OPS_RESULT_MAIN;
}
if (in_data_set_idx < leftovers)
{
@ -119,8 +178,24 @@ KERNEL (softmax_gpu_continuous_bfyx)(
output[data_set_offset + workers_per_data_set * items_num + in_data_set_idx] = FUSED_OPS_RESULT_LEFTOVERS;
}
#else
for (uint i=0; i<items_num; ++i)
output[my_data_offset + i * workers_per_data_set] = ACTIVATION(my_chunk[i] / my_sum, ACTIVATION_PARAMS);
if (workers_per_data_set > SUB_GROUP_SIZE)
{
for (; i<items_num - (items_num % SUBGROUP_BLOCK_SIZE); i+=SUBGROUP_BLOCK_SIZE)
{
BLOCK_TYPE vec_tmp;
#if SUBGROUP_BLOCK_SIZE == 1
vec_tmp = ACTIVATION(my_chunk[i] / my_sum, ACTIVATION_PARAMS);
#else
for (int j = 0; j < SUBGROUP_BLOCK_SIZE; j++)
vec_tmp[j] = ACTIVATION(my_chunk[i + j] / my_sum, ACTIVATION_PARAMS);
#endif
BLOCK_WRITE(output, data_set_offset + subgroup_offset + i * get_sub_group_size(), vec_tmp);
}
}
for (; i < items_num; i++)
{
output[data_set_offset + subgroup_offset + get_sub_group_local_id() + i * get_sub_group_size()] = ACTIVATION(my_chunk[i] / my_sum, ACTIVATION_PARAMS);
}
if (in_data_set_idx < leftovers)
output[data_set_offset + workers_per_data_set * items_num + in_data_set_idx] = ACTIVATION(my_chunk[items_num] / my_sum, ACTIVATION_PARAMS);
#endif
@ -128,3 +203,7 @@ KERNEL (softmax_gpu_continuous_bfyx)(
#ifdef CALC_POWER
#undef CALC_POWER
#endif
#undef BLOCK_READ
#undef BLOCK_WRITE
#undef BLOCK_TYPE

View File

@ -46,6 +46,7 @@ public:
size_t maxSlmSize;
size_t normIndex; // which dimension (from in-memory representation) is normalized, e.g. for bfyx and
// softmax::normalize_f, it will be f's index == 2 (used only by naive kernel)
size_t subgroupBlockSize;
};
protected:

View File

@ -31,6 +31,7 @@ SoftmaxKernel_bf::Parent::DispatchData SoftmaxKernel_bf::SetDefault(const softma
auto dispatchData = Parent::SetDefault(params);
dispatchData.normIndex = 0;
// We have two units of data per work item in current implementation.
auto local_mem_per_wi = 2 * BytesPerElement(params.inputs[0].GetDType());
// Combining device execution and local memory restrictions to compute maximum possible LWS.
@ -50,12 +51,23 @@ SoftmaxKernel_bf::Parent::DispatchData SoftmaxKernel_bf::SetDefault(const softma
dispatchData.itemsNum /= 2;
}
if (dispatchData.itemsNum >> 3)
dispatchData.subgroupBlockSize = 8;
else if (dispatchData.itemsNum >> 2)
dispatchData.subgroupBlockSize = 4;
else if (dispatchData.itemsNum >> 1)
dispatchData.subgroupBlockSize = 2;
else
dispatchData.subgroupBlockSize = 1;
assert((dispatchData.itemsNum + 1) * dispatchData.lws[0] >= dispatchData.dataSetSize && "More than 'lws[0]' items per batch remains! Lws too small?");
dispatchData.gws[0] = dispatchData.lws[0];
dispatchData.leftovers = dispatchData.dataSetSize % dispatchData.lws[0];
assert(dispatchData.itemsNum > 0 && dispatchData.lws[0] && dispatchData.gws[0] > 0);
} else {
dispatchData.subgroupBlockSize = 1;
}
return dispatchData;
}
@ -106,6 +118,7 @@ JitConstants SoftmaxKernel_bf::GetJitConstants(const softmax_params& params, Dis
MakeJitConstant("DATA_SETS_COUNT", data_set_count),
MakeJitConstant("DATA_SET_SIZE", data_set_size),
MakeJitConstant("STACK_SIZE", stack_size),
MakeJitConstant("SUBGROUP_BLOCK_SIZE", dispatchData.subgroupBlockSize),
});
} else {
jit.AddConstants({
@ -116,6 +129,7 @@ JitConstants SoftmaxKernel_bf::GetJitConstants(const softmax_params& params, Dis
MakeJitConstant("DATA_SET_SIZE", dispatchData.dataSetSize),
MakeJitConstant("LEFTOVERS", dispatchData.leftovers),
MakeJitConstant("STACK_SIZE", dispatchData.itemsNum + 1),
MakeJitConstant("SUBGROUP_BLOCK_SIZE", dispatchData.subgroupBlockSize),
});
}
auto activation_dt = GetActivationType(params);

View File

@ -84,6 +84,18 @@ INSTANTIATE_TEST_SUITE_P(
testing::Values(ov::AnyMap())),
SoftMax8LayerTest::getTestCaseName);
// Regression coverage for the large softmax shape found in Stable
// Diffusion attention blocks: {16, 4096, 4096}, normalized over the
// last axis (-1), executed on the GPU device.
INSTANTIATE_TEST_SUITE_P(smoke_SoftMaxStableDiffusion,
                         SoftMax8LayerTest,
                         testing::Combine(testing::ValuesIn(netPrecisions),
                                          testing::Values(ov::element::undefined),
                                          testing::Values(ov::element::undefined),
                                          testing::ValuesIn(ov::test::static_shapes_to_test_representation({{16, 4096, 4096}})),
                                          testing::Values(-1),
                                          testing::Values(CommonTestUtils::DEVICE_GPU),
                                          testing::Values(ov::AnyMap())),
                         SoftMax8LayerTest::getTestCaseName);
const std::vector<ov::Shape> inputShapes5D = {
{1, 100, 1, 1, 1},
{1, 3, 4, 3, 4},

View File

@ -812,6 +812,21 @@ ov::runtime::Tensor generate(const std::shared_ptr<ngraph::op::v5::Round>& node,
return Activation::generate(elemType, targetShape, InputGenerateData(-10, 20, 4));
}
// Input generator for v8::Softmax test cases.
//
// Computes the size of one normalization data set (the product of the
// dimensions from the softmax axis to the end of the shape). For fp16
// inputs whose data sets hold 2048 or more elements, values are drawn
// from a normal distribution centered on small negatives so the
// exp/accumulate inside Softmax stays finite in half precision;
// otherwise generation falls back to the generic ov::Node overload.
ov::runtime::Tensor generate(const std::shared_ptr<ngraph::op::v8::Softmax>& node,
                             size_t port,
                             const ov::element::Type& elemType,
                             const ov::Shape& targetShape) {
    auto axis = node->get_axis();
    // Normalize a negative axis (e.g. -1) to its positive equivalent.
    axis = axis < 0 ? targetShape.size() + axis : axis;
    // Fold in size_t: the previous int init value (`1`) made
    // std::accumulate carry an int accumulator, narrowing each size_t
    // product back to int and risking overflow for large shapes.
    const size_t datasetSize = std::accumulate(targetShape.begin() + axis, targetShape.end(), size_t{1},
                                               [](size_t a, size_t b) { return a * b; });
    // Generate small negative values for data sets of 2048 elements or more
    // to avoid NaN values in Softmax results for fp16 precision.
    if (datasetSize >= 2048 && static_cast<ov::element::Type_t>(elemType) == ov::element::Type_t::f16)
        return ov::test::utils::create_and_fill_tensor_normal_distribution(elemType, targetShape, -5.f, 0.5f, 7235346);
    return generate(std::dynamic_pointer_cast<ov::Node>(node), port, elemType, targetShape);
}
template<typename T>
ov::runtime::Tensor generateInput(const std::shared_ptr<ov::Node>& node,
size_t port,