[GPU] Add parallel quantizes optimization (#9370)

This commit is contained in:
Sergey Shlyapnikov 2021-12-27 09:47:20 +03:00 committed by GitHub
parent 0fa226a0c2
commit 95d86eb2bf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 98 additions and 0 deletions

View File

@ -327,6 +327,9 @@ void prepare_quantization::prepare_scale_shift_opt(program &p, quantize_node& q
}
void prepare_quantization::handle_quantize_node(program& p, quantize_node& quantize_node) {
if (optimize_quantize(p, quantize_node))
return;
if (quantize_node.get_primitive()->levels == 2) {
prepare_packed_quantize(p, quantize_node);
} else if (quantize_node.get_primitive()->levels <= 256 && !quantize_node.get_scale_shift_opt() && !quantize_node.is_constant()) {
@ -759,6 +762,100 @@ void prepare_quantization::prepare_asymmetric_quantization(program &p, convoluti
new_conv_node.recalc_output_layout();
}
// Tries to merge `quantize_node` with another quantize primitive that reads the same
// input (dependency 0) and has identical constant limits (dependencies 1..4), the same
// `levels` and the same `output_data_type`.
//
// On success all dependencies of `quantize_node` are disconnected, its users are
// re-pointed to the equivalent quantize, and the substitution is recorded via
// add_optimized_primitive_info().
//
// Returns true if `quantize_node` was replaced by an equivalent sibling, false if
// nothing was changed.
bool prepare_quantization::optimize_quantize(program &p, quantize_node& quantize_node) {
    const auto& stream = p.get_stream();

    auto& input = quantize_node.get_dependency(0);

    // The optimization only makes sense when the input feeds at least two quantizes.
    size_t parallel_quantizes_num = 0;
    for (auto& usr : input.get_users()) {
        if (usr->is_type<quantize>())
            parallel_quantizes_num++;
    }

    if (parallel_quantizes_num < 2)
        return false;

    auto quantize_prim_first = quantize_node.get_primitive();

    program_node &input_low_node_first = quantize_node.get_dependency(1);
    program_node &input_high_node_first = quantize_node.get_dependency(2);
    program_node &output_low_node_first = quantize_node.get_dependency(3);
    program_node &output_high_node_first = quantize_node.get_dependency(4);

    // Limits can only be compared when they are constant data nodes.
    if (!input_low_node_first.is_type<data>() || !input_high_node_first.is_type<data>() ||
        !output_low_node_first.is_type<data>() || !output_high_node_first.is_type<data>()) {
        return false;
    }

    auto mem_input_low_first = input_low_node_first.as<data>().get_attached_memory_ptr();
    auto mem_input_high_first = input_high_node_first.as<data>().get_attached_memory_ptr();
    auto mem_output_low_first = output_low_node_first.as<data>().get_attached_memory_ptr();
    auto mem_output_high_first = output_high_node_first.as<data>().get_attached_memory_ptr();

    // Locked once up front: these buffers are compared against every candidate below.
    mem_lock<uint8_t, mem_lock_type::read> mem_input_low_lock_first{mem_input_low_first, stream};
    mem_lock<uint8_t, mem_lock_type::read> mem_input_high_lock_first{mem_input_high_first, stream};
    mem_lock<uint8_t, mem_lock_type::read> mem_output_low_lock_first{mem_output_low_first, stream};
    mem_lock<uint8_t, mem_lock_type::read> mem_output_high_lock_first{mem_output_high_first, stream};

    program_node* same_quantize = nullptr;
    for (auto& usr : input.get_users()) {
        if (!usr->is_type<quantize>() || usr == &quantize_node)
            continue;

        // Cheap primitive-level checks first, before touching any memory.
        auto quantize_prim_second = usr->as<quantize>().get_primitive();
        if (quantize_prim_first->output_data_type != quantize_prim_second->output_data_type ||
            quantize_prim_first->levels != quantize_prim_second->levels)
            continue;

        program_node &input_low_node_second = usr->get_dependency(1);
        program_node &input_high_node_second = usr->get_dependency(2);
        program_node &output_low_node_second = usr->get_dependency(3);
        program_node &output_high_node_second = usr->get_dependency(4);

        if (!input_low_node_second.is_type<data>() || !input_high_node_second.is_type<data>() ||
            !output_low_node_second.is_type<data>() || !output_high_node_second.is_type<data>())
            continue;

        auto mem_input_low_second = input_low_node_second.as<data>().get_attached_memory_ptr();
        auto mem_input_high_second = input_high_node_second.as<data>().get_attached_memory_ptr();
        auto mem_output_low_second = output_low_node_second.as<data>().get_attached_memory_ptr();
        auto mem_output_high_second = output_high_node_second.as<data>().get_attached_memory_ptr();

        if (mem_input_low_first->count() != mem_input_low_second->count() ||
            mem_input_high_first->count() != mem_input_high_second->count() ||
            mem_output_low_first->count() != mem_output_low_second->count() ||
            mem_output_high_first->count() != mem_output_high_second->count())
            continue;

        // memcmp below reads size() bytes of the *first* buffer from both buffers, so the
        // sizes must match as well: an equal element count alone does not guarantee
        // that the second buffer is at least that large (e.g. different element types).
        if (mem_input_low_first->size() != mem_input_low_second->size() ||
            mem_input_high_first->size() != mem_input_high_second->size() ||
            mem_output_low_first->size() != mem_output_low_second->size() ||
            mem_output_high_first->size() != mem_output_high_second->size())
            continue;

        mem_lock<uint8_t, mem_lock_type::read> mem_input_low_lock_second{mem_input_low_second, stream};
        mem_lock<uint8_t, mem_lock_type::read> mem_input_high_lock_second{mem_input_high_second, stream};
        mem_lock<uint8_t, mem_lock_type::read> mem_output_low_lock_second{mem_output_low_second, stream};
        mem_lock<uint8_t, mem_lock_type::read> mem_output_high_lock_second{mem_output_high_second, stream};

        if (memcmp(mem_input_low_lock_first.data(), mem_input_low_lock_second.data(), mem_input_low_first->size()) != 0 ||
            memcmp(mem_input_high_lock_first.data(), mem_input_high_lock_second.data(), mem_input_high_first->size()) != 0 ||
            memcmp(mem_output_low_lock_first.data(), mem_output_low_lock_second.data(), mem_output_low_first->size()) != 0 ||
            memcmp(mem_output_high_lock_first.data(), mem_output_high_lock_second.data(), mem_output_high_first->size()) != 0)
            continue;

        same_quantize = usr;
        break;
    }

    if (!same_quantize)
        return false;

    // Detach every dependency of the redundant quantize; dependencies left without
    // users are handed to remove_if_dangling().
    while (!quantize_node.get_dependencies().empty()) {
        auto& dep = quantize_node.get_dependency(0);
        p.remove_connection(dep, quantize_node);
        p.remove_if_dangling(dep);
    }

    p.add_optimized_primitive_info(quantize_node.id(), {same_quantize->id()});
    p.replace_all_usages(quantize_node, *same_quantize);

    return true;
}
void prepare_quantization::run(program& p) {
auto itr = p.get_processing_order().begin();
while (itr != p.get_processing_order().end()) {

View File

@ -159,6 +159,7 @@ private:
void remove_fake_reorders(program& p, reorder_node& reorder_node);
void prepare_asymmetric_quantization(program& p, convolution_node& convolution_node);
void prepare_scale_shift_opt(program &p, quantize_node& quantize_node);
bool optimize_quantize(program &p, quantize_node& quantize_node);
};
class prepare_conv_eltw_fusing : public base_pass {