Add reshape in front of a fully connected node for using bf input (#9449)
Signed-off-by: Min, Byungil <byungil.min@intel.com>
This commit is contained in:
parent
fc4185e92a
commit
e0485c1ad2
@ -28,7 +28,7 @@ void prepare_padding::run(program& p) {
|
||||
continue;
|
||||
|
||||
auto add_required_padding = [&p](program_node& node, padding& needed_padding) {
|
||||
// Add extra reorder for cldnn primitive to handle required padding if needed
|
||||
// Add extra reorder if a previous node or one of its user nodes is an onednn kernel, so that padding is not added to the onednn kernel
|
||||
auto& input = node.get_dependency(0);
|
||||
bool is_usr_onednn = false;
|
||||
for (auto& input_usr : input.get_users())
|
||||
|
@ -13,6 +13,7 @@
|
||||
#include "binary_convolution_inst.h"
|
||||
#include "mvn_inst.h"
|
||||
#include "to_string_utils.h"
|
||||
#include "reshape_inst.h"
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
@ -575,13 +576,30 @@ void reorder_inputs::run(program& p, layout_optimizer& lo, reorder_factory& rf)
|
||||
}
|
||||
};
|
||||
|
||||
// Inserts a reshape primitive in front of a fully-connected node when the node's
// input carries its data along spatial[0] only (feature == 1, spatial[1] == 1),
// so that the same extent is presented in the feature dimension instead.
// Nothing is changed when the condition below does not hold.
// NOTE(review): `lo` and `rf` are captured but not referenced in this body.
const auto reorder_input_fully_connected = [&p, &lo, &rf](typed_program_node<fully_connected>& fc_node) {
|
||||
auto& weights = fc_node.weights();
|
||||
auto& input = fc_node.input();
|
||||
auto input_layout = input.get_output_layout();
|
||||
// Change input data of fully-connected node from bx to bf
|
||||
// Applies only when all of the following hold: the input format is a simple data
// format with 4 dimensions, the weights node is constant, feature is 1,
// spatial[0] is not 1 and spatial[1] is 1.
if (format::is_simple_data_format(input_layout.format) && weights.is_constant() && input_layout.format.dimension() == 4 &&
|
||||
input_layout.size.feature[0] == 1 && input_layout.size.spatial[0] != 1 && input_layout.size.spatial[1] == 1) {
|
||||
// Start from the current size, then move the spatial[0] extent into feature.
auto new_tensor = input_layout.size;
|
||||
new_tensor.feature[0] = input_layout.size.spatial[0];
|
||||
new_tensor.spatial[0] = 1;
|
||||
// The reshape reads from the current input node and is given new_tensor as its shape argument.
auto new_reshape = std::make_shared<reshape>("reorder:Reshape_bf_" + fc_node.id() + "_for_input", input.id(), new_tensor);
|
||||
// Despite its name, this is the program node obtained for the reshape primitive above.
auto& new_reorder_node = p.get_or_create(new_reshape);
|
||||
// NOTE(review): presumably splices the reshape between fc_node and its dependency
// at index 0 - confirm against program::add_intermediate.
p.add_intermediate(new_reorder_node, fc_node, 0);
|
||||
}
|
||||
};
|
||||
|
||||
for (auto& prim : p.get_processing_order()) {
|
||||
program_helpers::do_for_types<detection_output, binary_convolution, deconvolution, convolution>(
|
||||
program_helpers::do_for_types<detection_output, binary_convolution, deconvolution, convolution, fully_connected>(
|
||||
*prim,
|
||||
reorder_input_detection_output,
|
||||
reorder_input_binary_convolution,
|
||||
reorder_input_and_weights_deconvolution,
|
||||
reorder_weights_convolution);
|
||||
reorder_weights_convolution,
|
||||
reorder_input_fully_connected);
|
||||
}
|
||||
|
||||
for (auto n : p.get_processing_order()) {
|
||||
|
@ -1618,8 +1618,9 @@ TEST(fully_connected_onednn_gpu, no_biases_int8) {
|
||||
|
||||
auto& engine = get_onednn_test_engine();
|
||||
|
||||
// Change input data of fully-connected node from bx to bf
|
||||
auto input_prim = engine.allocate_memory({ data_types::f32, format::bfyx, { input_b, 1, input_x, 1 } });
|
||||
auto weights_prim = engine.allocate_memory({ data_types::i8, format::bfyx, { weight_b, 1, weight_x, 1 } });
|
||||
auto weights_prim = engine.allocate_memory({ data_types::i8, format::bfyx, { weight_b, weight_x, 1, 1 } });
|
||||
|
||||
set_values(input_prim, { 8.4f, 2.3f, -4.49f });
|
||||
set_values<char>(weights_prim, { 2, 1, 0, -3, -2, 1, 0, -2, -4, -5, 10, 8 });
|
||||
|
Loading…
Reference in New Issue
Block a user