From 98fffe7f228ad0f802ca844a03d5091e487852c2 Mon Sep 17 00:00:00 2001 From: Yegor Kruglov Date: Mon, 14 Dec 2020 21:30:26 +0300 Subject: [PATCH] Possible fix for GroupConvolution unit test (#2584) * initial commit * initial commit * move fix to tf conv_extractor * added 3d case * fix e2e with 3d conv * remove 3d case Co-authored-by: yegor.kruglov --- model-optimizer/extensions/front/tf/conv_ext.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/model-optimizer/extensions/front/tf/conv_ext.py b/model-optimizer/extensions/front/tf/conv_ext.py index 82ae051cbb8..5b89dfb7f57 100644 --- a/model-optimizer/extensions/front/tf/conv_ext.py +++ b/model-optimizer/extensions/front/tf/conv_ext.py @@ -31,7 +31,8 @@ class Conv2DFrontExtractor(FrontExtractorOp): def extract(cls, node): attrs = tf_create_attrs(node, 2, 3) attrs.update({'op': __class__.op, - 'get_group': lambda node: 1, + 'get_group': lambda node: node.group if 'group' in node and node.group is not None else + node.in_node(0).shape[node.channel_dims] // node.kernel_shape[node.input_feature_channel], 'get_output_feature_dim': lambda node: node.kernel_shape[node.output_feature_channel], 'get_weights_permute': PermuteAttrs.Permutation(perm=int64_array([3, 2, 0, 1]), inv=int64_array([2, 3, 1, 0]))