Incorrect type of 'scales' input of the operation Interpolate-4 in some cases (#4375)

* Commit.

* Fixed element type of scales input in the MO transformation UpsampleToResample.

* Fixes in the transformation TFSliceToSliceReplacer.

* Fixes in tests.

* Small fixes.

* Reverted fixes in TFSliceToSliceReplacer.

* Small fix.

* Added tests for fractional scales in the transformation UpsampleToResample.
This commit is contained in:
Vladimir Gavrilov
2021-02-19 10:52:41 +03:00
committed by GitHub
parent 9559f6f301
commit 2b732ec1d7
2 changed files with 27 additions and 14 deletions

View File

@@ -24,7 +24,7 @@ from extensions.ops.Cast import Cast
from extensions.ops.elementwise import Mul
from extensions.ops.interpolate import Interpolate
from mo.front.common.layout import get_height_dim, get_width_dim, get_depth_dim
from mo.front.common.partial_infer.utils import int64_array
from mo.front.common.partial_infer.utils import int64_array, float32_array
from mo.front.tf.graph_utils import create_op_with_const_inputs, create_op_node_with_second_input
from mo.graph.graph import Graph, Node, rename_nodes
from mo.middle.replacement import MiddleReplacementPattern
@@ -92,10 +92,10 @@ class UpsampleToResample(MiddleReplacementPattern):
if input_shape_rank == 4:
begin_value = int64_array([get_height_dim(layout, input_shape_rank)])
factor_value = np.array([height_scale, width_scale])
factor_value = float32_array([height_scale, width_scale])
else:
begin_value = int64_array([get_depth_dim(layout, input_shape_rank)])
factor_value = np.array([depth_scale, height_scale, width_scale])
factor_value = float32_array([depth_scale, height_scale, width_scale])
ss = create_op_with_const_inputs(graph, StridedSlice,
{1: begin_value,
@@ -141,7 +141,8 @@ class UpsampleToResample(MiddleReplacementPattern):
mul.out_port(0).connect(interpolate.in_port(1))
axes_node.out_port(0).connect(interpolate.in_port(3))
scales_node = Const(graph, {'name': upsample_name + '/scales', 'value': factor_value}).create_node()
scales_node = Const(graph, {'name': upsample_name + '/scales',
'value': factor_value}).create_node()
scales_node.out_port(0).connect(interpolate.in_port(2))
upsample.in_port(0).get_connection().set_destination(interpolate.in_port(0))

View File

@@ -20,7 +20,7 @@ import numpy as np
from generator import generator, generate
from extensions.middle.UpsampleToResample import UpsampleToResample
from mo.front.common.partial_infer.utils import int64_array
from mo.front.common.partial_infer.utils import int64_array, float32_array
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph
@@ -166,15 +166,26 @@ class UpsampleToResampleTest(unittest.TestCase):
([2, 3, 20, 30, 40], [1, 1, 3, 4, 3], [2, 3, 4]),
([2, 3, 20, 30, 40], [1, 1, 4, 3, 3], [2, 3, 4]),
([2, 3, 20, 30, 40], [1, 1, 3, 3, 4], [2, 3, 4]),
([2, 10, 20, 30], [1, 1, 5.5, 5.7], [2, 3]),
([2, 20, 30, 40], [1, 1, 3.3, 3.1], [2, 3]),
([2, 10, 20, 30], [1, 1, 6.18, 5.34], [2, 3]),
([2, 20, 30, 40], [1, 1, 3.79, 4.16], [2, 3]),
([2, 3, 20, 30, 40], [1, 1, 3.12, 3.87, 3.92], [2, 3, 4]),
([2, 3, 20, 30, 40], [1, 1, 3.74, 4.873, 3.287], [2, 3, 4]),
([2, 3, 20, 30, 40], [1, 1, 4.8, 3.6, 3.11], [2, 3, 4]),
([2, 3, 20, 30, 40], [1, 1, 3.33, 3.73, 4.765], [2, 3, 4]),
])
def test_conversion(self, input_shape, scales, axes):
input_shape_as_array = int64_array(input_shape)
scales_as_array = float32_array(scales)
graph = build_graph(graph_node_attrs,
graph_edges,
{
'placeholder_data': {'shape': int64_array(input_shape)},
'scales': {'value': int64_array(scales), 'shape': int64_array(scales).shape},
'scales_data': {'value': int64_array(scales), 'shape': int64_array(scales).shape},
'upsample_data': {'shape': int64_array(input_shape) * int64_array(scales)}
'placeholder_data': {'shape': input_shape_as_array},
'scales': {'value': scales_as_array, 'shape': scales_as_array.shape},
'scales_data': {'value': scales_as_array, 'shape': scales_as_array.shape},
'upsample_data':
{'shape': ((input_shape_as_array + 1.e-5) * scales_as_array).astype(np.int64)}
})
graph.graph['layout'] = 'NCHW'
ref_graph = build_graph(new_ref_graph_node_attr,
@@ -185,12 +196,13 @@ class UpsampleToResampleTest(unittest.TestCase):
'ss_end': {'value': int64_array([axes[-1] + 1])},
'ss_begin_data': {'value': int64_array([axes[0]])},
'ss_end_data': {'value': int64_array([axes[-1] + 1])},
'factor': {'value': int64_array(scales)[2:],
'shape': int64_array(scales[2:]).shape},
'factor_data': {'value': int64_array(scales)[2:],
'shape': int64_array(scales[2:]).shape},
'factor': {'value': scales_as_array[2:],
'shape': scales_as_array[2:].shape},
'factor_data': {'value': scales_as_array[2:],
'shape': scales_as_array[2:].shape},
'axes_const': {'value': int64_array(axes), 'shape': int64_array(axes).shape},
'interpolate_data': {'shape': int64_array(input_shape) * int64_array(scales)},
'interpolate_data': {
'shape': (input_shape_as_array * scales_as_array + 1e-5).astype(np.int64)},
})
UpsampleToResample().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, ref_graph, 'output')