[GPU][DG2] Fix output scale post-op condition (#13567)

* Fix bug in the output scale (oscale) post-op condition
This commit is contained in:
Felix Dohyun Kim 2022-10-24 10:17:00 +09:00 committed by GitHub
parent a55b277c68
commit 385d87edaf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 15 additions and 8 deletions

View File

@ -989,16 +989,23 @@ void program_node::init_onednn_primitive_attributes() {
update_onednn_post_op_list(onednn_post_op_type::binary_add, dep_idx); update_onednn_post_op_list(onednn_post_op_type::binary_add, dep_idx);
} }
} else { } else {
// convolution using post-op output scales can only be int8/uint8 auto input_datatype = get_dependency(0).get_output_layout().data_type;
if (idx == 0 && !has_out_scales(attrs) && !is_type<pooling>() && !is_type<reduce>() && // convolution using post-op output scales can only be used when i8/u8 input (which use integer accumulator)
!(is_type<convolution>() && data_type_traits::is_floating_point(output_layout.data_type))) { bool cant_use_output_scales =
int mask = in.count() > 1 ? 2 : 0; idx != 0 ||
attrs->set_output_scales(mask, {DNNL_RUNTIME_F32_VAL}); has_out_scales(attrs) ||
update_onednn_post_op_list(onednn_post_op_type::scale, dep_idx); is_type<pooling>() ||
} else { is_type<reduce>() ||
(is_type<convolution>() && data_type_traits::is_floating_point(input_datatype)) ||
(is_type<deconvolution>() && data_type_traits::is_floating_point(input_datatype));
if (cant_use_output_scales) {
dnnl::memory::desc in_desc = onednn::layout_to_memory_desc(in, dnnl::memory::format_tag::ab, true); dnnl::memory::desc in_desc = onednn::layout_to_memory_desc(in, dnnl::memory::format_tag::ab, true);
post_ops.append_binary(dnnl::algorithm::binary_mul, in_desc); post_ops.append_binary(dnnl::algorithm::binary_mul, in_desc);
update_onednn_post_op_list(onednn_post_op_type::binary_mul, dep_idx); update_onednn_post_op_list(onednn_post_op_type::binary_mul, dep_idx);
} else {
int mask = in.count() > 1 ? 2 : 0;
attrs->set_output_scales(mask, {DNNL_RUNTIME_F32_VAL});
update_onednn_post_op_list(onednn_post_op_type::scale, dep_idx);
} }
} }
} else if (desc.is_type<quantize>()) { } else if (desc.is_type<quantize>()) {

View File

@ -1461,7 +1461,7 @@ TEST_P(conv_fp32_scale_activation_quantize_i8_eltwise_fp32_quantize_i8, basic) {
reorder("reorder_bfyx", "quantize_1", p.default_format, data_types::f32) reorder("reorder_bfyx", "quantize_1", p.default_format, data_types::f32)
); );
tolerance = 1.f; tolerance = 2.f;
execute(p); execute(p);
} }