Add support resize with 2 inputs (#5927)

* Add support resize with 2 inputs

* Add unit tests

* Hot fix

* Change resize check from port count to connected num port conditions

* Fix conditions

* Refactoring code according to review

* Fix according to review

* Change onnresize11 input condition
This commit is contained in:
iliya mironov
2021-06-08 13:15:54 +03:00
committed by GitHub
parent 503d18c80f
commit ae5608534e
3 changed files with 80 additions and 14 deletions

View File

@@ -34,12 +34,13 @@ def replace_resize(graph: Graph, resize: Node):
log.warning('The input shape is not 4D or 5D for op with name {}'.format(resize_name))
return
num_of_inputs = len([port for port in resize.in_ports().values() if not port.disconnected()])
assert num_of_inputs in {3, 4}, \
"Number of inputs of ONNXResize (with name {}) should be equal to 3 or 4".format(resize_name)
assert (resize.is_in_port_connected(0) and (resize.is_in_port_connected(2) or resize.is_in_port_connected(3))), \
"Scales or sizes inputs must be connected to Node {} with op {}.".format(resize.soft_get("name", resize.id),
resize.op)
assert resize.soft_get('coordinate_transformation_mode') != 'tf_crop_and_resize', \
'Mode tf_crop_and_resize is not supported for op {} with name {}'.format(resize.op, resize_name)
'Mode tf_crop_and_resize is not supported for op {} with name {}'.format(resize.op,
resize.soft_get("name", resize.id))
layout = graph.graph['layout']
@@ -74,7 +75,7 @@ def replace_resize(graph: Graph, resize: Node):
{'name': resize_name + '/axis',
'value': int64_array(np.arange(begin_dim, end_dim))}).create_node()
shape_calculation_mode = 'scales' if num_of_inputs == 3 else 'sizes'
shape_calculation_mode = 'sizes' if resize.is_in_port_connected(3) else 'scales'
interpolate_node = Interpolate(graph, {'version': 'opset4',
'mode': convert_mode(resize.mode),
@@ -96,7 +97,7 @@ def replace_resize(graph: Graph, resize: Node):
dst_dtype = np.float32 # even if data_type=FP16 use float32 for shape values
if num_of_inputs == 3:
if not resize.is_in_port_connected(3):
cast_shape_to_float = Cast(graph, {'dst_type': dst_dtype}).create_node()
mul_node = Mul(graph, {'name': resize_name + '/Mul'}).create_node()
shape_of.out_port(0).connect(cast_shape_to_float.in_port(0))

View File

@@ -35,14 +35,15 @@ class ONNXResize11Op(Op):
if input_shape is None:
return
num_of_in_nodes = len(node.in_nodes())
assert num_of_in_nodes in {3, 4}, \
"Node {} with op {} number of inputs must be equal to 3 or 4.".format(node.name, node.op)
assert (node.is_in_port_connected(0) and (node.is_in_port_connected(2) or node.is_in_port_connected(3))), \
"One of the scales or sizes inputs must be connected to Node {} with op {}.".format(node.soft_get("name", node.id),
node.op)
assert node.coordinate_transformation_mode != 'tf_crop_and_resize', \
'Mode tf_crop_and_resize is not supported for op {} with name {}'.format(node.op, node.name)
'Mode tf_crop_and_resize is not supported for op {} with name {}'.format(node.op,
node.soft_get("name", node.id))
if num_of_in_nodes == 3:
if not node.is_in_port_connected(3):
# i.e. input 'sizes' is not given
input2_value = node.in_port(2).data.get_value()
assert input2_value is not None, \
@@ -53,7 +54,7 @@ class ONNXResize11Op(Op):
# i.e. input 'sizes' is given
sizes = node.in_port(3).data.get_value()
assert sizes is not None, \
"Node {} with op {} has no value in input port 3".format(node.name, node.op)
"Node {} with op {} has no value in input port 3".format(node.soft_get("name", node.id), node.op)
output_shape = input_shape.copy()
spatial_dimension_indices = range(2, len(input_shape))
output_shape[spatial_dimension_indices] = int64_array(sizes)[2:]

View File

@@ -33,10 +33,8 @@ graph_edges_sizes = [
('input', 'input_data'),
('roi', 'roi_data'),
('sizes', 'sizes_data'),
('scales', 'scales_data'),
('input_data', 'onnx_resize11', {'in': 0}),
('roi_data', 'onnx_resize11', {'in': 1}),
('scales_data', 'onnx_resize11', {'in': 2}),
('sizes_data', 'onnx_resize11', {'in': 3}),
('onnx_resize11', 'onnx_resize11_data'),
('onnx_resize11_data', 'op_output'),
@@ -125,3 +123,69 @@ class TestONNXResize11Op(unittest.TestCase):
self.assertTrue(np.array_equal(graph.node['onnx_resize11_data']['shape'], int64_array(output_shape)),
msg.format(scales, output_shape, graph.node['onnx_resize11_data']['shape']))
@generate(*[([1, 260, 100, 150], [1, 260, 200, 350], [1, 260, 200, 350], [1.0, 1.0, 1.0, 1.0]),
([1, 260, 100, 150], [1, 260, 200, 350], [1, 1, 200, 350], [1.0, 1.0, 1.0, 1.0]),
([5, 14, 300, 40], [5, 14, 140, 280], [1, 1, 140, 280], [1.0, 1.0, 1.0, 1.0]),
([5, 14, 300, 40], [5, 14, 140, 280], [5, 14, 140, 280], [1.0, 1.0, 1.0, 1.0]),
([1, 3, 260, 100, 150], [1, 3, 780, 200, 350], [1, 3, 780, 200, 350], [1.0, 1.0, 1.0, 1.0, 1.0]),
([1, 3, 450, 100, 150], [1, 3, 260, 200, 350], [1, 3, 260, 200, 350], [1.0, 1.0, 1.0, 1.0, 1.0]),
([5, 14, 1000, 300, 40], [5, 14, 500, 140, 280], [1, 1, 500, 140, 280], [1.0, 1.0, 1.0, 1.0, 1.0]),
([5, 14, 1000, 300, 40], [5, 14, 500, 140, 280], [5, 14, 500, 140, 280], [1.0, 1.0, 1.0, 1.0, 1.0])])
def test_onnx_resize11_using_sizes_without_roi_input(self, input_shape, output_shape, sizes, scales):
np_scales = np.array(scales)
np_sizes = int64_array(sizes)
graph = build_graph(nodes_attrs=graph_node_attrs_sizes,
edges=[('input', 'input_data'),
('sizes', 'sizes_data'),
('input_data', 'onnx_resize11', {'in': 0}),
('sizes_data', 'onnx_resize11', {'in': 3}),
('onnx_resize11', 'onnx_resize11_data'),
('onnx_resize11_data', 'op_output'),
],
update_attributes={
'input_data': {'shape': int64_array(input_shape)},
'scales': {'shape': int64_array(np_scales.shape), 'value': np_scales},
'scales_data': {'shape': int64_array(np_scales.shape), 'value': np_scales},
'sizes': {'shape': int64_array(np_sizes.shape), 'value': np_sizes},
'sizes_data': {'shape': int64_array(np_sizes.shape), 'value': np_sizes},
})
node = Node(graph, 'onnx_resize11')
ONNXResize11Op.onnx_resize_infer(node)
msg = "ONNXResize11 infer failed for case: sizes={}, scales={}, expected_shape={}, actual_shape={}"
self.assertTrue(np.array_equal(graph.node['onnx_resize11_data']['shape'], int64_array(output_shape)),
msg.format(sizes, scales, output_shape, graph.node['onnx_resize11_data']['shape']))
@generate(*[([1, 260, 100, 150], [1, 260, 200, 350], [1.0, 1.0, 2.0, 350 / 150]),
([1, 3, 100, 200], [1, 3, 350, 150], [1.0, 1.0, 3.5, 150 / 200]),
([5, 14, 300, 40], [5, 14, 140, 280], [1.0, 1.0, 140 / 300, 7.0]),
([5, 14, 300, 40], [5, 14, 140, 560], [1.0, 1.0, 140 / 300, 14.0]),
([1, 3, 260, 100, 150], [1, 3, 780, 200, 350], [1.0, 1.0, 3.0, 2.0, 350 / 150]),
([1, 3, 450, 100, 150], [1, 3, 260, 200, 350], [1.0, 1.0, 260 / 450, 2.0, 350 / 150]),
([5, 14, 1000, 300, 40], [5, 14, 500, 140, 280], [1.0, 1.0, 0.5, 140 / 300, 7.0]),
([4, 3, 180, 1340], [4, 3, 60, 804], [1.0, 1.0, 0.33333334, 0.6]),
([4, 3, 500, 180, 1340], [4, 3, 750, 60, 804], [1.0, 1.0, 1.5, 0.33333334, 0.6])])
def test_onnx_resize_using_scales_without_roi(self, input_shape, output_shape, scales):
np_scales = np.array(scales)
graph = build_graph(nodes_attrs=graph_node_attrs_scales,
edges=[('input', 'input_data'),
('scales', 'scales_data'),
('input_data', 'onnx_resize11', {'in': 0}),
('scales_data', 'onnx_resize11', {'in': 2}),
('onnx_resize11', 'onnx_resize11_data'),
('onnx_resize11_data', 'op_output'),
],
update_attributes={
'input_data': {'shape': int64_array(input_shape)},
'scales': {'shape': int64_array(np_scales.shape), 'value': np_scales},
'scales_data': {'shape': int64_array(np_scales.shape), 'value': np_scales},
})
node = Node(graph, 'onnx_resize11')
ONNXResize11Op.onnx_resize_infer(node)
msg = "ONNXResize11 infer failed for case: scales={}, expected_shape={}, actual_shape={}"
self.assertTrue(np.array_equal(graph.node['onnx_resize11_data']['shape'], int64_array(output_shape)),
msg.format(scales, output_shape, graph.node['onnx_resize11_data']['shape']))