[MO] Support TensorFlow FusedBatchNorm with channel_first data_format (#11084)
* layout fix in FusedBatchNorm decomposition * added tests
This commit is contained in:
@@ -19,7 +19,7 @@ def tf_fused_bn_extractor(pb):
|
||||
log.warning('FusedBatchNorm doesn\'t support is_training=True')
|
||||
|
||||
return {
|
||||
'data_format': pb.attr["data_format"].s,
|
||||
'data_format': pb.attr["data_format"].s.decode(),
|
||||
'data_type': tf_dtype_extractor(pb.attr["T"].type),
|
||||
'eps': pb.attr['epsilon'].f,
|
||||
'infer': tf_fused_bn_infer,
|
||||
|
||||
@@ -58,7 +58,8 @@ def convert_batch_norm(graph: Graph):
|
||||
shift = (mean.data.get_value() * (-1)) * scale
|
||||
|
||||
# Expand dims for current layout
|
||||
broadcast_dims_cnt = len(node.in_port(0).data.get_shape()) - 2 if graph.graph['layout'] == 'NCHW' else 0
|
||||
layout = node.soft_get('data_format', graph.graph['layout'])
|
||||
broadcast_dims_cnt = len(node.in_port(0).data.get_shape()) - 2 if layout in ['NCHW', "NCDHW"] else 0
|
||||
|
||||
# Update values and shapes with new shape
|
||||
expand_node_shape(const, broadcast_dims_cnt)
|
||||
|
||||
@@ -500,4 +500,162 @@ class BatchNormDecomposition(unittest.TestCase):
|
||||
graph.clean_up()
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'concat_data')
|
||||
self.assertTrue(flag, resp)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
# graph - NCHW
|
||||
# BatchNorm - NHWC
|
||||
def test_bn_decomposition_different_layouts_1(self):
|
||||
graph = build_graph(nodes_attributes,
|
||||
[('placeholder_1', 'placeholder_1_data'),
|
||||
('placeholder_1_data', 'bn_op'),
|
||||
('const_bn_const', 'bn_const'),
|
||||
('const_bn_beta', 'bn_beta'),
|
||||
('const_bn_mean', 'bn_mean'),
|
||||
('const_bn_var', 'bn_var'),
|
||||
('bn_const', 'bn_op'),
|
||||
('bn_beta', 'bn_op'),
|
||||
('bn_mean', 'bn_op'),
|
||||
('bn_var', 'bn_op'),
|
||||
('bn_op', 'bn_data'),
|
||||
('concat', 'concat_data'),
|
||||
('bn_data', 'concat'),
|
||||
('concat_data', 'op_output')
|
||||
],
|
||||
{'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
|
||||
'bn_op': {'eps': 1.2, 'data_format': 'NHWC'},
|
||||
'bn_const': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
|
||||
'bn_beta': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
|
||||
'bn_mean': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
|
||||
'bn_var': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
|
||||
'bn_data': {'shape': np.array([1, 227, 227, 3])},
|
||||
'concat_data': {}
|
||||
}, nodes_with_edges_only=True)
|
||||
|
||||
graph_ref = build_graph(nodes_attributes,
|
||||
[('placeholder_1', 'placeholder_1_data'),
|
||||
('placeholder_1_data', 'mul_1'),
|
||||
('const_mul_1_w', 'mul_1_w'),
|
||||
('mul_1_w', 'mul_1'),
|
||||
('mul_1', 'mul_1_data'),
|
||||
('mul_1_data', 'add_1'),
|
||||
('const_add_1_w', 'add_1_w'),
|
||||
('add_1_w', 'add_1'),
|
||||
('add_1', 'add_1_data'),
|
||||
('add_1_data', 'mul_2'),
|
||||
('const_mul_2_w', 'mul_2_w'),
|
||||
('mul_2_w', 'mul_2'),
|
||||
('mul_2', 'mul_2_data'),
|
||||
('mul_2_data', 'add_2'),
|
||||
('const_add_2_w', 'add_2_w'),
|
||||
('add_2_w', 'add_2'),
|
||||
('add_2', 'add_2_data'),
|
||||
('concat', 'concat_data'),
|
||||
('add_2_data', 'concat'),
|
||||
('concat_data', 'op_output')
|
||||
],
|
||||
{'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
|
||||
'const_mul_1_w': {'shape': np.array([3]),
|
||||
'value': np.array([0.67419986, 0.55901699, 0.48795004])},
|
||||
'mul_1_w': {'shape': np.array([3]),
|
||||
'value': np.array([0.67419986, 0.55901699, 0.48795004])},
|
||||
'const_mul_2_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
|
||||
'mul_2_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
|
||||
'const_add_1_w': {'shape': np.array([3]),
|
||||
'value': np.array([-0.67419986, -1.11803399, -1.46385011])},
|
||||
'add_1_w': {'shape': np.array([3]),
|
||||
'value': np.array([-0.67419986, -1.11803399, -1.46385011])},
|
||||
'const_add_2_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
|
||||
'add_2_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
|
||||
'add_2_data': {'shape': np.array([1, 227, 227, 3])},
|
||||
'mul_1': {'can_be_fused': True},
|
||||
'mul_2': {'can_be_fused': True},
|
||||
'add_1': {'can_be_fused': True},
|
||||
'add_2': {'can_be_fused': True},
|
||||
'concat_data': {}
|
||||
}, nodes_with_edges_only=True)
|
||||
|
||||
graph.graph['layout'] = 'NCHW'
|
||||
convert_batch_norm(graph)
|
||||
graph.clean_up()
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'concat_data')
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
# graph - NHWC
|
||||
# BatchNorm - NCHW
|
||||
def test_bn_decomposition_different_layouts_2(self):
|
||||
graph = build_graph(nodes_attributes,
|
||||
[('placeholder_1', 'placeholder_1_data'),
|
||||
('placeholder_1_data', 'bn_op'),
|
||||
('const_bn_const', 'bn_const'),
|
||||
('const_bn_beta', 'bn_beta'),
|
||||
('const_bn_mean', 'bn_mean'),
|
||||
('const_bn_var', 'bn_var'),
|
||||
('bn_const', 'bn_op'),
|
||||
('bn_beta', 'bn_op'),
|
||||
('bn_mean', 'bn_op'),
|
||||
('bn_var', 'bn_op'),
|
||||
('bn_op', 'bn_data'),
|
||||
('concat', 'concat_data'),
|
||||
('bn_data', 'concat'),
|
||||
('concat_data', 'op_output')
|
||||
],
|
||||
{'placeholder_1_data': {'shape': np.array([1, 3, 227, 227])},
|
||||
'bn_op': {'eps': 1.2, 'data_format': 'NCHW'},
|
||||
'bn_const': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
|
||||
'bn_beta': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
|
||||
'bn_mean': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
|
||||
'bn_var': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
|
||||
'bn_data': {'shape': np.array([1, 3, 227, 227])},
|
||||
'concat_data': {}
|
||||
}, nodes_with_edges_only=True)
|
||||
|
||||
graph_ref = build_graph(nodes_attributes,
|
||||
[('placeholder_1', 'placeholder_1_data'),
|
||||
('placeholder_1_data', 'mul_1'),
|
||||
('const_mul_1_w', 'mul_1_w'),
|
||||
('mul_1_w', 'mul_1'),
|
||||
('mul_1', 'mul_1_data'),
|
||||
('mul_1_data', 'add_1'),
|
||||
('const_add_1_w', 'add_1_w'),
|
||||
('add_1_w', 'add_1'),
|
||||
('add_1', 'add_1_data'),
|
||||
('add_1_data', 'mul_2'),
|
||||
('const_mul_2_w', 'mul_2_w'),
|
||||
('mul_2_w', 'mul_2'),
|
||||
('mul_2', 'mul_2_data'),
|
||||
('mul_2_data', 'add_2'),
|
||||
('const_add_2_w', 'add_2_w'),
|
||||
('add_2_w', 'add_2'),
|
||||
('add_2', 'add_2_data'),
|
||||
('concat', 'concat_data'),
|
||||
('add_2_data', 'concat'),
|
||||
('concat_data', 'op_output')
|
||||
],
|
||||
{'placeholder_1_data': {'shape': np.array([1, 3, 227, 227])},
|
||||
'const_mul_1_w': {'shape': np.array([3, 1, 1]),
|
||||
'value': np.array([[[0.67419986]], [[0.55901699]], [[0.48795004]]])},
|
||||
'mul_1_w': {'shape': np.array([3, 1, 1]),
|
||||
'value': np.array([[[0.67419986]], [[0.55901699]], [[0.48795004]]])},
|
||||
'const_mul_2_w': {'shape': np.array([3, 1, 1]), 'value': np.array([[[1]], [[2]], [[3]]])},
|
||||
'mul_2_w': {'shape': np.array([3, 1, 1]), 'value': np.array([[[1]], [[2]], [[3]]])},
|
||||
'const_add_1_w': {'shape': np.array([3, 1, 1]),
|
||||
'value': np.array([[[-0.67419986]], [[-1.11803399]], [[-1.46385011]]])},
|
||||
'add_1_w': {'shape': np.array([3, 1, 1]),
|
||||
'value': np.array([[[-0.67419986]], [[-1.11803399]], [[-1.46385011]]])},
|
||||
'const_add_2_w': {'shape': np.array([3, 1, 1]), 'value': np.array([[[1]], [[2]], [[3]]])},
|
||||
'add_2_w': {'shape': np.array([3, 1, 1]), 'value': np.array([[[1]], [[2]], [[3]]])},
|
||||
'add_2_data': {'shape': np.array([1, 3, 227, 227])},
|
||||
'mul_1': {'can_be_fused': True},
|
||||
'mul_2': {'can_be_fused': True},
|
||||
'add_1': {'can_be_fused': True},
|
||||
'add_2': {'can_be_fused': True},
|
||||
'concat_data': {}
|
||||
}, nodes_with_edges_only=True)
|
||||
|
||||
graph.graph['layout'] = 'NHWC'
|
||||
convert_batch_norm(graph)
|
||||
graph.clean_up()
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'concat_data')
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
Reference in New Issue
Block a user