[GPU] Softmax for stable diffusion (#15863)
This commit is contained in:
parent
b64cbff10b
commit
8491f15ba7
@ -3,11 +3,27 @@
|
|||||||
//
|
//
|
||||||
|
|
||||||
#include "include/batch_headers/common.cl"
|
#include "include/batch_headers/common.cl"
|
||||||
|
#include "include/batch_headers/fetch_data.cl"
|
||||||
|
#include "include/batch_headers/sub_group_block_read.cl"
|
||||||
|
#include "include/batch_headers/sub_group_block_write.cl"
|
||||||
|
|
||||||
|
#if SUBGROUP_BLOCK_SIZE == 1
|
||||||
|
#define BLOCK_READ(ptr, offset) DT_INPUT_BLOCK_READ(ptr, offset)
|
||||||
|
#define BLOCK_WRITE(ptr, offset, val) DT_OUTPUT_BLOCK_WRITE(ptr, offset, val)
|
||||||
|
#define BLOCK_TYPE INPUT0_TYPE
|
||||||
|
#else
|
||||||
|
#define BLOCK_READ(ptr, offset) CAT(DT_INPUT_BLOCK_READ, SUBGROUP_BLOCK_SIZE)(ptr, offset)
|
||||||
|
#define BLOCK_WRITE(ptr, offset, val) CAT(DT_OUTPUT_BLOCK_WRITE, SUBGROUP_BLOCK_SIZE)(ptr, offset, val)
|
||||||
|
#define BLOCK_TYPE MAKE_VECTOR_TYPE(INPUT0_TYPE, SUBGROUP_BLOCK_SIZE)
|
||||||
|
#endif
|
||||||
|
|
||||||
#if IS_DYNAMIC
|
#if IS_DYNAMIC
|
||||||
#define CALC_POWER(n) ({uint pos = 0; uint i = n; do { i >>= 1; ++pos; } while (i); --pos;})
|
#define CALC_POWER(n) ({uint pos = 0; uint i = n; do { i >>= 1; ++pos; } while (i); --pos;})
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#define SUB_GROUP_SIZE 16
|
||||||
|
|
||||||
|
REQD_SUB_GROUP_SIZE(SUB_GROUP_SIZE)
|
||||||
#if !IS_DYNAMIC
|
#if !IS_DYNAMIC
|
||||||
__attribute__((reqd_work_group_size(LWS, 1, 1)))
|
__attribute__((reqd_work_group_size(LWS, 1, 1)))
|
||||||
#endif
|
#endif
|
||||||
@ -36,36 +52,54 @@ KERNEL (softmax_gpu_continuous_bfyx)(
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
const uint data_set_offset = data_set_idx * data_set_size;
|
const uint data_set_offset = data_set_idx * data_set_size;
|
||||||
const uint my_data_offset = data_set_offset + in_data_set_idx;
|
const uint subgroup_offset = get_sub_group_id() * get_sub_group_size() * items_num;
|
||||||
|
|
||||||
INPUT0_TYPE my_chunk[STACK_SIZE];
|
INPUT0_TYPE my_chunk[STACK_SIZE];
|
||||||
INPUT0_TYPE my_maximum = -UNIT_VAL_MAX;
|
INPUT0_TYPE my_maximum = -UNIT_VAL_MAX;
|
||||||
INPUT0_TYPE my_sum = UNIT_VAL_ZERO;
|
INPUT0_TYPE my_sum = UNIT_VAL_ZERO;
|
||||||
INPUT0_TYPE tmp;
|
|
||||||
|
|
||||||
__local INPUT0_TYPE lg_storage[SLM_SIZE];
|
__local INPUT0_TYPE lg_storage[SLM_SIZE];
|
||||||
|
|
||||||
//each WI reads items_num consecutive items from batch
|
uint i=0;
|
||||||
for (uint i=0; i<items_num; ++i)
|
if (workers_per_data_set > SUB_GROUP_SIZE)
|
||||||
{
|
{
|
||||||
tmp = input[my_data_offset + i * workers_per_data_set];
|
for (; i<items_num - (items_num % SUBGROUP_BLOCK_SIZE); i+=SUBGROUP_BLOCK_SIZE)
|
||||||
|
{
|
||||||
|
BLOCK_TYPE vec_tmp = BLOCK_READ(input, data_set_offset + subgroup_offset + i * get_sub_group_size());
|
||||||
|
#if SUBGROUP_BLOCK_SIZE == 1
|
||||||
|
my_maximum = max(my_maximum, vec_tmp);
|
||||||
|
my_chunk[i] = vec_tmp;
|
||||||
|
#else
|
||||||
|
for (int j = 0; j < SUBGROUP_BLOCK_SIZE; j++)
|
||||||
|
{
|
||||||
|
INPUT0_TYPE tmp = vec_tmp[j];
|
||||||
|
my_maximum = max(my_maximum, tmp);
|
||||||
|
my_chunk[i+j] = tmp;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (; i<items_num; i++)
|
||||||
|
{
|
||||||
|
INPUT0_TYPE tmp = input[data_set_offset + subgroup_offset + get_sub_group_local_id() + i * get_sub_group_size()];
|
||||||
my_maximum = max(my_maximum, tmp);
|
my_maximum = max(my_maximum, tmp);
|
||||||
my_chunk[i] = tmp;
|
my_chunk[i] = tmp;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (in_data_set_idx < leftovers)
|
if (in_data_set_idx < leftovers)
|
||||||
{
|
{
|
||||||
tmp = input[data_set_offset + workers_per_data_set * items_num + in_data_set_idx];
|
INPUT0_TYPE tmp = input[data_set_offset + workers_per_data_set * items_num + in_data_set_idx];
|
||||||
my_maximum = max(my_maximum, tmp);
|
my_maximum = max(my_maximum, tmp);
|
||||||
my_chunk[items_num] = tmp;
|
my_chunk[items_num] = tmp;
|
||||||
}
|
}
|
||||||
|
my_maximum = sub_group_reduce_max(my_maximum);
|
||||||
|
|
||||||
lg_storage[in_data_set_idx] = my_maximum;
|
if (get_sub_group_local_id() == 0)
|
||||||
|
lg_storage[get_sub_group_id()] = my_maximum;
|
||||||
|
|
||||||
barrier(CLK_LOCAL_MEM_FENCE);
|
barrier(CLK_LOCAL_MEM_FENCE);
|
||||||
if (in_data_set_idx == 0)
|
if (in_data_set_idx == 0)
|
||||||
{
|
{
|
||||||
for (uint i=1; i<LWS; ++i)
|
for (uint i=1; i<get_num_sub_groups(); ++i)
|
||||||
my_maximum = max(my_maximum, lg_storage[i]);
|
my_maximum = max(my_maximum, lg_storage[i]);
|
||||||
|
|
||||||
lg_storage[0] = my_maximum;
|
lg_storage[0] = my_maximum;
|
||||||
@ -79,24 +113,27 @@ KERNEL (softmax_gpu_continuous_bfyx)(
|
|||||||
|
|
||||||
for (uint i=0; i<items_num; ++i)
|
for (uint i=0; i<items_num; ++i)
|
||||||
{
|
{
|
||||||
tmp = native_exp(my_chunk[i] - my_maximum);
|
INPUT0_TYPE tmp = native_exp(my_chunk[i] - my_maximum);
|
||||||
my_sum += tmp;
|
my_sum += tmp;
|
||||||
my_chunk[i] = tmp;
|
my_chunk[i] = tmp;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (in_data_set_idx < leftovers)
|
if (in_data_set_idx < leftovers)
|
||||||
{
|
{
|
||||||
tmp = native_exp(my_chunk[items_num] - my_maximum);
|
INPUT0_TYPE tmp = native_exp(my_chunk[items_num] - my_maximum);
|
||||||
my_sum += tmp;
|
my_sum += tmp;
|
||||||
my_chunk[items_num] = tmp;
|
my_chunk[items_num] = tmp;
|
||||||
}
|
}
|
||||||
|
|
||||||
lg_storage[in_data_set_idx] = my_sum;
|
my_sum = sub_group_reduce_add(my_sum);
|
||||||
|
|
||||||
|
if (get_sub_group_local_id() == 0)
|
||||||
|
lg_storage[get_sub_group_id()] = my_sum;
|
||||||
|
|
||||||
barrier(CLK_LOCAL_MEM_FENCE);
|
barrier(CLK_LOCAL_MEM_FENCE);
|
||||||
if (in_data_set_idx == 0)
|
if (in_data_set_idx == 0)
|
||||||
{
|
{
|
||||||
for (uint i=1; i<LWS; ++i)
|
for (uint i=1; i<get_num_sub_groups(); ++i)
|
||||||
my_sum += lg_storage[i];
|
my_sum += lg_storage[i];
|
||||||
|
|
||||||
lg_storage[0] = my_sum;
|
lg_storage[0] = my_sum;
|
||||||
@ -105,12 +142,34 @@ KERNEL (softmax_gpu_continuous_bfyx)(
|
|||||||
|
|
||||||
my_sum = lg_storage[0];
|
my_sum = lg_storage[0];
|
||||||
|
|
||||||
|
i=0;
|
||||||
|
|
||||||
#if HAS_FUSED_OPS
|
#if HAS_FUSED_OPS
|
||||||
for (uint i=0; i<items_num; ++i)
|
if (workers_per_data_set > SUB_GROUP_SIZE)
|
||||||
|
{
|
||||||
|
for (; i < items_num - (items_num % SUBGROUP_BLOCK_SIZE); i+=SUBGROUP_BLOCK_SIZE)
|
||||||
|
{
|
||||||
|
BLOCK_TYPE vec_tmp;
|
||||||
|
#if SUBGROUP_BLOCK_SIZE == 1
|
||||||
|
ACTIVATION_TYPE dequantized = my_chunk[i] / my_sum;
|
||||||
|
FUSED_OPS_MAIN;
|
||||||
|
vec_tmp = FUSED_OPS_RESULT_MAIN;
|
||||||
|
#else
|
||||||
|
for (int j = 0; j < SUBGROUP_BLOCK_SIZE; j++)
|
||||||
|
{
|
||||||
|
ACTIVATION_TYPE dequantized = my_chunk[i + j] / my_sum;
|
||||||
|
FUSED_OPS_MAIN;
|
||||||
|
vec_tmp[j] = FUSED_OPS_RESULT_MAIN;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
BLOCK_WRITE(output, data_set_offset + subgroup_offset + i * get_sub_group_size(), vec_tmp);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (; i<items_num; i++)
|
||||||
{
|
{
|
||||||
ACTIVATION_TYPE dequantized = my_chunk[i] / my_sum;
|
ACTIVATION_TYPE dequantized = my_chunk[i] / my_sum;
|
||||||
FUSED_OPS_MAIN;
|
FUSED_OPS_MAIN;
|
||||||
output[my_data_offset + i * workers_per_data_set] = FUSED_OPS_RESULT_MAIN;
|
output[data_set_offset + subgroup_offset + get_sub_group_local_id() + i * get_sub_group_size()] = FUSED_OPS_RESULT_MAIN;
|
||||||
}
|
}
|
||||||
if (in_data_set_idx < leftovers)
|
if (in_data_set_idx < leftovers)
|
||||||
{
|
{
|
||||||
@ -119,8 +178,24 @@ KERNEL (softmax_gpu_continuous_bfyx)(
|
|||||||
output[data_set_offset + workers_per_data_set * items_num + in_data_set_idx] = FUSED_OPS_RESULT_LEFTOVERS;
|
output[data_set_offset + workers_per_data_set * items_num + in_data_set_idx] = FUSED_OPS_RESULT_LEFTOVERS;
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
for (uint i=0; i<items_num; ++i)
|
if (workers_per_data_set > SUB_GROUP_SIZE)
|
||||||
output[my_data_offset + i * workers_per_data_set] = ACTIVATION(my_chunk[i] / my_sum, ACTIVATION_PARAMS);
|
{
|
||||||
|
for (; i<items_num - (items_num % SUBGROUP_BLOCK_SIZE); i+=SUBGROUP_BLOCK_SIZE)
|
||||||
|
{
|
||||||
|
BLOCK_TYPE vec_tmp;
|
||||||
|
#if SUBGROUP_BLOCK_SIZE == 1
|
||||||
|
vec_tmp = ACTIVATION(my_chunk[i] / my_sum, ACTIVATION_PARAMS);
|
||||||
|
#else
|
||||||
|
for (int j = 0; j < SUBGROUP_BLOCK_SIZE; j++)
|
||||||
|
vec_tmp[j] = ACTIVATION(my_chunk[i + j] / my_sum, ACTIVATION_PARAMS);
|
||||||
|
#endif
|
||||||
|
BLOCK_WRITE(output, data_set_offset + subgroup_offset + i * get_sub_group_size(), vec_tmp);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (; i < items_num; i++)
|
||||||
|
{
|
||||||
|
output[data_set_offset + subgroup_offset + get_sub_group_local_id() + i * get_sub_group_size()] = ACTIVATION(my_chunk[i] / my_sum, ACTIVATION_PARAMS);
|
||||||
|
}
|
||||||
if (in_data_set_idx < leftovers)
|
if (in_data_set_idx < leftovers)
|
||||||
output[data_set_offset + workers_per_data_set * items_num + in_data_set_idx] = ACTIVATION(my_chunk[items_num] / my_sum, ACTIVATION_PARAMS);
|
output[data_set_offset + workers_per_data_set * items_num + in_data_set_idx] = ACTIVATION(my_chunk[items_num] / my_sum, ACTIVATION_PARAMS);
|
||||||
#endif
|
#endif
|
||||||
@ -128,3 +203,7 @@ KERNEL (softmax_gpu_continuous_bfyx)(
|
|||||||
#ifdef CALC_POWER
|
#ifdef CALC_POWER
|
||||||
#undef CALC_POWER
|
#undef CALC_POWER
|
||||||
#endif
|
#endif
|
||||||
|
#undef BLOCK_READ
|
||||||
|
#undef BLOCK_WRITE
|
||||||
|
#undef BLOCK_TYPE
|
||||||
|
|
||||||
|
@ -46,6 +46,7 @@ public:
|
|||||||
size_t maxSlmSize;
|
size_t maxSlmSize;
|
||||||
size_t normIndex; // which dimension (from in-memory representation) is normalized, e.g. for bfyx and
|
size_t normIndex; // which dimension (from in-memory representation) is normalized, e.g. for bfyx and
|
||||||
// softmax::normalize_f, it will be f's index == 2 (used only by naive kernel)
|
// softmax::normalize_f, it will be f's index == 2 (used only by naive kernel)
|
||||||
|
size_t subgroupBlockSize;
|
||||||
};
|
};
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
@ -31,6 +31,7 @@ SoftmaxKernel_bf::Parent::DispatchData SoftmaxKernel_bf::SetDefault(const softma
|
|||||||
auto dispatchData = Parent::SetDefault(params);
|
auto dispatchData = Parent::SetDefault(params);
|
||||||
|
|
||||||
dispatchData.normIndex = 0;
|
dispatchData.normIndex = 0;
|
||||||
|
|
||||||
// We have two units of data per work item in current implementation.
|
// We have two units of data per work item in current implementation.
|
||||||
auto local_mem_per_wi = 2 * BytesPerElement(params.inputs[0].GetDType());
|
auto local_mem_per_wi = 2 * BytesPerElement(params.inputs[0].GetDType());
|
||||||
// Combining device execution and local memory restrictions to compute maximum possible LWS.
|
// Combining device execution and local memory restrictions to compute maximum possible LWS.
|
||||||
@ -50,12 +51,23 @@ SoftmaxKernel_bf::Parent::DispatchData SoftmaxKernel_bf::SetDefault(const softma
|
|||||||
dispatchData.itemsNum /= 2;
|
dispatchData.itemsNum /= 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (dispatchData.itemsNum >> 3)
|
||||||
|
dispatchData.subgroupBlockSize = 8;
|
||||||
|
else if (dispatchData.itemsNum >> 2)
|
||||||
|
dispatchData.subgroupBlockSize = 4;
|
||||||
|
else if (dispatchData.itemsNum >> 1)
|
||||||
|
dispatchData.subgroupBlockSize = 2;
|
||||||
|
else
|
||||||
|
dispatchData.subgroupBlockSize = 1;
|
||||||
|
|
||||||
assert((dispatchData.itemsNum + 1) * dispatchData.lws[0] >= dispatchData.dataSetSize && "More than 'lws[0]' items per batch remains! Lws too small?");
|
assert((dispatchData.itemsNum + 1) * dispatchData.lws[0] >= dispatchData.dataSetSize && "More than 'lws[0]' items per batch remains! Lws too small?");
|
||||||
|
|
||||||
dispatchData.gws[0] = dispatchData.lws[0];
|
dispatchData.gws[0] = dispatchData.lws[0];
|
||||||
dispatchData.leftovers = dispatchData.dataSetSize % dispatchData.lws[0];
|
dispatchData.leftovers = dispatchData.dataSetSize % dispatchData.lws[0];
|
||||||
|
|
||||||
assert(dispatchData.itemsNum > 0 && dispatchData.lws[0] && dispatchData.gws[0] > 0);
|
assert(dispatchData.itemsNum > 0 && dispatchData.lws[0] && dispatchData.gws[0] > 0);
|
||||||
|
} else {
|
||||||
|
dispatchData.subgroupBlockSize = 1;
|
||||||
}
|
}
|
||||||
return dispatchData;
|
return dispatchData;
|
||||||
}
|
}
|
||||||
@ -106,6 +118,7 @@ JitConstants SoftmaxKernel_bf::GetJitConstants(const softmax_params& params, Dis
|
|||||||
MakeJitConstant("DATA_SETS_COUNT", data_set_count),
|
MakeJitConstant("DATA_SETS_COUNT", data_set_count),
|
||||||
MakeJitConstant("DATA_SET_SIZE", data_set_size),
|
MakeJitConstant("DATA_SET_SIZE", data_set_size),
|
||||||
MakeJitConstant("STACK_SIZE", stack_size),
|
MakeJitConstant("STACK_SIZE", stack_size),
|
||||||
|
MakeJitConstant("SUBGROUP_BLOCK_SIZE", dispatchData.subgroupBlockSize),
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
jit.AddConstants({
|
jit.AddConstants({
|
||||||
@ -116,6 +129,7 @@ JitConstants SoftmaxKernel_bf::GetJitConstants(const softmax_params& params, Dis
|
|||||||
MakeJitConstant("DATA_SET_SIZE", dispatchData.dataSetSize),
|
MakeJitConstant("DATA_SET_SIZE", dispatchData.dataSetSize),
|
||||||
MakeJitConstant("LEFTOVERS", dispatchData.leftovers),
|
MakeJitConstant("LEFTOVERS", dispatchData.leftovers),
|
||||||
MakeJitConstant("STACK_SIZE", dispatchData.itemsNum + 1),
|
MakeJitConstant("STACK_SIZE", dispatchData.itemsNum + 1),
|
||||||
|
MakeJitConstant("SUBGROUP_BLOCK_SIZE", dispatchData.subgroupBlockSize),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
auto activation_dt = GetActivationType(params);
|
auto activation_dt = GetActivationType(params);
|
||||||
|
@ -84,6 +84,18 @@ INSTANTIATE_TEST_SUITE_P(
|
|||||||
testing::Values(ov::AnyMap())),
|
testing::Values(ov::AnyMap())),
|
||||||
SoftMax8LayerTest::getTestCaseName);
|
SoftMax8LayerTest::getTestCaseName);
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
|
smoke_SoftMaxStableDiffusion,
|
||||||
|
SoftMax8LayerTest,
|
||||||
|
testing::Combine(testing::ValuesIn(netPrecisions),
|
||||||
|
::testing::Values(ov::element::undefined),
|
||||||
|
::testing::Values(ov::element::undefined),
|
||||||
|
::testing::ValuesIn(ov::test::static_shapes_to_test_representation({{16, 4096, 4096}})),
|
||||||
|
testing::Values(-1),
|
||||||
|
testing::Values(CommonTestUtils::DEVICE_GPU),
|
||||||
|
testing::Values(ov::AnyMap())),
|
||||||
|
SoftMax8LayerTest::getTestCaseName);
|
||||||
|
|
||||||
const std::vector<ov::Shape> inputShapes5D = {
|
const std::vector<ov::Shape> inputShapes5D = {
|
||||||
{1, 100, 1, 1, 1},
|
{1, 100, 1, 1, 1},
|
||||||
{1, 3, 4, 3, 4},
|
{1, 3, 4, 3, 4},
|
||||||
|
@ -812,6 +812,21 @@ ov::runtime::Tensor generate(const std::shared_ptr<ngraph::op::v5::Round>& node,
|
|||||||
return Activation::generate(elemType, targetShape, InputGenerateData(-10, 20, 4));
|
return Activation::generate(elemType, targetShape, InputGenerateData(-10, 20, 4));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ov::runtime::Tensor generate(const std::shared_ptr<ngraph::op::v8::Softmax>& node,
|
||||||
|
size_t port,
|
||||||
|
const ov::element::Type& elemType,
|
||||||
|
const ov::Shape& targetShape) {
|
||||||
|
auto axis = node->get_axis();
|
||||||
|
axis = axis < 0 ? targetShape.size() + axis : axis;
|
||||||
|
unsigned datasetSize = std::accumulate(targetShape.begin() + axis, targetShape.end(), 1,
|
||||||
|
[](std::size_t a, size_t b) { return a * b; });
|
||||||
|
// Generate small negative values for datasets which exceed 2048 size
|
||||||
|
// to avoid NaN values in Softmax results for fp16 precision
|
||||||
|
if (datasetSize >= 2048 && static_cast<ov::element::Type_t>(elemType) == ov::element::Type_t::f16)
|
||||||
|
return ov::test::utils::create_and_fill_tensor_normal_distribution(elemType, targetShape, -5.f, 0.5f, 7235346);
|
||||||
|
return generate(std::dynamic_pointer_cast<ov::Node>(node), port, elemType, targetShape);
|
||||||
|
}
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
ov::runtime::Tensor generateInput(const std::shared_ptr<ov::Node>& node,
|
ov::runtime::Tensor generateInput(const std::shared_ptr<ov::Node>& node,
|
||||||
size_t port,
|
size_t port,
|
||||||
|
Loading…
Reference in New Issue
Block a user