[GPU] Add parallel quantizes optimization (#9370)
This commit is contained in:
parent
0fa226a0c2
commit
95d86eb2bf
@ -327,6 +327,9 @@ void prepare_quantization::prepare_scale_shift_opt(program &p, quantize_node& q
|
||||
}
|
||||
|
||||
void prepare_quantization::handle_quantize_node(program& p, quantize_node& quantize_node) {
|
||||
if (optimize_quantize(p, quantize_node))
|
||||
return;
|
||||
|
||||
if (quantize_node.get_primitive()->levels == 2) {
|
||||
prepare_packed_quantize(p, quantize_node);
|
||||
} else if (quantize_node.get_primitive()->levels <= 256 && !quantize_node.get_scale_shift_opt() && !quantize_node.is_constant()) {
|
||||
@ -759,6 +762,100 @@ void prepare_quantization::prepare_asymmetric_quantization(program &p, convoluti
|
||||
new_conv_node.recalc_output_layout();
|
||||
}
|
||||
|
||||
// Merges "parallel" quantize nodes: when another quantize consuming the same input as
// `quantize_node` has the same levels, the same output data type and byte-identical constant
// input_low / input_high / output_low / output_high buffers, `quantize_node` is disconnected
// from the graph and all of its usages are redirected to that other quantize.
//
// Returns true when `quantize_node` was replaced (the caller must not process it further),
// false when no equivalent sibling quantize was found and the graph was left untouched.
bool prepare_quantization::optimize_quantize(program &p, quantize_node& quantize_node) {
    const auto& stream = p.get_stream();

    auto& input = quantize_node.get_dependency(0);

    // Nothing to merge unless at least two quantize nodes read the same input.
    auto parallel_quantizes_num = 0;
    for (auto& usr : input.get_users()) {
        if (usr->is_type<quantize>())
            parallel_quantizes_num++;
    }

    if (parallel_quantizes_num < 2)
        return false;

    auto quantize_prim_first = quantize_node.get_primitive();

    program_node &input_low_node_first = quantize_node.get_dependency(1);
    program_node &input_high_node_first = quantize_node.get_dependency(2);
    program_node &output_low_node_first = quantize_node.get_dependency(3);
    program_node &output_high_node_first = quantize_node.get_dependency(4);

    // The ranges can only be compared at graph-compile time when they are constant data nodes.
    if (!input_low_node_first.is_type<data>() || !input_high_node_first.is_type<data>() ||
        !output_low_node_first.is_type<data>() || !output_high_node_first.is_type<data>()) {
        return false;
    }

    auto mem_input_low_first = input_low_node_first.as<data>().get_attached_memory_ptr();
    auto mem_input_high_first = input_high_node_first.as<data>().get_attached_memory_ptr();
    auto mem_output_low_first = output_low_node_first.as<data>().get_attached_memory_ptr();
    auto mem_output_high_first = output_high_node_first.as<data>().get_attached_memory_ptr();

    mem_lock<uint8_t, mem_lock_type::read> mem_input_low_lock_first{mem_input_low_first, stream};
    mem_lock<uint8_t, mem_lock_type::read> mem_input_high_lock_first{mem_input_high_first, stream};
    mem_lock<uint8_t, mem_lock_type::read> mem_output_low_lock_first{mem_output_low_first, stream};
    mem_lock<uint8_t, mem_lock_type::read> mem_output_high_lock_first{mem_output_high_first, stream};

    program_node* same_quantize = nullptr;
    for (auto& usr : input.get_users()) {
        if (!usr->is_type<quantize>() || usr == &quantize_node)
            continue;

        // Cheapest test first: the primitives themselves must agree.
        auto quantize_prim_second = usr->as<quantize>().get_primitive();
        if (quantize_prim_first->output_data_type != quantize_prim_second->output_data_type ||
            quantize_prim_first->levels != quantize_prim_second->levels)
            continue;

        program_node &input_low_node_second = usr->get_dependency(1);
        program_node &input_high_node_second = usr->get_dependency(2);
        program_node &output_low_node_second = usr->get_dependency(3);
        program_node &output_high_node_second = usr->get_dependency(4);

        if (!input_low_node_second.is_type<data>() || !input_high_node_second.is_type<data>() ||
            !output_low_node_second.is_type<data>() || !output_high_node_second.is_type<data>())
            continue;

        auto mem_input_low_second = input_low_node_second.as<data>().get_attached_memory_ptr();
        auto mem_input_high_second = input_high_node_second.as<data>().get_attached_memory_ptr();
        auto mem_output_low_second = output_low_node_second.as<data>().get_attached_memory_ptr();
        auto mem_output_high_second = output_high_node_second.as<data>().get_attached_memory_ptr();

        if (mem_input_low_first->count() != mem_input_low_second->count() || mem_input_high_first->count() != mem_input_high_second->count() ||
            mem_output_low_first->count() != mem_output_low_second->count() || mem_output_high_first->count() != mem_output_high_second->count())
            continue;

        // Equal element counts do not imply equal byte sizes (the element types may differ).
        // The memcmp below is sized by the first buffer, so a smaller second buffer would be
        // read out of bounds; buffers of different byte size cannot be identical anyway.
        if (mem_input_low_first->size() != mem_input_low_second->size() || mem_input_high_first->size() != mem_input_high_second->size() ||
            mem_output_low_first->size() != mem_output_low_second->size() || mem_output_high_first->size() != mem_output_high_second->size())
            continue;

        // Lock the candidate's buffers only after all metadata checks have passed.
        mem_lock<uint8_t, mem_lock_type::read> mem_input_low_lock_second{mem_input_low_second, stream};
        mem_lock<uint8_t, mem_lock_type::read> mem_input_high_lock_second{mem_input_high_second, stream};
        mem_lock<uint8_t, mem_lock_type::read> mem_output_low_lock_second{mem_output_low_second, stream};
        mem_lock<uint8_t, mem_lock_type::read> mem_output_high_lock_second{mem_output_high_second, stream};

        if (memcmp(mem_input_low_lock_first.data(), mem_input_low_lock_second.data(), mem_input_low_first->size()) != 0 ||
            memcmp(mem_input_high_lock_first.data(), mem_input_high_lock_second.data(), mem_input_high_first->size()) != 0 ||
            memcmp(mem_output_low_lock_first.data(), mem_output_low_lock_second.data(), mem_output_low_first->size()) != 0 ||
            memcmp(mem_output_high_lock_first.data(), mem_output_high_lock_second.data(), mem_output_high_first->size()) != 0)
            continue;

        same_quantize = usr;
        break;
    }

    if (!same_quantize)
        return false;

    // Detach quantize_node from every dependency; dependencies that end up without users
    // (e.g. its private range constants) are removed from the program as well.
    while (!quantize_node.get_dependencies().empty()) {
        auto& dep = quantize_node.get_dependency(0);
        p.remove_connection(dep, quantize_node);
        p.remove_if_dangling(dep);
    }

    // Register quantize_node as optimized in favour of same_quantize, then reroute its users.
    p.add_optimized_primitive_info(quantize_node.id(), {same_quantize->id()});
    p.replace_all_usages(quantize_node, *same_quantize);

    return true;
}
|
||||
|
||||
void prepare_quantization::run(program& p) {
|
||||
auto itr = p.get_processing_order().begin();
|
||||
while (itr != p.get_processing_order().end()) {
|
||||
|
@ -159,6 +159,7 @@ private:
|
||||
void remove_fake_reorders(program& p, reorder_node& reorder_node);
|
||||
void prepare_asymmetric_quantization(program& p, convolution_node& convolution_node);
|
||||
void prepare_scale_shift_opt(program &p, quantize_node& quantize_node);
|
||||
bool optimize_quantize(program &p, quantize_node& quantize_node);
|
||||
};
|
||||
|
||||
class prepare_conv_eltw_fusing : public base_pass {
|
||||
|
Loading…
Reference in New Issue
Block a user