parent
c323775f2c
commit
25c2d5c6c4
@ -58,9 +58,9 @@ class FusedBatchNormTraining(MiddleReplacementPattern):
|
|||||||
|
|
||||||
input_rank = len(node.in_port(0).data.get_shape())
|
input_rank = len(node.in_port(0).data.get_shape())
|
||||||
rng = create_op_with_const_inputs(graph, Range,
|
rng = create_op_with_const_inputs(graph, Range,
|
||||||
{0: int64_array(2), 1: int64_array(input_rank), 2: int64_array(1)},
|
{0: int64_array(1), 1: int64_array(input_rank - 1), 2: int64_array(1)},
|
||||||
{'name': node_name + '/Range', 'output_type': np.int64})
|
{'name': node_name + '/Range', 'output_type': np.int64})
|
||||||
mvn = MVN(graph, {'name': node_name + '/mvn_', 'eps': node.soft_get('eps', 1e-6), 'eps_mode': 'outside_sqrt',
|
mvn = MVN(graph, {'name': node_name + '/mvn_', 'eps': node.soft_get('eps', 1e-6), 'eps_mode': 'inside_sqrt',
|
||||||
'normalize_variance': 1, 'override_output_shape': True}).create_node()
|
'normalize_variance': 1, 'override_output_shape': True}).create_node()
|
||||||
node.in_port(0).get_connection().insert_node(mvn)
|
node.in_port(0).get_connection().insert_node(mvn)
|
||||||
mvn.in_port(1).connect(rng.out_port(0))
|
mvn.in_port(1).connect(rng.out_port(0))
|
||||||
|
@ -52,16 +52,16 @@ nodes_attributes = {
|
|||||||
'reshape_to_orig': {'type': 'Reshape', 'value': None, 'kind': 'op', 'op': 'Reshape'},
|
'reshape_to_orig': {'type': 'Reshape', 'value': None, 'kind': 'op', 'op': 'Reshape'},
|
||||||
'reshape_to_orig_data': {'value': None, 'shape': None, 'kind': 'data'},
|
'reshape_to_orig_data': {'value': None, 'shape': None, 'kind': 'data'},
|
||||||
|
|
||||||
'start': {'kind': 'op', 'op': 'Const'},
|
'start': {'kind': 'op', 'op': 'Const', 'value': int64_array(1)},
|
||||||
'start_data': {'value': None, 'shape': None, 'kind': 'data'},
|
'start_data': {'value': None, 'shape': None, 'kind': 'data', 'value': int64_array(1)},
|
||||||
'stop': {'kind': 'op', 'op': 'Const'},
|
'stop': {'kind': 'op', 'op': 'Const', 'value': int64_array(3)},
|
||||||
'stop_data': {'value': None, 'shape': None, 'kind': 'data'},
|
'stop_data': {'value': None, 'shape': None, 'kind': 'data', 'value': int64_array(3)},
|
||||||
'step': {'kind': 'op', 'op': 'Const'},
|
'step': {'kind': 'op', 'op': 'Const', 'value': int64_array(1)},
|
||||||
'step_data': {'value': None, 'shape': None, 'kind': 'data'},
|
'step_data': {'value': None, 'shape': None, 'kind': 'data', 'value': int64_array(1)},
|
||||||
'mvn_axes': {'kind': 'op', 'op': 'Range'},
|
'mvn_axes': {'kind': 'op', 'op': 'Range'},
|
||||||
'mvn_axes_data': {'value': None, 'shape': None, 'kind': 'data'},
|
'mvn_axes_data': {'value': None, 'shape': None, 'kind': 'data'},
|
||||||
|
|
||||||
'mvn': {'type': 'MVN', 'value': None, 'kind': 'op', 'op': 'MVN', 'eps': 1e-3},
|
'mvn': {'type': 'MVN', 'value': None, 'kind': 'op', 'op': 'MVN', 'eps': 1e-3, 'eps_mode': 'inside_sqrt'},
|
||||||
'mvn_data': {'value': None, 'shape': None, 'kind': 'data'},
|
'mvn_data': {'value': None, 'shape': None, 'kind': 'data'},
|
||||||
|
|
||||||
'reshape_1': {'type': 'Reshape', 'value': None, 'kind': 'op', 'op': 'Reshape'},
|
'reshape_1': {'type': 'Reshape', 'value': None, 'kind': 'op', 'op': 'Reshape'},
|
||||||
|
Loading…
Reference in New Issue
Block a user