[GPU] update weights_layout for GroupConv 1d spatial (#17109)
* update weights_layout for GroupConv 1d spatial
This commit is contained in:
@@ -21,10 +21,31 @@ static std::shared_ptr<dnnl::convolution_forward::primitive_desc> get_convolutio
|
||||
dnnl::memory::dims pad_l(prim->pad.begin(), prim->pad.end());
|
||||
dnnl::memory::dims pad_r(prim->pad.begin(), prim->pad.end());
|
||||
|
||||
// issue: it could not find the implementation for 1d kernel GroupConvolution from onednn.
|
||||
// root-cause: 3d tensor of input/output is changed to 4d via ngraph.
|
||||
// Creating conv description returns error if two inputs have same tensor of data input and weight.
|
||||
// - original dims of IR
|
||||
// input1: [ 1, 280, 1200] // [number of batches, number of channels, X]
|
||||
// input2: [280, 1, 1, 67] // [number of output channels, number of input channels, Y, X]
|
||||
// output: [ 1, 280, 1200] // [number of batches, number of kernel output channels, X]
|
||||
// - changed dims
|
||||
// input1: [ 1, 280, 1200, 1]
|
||||
// input2: [280, 1, 67, 1]
|
||||
// output: [ 1, 280, 1200, 1]
|
||||
// WA: Weight tensor will be updated from 4d to 5d.
|
||||
auto grouped_weights = format::is_grouped(weights_layout.format) || prim->grouped_weights_shape;
|
||||
if (grouped_weights && (input_layout.get_rank() == weights_layout.get_rank())) {
|
||||
auto tensor = weights_layout.get_tensor();
|
||||
if (tensor.spatial[0] == 1 && tensor.spatial[1] != 1) {
|
||||
std::swap(tensor.spatial[0], tensor.spatial[1]);
|
||||
weights_layout.set_tensor(tensor);
|
||||
}
|
||||
weights_layout.format = format::get_default_format(weights_layout.get_rank() + 1, true, true);
|
||||
}
|
||||
|
||||
auto input_md = onednn::layout_to_memory_desc(input_layout, tag_in_out);
|
||||
auto weights_md = onednn::layout_to_memory_desc(weights_layout, dnnl::memory::format_tag::any);
|
||||
auto output_md = onednn::layout_to_memory_desc(output_layout, tag_in_out);
|
||||
auto grouped_weights = format::is_grouped(weights_layout.format) || prim->grouped_weights_shape;
|
||||
|
||||
// adjust_conv_dilation_pad(dilation, stride, pad_l, pad_r, input_md, output_md, weights_md, grouped_weights);
|
||||
for (size_t i = 0; i < dilation.size(); i++) {
|
||||
|
||||
Reference in New Issue
Block a user