Incorrect type of 'scales' input of the operation Interpolate-4 in some cases (#4375)
* Commit. * Fixed element type of scales input in the MO transformation UpsampleToResample. * Fixes in the transformation TFSliceToSliceReplacer. * Fixes in tests. * Small fixes. * Reverted fixes in TFSliceToSliceReplacer. * Small fix. * Added tests for fractional scales in the transformation UpsampleToResample.
This commit is contained in:
committed by
GitHub
parent
9559f6f301
commit
2b732ec1d7
@@ -24,7 +24,7 @@ from extensions.ops.Cast import Cast
|
||||
from extensions.ops.elementwise import Mul
|
||||
from extensions.ops.interpolate import Interpolate
|
||||
from mo.front.common.layout import get_height_dim, get_width_dim, get_depth_dim
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.front.common.partial_infer.utils import int64_array, float32_array
|
||||
from mo.front.tf.graph_utils import create_op_with_const_inputs, create_op_node_with_second_input
|
||||
from mo.graph.graph import Graph, Node, rename_nodes
|
||||
from mo.middle.replacement import MiddleReplacementPattern
|
||||
@@ -92,10 +92,10 @@ class UpsampleToResample(MiddleReplacementPattern):
|
||||
|
||||
if input_shape_rank == 4:
|
||||
begin_value = int64_array([get_height_dim(layout, input_shape_rank)])
|
||||
factor_value = np.array([height_scale, width_scale])
|
||||
factor_value = float32_array([height_scale, width_scale])
|
||||
else:
|
||||
begin_value = int64_array([get_depth_dim(layout, input_shape_rank)])
|
||||
factor_value = np.array([depth_scale, height_scale, width_scale])
|
||||
factor_value = float32_array([depth_scale, height_scale, width_scale])
|
||||
|
||||
ss = create_op_with_const_inputs(graph, StridedSlice,
|
||||
{1: begin_value,
|
||||
@@ -141,7 +141,8 @@ class UpsampleToResample(MiddleReplacementPattern):
|
||||
mul.out_port(0).connect(interpolate.in_port(1))
|
||||
axes_node.out_port(0).connect(interpolate.in_port(3))
|
||||
|
||||
scales_node = Const(graph, {'name': upsample_name + '/scales', 'value': factor_value}).create_node()
|
||||
scales_node = Const(graph, {'name': upsample_name + '/scales',
|
||||
'value': factor_value}).create_node()
|
||||
scales_node.out_port(0).connect(interpolate.in_port(2))
|
||||
|
||||
upsample.in_port(0).get_connection().set_destination(interpolate.in_port(0))
|
||||
|
||||
@@ -20,7 +20,7 @@ import numpy as np
|
||||
from generator import generator, generate
|
||||
|
||||
from extensions.middle.UpsampleToResample import UpsampleToResample
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.front.common.partial_infer.utils import int64_array, float32_array
|
||||
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from mo.utils.unittest.graph import build_graph
|
||||
|
||||
@@ -166,15 +166,26 @@ class UpsampleToResampleTest(unittest.TestCase):
|
||||
([2, 3, 20, 30, 40], [1, 1, 3, 4, 3], [2, 3, 4]),
|
||||
([2, 3, 20, 30, 40], [1, 1, 4, 3, 3], [2, 3, 4]),
|
||||
([2, 3, 20, 30, 40], [1, 1, 3, 3, 4], [2, 3, 4]),
|
||||
([2, 10, 20, 30], [1, 1, 5.5, 5.7], [2, 3]),
|
||||
([2, 20, 30, 40], [1, 1, 3.3, 3.1], [2, 3]),
|
||||
([2, 10, 20, 30], [1, 1, 6.18, 5.34], [2, 3]),
|
||||
([2, 20, 30, 40], [1, 1, 3.79, 4.16], [2, 3]),
|
||||
([2, 3, 20, 30, 40], [1, 1, 3.12, 3.87, 3.92], [2, 3, 4]),
|
||||
([2, 3, 20, 30, 40], [1, 1, 3.74, 4.873, 3.287], [2, 3, 4]),
|
||||
([2, 3, 20, 30, 40], [1, 1, 4.8, 3.6, 3.11], [2, 3, 4]),
|
||||
([2, 3, 20, 30, 40], [1, 1, 3.33, 3.73, 4.765], [2, 3, 4]),
|
||||
])
|
||||
def test_conversion(self, input_shape, scales, axes):
|
||||
input_shape_as_array = int64_array(input_shape)
|
||||
scales_as_array = float32_array(scales)
|
||||
graph = build_graph(graph_node_attrs,
|
||||
graph_edges,
|
||||
{
|
||||
'placeholder_data': {'shape': int64_array(input_shape)},
|
||||
'scales': {'value': int64_array(scales), 'shape': int64_array(scales).shape},
|
||||
'scales_data': {'value': int64_array(scales), 'shape': int64_array(scales).shape},
|
||||
'upsample_data': {'shape': int64_array(input_shape) * int64_array(scales)}
|
||||
'placeholder_data': {'shape': input_shape_as_array},
|
||||
'scales': {'value': scales_as_array, 'shape': scales_as_array.shape},
|
||||
'scales_data': {'value': scales_as_array, 'shape': scales_as_array.shape},
|
||||
'upsample_data':
|
||||
{'shape': ((input_shape_as_array + 1.e-5) * scales_as_array).astype(np.int64)}
|
||||
})
|
||||
graph.graph['layout'] = 'NCHW'
|
||||
ref_graph = build_graph(new_ref_graph_node_attr,
|
||||
@@ -185,12 +196,13 @@ class UpsampleToResampleTest(unittest.TestCase):
|
||||
'ss_end': {'value': int64_array([axes[-1] + 1])},
|
||||
'ss_begin_data': {'value': int64_array([axes[0]])},
|
||||
'ss_end_data': {'value': int64_array([axes[-1] + 1])},
|
||||
'factor': {'value': int64_array(scales)[2:],
|
||||
'shape': int64_array(scales[2:]).shape},
|
||||
'factor_data': {'value': int64_array(scales)[2:],
|
||||
'shape': int64_array(scales[2:]).shape},
|
||||
'factor': {'value': scales_as_array[2:],
|
||||
'shape': scales_as_array[2:].shape},
|
||||
'factor_data': {'value': scales_as_array[2:],
|
||||
'shape': scales_as_array[2:].shape},
|
||||
'axes_const': {'value': int64_array(axes), 'shape': int64_array(axes).shape},
|
||||
'interpolate_data': {'shape': int64_array(input_shape) * int64_array(scales)},
|
||||
'interpolate_data': {
|
||||
'shape': (input_shape_as_array * scales_as_array + 1e-5).astype(np.int64)},
|
||||
})
|
||||
UpsampleToResample().find_and_replace_pattern(graph)
|
||||
(flag, resp) = compare_graphs(graph, ref_graph, 'output')
|
||||
|
||||
Reference in New Issue
Block a user