[GPU][DG2] Fix some bugs (#13517)

* Bugfix: deconv 2 conv optimization
This commit is contained in:
Felix Dohyun Kim
2022-10-18 15:34:50 +09:00
committed by GitHub
parent b21510f9f6
commit f2bdffb04f
2 changed files with 4 additions and 3 deletions

View File

@@ -148,7 +148,7 @@ protected:
return attrs;
}
static kernel_selector::WeightsReorderParams get_weights_reorder(const kernel_impl_params& impl_params, const dnnl::primitive_desc& pd) {
static kernel_selector::WeightsReorderParams get_weights_reorder(const kernel_impl_params& impl_params, const dnnl::primitive_desc& pd, bool rotate) {
kernel_selector::WeightsReorderParams weights_reorder_params;
auto& reorderKS = kernel_selector::ReorderWeightsKernelSelctor::Instance();
kernel_selector::reorder_weights_params r_params;
@@ -163,7 +163,7 @@ protected:
r_params.layerID = cldnn_prim->id + "_reorder_";
r_params.input = convert_weights_tensor(weights_layout, cldnn_prim->grouped_weights_shape);
r_params.output = r_params.input.TransformIgnorePadding(reqLayout, r_params.input.GetDType(), cldnn_prim->groups, false);
r_params.rotate_180 = false;
r_params.rotate_180 = rotate;
kernel_selector::reorder_optional_params op;
kernel_selector::KernelsData kernels_data = reorderKS.GetBestKernels(r_params, op);
@@ -190,7 +190,7 @@ public:
auto attr = get_primitive_attributes(arg);
dnnl::primitive_desc prim_desc{&desc->data, attr.get(), engine.get_onednn_engine(), nullptr};
return new convolution_onednn(engine, desc, attr, prim_desc, get_weights_reorder(impl_params, prim_desc));
return new convolution_onednn(engine, desc, attr, prim_desc, get_weights_reorder(impl_params, prim_desc, arg.get_transposed()));
}
};

View File

@@ -454,6 +454,7 @@ static cldnn::format convert_format(dnnl::memory::format_tag fmt, bool is_groupe
case dnnl::memory::format_tag::Acdeb16a: return cldnn::format::os_zyxi_osv16;
case dnnl::memory::format_tag::ABcde16b16a: return cldnn::format::os_is_zyx_isv16_osv16;
case dnnl::memory::format_tag::aBcd16b: return cldnn::format::o_is_yx_isv16;
case dnnl::memory::format_tag::Abcd16a: return cldnn::format::os_iyx_osv16;
case dnnl::memory::format_tag::ABcd2a8b8a2b: return cldnn::format::os_is_yx_osa2_isa8_osv8_isv2;
case dnnl::memory::format_tag::ABcd2a8b16a4b: return cldnn::format::os_is_yx_osa2_isa8_osv16_isv4;
case dnnl::memory::format_tag::ABcd2a8b16a2b: return cldnn::format::os_is_yx_osa2_isa8_osv16_isv2;