Set reorder format in reorder_inputs pass (#11992)
* Set reorder format in reorder_inputs pass * set zyx_fsv16 formats if input is zyx_fsv4 formats
This commit is contained in:
parent
3e97d12fe2
commit
654105d567
@ -19,6 +19,46 @@ primitive_type_id convolution::type_id() {
|
||||
return &instance;
|
||||
}
|
||||
|
||||
// Maps the input layout's format to the output format recommended for the
// given output data type.
//
// - i8/u8 output: fsv16 inputs and the small-feature-block inputs
//   (fsv2 / fsv4) are mapped to the matching fsv32 format.
// - floating-point output: fsv32 inputs and the small-feature-block inputs
//   (fsv2 / fsv4) are mapped to the matching fsv16 format.
// - bsv8_fsv2 / bsv8_fsv4 inputs pick the bsv32-blocked variant only when
//   input_layout.batch() is greater than 16; otherwise the plain b_fs_* variant.
//
// The spatial rank (yx vs. zyx) of the input format is always preserved.
//
// @param input_layout layout of the convolution input; only its format and
//                     batch size are read. Taken by const reference so the
//                     layout is not copied on every call.
// @param output_type  data type of the convolution output.
// @return the recommended format, or format::any when there is no
//         recommendation for this format / data type combination.
static format get_recommended_format(const layout& input_layout, data_types output_type) {
    if (data_type_traits::is_i8_u8(output_type)) {
        switch (input_layout.format) {
            case format::b_fs_yx_fsv16: return format::b_fs_yx_fsv32;
            case format::bs_fs_yx_bsv32_fsv16: return format::bs_fs_yx_bsv32_fsv32;
            case format::b_fs_zyx_fsv16: return format::b_fs_zyx_fsv32;
            case format::bs_fs_zyx_bsv32_fsv16: return format::bs_fs_zyx_bsv32_fsv32;
            case format::b_fs_yx_fsv2:
            case format::b_fs_yx_fsv4: return format::b_fs_yx_fsv32;
            case format::b_fs_zyx_fsv2:
            case format::b_fs_zyx_fsv4: return format::b_fs_zyx_fsv32;
            case format::bs_fs_yx_bsv8_fsv2:
            case format::bs_fs_yx_bsv8_fsv4: return input_layout.batch() > 16 ? format::bs_fs_yx_bsv32_fsv32 : format::b_fs_yx_fsv32;
            case format::bs_fs_zyx_bsv8_fsv2:
            case format::bs_fs_zyx_bsv8_fsv4: return input_layout.batch() > 16 ? format::bs_fs_zyx_bsv32_fsv32 : format::b_fs_zyx_fsv32;
            default:
                break;
        }
    } else if (data_type_traits::is_floating_point(output_type)) {
        switch (input_layout.format) {
            case format::b_fs_yx_fsv32: return format::b_fs_yx_fsv16;
            case format::bs_fs_yx_bsv32_fsv32: return format::bs_fs_yx_bsv32_fsv16;
            case format::b_fs_zyx_fsv32: return format::b_fs_zyx_fsv16;
            case format::bs_fs_zyx_bsv32_fsv32: return format::bs_fs_zyx_bsv32_fsv16;
            case format::b_fs_yx_fsv2:
            case format::b_fs_yx_fsv4: return format::b_fs_yx_fsv16;
            case format::b_fs_zyx_fsv2:
            case format::b_fs_zyx_fsv4: return format::b_fs_zyx_fsv16;
            case format::bs_fs_yx_bsv8_fsv2:
            case format::bs_fs_yx_bsv8_fsv4: return input_layout.batch() > 16 ? format::bs_fs_yx_bsv32_fsv16 : format::b_fs_yx_fsv16;
            case format::bs_fs_zyx_bsv8_fsv2:
            case format::bs_fs_zyx_bsv8_fsv4: return input_layout.batch() > 16 ? format::bs_fs_zyx_bsv32_fsv16 : format::b_fs_zyx_fsv16;
            default:
                break;
        }
    }

    // No recommendation for this format / data type combination.
    return format::any;
}
|
||||
|
||||
layout convolution_inst::calc_output_layout(convolution_node const& node) {
|
||||
auto desc = node.get_primitive();
|
||||
|
||||
@ -173,40 +213,10 @@ layout convolution_inst::calc_output_layout(convolution_node const& node) {
|
||||
|
||||
// Adjust output format for mixed precision case in onednn
|
||||
auto out_fmt = input_layout.format;
|
||||
bool is_2d = (input_layout.format.spatial_num() == 2);
|
||||
bool is_3d = (input_layout.format.spatial_num() == 3);
|
||||
if (node.get_preferred_impl_type() == impl_types::onednn) {
|
||||
if (data_type_traits::is_i8_u8(output_type)) {
|
||||
if (is_2d) {
|
||||
if (input_layout.format == format::b_fs_yx_fsv16)
|
||||
out_fmt = format::b_fs_yx_fsv32;
|
||||
else if (input_layout.format == format::bs_fs_yx_bsv32_fsv16)
|
||||
out_fmt = format::bs_fs_yx_bsv32_fsv32;
|
||||
else if (input_layout.format == format::b_fs_yx_fsv2)
|
||||
out_fmt = format::b_fs_yx_fsv32;
|
||||
} else if (is_3d) {
|
||||
if (input_layout.format == format::b_fs_zyx_fsv16)
|
||||
out_fmt = format::b_fs_zyx_fsv32;
|
||||
else if (input_layout.format == format::bs_fs_zyx_bsv32_fsv16)
|
||||
out_fmt = format::bs_fs_zyx_bsv32_fsv32;
|
||||
}
|
||||
} else if (data_type_traits::is_floating_point(output_type)) {
|
||||
if (is_2d) {
|
||||
if (input_layout.format == format::b_fs_yx_fsv32)
|
||||
out_fmt = format::b_fs_yx_fsv16;
|
||||
else if (input_layout.format == format::bs_fs_yx_bsv32_fsv32)
|
||||
out_fmt = format::bs_fs_yx_bsv32_fsv16;
|
||||
} else if (is_3d) {
|
||||
if (input_layout.format == format::b_fs_zyx_fsv32)
|
||||
out_fmt = format::b_fs_zyx_fsv16;
|
||||
else if (input_layout.format == format::bs_fs_zyx_bsv32_fsv32)
|
||||
out_fmt = format::bs_fs_zyx_bsv32_fsv16;
|
||||
else if (input_layout.format == format::b_fs_zyx_fsv2)
|
||||
out_fmt = format::b_fs_zyx_fsv16;
|
||||
else if (input_layout.format == format::bs_fs_zyx_bsv8_fsv2)
|
||||
out_fmt = input_layout.batch() > 16 ? format::bs_fs_zyx_bsv32_fsv16 : format::b_fs_zyx_fsv16;
|
||||
}
|
||||
}
|
||||
format recommended_fmt = get_recommended_format(input_layout, output_type);
|
||||
if (recommended_fmt != format::any)
|
||||
out_fmt = recommended_fmt;
|
||||
}
|
||||
|
||||
// get output feature map from weights. It should be the same as number of biases. Will be verified in
|
||||
|
@ -608,12 +608,10 @@ void reorder_inputs::run(program& p, layout_optimizer& lo, reorder_factory& rf)
|
||||
if (new_layout == input_layout)
|
||||
return;
|
||||
|
||||
if (!input.is_type<reorder>() || input.get_users().size() > 1) {
|
||||
auto new_input = rf.get_reorder(input.id(), input_layout, new_layout);
|
||||
if (new_input.first)
|
||||
p.add_intermediate(new_input.first, conv_node, 0, !new_input.second);
|
||||
auto new_input = rf.get_reorder(input.id(), input_layout, new_layout);
|
||||
if (new_input.first) {
|
||||
p.add_intermediate(new_input.first, conv_node, 0, !new_input.second);
|
||||
}
|
||||
|
||||
conv_node.get_dependencies().front()->set_output_layout(new_layout, false);
|
||||
}
|
||||
|
||||
|
@ -32,16 +32,6 @@ layout reorder_inst::calc_output_layout(reorder_node const& node) {
|
||||
ofmt = ifmt;
|
||||
}
|
||||
|
||||
if (node.is_valid_output_layout() && input_layout.feature() <= 4) {
|
||||
auto users = node.get_users();
|
||||
if (users.size() > 0 && users.front()->is_type<convolution>()) {
|
||||
auto expected_fmt = node.get_output_layout().format;
|
||||
if (expected_fmt == format::b_fs_zyx_fsv2 || expected_fmt == format::bs_fs_zyx_bsv8_fsv2) {
|
||||
ofmt = expected_fmt;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (ifmt.is_nv12()) {
|
||||
auto data_size = tensor{ input_layout.batch(), input_layout.feature() * 3,
|
||||
input_layout.spatial(0), input_layout.spatial(1) };
|
||||
|
Loading…
Reference in New Issue
Block a user