From 3019a34dc8ea710695c41fa7a10a2fcd26bf1906 Mon Sep 17 00:00:00 2001 From: Maxim Vafin Date: Thu, 29 Oct 2020 15:26:23 +0300 Subject: [PATCH] Improve support ONNX Resize-10 created by PyTorch (#1350) --- .../extensions/middle/UpsampleToResample.py | 9 --------- .../extensions/middle/UpsampleToResample_test.py | 13 +++++++------ model-optimizer/extensions/ops/upsample.py | 3 ++- 3 files changed, 9 insertions(+), 16 deletions(-) diff --git a/model-optimizer/extensions/middle/UpsampleToResample.py b/model-optimizer/extensions/middle/UpsampleToResample.py index 4a516976bd9..f73f36370f0 100644 --- a/model-optimizer/extensions/middle/UpsampleToResample.py +++ b/model-optimizer/extensions/middle/UpsampleToResample.py @@ -79,15 +79,6 @@ class UpsampleToResample(MiddleReplacementPattern): height_scale = upsample['height_scale'] width_scale = upsample['width_scale'] - if not math.isclose(height_scale, width_scale, rel_tol=1e-5): - log.debug('Width and height scales are not equal: {} vs {} for node {}'.format( - width_scale, height_scale, upsample_name)) - return - if depth_scale is not None and not math.isclose(height_scale, depth_scale, rel_tol=1e-5): - log.debug('Depth and height scales are not equal: {} vs {} for node {}'.format( - depth_scale, height_scale, upsample_name)) - return - if 1 in upsample.in_ports() and not upsample.in_port(1).disconnected(): upsample.in_port(1).disconnect() diff --git a/model-optimizer/extensions/middle/UpsampleToResample_test.py b/model-optimizer/extensions/middle/UpsampleToResample_test.py index d87a37fe039..908d32d044b 100644 --- a/model-optimizer/extensions/middle/UpsampleToResample_test.py +++ b/model-optimizer/extensions/middle/UpsampleToResample_test.py @@ -100,7 +100,12 @@ ref_graph_edges = [ class UpsampleToResampleTest(unittest.TestCase): @generate(*[([2, 10, 20, 30], [1, 1, 5, 5],), ([2, 20, 30, 40], [1, 1, 3, 3],), - ([2, 3, 20, 30, 40], [1, 1, 3, 3, 3],) + ([2, 10, 20, 30], [1, 1, 6, 5],), + ([2, 20, 30, 40], [1, 1, 3, 4],), + ([2, 3, 20, 30, 40], [1, 1, 3, 3, 3],), + ([2, 3, 20, 30, 40], [1, 1, 3, 4, 3],), + ([2, 3, 20, 30, 40], [1, 1, 4, 3, 3],), + ([2, 3, 20, 30, 40], [1, 1, 3, 3, 4],), ]) def test_conversion(self, input_shape, scales): graph = build_graph(graph_node_attrs, graph_edges, @@ -122,11 +127,7 @@ class UpsampleToResampleTest(unittest.TestCase): self.assertTrue(flag, resp) @generate(*[([2, 10, 20, 30], [1, 2, 5, 5],), - ([2, 10, 20, 30], [1, 1, 6, 5],), - ([2, 20, 30, 40], [1, 1, 3, 4],), - ([2, 3, 20, 30, 40], [1, 1, 3, 4, 3],), - ([2, 3, 20, 30, 40], [1, 1, 4, 3, 3],), - ([2, 3, 20, 30, 40], [1, 1, 3, 3, 4],), + ([2, 3, 20, 30, 40], [1, 2, 3, 3, 3],), ]) def test_pattern_does_not_satisfy(self, input_shape, scales): graph = build_graph(graph_node_attrs, graph_edges, diff --git a/model-optimizer/extensions/ops/upsample.py b/model-optimizer/extensions/ops/upsample.py index 613622c8b5a..dc2253ec28a 100644 --- a/model-optimizer/extensions/ops/upsample.py +++ b/model-optimizer/extensions/ops/upsample.py @@ -63,5 +63,6 @@ class UpsampleOp(Op): width=out_width) else: assert node.in_node(1).value is not None + eps = 1e-5 # This is to make rounding in case of very close number to round to closest instead of down # generic output shape calculation to support 5D input shape case - node.out_node().shape = np.array(input_shape * node.in_node(1).value).astype(np.int64) + node.out_node().shape = np.array((input_shape + eps) * node.in_node(1).value).astype(np.int64)