[ MO GroupNorm ] Covered float Multiplication with Converts (#1602)

This commit is contained in:
Evgenya Stepyreva
2020-08-03 14:45:39 +03:00
committed by GitHub
parent 9f767f7b93
commit 067c2414d1

View File

@@ -18,10 +18,12 @@ from typing import Dict
import numpy as np
from extensions.ops.Cast import Cast
from extensions.ops.elementwise import Mul, Add
from extensions.ops.mvn import MVN
from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Graph, Node
from mo.middle.passes.convert_data_type import data_type_str_to_np
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.const import Const
from mo.ops.reshape import Reshape
@@ -58,9 +60,19 @@ class GroupNormToMVN(MiddleReplacementPattern):
initial_shape_op_node = Shape(graph, {'name': group_norm_node.name + '/Shape'}).create_node()
initial_shape_op_node.in_port(0).connect(group_norm_node.in_port(0).get_source())
initial_batch_dim_node = node_to_get_batch_value(initial_shape_op_node)
initial_features_dim_node = node_to_get_features_dimension_value(initial_shape_op_node)
initial_spatial_dims_node = node_to_get_spatial_dimensions_value(initial_shape_op_node)
initial_shape_op_node_float = Cast(
graph, {'name': initial_shape_op_node.name + '/to_float',
'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)}).create_node()
initial_shape_op_node.out_port(0).connect(initial_shape_op_node_float.in_port(0))
initial_batch_dim_node = node_to_get_batch_value(initial_shape_op_node_float)
initial_features_dim_node = node_to_get_features_dimension_value(initial_shape_op_node_float)
initial_spatial_dims_node_int = node_to_get_spatial_dimensions_value(initial_shape_op_node)
initial_spatial_dims_node = Cast(
graph, {'name': initial_spatial_dims_node_int.name + '/to_float',
'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)}).create_node()
initial_spatial_dims_node_int.out_port(0).connect(initial_spatial_dims_node.in_port(0))
group_size_node = Const(graph, {'value': int64_array([group_norm_node.num_groups]),
'name': group_norm_node.name + '/GroupSize'}).create_node()
@@ -77,8 +89,11 @@ class GroupNormToMVN(MiddleReplacementPattern):
batch_mul_group_size_node.in_port(1).connect(group_size_node.out_port(0))
# create new node which concatenates several dims to one
new_shape_node = new_shape_node_from_shape_nodes([batch_mul_group_size_node, c_div_g_node,
initial_spatial_dims_node])
new_shape_node_float = new_shape_node_from_shape_nodes([batch_mul_group_size_node, c_div_g_node,
initial_spatial_dims_node])
new_shape_node = Cast(graph,
{'name': new_shape_node_float.name + '/to_int64', 'dst_type': np.int64}).create_node()
new_shape_node_float.out_port(0).connect(new_shape_node.in_port(0))
reshape_for_mvn_node = Reshape(graph, {}).create_node()