Possible fix for GroupConvolution unit test (#2584)

* initial commit

* initial commit

* move fix to tf conv_extractor

* added 3d case

* fix e2e with 3d conv

* remove 3d case

Co-authored-by: yegor.kruglov <ykruglov@nnlvdp-mkaglins.inn.intel.com>
This commit is contained in:
Yegor Kruglov 2020-12-14 21:30:26 +03:00 committed by GitHub
parent f1d99b5887
commit 98fffe7f22
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -31,7 +31,8 @@ class Conv2DFrontExtractor(FrontExtractorOp):
def extract(cls, node): def extract(cls, node):
attrs = tf_create_attrs(node, 2, 3) attrs = tf_create_attrs(node, 2, 3)
attrs.update({'op': __class__.op, attrs.update({'op': __class__.op,
'get_group': lambda node: 1, 'get_group': lambda node: node.group if 'group' in node and node.group is not None else
node.in_node(0).shape[node.channel_dims] // node.kernel_shape[node.input_feature_channel],
'get_output_feature_dim': lambda node: node.kernel_shape[node.output_feature_channel], 'get_output_feature_dim': lambda node: node.kernel_shape[node.output_feature_channel],
'get_weights_permute': PermuteAttrs.Permutation(perm=int64_array([3, 2, 0, 1]), 'get_weights_permute': PermuteAttrs.Permutation(perm=int64_array([3, 2, 0, 1]),
inv=int64_array([2, 3, 1, 0])) inv=int64_array([2, 3, 1, 0]))