[GPU] Refactoring of pre_replace_deconv pass (#6218)
commit 9b706711fe
parent 33dfcd62c2
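Most of this change replaces repeated node->as<deconvolution>() casts and p.nodes_map.find(...) lookups with values computed once at the top of the handler (deconv_node, deconv_prim, input_node, deconv_node_id, weights_nodes_id, biases_nodes_id), and it keeps the weight/bias program_node pointers gathered while disconnecting the deconvolution so they can be reattached to the replacement convolution without a second lookup. Below is a minimal, self-contained sketch of that hoist-and-reuse pattern; Node, nodes_map and the ids here are toy stand-ins for the clDNN types, not the pass code itself.

    #include <iostream>
    #include <map>
    #include <memory>
    #include <string>
    #include <vector>

    // Toy stand-ins for program_node / nodes_map, only to illustrate the pattern.
    struct Node { std::string id; };
    using NodePtr = std::shared_ptr<Node>;

    int main() {
        std::map<std::string, NodePtr> nodes_map = {
            {"w0", std::make_shared<Node>(Node{"w0"})},
            {"w1", std::make_shared<Node>(Node{"w1"})},
        };
        std::vector<std::string> weights_ids = {"w0", "w1", "missing"};

        // Look each weight up once, remember the pointer, and reuse it for every
        // later step (disconnect, then reconnect to the new node) instead of
        // calling nodes_map.find() again in each step.
        std::vector<NodePtr> weight_connections;
        for (const auto& id : weights_ids) {
            auto it = nodes_map.find(id);
            if (it == nodes_map.end())
                continue;                     // skip ids with no node, as the pass does
            weight_connections.push_back(it->second);
        }

        for (const auto& w : weight_connections)
            std::cout << "reconnect " << w->id << '\n';   // prints w0, w1
        return 0;
    }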
@@ -34,19 +34,13 @@ void pre_replace_deconv::run(program_impl& p) {
             auto& deconv_node = node->as<deconvolution>();
             auto& weights_node = deconv_node.weights();
-            auto deconv_prim = node->as<deconvolution>().typed_desc();
+            auto deconv_prim = deconv_node.typed_desc();
             tensor filter_size = weights_node.get_output_layout().size;
-            auto weights = deconv_prim->weights;
-
-            std::vector<primitive_id> weights_vec;
-            for (auto& weights_id : weights)
-                weights_vec.push_back(weights_id);
-
-            for (auto& weights_id : weights_vec) {
-                auto weights_iter = p.nodes_map.find(weights_id);
-                if (weights_iter == p.nodes_map.end())
-                    continue;
-            }
+            auto weights_nodes_id = deconv_prim->weights;
+            auto biases_nodes_id = deconv_prim->bias;
+            auto& input_node = deconv_node.get_dependency(0);
+            const primitive_id deconv_node_id = deconv_node.id();
+            const primitive_id& input_node_id = input_node.id();
 
             // limit optimization to stride = 1
             // iterators shouldn't be used here because of incorrect iterator functionality in mutable_array_ref<>
@@ -55,8 +49,6 @@ void pre_replace_deconv::run(program_impl& p) {
                 unit_stride &= (deconv_prim->stride.spatial[i] == 1);
             }
             if (unit_stride) {
-                primitive_id deconv_id = node->id();
-                auto& input_node = node->get_dependency(0);
                 auto groups = deconv_node.get_groups();
 
                 bool perform_opt = false;
@@ -64,155 +56,132 @@ void pre_replace_deconv::run(program_impl& p) {
                 perform_opt |= cldnn::format::dimension(input_node.get_output_layout().format) == 4 &&
                                (input_node.get_output_layout().data_type == data_types::f32 || input_node.get_output_layout().data_type == data_types::f16) &&
                                !((_lo.get_optimization_attributes().b_fs_yx_fsv16_network || input_node.get_output_layout().format == format::b_fs_yx_fsv16) &&
-                                  _lo.is_format_optimized(node->as<deconvolution>(), format::b_fs_yx_fsv16));
+                                  _lo.is_format_optimized(deconv_node, format::b_fs_yx_fsv16));
                 // int8/uint8 input
                 perform_opt |= (input_node.get_output_layout().data_type == data_types::i8 || input_node.get_output_layout().data_type == data_types::u8);
 
                 if (!perform_opt)
                     continue;
 
-                primitive_id input_id = deconv_prim->input[0];
-
                 // setting convolution parameters based on deconvolution params
                 auto stride = deconv_prim->stride;
-                auto biases = deconv_prim->bias;
-                std::vector<primitive_id> bias_vec;
-                for (auto& bias_id : biases)
-                    bias_vec.push_back(bias_id);
                 auto input_offset = deconv_prim->input_offset;
                 auto output_padding = deconv_prim->output_padding;
                 auto grouped_weights_shape = deconv_prim->grouped_weights_shape;
 
                 // remove deconvolution node and its connections to weights and biases, rename it and move to the optimized
                 // list
-                p.remove_connection(node->get_dependency(0), *node);
-                for (auto& weights_id : weights_vec) {
+                p.remove_connection(input_node, deconv_node);
+                std::vector<std::shared_ptr<program_node>> weight_connections;
+                for (auto& weights_id : weights_nodes_id) {
                     auto weights_iter = p.nodes_map.find(weights_id);
                     if (weights_iter == p.nodes_map.end())
                         continue;
 
                     auto weights_node_ptr = weights_iter->second;
-                    p.remove_connection(*weights_node_ptr, *node);
+                    weight_connections.push_back(weights_node_ptr);
+                    p.remove_connection(*weights_node_ptr, deconv_node);
                 }
 
                 input_offset.spatial[0] = std::abs(input_offset.spatial[0]) - (filter_size.spatial[0] - 1);
                 input_offset.spatial[1] = std::abs(input_offset.spatial[1]) - (filter_size.spatial[1] - 1);
                 input_offset.spatial[2] = std::abs(input_offset.spatial[2]) - (filter_size.spatial[2] - 1);
 
-                if (!bias_vec.empty()) {
-                    for (auto& bias_id : bias_vec) {
-                        auto bias_iter = p.nodes_map.find(bias_id);
-                        if (bias_iter == p.nodes_map.end())
-                            continue;
-
-                        auto bias_id_node_ptr = bias_iter->second;
-                        p.remove_connection(*bias_id_node_ptr, *node);
-                    }
-                }
+                std::vector<std::shared_ptr<program_node>> bias_connections;
+                for (auto& bias_id : biases_nodes_id) {
+                    auto bias_iter = p.nodes_map.find(bias_id);
+                    if (bias_iter == p.nodes_map.end())
+                        continue;
+
+                    auto bias_id_node_ptr = bias_iter->second;
+                    bias_connections.push_back(bias_id_node_ptr);
+                    p.remove_connection(*bias_id_node_ptr, deconv_node);
+                }
 
-                auto rename_id = deconv_id + "_tmp";
-                auto was_output = node->is_output();
+                auto was_output = deconv_node.is_output();
                 if (was_output) {
-                    node->set_output(false);
+                    deconv_node.set_output(false);
                     auto& outputs = p.get_outputs();
                     outputs.erase(std::remove(outputs.begin(), outputs.end(), node.get()), outputs.end());
                 }
-                p.rename(*node, rename_id);
+                auto rename_id = deconv_node_id + "_tmp";
+                p.rename(deconv_node, rename_id);
 
                 // create convolution primitive
-                if (!biases.empty()) {
-                    auto conv_prim = std::make_shared<convolution>(deconv_id,
-                                                                   input_id,
-                                                                   weights_vec,
-                                                                   bias_vec,
-                                                                   groups,
-                                                                   stride,
-                                                                   input_offset,
-                                                                   tensor{ 1, 1, 1, 1 },
-                                                                   grouped_weights_shape,
-                                                                   output_padding);
-                    p.get_or_create(conv_prim);
+                std::shared_ptr<convolution> conv_prim;
+                if (!biases_nodes_id.empty()) {
+                    conv_prim = std::make_shared<convolution>(deconv_node_id,
+                                                              input_node_id,
+                                                              weights_nodes_id,
+                                                              biases_nodes_id,
+                                                              groups,
+                                                              stride,
+                                                              input_offset,
+                                                              tensor{ 1, 1, 1, 1 },
+                                                              grouped_weights_shape,
+                                                              output_padding);
                 } else {
-                    auto conv_prim = std::make_shared<convolution>(deconv_id,
-                                                                   input_id,
-                                                                   weights_vec,
+                    conv_prim = std::make_shared<convolution>(deconv_node_id,
+                                                              input_node_id,
+                                                              weights_nodes_id,
                                                               groups,
                                                               stride,
                                                               input_offset,
                                                               tensor{ 1, 1, 1, 1 },
                                                               grouped_weights_shape,
                                                               output_padding);
-                    p.get_or_create(conv_prim);
                 }
+                program_node& new_node = p.get_or_create(conv_prim);
 
-                auto conv_node_itr = p.nodes_map.find(deconv_id);
-                if (conv_node_itr == p.nodes_map.end())
-                    continue;
-
-                auto conv_node_ptr = conv_node_itr->second;
-                auto conv_node = &conv_node_ptr->as<convolution>();
-                conv_node->set_transposed(true);
+                auto& conv_node = new_node.as<convolution>();
+                conv_node.set_transposed(true);
 
                 // add connections input->convolution, weights->convolution and bias->convolution
-                p.add_connection(input_node, *conv_node_ptr);
+                p.add_connection(input_node, conv_node);
 
-                for (auto& weights_id : weights_vec) {
-                    auto weights_node_itr = p.nodes_map.find(weights_id);
-                    if (weights_node_itr == p.nodes_map.end())
-                        continue;
-
-                    auto weights_node_ptr = weights_node_itr->second;
-                    p.add_connection(*weights_node_ptr, *conv_node_ptr);
+                for (auto& weight_node : weight_connections) {
+                    p.add_connection(*weight_node, conv_node);
                 }
 
-                if (!bias_vec.empty()) {
-                    for (auto& bias_id : bias_vec) {
-                        auto bias_id_node_itr = p.nodes_map.find(bias_id);
-                        if (bias_id_node_itr == p.nodes_map.end())
-                            continue;
-
-                        auto bias_id_node_ptr = bias_id_node_itr->second;
-                        p.add_connection(*bias_id_node_ptr, *conv_node_ptr);
-                    }
-                }
+                for (auto& bias_node : bias_connections) {
+                    p.add_connection(*bias_node, conv_node);
+                }
 
                 auto deconv_node_itr = p.nodes_map.find(rename_id);
                 if (deconv_node_itr != p.nodes_map.end()) {
                     auto deconv_node_ptr = deconv_node_itr->second;
-                    p.replace_all_usages(*deconv_node_ptr, *conv_node_ptr);
+                    p.replace_all_usages(*deconv_node_ptr, conv_node);
                     p.optimized_out.push_back(rename_id);
                     p.nodes_map.erase(rename_id);
                 }
 
                 if (was_output) {
-                    conv_node->set_output(true);
-                    p.get_outputs().push_back(conv_node);
+                    conv_node.set_output(true);
+                    p.get_outputs().push_back(&conv_node);
                 }
 
-                p.mark_if_data_flow(*conv_node);
-                conv_node->recalc_output_layout(true);
+                p.mark_if_data_flow(conv_node);
+                conv_node.recalc_output_layout(true);
 
                 update_processing_order = true;
             // current optimization only available for specific deconvolution parameters
-            } else if (node->is_output() == false &&
-                       node->get_output_layout().size.feature[0] == 1 &&
+            } else if (deconv_node.is_output() == false &&
+                       deconv_node.get_output_layout().size.feature[0] == 1 &&
                        deconv_prim->stride.spatial[0] == 2 && deconv_prim->stride.spatial[1] == 2 &&
                        filter_size.spatial[0] == 9 && filter_size.spatial[1] == 9 &&
                        deconv_prim->input_offset.spatial[0] == -4 && deconv_prim->input_offset.spatial[1] == -4 &&
-                       weights_vec.size() == 1 && deconv_prim->bias.size() == 1 &&
-                       node->get_dependency(0).get_output_layout().format == format::bfyx) {
-                primitive_id deconv_id = node->id();
-                auto& input_node = node->get_dependency(0);
-                primitive_id input_id = deconv_prim->input[0];
-
-                auto scale_factor = deconv_prim->stride.spatial[0];
-
-                auto cur_weights_node_ptr = p.nodes_map.find(weights_vec[0])->second;
-                auto weights_layout = cur_weights_node_ptr->get_output_layout();
-                auto weights_data_type = weights_layout.data_type;
-
-                auto biases = deconv_prim->bias[0];
-                auto bias_id_node_ptr = p.nodes_map.find(biases)->second;
-                auto bias_data_type = bias_id_node_ptr->get_output_layout().data_type;
+                       weights_nodes_id.size() == 1 && biases_nodes_id.size() == 1 &&
+                       input_node.get_output_layout().format == format::bfyx) {
+                const auto scale_factor = deconv_prim->stride.spatial[0];
+
+                const auto& weight_node_id = weights_nodes_id.front();
+                auto weights_node_ptr = p.nodes_map.find(weight_node_id)->second;
+                const auto& weights_layout = weights_node_ptr->get_output_layout();
+                const auto& weights_data_type = weights_layout.data_type;
+
+                const auto& bias_node_id = biases_nodes_id.front();
+                auto bias_id_node_ptr = p.nodes_map.find(bias_node_id)->second;
+                const auto& bias_data_type = bias_id_node_ptr->get_output_layout().data_type;
 
                 // enable only for fp32 and fp16
                 if (weights_data_type != data_types::f16 &&
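Note on the unit-stride branch above: the deconvolution becomes a convolution marked set_transposed(true), and input_offset is rewritten per spatial dimension as |offset| - (filter_size - 1). The size bookkeeping behind that rewrite can be sanity-checked with the standard (de)convolution size formulas, assuming clDNN's convention that a negative input_offset represents padding (which the -4 offsets in the condition above suggest). All names and values below are illustrative, not taken from the pass.

    #include <cassert>
    #include <cstdlib>
    #include <iostream>

    // Stride-1 deconvolution (transposed convolution) with kernel k and padding 'pad':
    //   out_deconv = in + (k - 1) - 2 * pad
    // A plain convolution with input offset (offset = -padding) gives:
    //   out_conv = in - k + 1 - 2 * offset
    // The rewrite offset = |old_offset| - (k - 1) = pad - (k - 1) makes the two sizes match.
    int main() {
        const int in = 32, k = 3;
        for (int pad = 0; pad < k; ++pad) {
            const int out_deconv  = in + (k - 1) - 2 * pad;
            const int conv_offset = std::abs(-pad) - (k - 1);   // the rewrite from the hunk above
            const int out_conv    = in - k + 1 - 2 * conv_offset;
            assert(out_deconv == out_conv);
            std::cout << "pad=" << pad << "  out=" << out_conv << "\n";
        }
        return 0;
    }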
@@ -229,14 +198,13 @@ void pre_replace_deconv::run(program_impl& p) {
 
                 // remove deconvolution node and its connections to weights and biases,
                 // rename it and move to the optimized list
-                p.remove_connection(node->get_dependency(0), *node);
+                p.remove_connection(input_node, deconv_node);
 
-                auto weights_node_ptr = p.nodes_map.find(weights_vec[0])->second;
-                p.remove_connection(*weights_node_ptr, *node);
-                p.remove_connection(*bias_id_node_ptr, *node);
+                p.remove_connection(*weights_node_ptr, deconv_node);
+                p.remove_connection(*bias_id_node_ptr, deconv_node);
 
-                auto rename_id = deconv_id + "_tmp";
-                p.rename(*node, rename_id);
+                auto rename_id = deconv_node_id + "_tmp";
+                p.rename(deconv_node, rename_id);
 
                 // reshape weights
                 int pixel_shuffle_size = scale_factor * scale_factor;
@@ -244,17 +212,18 @@ void pre_replace_deconv::run(program_impl& p) {
                 tensor target_weights_size = { pixel_shuffle_size, filter_size.feature[0], kernel_size, kernel_size };
                 auto target_weights_layout = layout{ weights_layout.data_type, weights_layout.format, target_weights_size };
 
+                const primitive_id weight_replace_node_id = weight_node_id + "_conv_rpl";
                 {
                     memory::ptr data_to_allocate = p.get_engine().allocate_memory(target_weights_layout);
 
                     std::vector<float> weights_vec_float;
 
                     if (weights_data_type == data_types::f16) {
-                        mem_lock<half_t> src{ cur_weights_node_ptr->as<data>().get_attached_memory_ptr(), stream };
+                        mem_lock<half_t> src{ weights_node_ptr->as<data>().get_attached_memory_ptr(), stream };
                         for (uint32_t i = 0; i < weights_layout.size.count(); i++)
                             weights_vec_float.push_back(static_cast<float>(src.data()[i]));
                     } else {
-                        mem_lock<float> src{ cur_weights_node_ptr->as<data>().get_attached_memory_ptr(), stream };
+                        mem_lock<float> src{ weights_node_ptr->as<data>().get_attached_memory_ptr(), stream };
                         for (uint32_t i = 0; i < weights_layout.size.count(); i++)
                             weights_vec_float.push_back(src.data()[i]);
                     }
@@ -278,12 +247,36 @@ void pre_replace_deconv::run(program_impl& p) {
                         throw std::logic_error("Not supported data type.");
                     }
 
-                    auto data_node_weights_replace = std::make_shared<data>(weights_vec[0] + "_conv_rpl", data_to_allocate);
-                    p.get_or_create(data_node_weights_replace);
-                    auto data_node_weights_replace_node_ptr = p.nodes_map.find(weights_vec[0] + "_conv_rpl")->second;
-                    auto& data_node = data_node_weights_replace_node_ptr->as<data>();
+                    auto data_node_weights_replace = std::make_shared<data>(weight_replace_node_id, data_to_allocate);
+                    program_node& weights_replace_node = p.get_or_create(data_node_weights_replace);
+                    auto& data_node = weights_replace_node.as<data>();
                     data_node.set_output_layout(target_weights_layout, false);
                 }
 
+                auto deconv_id_conv = deconv_node_id + "_conv";
+
+                // create convolution primitive
+                auto conv_prim = std::make_shared<convolution>(deconv_id_conv,
+                                                               input_node_id,
+                                                               std::vector<primitive_id>{ weight_replace_node_id },
+                                                               stride,
+                                                               input_offset,
+                                                               tensor{ 1, 1, 1, 1 },
+                                                               grouped_weights_shape,
+                                                               output_padding);
+                program_node& created_node = p.get_or_create(conv_prim);
+
+                auto& conv_node = created_node.as<convolution>();
+
+                // add connections input->convolution, weights->convolution and bias->convolution
+                p.add_connection(input_node, conv_node);
+
+                {
+                    auto weights_node_conv_rpl_ptr = p.nodes_map.find(weight_replace_node_id)->second;
+                    p.add_connection(*weights_node_conv_rpl_ptr, conv_node);
+                    p.inputs.push_back(weights_node_conv_rpl_ptr.get());
+                }
+
                 float bias = 0;
 
                 if (bias_data_type == data_types::f16) {
@@ -293,52 +286,22 @@ void pre_replace_deconv::run(program_impl& p) {
                     mem_lock<float> src{ bias_id_node_ptr->as<data>().get_attached_memory_ptr(), stream };
                     bias = src.data()[0];
                 }
+                auto pixel_shuffle_prim = std::make_shared<depth_to_space>(deconv_node_id, deconv_id_conv, 2, depth_to_space_mode::blocks_first);
 
-                auto deconv_id_conv = deconv_id + "_conv";
-
-                // create convolution primitive
-                auto conv_prim = std::make_shared<convolution>(deconv_id_conv,
-                                                               input_id,
-                                                               std::vector<primitive_id>{ weights_vec[0] + "_conv_rpl" },
-                                                               stride,
-                                                               input_offset,
-                                                               tensor{ 1, 1, 1, 1 },
-                                                               grouped_weights_shape,
-                                                               output_padding);
-                p.get_or_create(conv_prim);
-
-                auto conv_node_itr = p.nodes_map.find(deconv_id_conv);
-                if (conv_node_itr == p.nodes_map.end()) continue;
-
-                auto conv_node_ptr = conv_node_itr->second;
-                auto conv_node = &conv_node_ptr->as<convolution>();
-
-                // add connections input->convolution, weights->convolution and bias->convolution
-                p.add_connection(input_node, *conv_node_ptr);
-
-                {
-                    auto weights_node_conv_rpl_ptr = p.nodes_map.find(weights_vec[0] + "_conv_rpl")->second;
-                    p.add_connection(*weights_node_conv_rpl_ptr, *conv_node_ptr);
-                    p.inputs.push_back(weights_node_conv_rpl_ptr.get());
-                }
-
-                auto pixel_shuffle_prim = std::make_shared<depth_to_space>(deconv_id, deconv_id_conv, 2, depth_to_space_mode::blocks_first);
-
-                p.get_or_create(pixel_shuffle_prim);
-                auto pixel_shuffle_node_ptr = p.nodes_map.find(deconv_id)->second;
-                pixel_shuffle_node_ptr->add_fused_activation(activation_func::linear, { 1, bias });
+                program_node& pixel_shuffle_node = p.get_or_create(pixel_shuffle_prim);
+                pixel_shuffle_node.add_fused_activation(activation_func::linear, { 1, bias });
 
                 // add connections input->convolution, weights->convolution
-                p.add_connection(*conv_node_ptr, *pixel_shuffle_node_ptr);
+                p.add_connection(conv_node, pixel_shuffle_node);
 
                 auto deconv_node_ptr = p.nodes_map.find(rename_id);
                 if (deconv_node_ptr != p.nodes_map.end()) {
-                    p.replace_all_usages(*deconv_node_ptr->second, *pixel_shuffle_node_ptr);
+                    p.replace_all_usages(*deconv_node_ptr->second, pixel_shuffle_node);
                     p.optimized_out.push_back(rename_id);
                     p.nodes_map.erase(rename_id);
                 }
-                p.mark_if_data_flow(*conv_node);
-                conv_node->recalc_output_layout(true);
+                p.mark_if_data_flow(conv_node);
+                conv_node.recalc_output_layout(true);
 
                 update_processing_order = true;
             }
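Note on the stride-2 branch in the hunks above: a stride-2, 9x9, single-output-feature deconvolution is rebuilt as a convolution that emits scale_factor^2 feature maps (pixel_shuffle_size) followed by a depth_to_space with block size 2 (blocks_first), with the single bias folded into the fused linear activation { 1, bias } on the depth_to_space node. A small shape sketch of that sub-pixel trick follows; the spatial sizes are illustrative, not taken from the pass.

    #include <iostream>

    // Sub-pixel (pixel-shuffle) view of the replacement: instead of a deconvolution
    // that upsamples H x W -> 2H x 2W directly, run a convolution that produces
    // scale^2 feature maps at H x W, then depth_to_space rearranges each group of
    // scale^2 channels into a scale x scale spatial block.
    int main() {
        const int scale = 2;                           // deconv stride in the branch above
        const int pixel_shuffle_size = scale * scale;  // feature maps produced by the conv
        const int H = 60, W = 80;                      // illustrative input spatial size

        std::cout << "conv output:            f=" << pixel_shuffle_size
                  << "  " << H << "x" << W << "\n";
        std::cout << "depth_to_space output:  f=" << pixel_shuffle_size / (scale * scale)
                  << "  " << H * scale << "x" << W * scale << "\n";
        return 0;
    }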
@@ -834,13 +834,10 @@ void program_impl::swap_names(program_node& node1, program_node& node2) {
 }
 
 void program_impl::replace_all_usages(program_node& old_node, program_node& new_node) {
-    const std::list<program_node*> users(old_node.users);
-    auto itr = users.begin();
-    bool end = (itr == users.end());
-    while (!end) {
-        auto& usage = (*itr++);
-        end = (itr == users.end());
-        usage->replace_dependency(old_node, new_node);
+    auto itr = old_node.users.begin();
+    while (itr != old_node.users.end()) {
+        auto user = *(itr++);
+        user->replace_dependency(old_node, new_node);
     }
 }
 
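The rewritten replace_all_usages walks old_node.users directly and advances the iterator before invoking replace_dependency, so the loop remains valid even if that call unlinks the current user from the list; the copied list and explicit end flag of the previous version become unnecessary. The same erase-safe idiom on a plain std::list (illustrative types and values, not the clDNN ones):

    #include <iostream>
    #include <list>

    int main() {
        std::list<int> users = {1, 2, 3, 4};

        // Step the iterator first, then let the current element be removed.
        // For std::list, erasing an element invalidates only iterators to that
        // element, so 'itr' (already advanced) stays valid.
        auto itr = users.begin();
        while (itr != users.end()) {
            auto current = itr++;
            if (*current % 2 == 0)
                users.erase(current);
        }

        for (int v : users)
            std::cout << v << ' ';   // prints: 1 3
        std::cout << "\n";
        return 0;
    }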