diff --git a/model-optimizer/extensions/middle/GroupNorm.py b/model-optimizer/extensions/middle/GroupNorm.py index 4e7ed9bd35a..166fcd55859 100644 --- a/model-optimizer/extensions/middle/GroupNorm.py +++ b/model-optimizer/extensions/middle/GroupNorm.py @@ -18,10 +18,12 @@ from typing import Dict import numpy as np +from extensions.ops.Cast import Cast from extensions.ops.elementwise import Mul, Add from extensions.ops.mvn import MVN from mo.front.common.partial_infer.utils import int64_array from mo.graph.graph import Graph, Node +from mo.middle.passes.convert_data_type import data_type_str_to_np from mo.middle.replacement import MiddleReplacementPattern from mo.ops.const import Const from mo.ops.reshape import Reshape @@ -58,9 +60,19 @@ class GroupNormToMVN(MiddleReplacementPattern): initial_shape_op_node = Shape(graph, {'name': group_norm_node.name + '/Shape'}).create_node() initial_shape_op_node.in_port(0).connect(group_norm_node.in_port(0).get_source()) - initial_batch_dim_node = node_to_get_batch_value(initial_shape_op_node) - initial_features_dim_node = node_to_get_features_dimension_value(initial_shape_op_node) - initial_spatial_dims_node = node_to_get_spatial_dimensions_value(initial_shape_op_node) + initial_shape_op_node_float = Cast( + graph, {'name': initial_shape_op_node.name + '/to_float', + 'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)}).create_node() + initial_shape_op_node.out_port(0).connect(initial_shape_op_node_float.in_port(0)) + + initial_batch_dim_node = node_to_get_batch_value(initial_shape_op_node_float) + initial_features_dim_node = node_to_get_features_dimension_value(initial_shape_op_node_float) + initial_spatial_dims_node_int = node_to_get_spatial_dimensions_value(initial_shape_op_node) + initial_spatial_dims_node = Cast( + graph, {'name': initial_spatial_dims_node_int.name + '/to_float', + 'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)}).create_node() + initial_spatial_dims_node_int.out_port(0).connect(initial_spatial_dims_node.in_port(0)) + group_size_node = Const(graph, {'value': int64_array([group_norm_node.num_groups]), 'name': group_norm_node.name + '/GroupSize'}).create_node() @@ -77,8 +89,11 @@ class GroupNormToMVN(MiddleReplacementPattern): batch_mul_group_size_node.in_port(1).connect(group_size_node.out_port(0)) # create new node which concatenates several dims to one - new_shape_node = new_shape_node_from_shape_nodes([batch_mul_group_size_node, c_div_g_node, - initial_spatial_dims_node]) + new_shape_node_float = new_shape_node_from_shape_nodes([batch_mul_group_size_node, c_div_g_node, + initial_spatial_dims_node]) + new_shape_node = Cast(graph, + {'name': new_shape_node_float.name + '/to_int64', 'dst_type': np.int64}).create_node() + new_shape_node_float.out_port(0).connect(new_shape_node.in_port(0)) reshape_for_mvn_node = Reshape(graph, {}).create_node()