Add reshape in front of a fully connected node for using bf input (#9449)

Signed-off-by: Min, Byungil <byungil.min@intel.com>
This commit is contained in:
Min, Byungil 2022-01-11 13:18:40 +09:00 committed by GitHub
parent fc4185e92a
commit e0485c1ad2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 24 additions and 5 deletions

View File

@ -28,7 +28,7 @@ void prepare_padding::run(program& p) {
continue;
auto add_required_padding = [&p](program_node& node, padding& needed_padding) {
// Add extra reorder for cldnn primitive to handle required padding if needed
// Add an extra reorder if the previous node, or one of its user nodes, is a oneDNN kernel, so that padding is not added to the oneDNN kernel
auto& input = node.get_dependency(0);
bool is_usr_onednn = false;
for (auto& input_usr : input.get_users())

View File

@ -13,6 +13,7 @@
#include "binary_convolution_inst.h"
#include "mvn_inst.h"
#include "to_string_utils.h"
#include "reshape_inst.h"
#include <vector>
#include <memory>
@ -575,13 +576,30 @@ void reorder_inputs::run(program& p, layout_optimizer& lo, reorder_factory& rf)
}
};
const auto reorder_input_fully_connected = [&p, &lo, &rf](typed_program_node<fully_connected>& fc_node) {
    auto& weights = fc_node.weights();
    auto& input = fc_node.input();
    auto input_layout = input.get_output_layout();
    // Change the input of a fully-connected node from bx-shaped data
    // (feature == 1, payload in spatial[0]) to bf-shaped data, by inserting
    // a reshape that moves the payload dimension into the feature axis.
    // Applied only to constant-weight FCs with a simple 4-dim data format,
    // where spatial[1] is already 1 so the move is lossless.
    if (format::is_simple_data_format(input_layout.format) && weights.is_constant() && input_layout.format.dimension() == 4 &&
        input_layout.size.feature[0] == 1 && input_layout.size.spatial[0] != 1 && input_layout.size.spatial[1] == 1) {
        // Build the target shape: b x (old spatial[0]) x 1 x 1.
        auto new_tensor = input_layout.size;
        new_tensor.feature[0] = input_layout.size.spatial[0];
        new_tensor.spatial[0] = 1;
        auto new_reshape = std::make_shared<reshape>("reorder:Reshape_bf_" + fc_node.id() + "_for_input", input.id(), new_tensor);
        // NOTE: renamed from new_reorder_node — this is a reshape node, not a
        // reorder; the old name was misleading in a pass that also inserts
        // genuine reorder nodes.
        auto& new_reshape_node = p.get_or_create(new_reshape);
        // Splice the reshape between the FC node and its dependency 0 (the input).
        p.add_intermediate(new_reshape_node, fc_node, 0);
    }
};
for (auto& prim : p.get_processing_order()) {
program_helpers::do_for_types<detection_output, binary_convolution, deconvolution, convolution>(
program_helpers::do_for_types<detection_output, binary_convolution, deconvolution, convolution, fully_connected>(
*prim,
reorder_input_detection_output,
reorder_input_binary_convolution,
reorder_input_and_weights_deconvolution,
reorder_weights_convolution);
reorder_weights_convolution,
reorder_input_fully_connected);
}
for (auto n : p.get_processing_order()) {

View File

@ -1618,8 +1618,9 @@ TEST(fully_connected_onednn_gpu, no_biases_int8) {
auto& engine = get_onednn_test_engine();
// Change input data of fully-connected node from bx to bf
auto input_prim = engine.allocate_memory({ data_types::f32, format::bfyx, { input_b, 1, input_x, 1 } });
auto weights_prim = engine.allocate_memory({ data_types::i8, format::bfyx, { weight_b, 1, weight_x, 1 } });
auto weights_prim = engine.allocate_memory({ data_types::i8, format::bfyx, { weight_b, weight_x, 1, 1 } });
set_values(input_prim, { 8.4f, 2.3f, -4.49f });
set_values<char>(weights_prim, { 2, 1, 0, -3, -2, 1, 0, -2, -4, -5, 10, 8 });