Update pre_replace_deconv to support output_shape for transposed conv (#12335)

Signed-off-by: Andrew Park <andrew.park@intel.com>
This commit is contained in:
Andrew Kwangwoong Park 2022-08-03 10:22:42 +09:00 committed by GitHub
parent 32182bd3ce
commit bb1560c05c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 30 additions and 11 deletions

View File

@ -511,6 +511,7 @@ struct convolution : public primitive_base<convolution> {
ov::Strides stride = {1, 1},
ov::CoordinateDiff pad = {0, 0},
ov::Strides dilation = {1, 1},
tensor output_size = {0, 0, 0, 0},
bool grouped_weights_shape = false,
const primitive_id& ext_prim_id = "",
const padding& output_padding = padding())
@ -518,7 +519,8 @@ struct convolution : public primitive_base<convolution> {
pad(pad),
stride(stride),
dilation(dilation),
with_output_size(false),
with_output_size(output_size.count() > 0 ? true : false),
output_size(output_size),
groups(groups),
deformable_groups(1),
padding_above(stride.size(), 0),

View File

@ -121,16 +121,33 @@ void pre_replace_deconv::run(program& p) {
"",
output_padding);
} else {
conv_prim = std::make_shared<convolution>(deconv_node_id,
input_node_id,
weights_nodes_id,
groups,
stride,
pad,
dilation,
grouped_weights_shape,
"",
output_padding);
tensor output_size(0);
if (deconv_prim->with_output_size) {
output_size = deconv_prim->output_size;
conv_prim = std::make_shared<convolution>(deconv_node_id,
input_node_id,
weights_nodes_id,
groups,
stride,
pad,
dilation,
output_size,
grouped_weights_shape,
"",
output_padding);
} else {
conv_prim = std::make_shared<convolution>(deconv_node_id,
input_node_id,
weights_nodes_id,
groups,
stride,
pad,
dilation,
output_size,
grouped_weights_shape,
"",
output_padding);
}
}
program_node& new_node = p.get_or_create(conv_prim);