[GPU] Bugfix optimized slice_mem (#21452)

Signed-off-by: Min, Byungil <byungil.min@intel.com>
This commit is contained in:
Min, Byungil 2023-12-12 17:17:49 +09:00 committed by GitHub
parent 521da22797
commit 42c33ac7b1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -803,8 +803,8 @@ void loop_inst::concatenated_memory_mapping::slice_mem(const int64_t num_iterati
char* concat_data = reinterpret_cast<char*>(concatenated_mem->lock(stream, cldnn::mem_lock_type::read));
auto concate_layout = concatenated_mem->get_layout();
auto trait = format::traits(concate_layout.format);
if (format::is_blocked(concate_layout.format) || concate_layout.data_padding) {
auto dims = concat_mem_shape.size();
if (!format::is_default_format(concate_layout.format) || dims == 1 || concate_layout.data_padding) {
// BE CAREFUL: ov::reference::split is extremely slow.
// If we encounter any case where this code path is executed, we need to optimize it
ov::reference::split(concat_data, concat_mem_shape, elem_size, axis, num_iters, pointers_to_data.data());
@ -819,14 +819,16 @@ void loop_inst::concatenated_memory_mapping::slice_mem(const int64_t num_iterati
auto& lb_at_axis = lower_bounds[axis];
auto& ub_at_axis = upper_bounds[axis];
// Format of concat_layout is invalid here : No mixed order
size_t continuous_size = 1;
auto dims_order = trait._order;
auto target_axis = std::find(dims_order.begin(), dims_order.end(), axis);
for (auto iter = target_axis + 1 ; iter != dims_order.end() ; ++iter) {
continuous_size *= ((output_shape.size() > *iter) ? output_shape[*iter] : 1);
size_t inner_axis = axis + 1;
for (auto iter = inner_axis ; iter < dims ; ++iter) {
continuous_size *= ((output_shape.size() > iter) ? output_shape[iter] : 1);
}
auto strides = ov::Strides(lower_bounds.size(), 1);
strides[*(target_axis+1)] = continuous_size;
if (inner_axis < dims)
strides[inner_axis] = continuous_size;
const auto strides_copy_size = elem_size * continuous_size;
const auto out_last = std::next(out_data, num_iters);
@ -834,12 +836,9 @@ void loop_inst::concatenated_memory_mapping::slice_mem(const int64_t num_iterati
auto dst_mem = *out_iter;
auto slice_ranges = ov::coordinates::slice(concat_mem_shape, lower_bounds, upper_bounds, strides);
for (const auto& range : slice_ranges) {
auto src_index = range.begin_index;
for (size_t i = 0; i < range.element_number; src_index += range.step, ++i) {
const auto src_mem = concat_data + src_index * elem_size;
std::memcpy(dst_mem, src_mem, strides_copy_size);
std::advance(dst_mem, strides_copy_size);
}
const auto src_mem = concat_data + range.begin_index * elem_size;
std::memcpy(dst_mem, src_mem, strides_copy_size);
std::advance(dst_mem, strides_copy_size);
}
lb_at_axis += part_length;