Improve support ONNX Resize-10 created by PyTorch (#1350)

This commit is contained in:
Maxim Vafin 2020-10-29 15:26:23 +03:00 committed by GitHub
parent 04b7822761
commit 3019a34dc8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 9 additions and 16 deletions

View File

@ -79,15 +79,6 @@ class UpsampleToResample(MiddleReplacementPattern):
height_scale = upsample['height_scale']
width_scale = upsample['width_scale']
if not math.isclose(height_scale, width_scale, rel_tol=1e-5):
log.debug('Width and height scales are not equal: {} vs {} for node {}'.format(
width_scale, height_scale, upsample_name))
return
if depth_scale is not None and not math.isclose(height_scale, depth_scale, rel_tol=1e-5):
log.debug('Depth and height scales are not equal: {} vs {} for node {}'.format(
depth_scale, height_scale, upsample_name))
return
if 1 in upsample.in_ports() and not upsample.in_port(1).disconnected():
upsample.in_port(1).disconnect()

View File

@ -100,7 +100,12 @@ ref_graph_edges = [
class UpsampleToResampleTest(unittest.TestCase):
@generate(*[([2, 10, 20, 30], [1, 1, 5, 5],),
([2, 20, 30, 40], [1, 1, 3, 3],),
([2, 3, 20, 30, 40], [1, 1, 3, 3, 3],)
([2, 10, 20, 30], [1, 1, 6, 5],),
([2, 20, 30, 40], [1, 1, 3, 4],),
([2, 3, 20, 30, 40], [1, 1, 3, 3, 3],),
([2, 3, 20, 30, 40], [1, 1, 3, 4, 3],),
([2, 3, 20, 30, 40], [1, 1, 4, 3, 3],),
([2, 3, 20, 30, 40], [1, 1, 3, 3, 4],),
])
def test_conversion(self, input_shape, scales):
graph = build_graph(graph_node_attrs, graph_edges,
@ -122,11 +127,7 @@ class UpsampleToResampleTest(unittest.TestCase):
self.assertTrue(flag, resp)
@generate(*[([2, 10, 20, 30], [1, 2, 5, 5],),
([2, 10, 20, 30], [1, 1, 6, 5],),
([2, 20, 30, 40], [1, 1, 3, 4],),
([2, 3, 20, 30, 40], [1, 1, 3, 4, 3],),
([2, 3, 20, 30, 40], [1, 1, 4, 3, 3],),
([2, 3, 20, 30, 40], [1, 1, 3, 3, 4],),
([2, 3, 20, 30, 40], [1, 2, 3, 3, 3],),
])
def test_pattern_does_not_satisfy(self, input_shape, scales):
graph = build_graph(graph_node_attrs, graph_edges,

View File

@ -63,5 +63,6 @@ class UpsampleOp(Op):
width=out_width)
else:
assert node.in_node(1).value is not None
eps = 1e-5 # This is to make rounding in case of very close number to round to closest instead of down
# generic output shape calculation to support 5D input shape case
node.out_node().shape = np.array(input_shape * node.in_node(1).value).astype(np.int64)
node.out_node().shape = np.array((input_shape + eps) * node.in_node(1).value).astype(np.int64)