[ MO GroupNorm ] Covered float Multiplication with Converts (#1602)
This commit is contained in:
committed by
GitHub
parent
9f767f7b93
commit
067c2414d1
@@ -18,10 +18,12 @@ from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
|
||||
from extensions.ops.Cast import Cast
|
||||
from extensions.ops.elementwise import Mul, Add
|
||||
from extensions.ops.mvn import MVN
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.graph.graph import Graph, Node
|
||||
from mo.middle.passes.convert_data_type import data_type_str_to_np
|
||||
from mo.middle.replacement import MiddleReplacementPattern
|
||||
from mo.ops.const import Const
|
||||
from mo.ops.reshape import Reshape
|
||||
@@ -58,9 +60,19 @@ class GroupNormToMVN(MiddleReplacementPattern):
|
||||
initial_shape_op_node = Shape(graph, {'name': group_norm_node.name + '/Shape'}).create_node()
|
||||
initial_shape_op_node.in_port(0).connect(group_norm_node.in_port(0).get_source())
|
||||
|
||||
initial_batch_dim_node = node_to_get_batch_value(initial_shape_op_node)
|
||||
initial_features_dim_node = node_to_get_features_dimension_value(initial_shape_op_node)
|
||||
initial_spatial_dims_node = node_to_get_spatial_dimensions_value(initial_shape_op_node)
|
||||
initial_shape_op_node_float = Cast(
|
||||
graph, {'name': initial_shape_op_node.name + '/to_float',
|
||||
'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)}).create_node()
|
||||
initial_shape_op_node.out_port(0).connect(initial_shape_op_node_float.in_port(0))
|
||||
|
||||
initial_batch_dim_node = node_to_get_batch_value(initial_shape_op_node_float)
|
||||
initial_features_dim_node = node_to_get_features_dimension_value(initial_shape_op_node_float)
|
||||
initial_spatial_dims_node_int = node_to_get_spatial_dimensions_value(initial_shape_op_node)
|
||||
initial_spatial_dims_node = Cast(
|
||||
graph, {'name': initial_spatial_dims_node_int.name + '/to_float',
|
||||
'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)}).create_node()
|
||||
initial_spatial_dims_node_int.out_port(0).connect(initial_spatial_dims_node.in_port(0))
|
||||
|
||||
group_size_node = Const(graph, {'value': int64_array([group_norm_node.num_groups]),
|
||||
'name': group_norm_node.name + '/GroupSize'}).create_node()
|
||||
|
||||
@@ -77,8 +89,11 @@ class GroupNormToMVN(MiddleReplacementPattern):
|
||||
batch_mul_group_size_node.in_port(1).connect(group_size_node.out_port(0))
|
||||
|
||||
# create new node which concatenates several dims to one
|
||||
new_shape_node = new_shape_node_from_shape_nodes([batch_mul_group_size_node, c_div_g_node,
|
||||
initial_spatial_dims_node])
|
||||
new_shape_node_float = new_shape_node_from_shape_nodes([batch_mul_group_size_node, c_div_g_node,
|
||||
initial_spatial_dims_node])
|
||||
new_shape_node = Cast(graph,
|
||||
{'name': new_shape_node_float.name + '/to_int64', 'dst_type': np.int64}).create_node()
|
||||
new_shape_node_float.out_port(0).connect(new_shape_node.in_port(0))
|
||||
|
||||
reshape_for_mvn_node = Reshape(graph, {}).create_node()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user