[GPU] Refactoring of pre_replace_deconv pass (#6218)

Andrei Molotkov 2021-07-26 10:02:21 +03:00 committed by GitHub
parent 33dfcd62c2
commit 9b706711fe
2 changed files with 117 additions and 157 deletions
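The refactoring applies one idea throughout the pass: resolve each program node once, keep the typed reference or cached shared_ptr, and reuse it, instead of repeating node->as<deconvolution>() casts and p.nodes_map.find() lookups at every use site. Below is a minimal sketch of that lookup-once pattern, using simplified stand-in types rather than the real cldnn program_node / nodes_map API.

#include <map>
#include <memory>
#include <string>
#include <vector>

// Simplified stand-ins; the actual pass works on cldnn::program_node and
// program_impl::nodes_map.
struct Node { std::string id; };
using NodeMap = std::map<std::string, std::shared_ptr<Node>>;

// Resolve a list of primitive ids once, skipping ids that are not in the map
// (as the pass does), and keep the pointers for later reuse.
std::vector<std::shared_ptr<Node>> resolve_once(const NodeMap& nodes_map,
                                                const std::vector<std::string>& ids) {
    std::vector<std::shared_ptr<Node>> cached;
    for (const auto& id : ids) {
        auto it = nodes_map.find(id);
        if (it == nodes_map.end())
            continue;
        // Reused both when disconnecting the deconvolution and when wiring
        // the replacement convolution, instead of a second round of find().
        cached.push_back(it->second);
    }
    return cached;
}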


@@ -34,19 +34,13 @@ void pre_replace_deconv::run(program_impl& p) {
auto& deconv_node = node->as<deconvolution>();
auto& weights_node = deconv_node.weights();
auto deconv_prim = node->as<deconvolution>().typed_desc();
auto deconv_prim = deconv_node.typed_desc();
tensor filter_size = weights_node.get_output_layout().size;
auto weights = deconv_prim->weights;
std::vector<primitive_id> weights_vec;
for (auto& weights_id : weights)
weights_vec.push_back(weights_id);
for (auto& weights_id : weights_vec) {
auto weights_iter = p.nodes_map.find(weights_id);
if (weights_iter == p.nodes_map.end())
continue;
}
auto weights_nodes_id = deconv_prim->weights;
auto biases_nodes_id = deconv_prim->bias;
auto& input_node = deconv_node.get_dependency(0);
const primitive_id deconv_node_id = deconv_node.id();
const primitive_id& input_node_id = input_node.id();
// limit optimization to stride = 1
// iterators shouldn't be used here because of incorrect iterator functionality in mutable_array_ref<>
@@ -55,8 +49,6 @@ void pre_replace_deconv::run(program_impl& p) {
unit_stride &= (deconv_prim->stride.spatial[i] == 1);
}
if (unit_stride) {
primitive_id deconv_id = node->id();
auto& input_node = node->get_dependency(0);
auto groups = deconv_node.get_groups();
bool perform_opt = false;
@@ -64,155 +56,132 @@ void pre_replace_deconv::run(program_impl& p) {
perform_opt |= cldnn::format::dimension(input_node.get_output_layout().format) == 4 &&
(input_node.get_output_layout().data_type == data_types::f32 || input_node.get_output_layout().data_type == data_types::f16) &&
!((_lo.get_optimization_attributes().b_fs_yx_fsv16_network || input_node.get_output_layout().format == format::b_fs_yx_fsv16) &&
_lo.is_format_optimized(node->as<deconvolution>(), format::b_fs_yx_fsv16));
_lo.is_format_optimized(deconv_node, format::b_fs_yx_fsv16));
// int8/uint8 input
perform_opt |= (input_node.get_output_layout().data_type == data_types::i8 || input_node.get_output_layout().data_type == data_types::u8);
if (!perform_opt)
continue;
primitive_id input_id = deconv_prim->input[0];
// setting convolution parameters based on deconvolution params
auto stride = deconv_prim->stride;
auto biases = deconv_prim->bias;
std::vector<primitive_id> bias_vec;
for (auto& bias_id : biases)
bias_vec.push_back(bias_id);
auto input_offset = deconv_prim->input_offset;
auto output_padding = deconv_prim->output_padding;
auto grouped_weights_shape = deconv_prim->grouped_weights_shape;
// remove deconvolution node and its connections to weights and biases, rename it and move to the optimized
// list
p.remove_connection(node->get_dependency(0), *node);
for (auto& weights_id : weights_vec) {
p.remove_connection(input_node, deconv_node);
std::vector<std::shared_ptr<program_node>> weight_connections;
for (auto& weights_id : weights_nodes_id) {
auto weights_iter = p.nodes_map.find(weights_id);
if (weights_iter == p.nodes_map.end())
continue;
auto weights_node_ptr = weights_iter->second;
p.remove_connection(*weights_node_ptr, *node);
weight_connections.push_back(weights_node_ptr);
p.remove_connection(*weights_node_ptr, deconv_node);
}
input_offset.spatial[0] = std::abs(input_offset.spatial[0]) - (filter_size.spatial[0] - 1);
input_offset.spatial[1] = std::abs(input_offset.spatial[1]) - (filter_size.spatial[1] - 1);
input_offset.spatial[2] = std::abs(input_offset.spatial[2]) - (filter_size.spatial[2] - 1);
if (!bias_vec.empty()) {
for (auto& bias_id : bias_vec) {
auto bias_iter = p.nodes_map.find(bias_id);
if (bias_iter == p.nodes_map.end())
continue;
std::vector<std::shared_ptr<program_node>> bias_connections;
for (auto& bias_id : biases_nodes_id) {
auto bias_iter = p.nodes_map.find(bias_id);
if (bias_iter == p.nodes_map.end())
continue;
auto bias_id_node_ptr = bias_iter->second;
p.remove_connection(*bias_id_node_ptr, *node);
}
auto bias_id_node_ptr = bias_iter->second;
bias_connections.push_back(bias_id_node_ptr);
p.remove_connection(*bias_id_node_ptr, deconv_node);
}
auto rename_id = deconv_id + "_tmp";
auto was_output = node->is_output();
auto was_output = deconv_node.is_output();
if (was_output) {
node->set_output(false);
deconv_node.set_output(false);
auto& outputs = p.get_outputs();
outputs.erase(std::remove(outputs.begin(), outputs.end(), node.get()), outputs.end());
}
p.rename(*node, rename_id);
auto rename_id = deconv_node_id + "_tmp";
p.rename(deconv_node, rename_id);
// create convolution primitive
if (!biases.empty()) {
auto conv_prim = std::make_shared<convolution>(deconv_id,
input_id,
weights_vec,
bias_vec,
groups,
stride,
input_offset,
tensor{ 1, 1, 1, 1 },
grouped_weights_shape,
output_padding);
p.get_or_create(conv_prim);
std::shared_ptr<convolution> conv_prim;
if (!biases_nodes_id.empty()) {
conv_prim = std::make_shared<convolution>(deconv_node_id,
input_node_id,
weights_nodes_id,
biases_nodes_id,
groups,
stride,
input_offset,
tensor{ 1, 1, 1, 1 },
grouped_weights_shape,
output_padding);
} else {
auto conv_prim = std::make_shared<convolution>(deconv_id,
input_id,
weights_vec,
groups,
stride,
input_offset,
tensor{ 1, 1, 1, 1 },
grouped_weights_shape,
output_padding);
p.get_or_create(conv_prim);
conv_prim = std::make_shared<convolution>(deconv_node_id,
input_node_id,
weights_nodes_id,
groups,
stride,
input_offset,
tensor{ 1, 1, 1, 1 },
grouped_weights_shape,
output_padding);
}
program_node& new_node = p.get_or_create(conv_prim);
auto conv_node_itr = p.nodes_map.find(deconv_id);
if (conv_node_itr == p.nodes_map.end())
continue;
auto conv_node_ptr = conv_node_itr->second;
auto conv_node = &conv_node_ptr->as<convolution>();
conv_node->set_transposed(true);
auto& conv_node = new_node.as<convolution>();
conv_node.set_transposed(true);
// add connections input->convolution, weights->convolution and bias->convolution
p.add_connection(input_node, *conv_node_ptr);
p.add_connection(input_node, conv_node);
for (auto& weights_id : weights_vec) {
auto weights_node_itr = p.nodes_map.find(weights_id);
if (weights_node_itr == p.nodes_map.end())
continue;
auto weights_node_ptr = weights_node_itr->second;
p.add_connection(*weights_node_ptr, *conv_node_ptr);
for (auto& weight_node : weight_connections) {
p.add_connection(*weight_node, conv_node);
}
if (!bias_vec.empty()) {
for (auto& bias_id : bias_vec) {
auto bias_id_node_itr = p.nodes_map.find(bias_id);
if (bias_id_node_itr == p.nodes_map.end())
continue;
auto bias_id_node_ptr = bias_id_node_itr->second;
p.add_connection(*bias_id_node_ptr, *conv_node_ptr);
}
for (auto& bias_node : bias_connections) {
p.add_connection(*bias_node, conv_node);
}
auto deconv_node_itr = p.nodes_map.find(rename_id);
if (deconv_node_itr != p.nodes_map.end()) {
auto deconv_node_ptr = deconv_node_itr->second;
p.replace_all_usages(*deconv_node_ptr, *conv_node_ptr);
p.replace_all_usages(*deconv_node_ptr, conv_node);
p.optimized_out.push_back(rename_id);
p.nodes_map.erase(rename_id);
}
if (was_output) {
conv_node->set_output(true);
p.get_outputs().push_back(conv_node);
conv_node.set_output(true);
p.get_outputs().push_back(&conv_node);
}
p.mark_if_data_flow(*conv_node);
conv_node->recalc_output_layout(true);
p.mark_if_data_flow(conv_node);
conv_node.recalc_output_layout(true);
update_processing_order = true;
// current optimization only available for specific deconvolution parameters
} else if (node->is_output() == false &&
node->get_output_layout().size.feature[0] == 1 &&
} else if (deconv_node.is_output() == false &&
deconv_node.get_output_layout().size.feature[0] == 1 &&
deconv_prim->stride.spatial[0] == 2 && deconv_prim->stride.spatial[1] == 2 &&
filter_size.spatial[0] == 9 && filter_size.spatial[1] == 9 &&
deconv_prim->input_offset.spatial[0] == -4 && deconv_prim->input_offset.spatial[1] == -4 &&
weights_vec.size() == 1 && deconv_prim->bias.size() == 1 &&
node->get_dependency(0).get_output_layout().format == format::bfyx) {
primitive_id deconv_id = node->id();
auto& input_node = node->get_dependency(0);
primitive_id input_id = deconv_prim->input[0];
weights_nodes_id.size() == 1 && biases_nodes_id.size() == 1 &&
input_node.get_output_layout().format == format::bfyx) {
const auto scale_factor = deconv_prim->stride.spatial[0];
auto scale_factor = deconv_prim->stride.spatial[0];
const auto& weight_node_id = weights_nodes_id.front();
auto weights_node_ptr = p.nodes_map.find(weight_node_id)->second;
const auto& weights_layout = weights_node_ptr->get_output_layout();
const auto& weights_data_type = weights_layout.data_type;
auto cur_weights_node_ptr = p.nodes_map.find(weights_vec[0])->second;
auto weights_layout = cur_weights_node_ptr->get_output_layout();
auto weights_data_type = weights_layout.data_type;
auto biases = deconv_prim->bias[0];
auto bias_id_node_ptr = p.nodes_map.find(biases)->second;
auto bias_data_type = bias_id_node_ptr->get_output_layout().data_type;
const auto& bias_node_id = biases_nodes_id.front();
auto bias_id_node_ptr = p.nodes_map.find(bias_node_id)->second;
const auto& bias_data_type = bias_id_node_ptr->get_output_layout().data_type;
// enable only for fp32 and fp16
if (weights_data_type != data_types::f16 &&
@@ -229,14 +198,13 @@ void pre_replace_deconv::run(program_impl& p) {
// remove deconvolution node and its connections to weights and biases,
// rename it and move to the optimized list
p.remove_connection(node->get_dependency(0), *node);
p.remove_connection(input_node, deconv_node);
auto weights_node_ptr = p.nodes_map.find(weights_vec[0])->second;
p.remove_connection(*weights_node_ptr, *node);
p.remove_connection(*bias_id_node_ptr, *node);
p.remove_connection(*weights_node_ptr, deconv_node);
p.remove_connection(*bias_id_node_ptr, deconv_node);
auto rename_id = deconv_id + "_tmp";
p.rename(*node, rename_id);
auto rename_id = deconv_node_id + "_tmp";
p.rename(deconv_node, rename_id);
// reshape weights
int pixel_shuffle_size = scale_factor * scale_factor;
@@ -244,17 +212,18 @@ void pre_replace_deconv::run(program_impl& p) {
tensor target_weights_size = { pixel_shuffle_size, filter_size.feature[0], kernel_size, kernel_size };
auto target_weights_layout = layout{ weights_layout.data_type, weights_layout.format, target_weights_size };
const primitive_id weight_replace_node_id = weight_node_id + "_conv_rpl";
{
memory::ptr data_to_allocate = p.get_engine().allocate_memory(target_weights_layout);
std::vector<float> weights_vec_float;
if (weights_data_type == data_types::f16) {
mem_lock<half_t> src{ cur_weights_node_ptr->as<data>().get_attached_memory_ptr(), stream };
mem_lock<half_t> src{ weights_node_ptr->as<data>().get_attached_memory_ptr(), stream };
for (uint32_t i = 0; i < weights_layout.size.count(); i++)
weights_vec_float.push_back(static_cast<float>(src.data()[i]));
} else {
mem_lock<float> src{ cur_weights_node_ptr->as<data>().get_attached_memory_ptr(), stream };
mem_lock<float> src{ weights_node_ptr->as<data>().get_attached_memory_ptr(), stream };
for (uint32_t i = 0; i < weights_layout.size.count(); i++)
weights_vec_float.push_back(src.data()[i]);
}
@@ -278,12 +247,36 @@ void pre_replace_deconv::run(program_impl& p) {
throw std::logic_error("Not supported data type.");
}
auto data_node_weights_replace = std::make_shared<data>(weights_vec[0] + "_conv_rpl", data_to_allocate);
p.get_or_create(data_node_weights_replace);
auto data_node_weights_replace_node_ptr = p.nodes_map.find(weights_vec[0] + "_conv_rpl")->second;
auto& data_node = data_node_weights_replace_node_ptr->as<data>();
auto data_node_weights_replace = std::make_shared<data>(weight_replace_node_id, data_to_allocate);
program_node& weights_replace_node = p.get_or_create(data_node_weights_replace);
auto& data_node = weights_replace_node.as<data>();
data_node.set_output_layout(target_weights_layout, false);
}
auto deconv_id_conv = deconv_node_id + "_conv";
// create convolution primitive
auto conv_prim = std::make_shared<convolution>(deconv_id_conv,
input_node_id,
std::vector<primitive_id>{ weight_replace_node_id },
stride,
input_offset,
tensor{ 1, 1, 1, 1 },
grouped_weights_shape,
output_padding);
program_node& created_node = p.get_or_create(conv_prim);
auto& conv_node = created_node.as<convolution>();
// add connections input->convolution, weights->convolution and bias->convolution
p.add_connection(input_node, conv_node);
{
auto weights_node_conv_rpl_ptr = p.nodes_map.find(weight_replace_node_id)->second;
p.add_connection(*weights_node_conv_rpl_ptr, conv_node);
p.inputs.push_back(weights_node_conv_rpl_ptr.get());
}
float bias = 0;
if (bias_data_type == data_types::f16) {
@@ -293,52 +286,22 @@ void pre_replace_deconv::run(program_impl& p) {
mem_lock<float> src{ bias_id_node_ptr->as<data>().get_attached_memory_ptr(), stream };
bias = src.data()[0];
}
auto pixel_shuffle_prim = std::make_shared<depth_to_space>(deconv_node_id, deconv_id_conv, 2, depth_to_space_mode::blocks_first);
auto deconv_id_conv = deconv_id + "_conv";
// create convolution primitive
auto conv_prim = std::make_shared<convolution>(deconv_id_conv,
input_id,
std::vector<primitive_id>{ weights_vec[0] + "_conv_rpl" },
stride,
input_offset,
tensor{ 1, 1, 1, 1 },
grouped_weights_shape,
output_padding);
p.get_or_create(conv_prim);
auto conv_node_itr = p.nodes_map.find(deconv_id_conv);
if (conv_node_itr == p.nodes_map.end()) continue;
auto conv_node_ptr = conv_node_itr->second;
auto conv_node = &conv_node_ptr->as<convolution>();
// add connections input->convolution, weights->convolution and bias->convolution
p.add_connection(input_node, *conv_node_ptr);
{
auto weights_node_conv_rpl_ptr = p.nodes_map.find(weights_vec[0] + "_conv_rpl")->second;
p.add_connection(*weights_node_conv_rpl_ptr, *conv_node_ptr);
p.inputs.push_back(weights_node_conv_rpl_ptr.get());
}
auto pixel_shuffle_prim = std::make_shared<depth_to_space>(deconv_id, deconv_id_conv, 2, depth_to_space_mode::blocks_first);
p.get_or_create(pixel_shuffle_prim);
auto pixel_shuffle_node_ptr = p.nodes_map.find(deconv_id)->second;
pixel_shuffle_node_ptr->add_fused_activation(activation_func::linear, { 1, bias });
program_node& pixel_shuffle_node = p.get_or_create(pixel_shuffle_prim);
pixel_shuffle_node.add_fused_activation(activation_func::linear, { 1, bias });
// add connections input->convolution, weights->convolution
p.add_connection(*conv_node_ptr, *pixel_shuffle_node_ptr);
p.add_connection(conv_node, pixel_shuffle_node);
auto deconv_node_ptr = p.nodes_map.find(rename_id);
if (deconv_node_ptr != p.nodes_map.end()) {
p.replace_all_usages(*deconv_node_ptr->second, *pixel_shuffle_node_ptr);
p.replace_all_usages(*deconv_node_ptr->second, pixel_shuffle_node);
p.optimized_out.push_back(rename_id);
p.nodes_map.erase(rename_id);
}
p.mark_if_data_flow(*conv_node);
conv_node->recalc_output_layout(true);
p.mark_if_data_flow(conv_node);
conv_node.recalc_output_layout(true);
update_processing_order = true;
}
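Another recurring change in this file: the program_node& returned by get_or_create is used directly, where the old code created the primitive and then looked the node up again by id in nodes_map, bailing out if the lookup failed. A schematic version of that pattern follows, using a hypothetical Graph::get_or_create rather than the actual program_impl signature.

#include <map>
#include <memory>
#include <string>

struct Node { std::string id; bool transposed = false; };

// Hypothetical factory that mirrors the shape of program_impl::get_or_create:
// it returns a reference to the (possibly newly created) node.
struct Graph {
    std::map<std::string, std::shared_ptr<Node>> nodes_map;

    Node& get_or_create(const std::string& id) {
        auto& slot = nodes_map[id];
        if (!slot) {
            slot = std::make_shared<Node>();
            slot->id = id;
        }
        return *slot;
    }
};

int main() {
    Graph g;
    // Use the returned reference directly instead of a second find() by id.
    Node& conv_node = g.get_or_create("deconv_as_conv");
    conv_node.transposed = true;
    return 0;
}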


@@ -834,13 +834,10 @@ void program_impl::swap_names(program_node& node1, program_node& node2) {
}
void program_impl::replace_all_usages(program_node& old_node, program_node& new_node) {
const std::list<program_node*> users(old_node.users);
auto itr = users.begin();
bool end = (itr == users.end());
while (!end) {
auto& usage = (*itr++);
end = (itr == users.end());
usage->replace_dependency(old_node, new_node);
auto itr = old_node.users.begin();
while (itr != old_node.users.end()) {
auto user = *(itr++);
user->replace_dependency(old_node, new_node);
}
}
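The replace_all_usages change drops the defensive copy of old_node.users and instead advances the iterator before the call that can remove the current user from that list (replace_dependency detaches the user from old_node). For a node-based container like std::list, erasing one element invalidates only iterators to that element, so the already advanced iterator stays valid. A small self-contained illustration of the idiom, on a plain std::list of ints rather than the cldnn node types:

#include <iostream>
#include <list>

int main() {
    std::list<int> users{ 1, 2, 3, 4 };
    auto itr = users.begin();
    while (itr != users.end()) {
        int user = *(itr++);      // step past the element before touching it
        if (user % 2 == 0)
            users.remove(user);   // erases the element we already stepped over;
                                  // 'itr' stays valid because std::list::erase
                                  // only invalidates iterators to erased nodes
    }
    for (int u : users)
        std::cout << u << ' ';    // prints: 1 3
    std::cout << '\n';
    return 0;
}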