Update pre_replace_deconv to support output_shape for transposed conv (#12335)
Signed-off-by: Andrew Park <andrew.park@intel.com>
This commit is contained in:
parent
32182bd3ce
commit
bb1560c05c
@ -511,6 +511,7 @@ struct convolution : public primitive_base<convolution> {
|
||||
ov::Strides stride = {1, 1},
|
||||
ov::CoordinateDiff pad = {0, 0},
|
||||
ov::Strides dilation = {1, 1},
|
||||
tensor output_size = {0, 0, 0, 0},
|
||||
bool grouped_weights_shape = false,
|
||||
const primitive_id& ext_prim_id = "",
|
||||
const padding& output_padding = padding())
|
||||
@ -518,7 +519,8 @@ struct convolution : public primitive_base<convolution> {
|
||||
pad(pad),
|
||||
stride(stride),
|
||||
dilation(dilation),
|
||||
with_output_size(false),
|
||||
with_output_size(output_size.count() > 0 ? true : false),
|
||||
output_size(output_size),
|
||||
groups(groups),
|
||||
deformable_groups(1),
|
||||
padding_above(stride.size(), 0),
|
||||
|
@ -121,16 +121,33 @@ void pre_replace_deconv::run(program& p) {
|
||||
"",
|
||||
output_padding);
|
||||
} else {
|
||||
conv_prim = std::make_shared<convolution>(deconv_node_id,
|
||||
input_node_id,
|
||||
weights_nodes_id,
|
||||
groups,
|
||||
stride,
|
||||
pad,
|
||||
dilation,
|
||||
grouped_weights_shape,
|
||||
"",
|
||||
output_padding);
|
||||
tensor output_size(0);
|
||||
if (deconv_prim->with_output_size) {
|
||||
output_size = deconv_prim->output_size;
|
||||
conv_prim = std::make_shared<convolution>(deconv_node_id,
|
||||
input_node_id,
|
||||
weights_nodes_id,
|
||||
groups,
|
||||
stride,
|
||||
pad,
|
||||
dilation,
|
||||
output_size,
|
||||
grouped_weights_shape,
|
||||
"",
|
||||
output_padding);
|
||||
} else {
|
||||
conv_prim = std::make_shared<convolution>(deconv_node_id,
|
||||
input_node_id,
|
||||
weights_nodes_id,
|
||||
groups,
|
||||
stride,
|
||||
pad,
|
||||
dilation,
|
||||
output_size,
|
||||
grouped_weights_shape,
|
||||
"",
|
||||
output_padding);
|
||||
}
|
||||
}
|
||||
program_node& new_node = p.get_or_create(conv_prim);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user