@@ -148,7 +148,7 @@ protected:
|
||||
return attrs;
|
||||
}
|
||||
|
||||
static kernel_selector::WeightsReorderParams get_weights_reorder(const kernel_impl_params& impl_params, const dnnl::primitive_desc& pd) {
|
||||
static kernel_selector::WeightsReorderParams get_weights_reorder(const kernel_impl_params& impl_params, const dnnl::primitive_desc& pd, bool rotate) {
|
||||
kernel_selector::WeightsReorderParams weights_reorder_params;
|
||||
auto& reorderKS = kernel_selector::ReorderWeightsKernelSelctor::Instance();
|
||||
kernel_selector::reorder_weights_params r_params;
|
||||
@@ -163,7 +163,7 @@ protected:
|
||||
r_params.layerID = cldnn_prim->id + "_reorder_";
|
||||
r_params.input = convert_weights_tensor(weights_layout, cldnn_prim->grouped_weights_shape);
|
||||
r_params.output = r_params.input.TransformIgnorePadding(reqLayout, r_params.input.GetDType(), cldnn_prim->groups, false);
|
||||
r_params.rotate_180 = false;
|
||||
r_params.rotate_180 = rotate;
|
||||
|
||||
kernel_selector::reorder_optional_params op;
|
||||
kernel_selector::KernelsData kernels_data = reorderKS.GetBestKernels(r_params, op);
|
||||
@@ -190,7 +190,7 @@ public:
|
||||
auto attr = get_primitive_attributes(arg);
|
||||
dnnl::primitive_desc prim_desc{&desc->data, attr.get(), engine.get_onednn_engine(), nullptr};
|
||||
|
||||
return new convolution_onednn(engine, desc, attr, prim_desc, get_weights_reorder(impl_params, prim_desc));
|
||||
return new convolution_onednn(engine, desc, attr, prim_desc, get_weights_reorder(impl_params, prim_desc, arg.get_transposed()));
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -454,6 +454,7 @@ static cldnn::format convert_format(dnnl::memory::format_tag fmt, bool is_groupe
|
||||
case dnnl::memory::format_tag::Acdeb16a: return cldnn::format::os_zyxi_osv16;
|
||||
case dnnl::memory::format_tag::ABcde16b16a: return cldnn::format::os_is_zyx_isv16_osv16;
|
||||
case dnnl::memory::format_tag::aBcd16b: return cldnn::format::o_is_yx_isv16;
|
||||
case dnnl::memory::format_tag::Abcd16a: return cldnn::format::os_iyx_osv16;
|
||||
case dnnl::memory::format_tag::ABcd2a8b8a2b: return cldnn::format::os_is_yx_osa2_isa8_osv8_isv2;
|
||||
case dnnl::memory::format_tag::ABcd2a8b16a4b: return cldnn::format::os_is_yx_osa2_isa8_osv16_isv4;
|
||||
case dnnl::memory::format_tag::ABcd2a8b16a2b: return cldnn::format::os_is_yx_osa2_isa8_osv16_isv2;
|
||||
|
||||
Reference in New Issue
Block a user