[GPU] resolve accuracy issue related to concat (#13097)

+ Do not fuse activation function if concat is onednn
+ Added concat config to ForceImplType

Signed-off-by: Min, Byungil <byungil.min@intel.com>
commit 846d8e4605
parent 782615fdc6
@@ -291,7 +291,7 @@ void prepare_primitive_fusing::fuse_activations(program &p) {
         }
 
         if (use_onednn_impls) {
-            if (input.is_type<reshape>())
+            if (input.is_type<reshape>() || input.is_type<concatenation>())
                 return;
 #ifdef ENABLE_ONEDNN_FOR_GPU
             // Activation should not be fused if it isn't supported in onednn
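Reviewer note on the hunk above: with oneDNN implementations enabled, fuse_activations now skips an activation whose input is a concatenation, exactly as it already did for reshape, since fusing it into the oneDNN concat produced the accuracy issue this commit resolves. A self-contained sketch of the guard, using stand-in types rather than the plugin's real ones:

// Stand-in sketch of the new fuse_activations guard; node_kind and
// can_fuse_activation are illustrative, not cldnn API.
#include <iostream>

enum class node_kind { reshape, concatenation, convolution };

bool can_fuse_activation(bool use_onednn_impls, node_kind input) {
    if (use_onednn_impls) {
        // Keep the activation as a standalone primitive when the producer is
        // a reshape or (new in this commit) a concatenation.
        if (input == node_kind::reshape || input == node_kind::concatenation)
            return false;
    }
    return true;
}

int main() {
    std::cout << can_fuse_activation(true, node_kind::concatenation) << "\n";  // 0: not fused
    std::cout << can_fuse_activation(true, node_kind::convolution) << "\n";    // 1: may fuse
}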
@@ -1284,8 +1284,14 @@ impl_types layout_optimizer::get_forced_impl_type_by_config(program_node& node)
             return impl_types::ocl;
         else if (forced_impl_type == "reduce:onednn")
             return impl_types::onednn;
+    } else if (node.is_type<concatenation>()) {
+        if (forced_impl_type == "concat:ocl")
+            return impl_types::ocl;
+        else if (forced_impl_type == "concat:onednn")
+            return impl_types::onednn;
     }
+
 
     // Forcing one layer
     size_t found_type = forced_impl_type.rfind(":");
     if (found_type != std::string::npos) {
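Reviewer note: get_forced_impl_type_by_config now recognizes "concat:ocl" / "concat:onednn" alongside the existing per-primitive keys, while the "Forcing one layer" fallback still splits any other entry at the last ':' via rfind. How the string reaches the optimizer is outside this hunk; on debug-enabled builds it typically comes from the GPU plugin's debug configuration (e.g. an OV_GPU_ForceImplTypes environment variable; treat the exact name as an assumption for your branch). A standalone illustration of the per-layer split:

// Standalone illustration of the "Forcing one layer" parse: split at the
// last ':' so "my_layer:onednn" yields layer "my_layer" and impl "onednn".
#include <iostream>
#include <string>

int main() {
    std::string forced_impl_type = "my_layer:onednn";  // hypothetical entry
    size_t found_type = forced_impl_type.rfind(":");
    if (found_type != std::string::npos) {
        std::string layer = forced_impl_type.substr(0, found_type);
        std::string impl  = forced_impl_type.substr(found_type + 1);
        std::cout << layer << " -> " << impl << "\n";  // my_layer -> onednn
    }
}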
@@ -136,7 +136,7 @@ TEST_P(concat_onednn_eltwise, along_f) {
 }
 
 INSTANTIATE_TEST_SUITE_P(fusings_gpu, concat_onednn_activation, ::testing::ValuesIn(std::vector<concat_test_params>{
-    concat_test_params{ CASE_CONCAT_F16_1, 3, 3, "" },
+    concat_test_params{ CASE_CONCAT_F16_1, 4, 4, "" },
 }));
 
 INSTANTIATE_TEST_SUITE_P(fusings_gpu, concat_onednn_eltwise, ::testing::ValuesIn(std::vector<concat_test_params>{
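Reviewer note: the 3, 3 -> 4, 4 bump follows from the first hunk. By this suite's convention (an assumption read off the neighbouring fusion tests) the two integers are the expected primitive counts for the fused and non-fused builds; since the activation is no longer folded into the oneDNN concat, both graphs carry one extra primitive. For readers new to the mechanism, a minimal self-contained sketch of the same value-parameterized gtest pattern:

// Minimal gtest sketch of the INSTANTIATE_TEST_SUITE_P pattern used above;
// demo_params stands in for concat_test_params.
#include <gtest/gtest.h>
#include <vector>

struct demo_params { int expected_fused; int expected_not_fused; };

class demo_suite : public ::testing::TestWithParam<demo_params> {};

TEST_P(demo_suite, counts) {
    auto p = GetParam();
    // The real suite compares these against the compiled program sizes.
    EXPECT_LE(p.expected_fused, p.expected_not_fused);
}

INSTANTIATE_TEST_SUITE_P(sketch, demo_suite, ::testing::ValuesIn(std::vector<demo_params>{
    demo_params{ 4, 4 },  // mirrors the updated CASE_CONCAT_F16_1 row
}));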
@@ -1118,6 +1118,7 @@ INSTANTIATE_TEST_SUITE_P(fusings_gpu, conv_fp32_multi_eltwise_quantization, ::te
 class conv_fp32_multi_eltwise_concat : public ConvFusingTest {};
 TEST_P(conv_fp32_multi_eltwise_concat, basic) {
     auto p = GetParam();
+    data_types output_type = data_types::i8;
     create_topologies(
         input_layout("input", get_input_layout(p)),
         data("eltwise_data1", get_mem(get_output_layout(p))),
@@ -1129,15 +1130,15 @@ TEST_P(conv_fp32_multi_eltwise_concat, basic) {
         eltwise("eltwise2", "conv_prim", "eltwise_data2", eltwise_mode::sum),
         concatenation("concat",
                       { "eltwise1", "eltwise2" },
-                      1,
-                      data_types::i8,
+                      2,
+                      output_type,
                       padding{ { 0, 0, 0, 0 }, 0 }),
         reorder("reorder_bfyx", "concat", p.default_format, data_types::f32)
     );
     implementation_desc conv_impl = { format::b_fs_yx_fsv16, "" };
     bo_fused.set_option(build_option::force_implementations({ { "conv_prim", conv_impl } }));
 
-    tolerance = default_tolerance(p.default_type);
+    tolerance = default_tolerance(output_type);
     execute(p);
 }
 
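Reviewer note: the test-side fix names the concat output precision once (output_type) and derives the comparison tolerance from it instead of from p.default_type, so an i8 output is checked with an i8 tolerance. A self-contained sketch of why that matters; the epsilon values are assumptions, not the suite's real ones:

// Sketch: tolerance must follow the emitted precision; values are made up.
#include <cassert>

enum class data_types { f32, f16, i8 };

float default_tolerance(data_types dt) {          // stand-in for the helper
    return dt == data_types::i8 ? 1.0f : 1e-3f;   // low precision, wide epsilon
}

int main() {
    data_types output_type = data_types::i8;      // declared once, used twice
    float tolerance = default_tolerance(output_type);
    assert(tolerance == 1.0f);                    // i8 checked as i8
    return 0;
}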