[GPU][DG2] Fix output scale post-op condition (#13567)

* fix bug in oscale post-op condition
This commit is contained in:
Felix Dohyun Kim 2022-10-24 10:17:00 +09:00 committed by GitHub
parent a55b277c68
commit 385d87edaf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 15 additions and 8 deletions

View File

@ -989,16 +989,23 @@ void program_node::init_onednn_primitive_attributes() {
update_onednn_post_op_list(onednn_post_op_type::binary_add, dep_idx);
}
} else {
// convolution using post-op output scales can only be int8/uint8
if (idx == 0 && !has_out_scales(attrs) && !is_type<pooling>() && !is_type<reduce>() &&
!(is_type<convolution>() && data_type_traits::is_floating_point(output_layout.data_type))) {
int mask = in.count() > 1 ? 2 : 0;
attrs->set_output_scales(mask, {DNNL_RUNTIME_F32_VAL});
update_onednn_post_op_list(onednn_post_op_type::scale, dep_idx);
} else {
auto input_datatype = get_dependency(0).get_output_layout().data_type;
// convolution using post-op output scales can only be used when i8/u8 input (which use integer accumulator)
bool cant_use_output_scales =
idx != 0 ||
has_out_scales(attrs) ||
is_type<pooling>() ||
is_type<reduce>() ||
(is_type<convolution>() && data_type_traits::is_floating_point(input_datatype)) ||
(is_type<deconvolution>() && data_type_traits::is_floating_point(input_datatype));
if (cant_use_output_scales) {
dnnl::memory::desc in_desc = onednn::layout_to_memory_desc(in, dnnl::memory::format_tag::ab, true);
post_ops.append_binary(dnnl::algorithm::binary_mul, in_desc);
update_onednn_post_op_list(onednn_post_op_type::binary_mul, dep_idx);
} else {
int mask = in.count() > 1 ? 2 : 0;
attrs->set_output_scales(mask, {DNNL_RUNTIME_F32_VAL});
update_onednn_post_op_list(onednn_post_op_type::scale, dep_idx);
}
}
} else if (desc.is_type<quantize>()) {

View File

@ -1461,7 +1461,7 @@ TEST_P(conv_fp32_scale_activation_quantize_i8_eltwise_fp32_quantize_i8, basic) {
reorder("reorder_bfyx", "quantize_1", p.default_format, data_types::f32)
);
tolerance = 1.f;
tolerance = 2.f;
execute(p);
}