[GPU] Softmax for stable diffusion (#15863)
This commit is contained in:
parent
b64cbff10b
commit
8491f15ba7
@ -3,11 +3,27 @@
|
||||
//
|
||||
|
||||
#include "include/batch_headers/common.cl"
|
||||
#include "include/batch_headers/fetch_data.cl"
|
||||
#include "include/batch_headers/sub_group_block_read.cl"
|
||||
#include "include/batch_headers/sub_group_block_write.cl"
|
||||
|
||||
// Block-access helpers, selected at compile time by SUBGROUP_BLOCK_SIZE:
//   - SUBGROUP_BLOCK_SIZE == 1: BLOCK_READ / BLOCK_WRITE forward to the unsuffixed
//     DT_INPUT_BLOCK_READ / DT_OUTPUT_BLOCK_WRITE macros and BLOCK_TYPE is INPUT0_TYPE.
//   - otherwise: the macro name is combined with SUBGROUP_BLOCK_SIZE through CAT
//     (presumably token pasting, giving the size-suffixed variant - confirm in common.cl)
//     and BLOCK_TYPE is MAKE_VECTOR_TYPE(INPUT0_TYPE, SUBGROUP_BLOCK_SIZE).
#if SUBGROUP_BLOCK_SIZE == 1
|
||||
#define BLOCK_READ(ptr, offset) DT_INPUT_BLOCK_READ(ptr, offset)
|
||||
#define BLOCK_WRITE(ptr, offset, val) DT_OUTPUT_BLOCK_WRITE(ptr, offset, val)
|
||||
#define BLOCK_TYPE INPUT0_TYPE
|
||||
#else
|
||||
#define BLOCK_READ(ptr, offset) CAT(DT_INPUT_BLOCK_READ, SUBGROUP_BLOCK_SIZE)(ptr, offset)
|
||||
#define BLOCK_WRITE(ptr, offset, val) CAT(DT_OUTPUT_BLOCK_WRITE, SUBGROUP_BLOCK_SIZE)(ptr, offset, val)
|
||||
#define BLOCK_TYPE MAKE_VECTOR_TYPE(INPUT0_TYPE, SUBGROUP_BLOCK_SIZE)
|
||||
#endif
|
||||
|
||||
// CALC_POWER(n) is only defined when IS_DYNAMIC is non-zero.
// It is a statement expression: n is shifted right by one until it becomes zero,
// the number of shifts is counted, and the expression yields that count minus one.
// For n >= 1 this is the index of the highest set bit of n; for n == 0 it yields 0.
#if IS_DYNAMIC
|
||||
#define CALC_POWER(n) ({uint pos = 0; uint i = n; do { i >>= 1; ++pos; } while (i); --pos;})
|
||||
#endif
|
||||
|
||||
#define SUB_GROUP_SIZE 16
|
||||
|
||||
REQD_SUB_GROUP_SIZE(SUB_GROUP_SIZE)
|
||||
#if !IS_DYNAMIC
|
||||
__attribute__((reqd_work_group_size(LWS, 1, 1)))
|
||||
#endif
|
||||
@ -36,36 +52,54 @@ KERNEL (softmax_gpu_continuous_bfyx)(
|
||||
#endif
|
||||
|
||||
const uint data_set_offset = data_set_idx * data_set_size;
|
||||
const uint my_data_offset = data_set_offset + in_data_set_idx;
|
||||
const uint subgroup_offset = get_sub_group_id() * get_sub_group_size() * items_num;
|
||||
|
||||
INPUT0_TYPE my_chunk[STACK_SIZE];
|
||||
INPUT0_TYPE my_maximum = -UNIT_VAL_MAX;
|
||||
INPUT0_TYPE my_sum = UNIT_VAL_ZERO;
|
||||
INPUT0_TYPE tmp;
|
||||
|
||||
|
||||
__local INPUT0_TYPE lg_storage[SLM_SIZE];
|
||||
|
||||
//each WI reads items_num consecutive items from batch
|
||||
for (uint i=0; i<items_num; ++i)
|
||||
uint i=0;
|
||||
if (workers_per_data_set > SUB_GROUP_SIZE)
|
||||
{
|
||||
tmp = input[my_data_offset + i * workers_per_data_set];
|
||||
for (; i<items_num - (items_num % SUBGROUP_BLOCK_SIZE); i+=SUBGROUP_BLOCK_SIZE)
|
||||
{
|
||||
BLOCK_TYPE vec_tmp = BLOCK_READ(input, data_set_offset + subgroup_offset + i * get_sub_group_size());
|
||||
#if SUBGROUP_BLOCK_SIZE == 1
|
||||
my_maximum = max(my_maximum, vec_tmp);
|
||||
my_chunk[i] = vec_tmp;
|
||||
#else
|
||||
for (int j = 0; j < SUBGROUP_BLOCK_SIZE; j++)
|
||||
{
|
||||
INPUT0_TYPE tmp = vec_tmp[j];
|
||||
my_maximum = max(my_maximum, tmp);
|
||||
my_chunk[i+j] = tmp;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
for (; i<items_num; i++)
|
||||
{
|
||||
INPUT0_TYPE tmp = input[data_set_offset + subgroup_offset + get_sub_group_local_id() + i * get_sub_group_size()];
|
||||
my_maximum = max(my_maximum, tmp);
|
||||
my_chunk[i] = tmp;
|
||||
}
|
||||
|
||||
if (in_data_set_idx < leftovers)
|
||||
{
|
||||
tmp = input[data_set_offset + workers_per_data_set * items_num + in_data_set_idx];
|
||||
INPUT0_TYPE tmp = input[data_set_offset + workers_per_data_set * items_num + in_data_set_idx];
|
||||
my_maximum = max(my_maximum, tmp);
|
||||
my_chunk[items_num] = tmp;
|
||||
}
|
||||
|
||||
lg_storage[in_data_set_idx] = my_maximum;
|
||||
my_maximum = sub_group_reduce_max(my_maximum);
|
||||
|
||||
if (get_sub_group_local_id() == 0)
|
||||
lg_storage[get_sub_group_id()] = my_maximum;
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
if (in_data_set_idx == 0)
|
||||
{
|
||||
for (uint i=1; i<LWS; ++i)
|
||||
for (uint i=1; i<get_num_sub_groups(); ++i)
|
||||
my_maximum = max(my_maximum, lg_storage[i]);
|
||||
|
||||
lg_storage[0] = my_maximum;
|
||||
@ -79,24 +113,27 @@ KERNEL (softmax_gpu_continuous_bfyx)(
|
||||
|
||||
for (uint i=0; i<items_num; ++i)
|
||||
{
|
||||
tmp = native_exp(my_chunk[i] - my_maximum);
|
||||
INPUT0_TYPE tmp = native_exp(my_chunk[i] - my_maximum);
|
||||
my_sum += tmp;
|
||||
my_chunk[i] = tmp;
|
||||
}
|
||||
|
||||
if (in_data_set_idx < leftovers)
|
||||
{
|
||||
tmp = native_exp(my_chunk[items_num] - my_maximum);
|
||||
INPUT0_TYPE tmp = native_exp(my_chunk[items_num] - my_maximum);
|
||||
my_sum += tmp;
|
||||
my_chunk[items_num] = tmp;
|
||||
}
|
||||
|
||||
my_sum = sub_group_reduce_add(my_sum);
|
||||
|
||||
lg_storage[in_data_set_idx] = my_sum;
|
||||
if (get_sub_group_local_id() == 0)
|
||||
lg_storage[get_sub_group_id()] = my_sum;
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
if (in_data_set_idx == 0)
|
||||
{
|
||||
for (uint i=1; i<LWS; ++i)
|
||||
for (uint i=1; i<get_num_sub_groups(); ++i)
|
||||
my_sum += lg_storage[i];
|
||||
|
||||
lg_storage[0] = my_sum;
|
||||
@ -104,13 +141,35 @@ KERNEL (softmax_gpu_continuous_bfyx)(
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
my_sum = lg_storage[0];
|
||||
|
||||
i=0;
|
||||
|
||||
#if HAS_FUSED_OPS
|
||||
for (uint i=0; i<items_num; ++i)
|
||||
if (workers_per_data_set > SUB_GROUP_SIZE)
|
||||
{
|
||||
for (; i < items_num - (items_num % SUBGROUP_BLOCK_SIZE); i+=SUBGROUP_BLOCK_SIZE)
|
||||
{
|
||||
BLOCK_TYPE vec_tmp;
|
||||
#if SUBGROUP_BLOCK_SIZE == 1
|
||||
ACTIVATION_TYPE dequantized = my_chunk[i] / my_sum;
|
||||
FUSED_OPS_MAIN;
|
||||
vec_tmp = FUSED_OPS_RESULT_MAIN;
|
||||
#else
|
||||
for (int j = 0; j < SUBGROUP_BLOCK_SIZE; j++)
|
||||
{
|
||||
ACTIVATION_TYPE dequantized = my_chunk[i + j] / my_sum;
|
||||
FUSED_OPS_MAIN;
|
||||
vec_tmp[j] = FUSED_OPS_RESULT_MAIN;
|
||||
}
|
||||
#endif
|
||||
BLOCK_WRITE(output, data_set_offset + subgroup_offset + i * get_sub_group_size(), vec_tmp);
|
||||
}
|
||||
}
|
||||
for (; i<items_num; i++)
|
||||
{
|
||||
ACTIVATION_TYPE dequantized = my_chunk[i] / my_sum;
|
||||
FUSED_OPS_MAIN;
|
||||
output[my_data_offset + i * workers_per_data_set] = FUSED_OPS_RESULT_MAIN;
|
||||
output[data_set_offset + subgroup_offset + get_sub_group_local_id() + i * get_sub_group_size()] = FUSED_OPS_RESULT_MAIN;
|
||||
}
|
||||
if (in_data_set_idx < leftovers)
|
||||
{
|
||||
@ -119,8 +178,24 @@ KERNEL (softmax_gpu_continuous_bfyx)(
|
||||
output[data_set_offset + workers_per_data_set * items_num + in_data_set_idx] = FUSED_OPS_RESULT_LEFTOVERS;
|
||||
}
|
||||
#else
|
||||
for (uint i=0; i<items_num; ++i)
|
||||
output[my_data_offset + i * workers_per_data_set] = ACTIVATION(my_chunk[i] / my_sum, ACTIVATION_PARAMS);
|
||||
if (workers_per_data_set > SUB_GROUP_SIZE)
|
||||
{
|
||||
for (; i<items_num - (items_num % SUBGROUP_BLOCK_SIZE); i+=SUBGROUP_BLOCK_SIZE)
|
||||
{
|
||||
BLOCK_TYPE vec_tmp;
|
||||
#if SUBGROUP_BLOCK_SIZE == 1
|
||||
vec_tmp = ACTIVATION(my_chunk[i] / my_sum, ACTIVATION_PARAMS);
|
||||
#else
|
||||
for (int j = 0; j < SUBGROUP_BLOCK_SIZE; j++)
|
||||
vec_tmp[j] = ACTIVATION(my_chunk[i + j] / my_sum, ACTIVATION_PARAMS);
|
||||
#endif
|
||||
BLOCK_WRITE(output, data_set_offset + subgroup_offset + i * get_sub_group_size(), vec_tmp);
|
||||
}
|
||||
}
|
||||
for (; i < items_num; i++)
|
||||
{
|
||||
output[data_set_offset + subgroup_offset + get_sub_group_local_id() + i * get_sub_group_size()] = ACTIVATION(my_chunk[i] / my_sum, ACTIVATION_PARAMS);
|
||||
}
|
||||
if (in_data_set_idx < leftovers)
|
||||
output[data_set_offset + workers_per_data_set * items_num + in_data_set_idx] = ACTIVATION(my_chunk[items_num] / my_sum, ACTIVATION_PARAMS);
|
||||
#endif
|
||||
@ -128,3 +203,7 @@ KERNEL (softmax_gpu_continuous_bfyx)(
|
||||
// Undefine every helper macro this file defines so that none of them leaks
// into code compiled after it. CALC_POWER only exists when IS_DYNAMIC is set,
// hence the guard; the others are defined unconditionally.
#ifdef CALC_POWER
#undef CALC_POWER
#endif
#undef BLOCK_READ
#undef BLOCK_WRITE
#undef BLOCK_TYPE
// SUB_GROUP_SIZE is defined unconditionally near the top of this file as well.
#undef SUB_GROUP_SIZE
|
||||
|
||||
|
@ -46,6 +46,7 @@ public:
|
||||
size_t maxSlmSize;
|
||||
size_t normIndex; // which dimension (from in-memory representation) is normalized, e.g. for bfyx and
|
||||
// softmax::normalize_f, it will be f's index == 2 (used only by naive kernel)
|
||||
size_t subgroupBlockSize;
|
||||
};
|
||||
|
||||
protected:
|
||||
|
@ -31,6 +31,7 @@ SoftmaxKernel_bf::Parent::DispatchData SoftmaxKernel_bf::SetDefault(const softma
|
||||
auto dispatchData = Parent::SetDefault(params);
|
||||
|
||||
dispatchData.normIndex = 0;
|
||||
|
||||
// We have two units of data per work item in current implementation.
|
||||
auto local_mem_per_wi = 2 * BytesPerElement(params.inputs[0].GetDType());
|
||||
// Combining device execution and local memory restrictions to compute maximum possible LWS.
|
||||
@ -50,12 +51,23 @@ SoftmaxKernel_bf::Parent::DispatchData SoftmaxKernel_bf::SetDefault(const softma
|
||||
dispatchData.itemsNum /= 2;
|
||||
}
|
||||
|
||||
if (dispatchData.itemsNum >> 3)
|
||||
dispatchData.subgroupBlockSize = 8;
|
||||
else if (dispatchData.itemsNum >> 2)
|
||||
dispatchData.subgroupBlockSize = 4;
|
||||
else if (dispatchData.itemsNum >> 1)
|
||||
dispatchData.subgroupBlockSize = 2;
|
||||
else
|
||||
dispatchData.subgroupBlockSize = 1;
|
||||
|
||||
assert((dispatchData.itemsNum + 1) * dispatchData.lws[0] >= dispatchData.dataSetSize && "More than 'lws[0]' items per batch remains! Lws too small?");
|
||||
|
||||
dispatchData.gws[0] = dispatchData.lws[0];
|
||||
dispatchData.leftovers = dispatchData.dataSetSize % dispatchData.lws[0];
|
||||
|
||||
assert(dispatchData.itemsNum > 0 && dispatchData.lws[0] && dispatchData.gws[0] > 0);
|
||||
} else {
|
||||
dispatchData.subgroupBlockSize = 1;
|
||||
}
|
||||
return dispatchData;
|
||||
}
|
||||
@ -106,6 +118,7 @@ JitConstants SoftmaxKernel_bf::GetJitConstants(const softmax_params& params, Dis
|
||||
MakeJitConstant("DATA_SETS_COUNT", data_set_count),
|
||||
MakeJitConstant("DATA_SET_SIZE", data_set_size),
|
||||
MakeJitConstant("STACK_SIZE", stack_size),
|
||||
MakeJitConstant("SUBGROUP_BLOCK_SIZE", dispatchData.subgroupBlockSize),
|
||||
});
|
||||
} else {
|
||||
jit.AddConstants({
|
||||
@ -116,6 +129,7 @@ JitConstants SoftmaxKernel_bf::GetJitConstants(const softmax_params& params, Dis
|
||||
MakeJitConstant("DATA_SET_SIZE", dispatchData.dataSetSize),
|
||||
MakeJitConstant("LEFTOVERS", dispatchData.leftovers),
|
||||
MakeJitConstant("STACK_SIZE", dispatchData.itemsNum + 1),
|
||||
MakeJitConstant("SUBGROUP_BLOCK_SIZE", dispatchData.subgroupBlockSize),
|
||||
});
|
||||
}
|
||||
auto activation_dt = GetActivationType(params);
|
||||
|
@ -84,6 +84,18 @@ INSTANTIATE_TEST_SUITE_P(
|
||||
testing::Values(ov::AnyMap())),
|
||||
SoftMax8LayerTest::getTestCaseName);
|
||||
|
||||
// Smoke instantiation of SoftMax8LayerTest for one large static shape,
// {16, 4096, 4096}, with axis -1 on the GPU device and an empty ov::AnyMap.
// It runs once per entry of netPrecisions.
// NOTE(review): the suite name suggests this shape mirrors a Stable Diffusion softmax;
// the two ov::element::undefined values are presumably the input/output precision
// overrides - confirm against the SoftMax8LayerTest parameter tuple.
INSTANTIATE_TEST_SUITE_P(
|
||||
smoke_SoftMaxStableDiffusion,
|
||||
SoftMax8LayerTest,
|
||||
testing::Combine(testing::ValuesIn(netPrecisions),
|
||||
::testing::Values(ov::element::undefined),
|
||||
::testing::Values(ov::element::undefined),
|
||||
::testing::ValuesIn(ov::test::static_shapes_to_test_representation({{16, 4096, 4096}})),
|
||||
testing::Values(-1),
|
||||
testing::Values(CommonTestUtils::DEVICE_GPU),
|
||||
testing::Values(ov::AnyMap())),
|
||||
SoftMax8LayerTest::getTestCaseName);
|
||||
|
||||
const std::vector<ov::Shape> inputShapes5D = {
|
||||
{1, 100, 1, 1, 1},
|
||||
{1, 3, 4, 3, 4},
|
||||
|
@ -812,6 +812,21 @@ ov::runtime::Tensor generate(const std::shared_ptr<ngraph::op::v5::Round>& node,
|
||||
return Activation::generate(elemType, targetShape, InputGenerateData(-10, 20, 4));
|
||||
}
|
||||
|
||||
/// Generates the input tensor for an opset8 Softmax node.
///
/// @param node        Softmax node; only its axis is read here.
/// @param port        Input port index, forwarded to the generic generator.
/// @param elemType    Element type of the tensor to create.
/// @param targetShape Shape of the tensor to create.
/// @return A tensor filled from a normal distribution when the element type is f16
///         and the product of the dimensions from the (normalised) axis to the end
///         of targetShape is at least 2048; otherwise whatever the generic
///         ov::Node overload of generate() produces.
ov::runtime::Tensor generate(const std::shared_ptr<ngraph::op::v8::Softmax>& node,
                             size_t port,
                             const ov::element::Type& elemType,
                             const ov::Shape& targetShape) {
    auto axis = node->get_axis();
    // A negative axis counts from the end of the shape.
    axis = axis < 0 ? targetShape.size() + axis : axis;
    // The accumulator type of std::accumulate is the type of the initial value, so the
    // seed must be std::size_t: an int seed would narrow every partial product to int.
    const std::size_t datasetSize = std::accumulate(targetShape.begin() + axis,
                                                    targetShape.end(),
                                                    std::size_t{1},
                                                    [](std::size_t a, std::size_t b) { return a * b; });
    // Generate small negative values for datasets which exceed 2048 size
    // to avoid NaN values in Softmax results for fp16 precision
    // NOTE(review): the trailing arguments are presumably mean, standard deviation
    // and a fixed seed - confirm against create_and_fill_tensor_normal_distribution.
    if (datasetSize >= 2048 && static_cast<ov::element::Type_t>(elemType) == ov::element::Type_t::f16)
        return ov::test::utils::create_and_fill_tensor_normal_distribution(elemType, targetShape, -5.f, 0.5f, 7235346);
    return generate(std::dynamic_pointer_cast<ov::Node>(node), port, elemType, targetShape);
}
|
||||
|
||||
template<typename T>
|
||||
ov::runtime::Tensor generateInput(const std::shared_ptr<ov::Node>& node,
|
||||
size_t port,
|
||||
|
Loading…
Reference in New Issue
Block a user