[GPU] Refactoring of pre_replace_deconv pass (#6218)

Andrei Molotkov 2021-07-26 10:02:21 +03:00 committed by GitHub
parent 33dfcd62c2
commit 9b706711fe
2 changed files with 117 additions and 157 deletions

@@ -34,19 +34,13 @@ void pre_replace_deconv::run(program_impl& p) {
             auto& deconv_node = node->as<deconvolution>();
             auto& weights_node = deconv_node.weights();
-            auto deconv_prim = node->as<deconvolution>().typed_desc();
+            auto deconv_prim = deconv_node.typed_desc();
             tensor filter_size = weights_node.get_output_layout().size;
-            auto weights = deconv_prim->weights;
-
-            std::vector<primitive_id> weights_vec;
-            for (auto& weights_id : weights)
-                weights_vec.push_back(weights_id);
-
-            for (auto& weights_id : weights_vec) {
-                auto weights_iter = p.nodes_map.find(weights_id);
-                if (weights_iter == p.nodes_map.end())
-                    continue;
-            }
+            auto weights_nodes_id = deconv_prim->weights;
+            auto biases_nodes_id = deconv_prim->bias;
+            auto& input_node = deconv_node.get_dependency(0);
+            const primitive_id deconv_node_id = deconv_node.id();
+            const primitive_id& input_node_id = input_node.id();
 
             // limit optimization to stride = 1
             // iterators shouldn't be used here because of incorrect iterator functionality in mutable_array_ref<>
@@ -55,8 +49,6 @@ void pre_replace_deconv::run(program_impl& p) {
                 unit_stride &= (deconv_prim->stride.spatial[i] == 1);
             }
 
             if (unit_stride) {
-                primitive_id deconv_id = node->id();
-                auto& input_node = node->get_dependency(0);
                 auto groups = deconv_node.get_groups();
                 bool perform_opt = false;
@@ -64,155 +56,132 @@ void pre_replace_deconv::run(program_impl& p) {
                 perform_opt |= cldnn::format::dimension(input_node.get_output_layout().format) == 4 &&
                                (input_node.get_output_layout().data_type == data_types::f32 || input_node.get_output_layout().data_type == data_types::f16) &&
                                !((_lo.get_optimization_attributes().b_fs_yx_fsv16_network || input_node.get_output_layout().format == format::b_fs_yx_fsv16) &&
-                                 _lo.is_format_optimized(node->as<deconvolution>(), format::b_fs_yx_fsv16));
+                                 _lo.is_format_optimized(deconv_node, format::b_fs_yx_fsv16));
                 // int8/uint8 input
                 perform_opt |= (input_node.get_output_layout().data_type == data_types::i8 || input_node.get_output_layout().data_type == data_types::u8);
 
                 if (!perform_opt)
                     continue;
 
-                primitive_id input_id = deconv_prim->input[0];
-
                 // setting convolution parameters based on deconvolution params
                 auto stride = deconv_prim->stride;
-                auto biases = deconv_prim->bias;
-                std::vector<primitive_id> bias_vec;
-                for (auto& bias_id : biases)
-                    bias_vec.push_back(bias_id);
                 auto input_offset = deconv_prim->input_offset;
                 auto output_padding = deconv_prim->output_padding;
                 auto grouped_weights_shape = deconv_prim->grouped_weights_shape;
 
                 // remove deconvolution node and its connections to weights and biases, rename it and move to the optimized
                 // list
-                p.remove_connection(node->get_dependency(0), *node);
-                for (auto& weights_id : weights_vec) {
+                p.remove_connection(input_node, deconv_node);
+                std::vector<std::shared_ptr<program_node>> weight_connections;
+                for (auto& weights_id : weights_nodes_id) {
                     auto weights_iter = p.nodes_map.find(weights_id);
                     if (weights_iter == p.nodes_map.end())
                         continue;
                     auto weights_node_ptr = weights_iter->second;
-                    p.remove_connection(*weights_node_ptr, *node);
+                    weight_connections.push_back(weights_node_ptr);
+                    p.remove_connection(*weights_node_ptr, deconv_node);
                 }
 
                 input_offset.spatial[0] = std::abs(input_offset.spatial[0]) - (filter_size.spatial[0] - 1);
                 input_offset.spatial[1] = std::abs(input_offset.spatial[1]) - (filter_size.spatial[1] - 1);
                 input_offset.spatial[2] = std::abs(input_offset.spatial[2]) - (filter_size.spatial[2] - 1);
 
-                if (!bias_vec.empty()) {
-                    for (auto& bias_id : bias_vec) {
-                        auto bias_iter = p.nodes_map.find(bias_id);
-                        if (bias_iter == p.nodes_map.end())
-                            continue;
-                        auto bias_id_node_ptr = bias_iter->second;
-                        p.remove_connection(*bias_id_node_ptr, *node);
-                    }
-                }
+                std::vector<std::shared_ptr<program_node>> bias_connections;
+                for (auto& bias_id : biases_nodes_id) {
+                    auto bias_iter = p.nodes_map.find(bias_id);
+                    if (bias_iter == p.nodes_map.end())
+                        continue;
+                    auto bias_id_node_ptr = bias_iter->second;
+                    bias_connections.push_back(bias_id_node_ptr);
+                    p.remove_connection(*bias_id_node_ptr, deconv_node);
+                }
 
-                auto rename_id = deconv_id + "_tmp";
-                auto was_output = node->is_output();
+                auto was_output = deconv_node.is_output();
                 if (was_output) {
-                    node->set_output(false);
+                    deconv_node.set_output(false);
                     auto& outputs = p.get_outputs();
                     outputs.erase(std::remove(outputs.begin(), outputs.end(), node.get()), outputs.end());
                 }
 
-                p.rename(*node, rename_id);
+                auto rename_id = deconv_node_id + "_tmp";
+                p.rename(deconv_node, rename_id);
 
                 // create convolution primitive
-                if (!biases.empty()) {
-                    auto conv_prim = std::make_shared<convolution>(deconv_id,
-                                                                   input_id,
-                                                                   weights_vec,
-                                                                   bias_vec,
-                                                                   groups,
-                                                                   stride,
-                                                                   input_offset,
-                                                                   tensor{ 1, 1, 1, 1 },
-                                                                   grouped_weights_shape,
-                                                                   output_padding);
-                    p.get_or_create(conv_prim);
+                std::shared_ptr<convolution> conv_prim;
+                if (!biases_nodes_id.empty()) {
+                    conv_prim = std::make_shared<convolution>(deconv_node_id,
+                                                              input_node_id,
+                                                              weights_nodes_id,
+                                                              biases_nodes_id,
+                                                              groups,
+                                                              stride,
+                                                              input_offset,
+                                                              tensor{ 1, 1, 1, 1 },
+                                                              grouped_weights_shape,
+                                                              output_padding);
                 } else {
-                    auto conv_prim = std::make_shared<convolution>(deconv_id,
-                                                                   input_id,
-                                                                   weights_vec,
-                                                                   groups,
-                                                                   stride,
-                                                                   input_offset,
-                                                                   tensor{ 1, 1, 1, 1 },
-                                                                   grouped_weights_shape,
-                                                                   output_padding);
-                    p.get_or_create(conv_prim);
+                    conv_prim = std::make_shared<convolution>(deconv_node_id,
+                                                              input_node_id,
+                                                              weights_nodes_id,
+                                                              groups,
+                                                              stride,
+                                                              input_offset,
+                                                              tensor{ 1, 1, 1, 1 },
+                                                              grouped_weights_shape,
+                                                              output_padding);
                 }
+                program_node& new_node = p.get_or_create(conv_prim);
+                auto& conv_node = new_node.as<convolution>();
+                conv_node.set_transposed(true);
 
-                auto conv_node_itr = p.nodes_map.find(deconv_id);
-                if (conv_node_itr == p.nodes_map.end())
-                    continue;
-                auto conv_node_ptr = conv_node_itr->second;
-                auto conv_node = &conv_node_ptr->as<convolution>();
-                conv_node->set_transposed(true);
                 // add connections input->convolution, weights->convolution and bias->convolution
-                p.add_connection(input_node, *conv_node_ptr);
-                for (auto& weights_id : weights_vec) {
-                    auto weights_node_itr = p.nodes_map.find(weights_id);
-                    if (weights_node_itr == p.nodes_map.end())
-                        continue;
-                    auto weights_node_ptr = weights_node_itr->second;
-                    p.add_connection(*weights_node_ptr, *conv_node_ptr);
-                }
-                if (!bias_vec.empty()) {
-                    for (auto& bias_id : bias_vec) {
-                        auto bias_id_node_itr = p.nodes_map.find(bias_id);
-                        if (bias_id_node_itr == p.nodes_map.end())
-                            continue;
-                        auto bias_id_node_ptr = bias_id_node_itr->second;
-                        p.add_connection(*bias_id_node_ptr, *conv_node_ptr);
-                    }
-                }
+                p.add_connection(input_node, conv_node);
+                for (auto& weight_node : weight_connections) {
+                    p.add_connection(*weight_node, conv_node);
+                }
+                for (auto& bias_node : bias_connections) {
+                    p.add_connection(*bias_node, conv_node);
+                }
 
                 auto deconv_node_itr = p.nodes_map.find(rename_id);
                 if (deconv_node_itr != p.nodes_map.end()) {
                     auto deconv_node_ptr = deconv_node_itr->second;
-                    p.replace_all_usages(*deconv_node_ptr, *conv_node_ptr);
+                    p.replace_all_usages(*deconv_node_ptr, conv_node);
                     p.optimized_out.push_back(rename_id);
                     p.nodes_map.erase(rename_id);
                 }
 
                 if (was_output) {
-                    conv_node->set_output(true);
-                    p.get_outputs().push_back(conv_node);
+                    conv_node.set_output(true);
+                    p.get_outputs().push_back(&conv_node);
                 }
 
-                p.mark_if_data_flow(*conv_node);
-                conv_node->recalc_output_layout(true);
+                p.mark_if_data_flow(conv_node);
+                conv_node.recalc_output_layout(true);
 
                 update_processing_order = true;
             // current optimization only available for specific deconvolution parameters
-            } else if (node->is_output() == false &&
-                       node->get_output_layout().size.feature[0] == 1 &&
+            } else if (deconv_node.is_output() == false &&
+                       deconv_node.get_output_layout().size.feature[0] == 1 &&
                        deconv_prim->stride.spatial[0] == 2 && deconv_prim->stride.spatial[1] == 2 &&
                        filter_size.spatial[0] == 9 && filter_size.spatial[1] == 9 &&
                        deconv_prim->input_offset.spatial[0] == -4 && deconv_prim->input_offset.spatial[1] == -4 &&
-                       weights_vec.size() == 1 && deconv_prim->bias.size() == 1 &&
-                       node->get_dependency(0).get_output_layout().format == format::bfyx) {
-                primitive_id deconv_id = node->id();
-                auto& input_node = node->get_dependency(0);
-                primitive_id input_id = deconv_prim->input[0];
-
-                auto scale_factor = deconv_prim->stride.spatial[0];
-
-                auto cur_weights_node_ptr = p.nodes_map.find(weights_vec[0])->second;
-                auto weights_layout = cur_weights_node_ptr->get_output_layout();
-                auto weights_data_type = weights_layout.data_type;
-
-                auto biases = deconv_prim->bias[0];
-                auto bias_id_node_ptr = p.nodes_map.find(biases)->second;
-                auto bias_data_type = bias_id_node_ptr->get_output_layout().data_type;
+                       weights_nodes_id.size() == 1 && biases_nodes_id.size() == 1 &&
+                       input_node.get_output_layout().format == format::bfyx) {
+                const auto scale_factor = deconv_prim->stride.spatial[0];
+
+                const auto& weight_node_id = weights_nodes_id.front();
+                auto weights_node_ptr = p.nodes_map.find(weight_node_id)->second;
+                const auto& weights_layout = weights_node_ptr->get_output_layout();
+                const auto& weights_data_type = weights_layout.data_type;
+
+                const auto& bias_node_id = biases_nodes_id.front();
+                auto bias_id_node_ptr = p.nodes_map.find(bias_node_id)->second;
+                const auto& bias_data_type = bias_id_node_ptr->get_output_layout().data_type;
 
                 // enable only for fp32 and fp16
                 if (weights_data_type != data_types::f16 &&
@@ -229,14 +198,13 @@ void pre_replace_deconv::run(program_impl& p) {
 
                 // remove deconvolution node and its connections to weights and biases,
                 // rename it and move to the optimized list
-                p.remove_connection(node->get_dependency(0), *node);
-                auto weights_node_ptr = p.nodes_map.find(weights_vec[0])->second;
-                p.remove_connection(*weights_node_ptr, *node);
-                p.remove_connection(*bias_id_node_ptr, *node);
+                p.remove_connection(input_node, deconv_node);
+                p.remove_connection(*weights_node_ptr, deconv_node);
+                p.remove_connection(*bias_id_node_ptr, deconv_node);
 
-                auto rename_id = deconv_id + "_tmp";
-                p.rename(*node, rename_id);
+                auto rename_id = deconv_node_id + "_tmp";
+                p.rename(deconv_node, rename_id);
 
                 // reshape weights
                 int pixel_shuffle_size = scale_factor * scale_factor;
@@ -244,17 +212,18 @@ void pre_replace_deconv::run(program_impl& p) {
                 tensor target_weights_size = { pixel_shuffle_size, filter_size.feature[0], kernel_size, kernel_size };
                 auto target_weights_layout = layout{ weights_layout.data_type, weights_layout.format, target_weights_size };
 
+                const primitive_id weight_replace_node_id = weight_node_id + "_conv_rpl";
                 {
                     memory::ptr data_to_allocate = p.get_engine().allocate_memory(target_weights_layout);
 
                     std::vector<float> weights_vec_float;
 
                     if (weights_data_type == data_types::f16) {
-                        mem_lock<half_t> src{ cur_weights_node_ptr->as<data>().get_attached_memory_ptr(), stream };
+                        mem_lock<half_t> src{ weights_node_ptr->as<data>().get_attached_memory_ptr(), stream };
                         for (uint32_t i = 0; i < weights_layout.size.count(); i++)
                             weights_vec_float.push_back(static_cast<float>(src.data()[i]));
                     } else {
-                        mem_lock<float> src{ cur_weights_node_ptr->as<data>().get_attached_memory_ptr(), stream };
+                        mem_lock<float> src{ weights_node_ptr->as<data>().get_attached_memory_ptr(), stream };
                        for (uint32_t i = 0; i < weights_layout.size.count(); i++)
                             weights_vec_float.push_back(src.data()[i]);
                     }
@@ -278,12 +247,36 @@ void pre_replace_deconv::run(program_impl& p) {
                         throw std::logic_error("Not supported data type.");
                     }
 
-                    auto data_node_weights_replace = std::make_shared<data>(weights_vec[0] + "_conv_rpl", data_to_allocate);
-                    p.get_or_create(data_node_weights_replace);
-                    auto data_node_weights_replace_node_ptr = p.nodes_map.find(weights_vec[0] + "_conv_rpl")->second;
-                    auto& data_node = data_node_weights_replace_node_ptr->as<data>();
+                    auto data_node_weights_replace = std::make_shared<data>(weight_replace_node_id, data_to_allocate);
+                    program_node& weights_replace_node = p.get_or_create(data_node_weights_replace);
+                    auto& data_node = weights_replace_node.as<data>();
                     data_node.set_output_layout(target_weights_layout, false);
                 }
 
+                auto deconv_id_conv = deconv_node_id + "_conv";
+
+                // create convolution primitive
+                auto conv_prim = std::make_shared<convolution>(deconv_id_conv,
+                                                               input_node_id,
+                                                               std::vector<primitive_id>{ weight_replace_node_id },
+                                                               stride,
+                                                               input_offset,
+                                                               tensor{ 1, 1, 1, 1 },
+                                                               grouped_weights_shape,
+                                                               output_padding);
+
+                program_node& created_node = p.get_or_create(conv_prim);
+                auto& conv_node = created_node.as<convolution>();
+
+                // add connections input->convolution, weights->convolution and bias->convolution
+                p.add_connection(input_node, conv_node);
+                {
+                    auto weights_node_conv_rpl_ptr = p.nodes_map.find(weight_replace_node_id)->second;
+                    p.add_connection(*weights_node_conv_rpl_ptr, conv_node);
+                    p.inputs.push_back(weights_node_conv_rpl_ptr.get());
+                }
+
                 float bias = 0;
                 if (bias_data_type == data_types::f16) {
@@ -293,52 +286,22 @@ void pre_replace_deconv::run(program_impl& p) {
                     mem_lock<float> src{ bias_id_node_ptr->as<data>().get_attached_memory_ptr(), stream };
                     bias = src.data()[0];
                 }
 
-                auto deconv_id_conv = deconv_id + "_conv";
-
-                // create convolution primitive
-                auto conv_prim = std::make_shared<convolution>(deconv_id_conv,
-                                                               input_id,
-                                                               std::vector<primitive_id>{ weights_vec[0] + "_conv_rpl" },
-                                                               stride,
-                                                               input_offset,
-                                                               tensor{ 1, 1, 1, 1 },
-                                                               grouped_weights_shape,
-                                                               output_padding);
-                p.get_or_create(conv_prim);
-
-                auto conv_node_itr = p.nodes_map.find(deconv_id_conv);
-                if (conv_node_itr == p.nodes_map.end()) continue;
-                auto conv_node_ptr = conv_node_itr->second;
-                auto conv_node = &conv_node_ptr->as<convolution>();
-
-                // add connections input->convolution, weights->convolution and bias->convolution
-                p.add_connection(input_node, *conv_node_ptr);
-                {
-                    auto weights_node_conv_rpl_ptr = p.nodes_map.find(weights_vec[0] + "_conv_rpl")->second;
-                    p.add_connection(*weights_node_conv_rpl_ptr, *conv_node_ptr);
-                    p.inputs.push_back(weights_node_conv_rpl_ptr.get());
-                }
-
-                auto pixel_shuffle_prim = std::make_shared<depth_to_space>(deconv_id, deconv_id_conv, 2, depth_to_space_mode::blocks_first);
-                p.get_or_create(pixel_shuffle_prim);
-                auto pixel_shuffle_node_ptr = p.nodes_map.find(deconv_id)->second;
-                pixel_shuffle_node_ptr->add_fused_activation(activation_func::linear, { 1, bias });
+                auto pixel_shuffle_prim = std::make_shared<depth_to_space>(deconv_node_id, deconv_id_conv, 2, depth_to_space_mode::blocks_first);
+                program_node& pixel_shuffle_node = p.get_or_create(pixel_shuffle_prim);
+                pixel_shuffle_node.add_fused_activation(activation_func::linear, { 1, bias });
 
                 // add connections input->convolution, weights->convolution
-                p.add_connection(*conv_node_ptr, *pixel_shuffle_node_ptr);
+                p.add_connection(conv_node, pixel_shuffle_node);
 
                 auto deconv_node_ptr = p.nodes_map.find(rename_id);
                 if (deconv_node_ptr != p.nodes_map.end()) {
-                    p.replace_all_usages(*deconv_node_ptr->second, *pixel_shuffle_node_ptr);
+                    p.replace_all_usages(*deconv_node_ptr->second, pixel_shuffle_node);
                     p.optimized_out.push_back(rename_id);
                     p.nodes_map.erase(rename_id);
                 }
 
-                p.mark_if_data_flow(*conv_node);
-                conv_node->recalc_output_layout(true);
+                p.mark_if_data_flow(conv_node);
+                conv_node.recalc_output_layout(true);
 
                 update_processing_order = true;
             }
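
Side note on the second branch (illustration only, not part of the diff above): a stride-2, 9x9, single-feature deconvolution is rewritten into a convolution that emits scale_factor * scale_factor feature maps, followed by a depth_to_space (pixel shuffle) primitive with block size 2 whose fused linear activation { 1, bias } re-applies the bias. The sketch below is plain standalone C++ showing the pixel-shuffle rearrangement that rewrite relies on; the function name and the exact blocks_first index mapping are assumptions for illustration, not clDNN code.

    // Illustrative CPU sketch (not clDNN code): depth_to_space with block size 2
    // for a single output feature map, as used after the replacement convolution.
    // Input:  in[c][y][x] with c in [0, 4)  (4 = block * block feature maps).
    // Output: out[y][x]   with doubled spatial size.
    #include <cstddef>
    #include <vector>

    std::vector<float> depth_to_space_2x(const std::vector<float>& in,
                                         std::size_t in_h, std::size_t in_w) {
        const std::size_t block = 2;
        const std::size_t out_h = in_h * block;
        const std::size_t out_w = in_w * block;
        std::vector<float> out(out_h * out_w);
        for (std::size_t y = 0; y < out_h; ++y) {
            for (std::size_t x = 0; x < out_w; ++x) {
                // Assumed blocks_first mapping: the channel index selects the
                // (y % block, x % block) position inside each 2x2 output block.
                const std::size_t c = (y % block) * block + (x % block);
                out[y * out_w + x] = in[(c * in_h + y / block) * in_w + x / block];
            }
        }
        return out;
    }

Each 2x2 patch of the upsampled output is filled from the four feature maps of the convolution, which is why the replacement weights are reshaped to pixel_shuffle_size = scale_factor * scale_factor output features in the hunk above.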

@@ -834,13 +834,10 @@ void program_impl::swap_names(program_node& node1, program_node& node2) {
 }
 
 void program_impl::replace_all_usages(program_node& old_node, program_node& new_node) {
-    const std::list<program_node*> users(old_node.users);
-    auto itr = users.begin();
-    bool end = (itr == users.end());
-    while (!end) {
-        auto& usage = (*itr++);
-        end = (itr == users.end());
-        usage->replace_dependency(old_node, new_node);
+    auto itr = old_node.users.begin();
+    while (itr != old_node.users.end()) {
+        auto user = *(itr++);
+        user->replace_dependency(old_node, new_node);
     }
 }
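
Side note on replace_all_usages (illustration only, not part of the diff above): the rewritten loop walks old_node.users in place instead of copying the list first. This stays safe only because the iterator is advanced before replace_dependency() is called, since that call detaches the current user from old_node.users and would otherwise invalidate the iterator being dereferenced. A minimal standalone sketch of the pattern, with a hypothetical Node type standing in for program_node:

    // Illustrative sketch (not clDNN code), assuming `users` behaves like a
    // std::list<Node*>, each user appears in it once, and replace_dependency()
    // erases the calling node from old_node.users.
    #include <list>

    struct Node {
        std::list<Node*> users;
        void replace_dependency(Node& old_node, Node& new_node) {
            old_node.users.remove(this);     // erases the element the caller occupied
            new_node.users.push_back(this);
        }
    };

    void replace_all_usages(Node& old_node, Node& new_node) {
        auto itr = old_node.users.begin();
        while (itr != old_node.users.end()) {
            auto user = *(itr++);            // advance first: the current element is
            user->replace_dependency(old_node, new_node);  // about to be erased
        }
    }

The old version achieved the same safety by copying users into a local list up front; the new form avoids that copy.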