[MO] Fix axes in FusedBatchNorm -> MVN transformation (#7679) (#7913)

This commit is contained in:
Maxim Vafin 2021-10-12 16:47:29 +03:00 committed by GitHub
parent c323775f2c
commit 25c2d5c6c4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 9 additions and 9 deletions

View File

@ -58,9 +58,9 @@ class FusedBatchNormTraining(MiddleReplacementPattern):
input_rank = len(node.in_port(0).data.get_shape())
rng = create_op_with_const_inputs(graph, Range,
{0: int64_array(2), 1: int64_array(input_rank), 2: int64_array(1)},
{0: int64_array(1), 1: int64_array(input_rank - 1), 2: int64_array(1)},
{'name': node_name + '/Range', 'output_type': np.int64})
mvn = MVN(graph, {'name': node_name + '/mvn_', 'eps': node.soft_get('eps', 1e-6), 'eps_mode': 'outside_sqrt',
mvn = MVN(graph, {'name': node_name + '/mvn_', 'eps': node.soft_get('eps', 1e-6), 'eps_mode': 'inside_sqrt',
'normalize_variance': 1, 'override_output_shape': True}).create_node()
node.in_port(0).get_connection().insert_node(mvn)
mvn.in_port(1).connect(rng.out_port(0))

View File

@ -52,16 +52,16 @@ nodes_attributes = {
'reshape_to_orig': {'type': 'Reshape', 'value': None, 'kind': 'op', 'op': 'Reshape'},
'reshape_to_orig_data': {'value': None, 'shape': None, 'kind': 'data'},
'start': {'kind': 'op', 'op': 'Const'},
'start_data': {'value': None, 'shape': None, 'kind': 'data'},
'stop': {'kind': 'op', 'op': 'Const'},
'stop_data': {'value': None, 'shape': None, 'kind': 'data'},
'step': {'kind': 'op', 'op': 'Const'},
'step_data': {'value': None, 'shape': None, 'kind': 'data'},
'start': {'kind': 'op', 'op': 'Const', 'value': int64_array(1)},
'start_data': {'value': None, 'shape': None, 'kind': 'data', 'value': int64_array(1)},
'stop': {'kind': 'op', 'op': 'Const', 'value': int64_array(3)},
'stop_data': {'value': None, 'shape': None, 'kind': 'data', 'value': int64_array(3)},
'step': {'kind': 'op', 'op': 'Const', 'value': int64_array(1)},
'step_data': {'value': None, 'shape': None, 'kind': 'data', 'value': int64_array(1)},
'mvn_axes': {'kind': 'op', 'op': 'Range'},
'mvn_axes_data': {'value': None, 'shape': None, 'kind': 'data'},
'mvn': {'type': 'MVN', 'value': None, 'kind': 'op', 'op': 'MVN', 'eps': 1e-3},
'mvn': {'type': 'MVN', 'value': None, 'kind': 'op', 'op': 'MVN', 'eps': 1e-3, 'eps_mode': 'inside_sqrt'},
'mvn_data': {'value': None, 'shape': None, 'kind': 'data'},
'reshape_1': {'type': 'Reshape', 'value': None, 'kind': 'op', 'op': 'Reshape'},