[GPU][DG2] Fix output scale post-op condition (#13567)
* fix bug in oscale post-op condition
This commit is contained in:
parent
a55b277c68
commit
385d87edaf
@ -989,16 +989,23 @@ void program_node::init_onednn_primitive_attributes() {
|
|||||||
update_onednn_post_op_list(onednn_post_op_type::binary_add, dep_idx);
|
update_onednn_post_op_list(onednn_post_op_type::binary_add, dep_idx);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// convolution using post-op output scales can only be int8/uint8
|
auto input_datatype = get_dependency(0).get_output_layout().data_type;
|
||||||
if (idx == 0 && !has_out_scales(attrs) && !is_type<pooling>() && !is_type<reduce>() &&
|
// convolution using post-op output scales can only be used when i8/u8 input (which use integer accumulator)
|
||||||
!(is_type<convolution>() && data_type_traits::is_floating_point(output_layout.data_type))) {
|
bool cant_use_output_scales =
|
||||||
int mask = in.count() > 1 ? 2 : 0;
|
idx != 0 ||
|
||||||
attrs->set_output_scales(mask, {DNNL_RUNTIME_F32_VAL});
|
has_out_scales(attrs) ||
|
||||||
update_onednn_post_op_list(onednn_post_op_type::scale, dep_idx);
|
is_type<pooling>() ||
|
||||||
} else {
|
is_type<reduce>() ||
|
||||||
|
(is_type<convolution>() && data_type_traits::is_floating_point(input_datatype)) ||
|
||||||
|
(is_type<deconvolution>() && data_type_traits::is_floating_point(input_datatype));
|
||||||
|
if (cant_use_output_scales) {
|
||||||
dnnl::memory::desc in_desc = onednn::layout_to_memory_desc(in, dnnl::memory::format_tag::ab, true);
|
dnnl::memory::desc in_desc = onednn::layout_to_memory_desc(in, dnnl::memory::format_tag::ab, true);
|
||||||
post_ops.append_binary(dnnl::algorithm::binary_mul, in_desc);
|
post_ops.append_binary(dnnl::algorithm::binary_mul, in_desc);
|
||||||
update_onednn_post_op_list(onednn_post_op_type::binary_mul, dep_idx);
|
update_onednn_post_op_list(onednn_post_op_type::binary_mul, dep_idx);
|
||||||
|
} else {
|
||||||
|
int mask = in.count() > 1 ? 2 : 0;
|
||||||
|
attrs->set_output_scales(mask, {DNNL_RUNTIME_F32_VAL});
|
||||||
|
update_onednn_post_op_list(onednn_post_op_type::scale, dep_idx);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if (desc.is_type<quantize>()) {
|
} else if (desc.is_type<quantize>()) {
|
||||||
|
@ -1461,7 +1461,7 @@ TEST_P(conv_fp32_scale_activation_quantize_i8_eltwise_fp32_quantize_i8, basic) {
|
|||||||
reorder("reorder_bfyx", "quantize_1", p.default_format, data_types::f32)
|
reorder("reorder_bfyx", "quantize_1", p.default_format, data_types::f32)
|
||||||
);
|
);
|
||||||
|
|
||||||
tolerance = 1.f;
|
tolerance = 2.f;
|
||||||
execute(p);
|
execute(p);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user