reset hard (#12055)

commit 9f0acc4535 (parent 37a0bddd76)
Author: Felix Dohyun Kim
Date:   2022-08-19 09:57:54 +09:00, committed by GitHub
2 changed files with 36 additions and 8 deletions


@@ -51,6 +51,9 @@
 #include <utility>
 #include <deque>
 #include "intel_gpu/runtime/error_handler.hpp"
+#ifdef ENABLE_ONEDNN_FOR_GPU
+#include <impls/onednn/utils.hpp>
+#endif
 
 void prepare_primitive_fusing::run(program& p) {
     fuse_reorders(p);
@@ -246,8 +249,18 @@ void prepare_primitive_fusing::fuse_activations(program &p) {
             return;
         }
-        if (input.is_type<reshape>() && use_onednn_impls)
-            return;
+        if (use_onednn_impls) {
+            if (input.is_type<reshape>())
+                return;
+#ifdef ENABLE_ONEDNN_FOR_GPU
+            // Activation should not be fused if it isn't supported in onednn
+            try {
+                onednn::convert_activation_func(node.get_primitive()->activation_function);
+            } catch (...) {
+                return;
+            }
+#endif
+        }
 
         if (input.get_fused_primitives().empty()) {
             input.add_fused_activation(node.get_primitive()->activation_function, node.get_primitive()->additional_params);
@@ -706,6 +719,17 @@ void prepare_primitive_fusing::fuse_simple_primitives(program &p) {
         if (!input_data_supports_fusings(input_data, activation_node.id()) || input_data.get_dependencies().empty())
             return;
 
+        if (_lo.get_optimization_attributes().use_onednn_impls) {
+#ifdef ENABLE_ONEDNN_FOR_GPU
+            // Activation should not be fused if it isn't supported in onednn
+            try {
+                onednn::convert_activation_func(activation_node.get_primitive()->activation_function);
+            } catch (...) {
+                return;
+            }
+#endif
+        }
+
         bool should_fuse = input_data.is_type<binary_convolution>();
         should_fuse |= input_data.is_type<convolution>() && conv_supports_fusings(input_data.as<convolution>());
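
Both hunks above gate fusion with the same pattern: attempt to translate the activation into a oneDNN algorithm and treat an exception as "unsupported". Below is a minimal standalone sketch of that pattern; the enum and converter here are hypothetical stand-ins for the real onednn::convert_activation_func declared in impls/onednn/utils.hpp.

    #include <stdexcept>

    // Hypothetical stand-ins for the plugin's types.
    enum class activation_func { relu, softsign };

    // Mirrors the contract the pass relies on: return an algorithm id
    // for supported activations, throw for everything else.
    int convert_activation_func(activation_func f) {
        if (f == activation_func::relu)
            return 0; // would be a dnnl::algorithm value in the real code
        throw std::runtime_error("Unsupported activation function");
    }

    // The fusing pass keeps the activation as a separate primitive
    // whenever the conversion throws, exactly like the try/catch above.
    bool can_fuse_activation(activation_func f) {
        try {
            convert_activation_func(f);
        } catch (...) {
            return false;
        }
        return true;
    }

    // can_fuse_activation(activation_func::relu)     -> true
    // can_fuse_activation(activation_func::softsign) -> false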


@@ -372,7 +372,7 @@ TEST_P(deconv_scale, basic) {
         data("weights", get_mem(get_weights_layout(p))),
         data("scale_data", get_mem(get_per_channel_layout(p), 1.0f/p.kernel.count())),
         deconvolution("deconv", "input", { "weights" }, p.groups, p.stride, p.pad),
-        scale("scale", "deconv", "scale_data"),
+        eltwise("scale", { "deconv", "scale_data" }, eltwise_mode::prod),
         reorder("out", "scale", p.default_format, data_types::f32)
     );
@@ -387,7 +387,7 @@ TEST_P(deconv_scale, fp16_scale_out) {
         data("weights", get_mem(get_weights_layout(p))),
         data("scale_data", get_mem(get_per_channel_layout(p), 1.0f/p.kernel.count())),
         deconvolution("deconv", "input", { "weights" }, p.groups, p.stride, p.pad),
-        scale("scale", "deconv", "scale_data", optional_data_type{ data_types::f16 }),
+        eltwise("scale", { "deconv", "scale_data" }, eltwise_mode::prod, data_types::f16),
         reorder("out", "scale", p.default_format, data_types::f32)
     );
@@ -541,12 +541,16 @@ TEST_P(deconv_scale_actv_quant_i8, basic) {
         data("out_lo", get_mem(get_single_element_layout(p), -127)),
         data("out_hi", get_mem(get_single_element_layout(p), 127)),
         deconvolution("deconv", "input", { "weights" }, p.groups, p.stride, p.pad),
-        scale("scale", "deconv", "scale_data"),
+        eltwise("scale", { "deconv", "scale_data" }, eltwise_mode::prod),
         activation("actv", "scale", activation_func::softsign),
         quantize("quant", "actv", "in_lo", "in_hi", "out_lo", "out_hi", 255, data_types::i8),
         reorder("out", "quant", p.default_format, data_types::f32)
     );
 
+    // Activation won't be fused because onednn doesn't support softsign activation
+    if (engine.get_device_info().supports_immad)
+        p.expected_fused_primitives++;
+
     tolerance = 1.f;
     execute(p);
 }
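
The expectation tweak above is worth spelling out: on devices with immad (i.e. when the oneDNN path is active), softsign cannot be converted, so the activation stays a separate primitive and the fused-primitive count the test checks grows by one. A minimal sketch of that arithmetic, with a hypothetical helper name:

    // Hypothetical helper mirroring the test's expectation: an activation
    // that oneDNN cannot fuse remains its own primitive in the final graph.
    int expected_primitive_count(int base_fused, bool use_onednn, bool activation_supported) {
        if (use_onednn && !activation_supported)
            return base_fused + 1; // softsign stays un-fused on immad devices
        return base_fused;
    }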
@@ -647,11 +651,11 @@ TEST_P(deconv_scale_actv_quant_u8_eltw_scale_actv_quant_i8, basic) {
         data("out2_lo", get_mem(get_single_element_layout(p), -127)),
         data("out2_hi", get_mem(get_single_element_layout(p), 127)),
         deconvolution("deconv", "input", { "weights" }, p.groups, p.stride, p.pad),
-        scale("scale1", "deconv", "scale1_data"),
+        eltwise("scale1", { "deconv", "scale1_data" }, eltwise_mode::prod),
         activation("actv1", "scale1", activation_func::relu),
         quantize("quant1", "actv1", "in1_lo", "in1_hi", "out1_lo", "out1_hi", 256, data_types::u8),
         eltwise("eltw", { "quant1", "eltw_data" }, eltwise_mode::sum, p.default_type),
-        scale("scale2", "eltw", "scale2_data"),
+        eltwise("scale2", { "eltw", "scale2_data" }, eltwise_mode::prod),
         activation("actv2", "scale2", activation_func::relu),
         quantize("quant2", "actv2", "in2_lo", "in2_hi", "out2_lo", "out2_hi", 255, data_types::i8),
         reorder("out", "quant2", p.default_format, data_types::f32)
@@ -746,7 +750,7 @@ TEST_P(deconv_scale_activation_quantize_i8_eltwise_quantize_u8, basic) {
         data("weights", get_mem(get_weights_layout(p))),
         deconvolution("deconv_prim", "input", { "weights" }, p.groups, p.stride, p.pad),
         data("scale_data", get_mem(get_per_channel_layout(p), 1.f / p.kernel.count())),
-        scale("scale", "deconv_prim", "scale_data"),
+        eltwise("scale", { "deconv_prim", "scale_data" }, eltwise_mode::prod),
         activation("activation", "scale", activation_func::relu),
         data("in_low", get_mem(get_per_channel_layout(p), min_random, 0)),
         data("in_high", get_mem(get_per_channel_layout(p), 1, max_random)),