Add support resize with 2 inputs (#5927)
* Add support resize with 2 inputs * Add unit tests * Hot fix * Change resize check from port count to connected num port conditions * Fix conditions * Refactoring code according to review * Fix according to review * Change onnresize11 input condition
This commit is contained in:
@@ -34,12 +34,13 @@ def replace_resize(graph: Graph, resize: Node):
|
||||
log.warning('The input shape is not 4D or 5D for op with name {}'.format(resize_name))
|
||||
return
|
||||
|
||||
num_of_inputs = len([port for port in resize.in_ports().values() if not port.disconnected()])
|
||||
assert num_of_inputs in {3, 4}, \
|
||||
"Number of inputs of ONNXResize (with name {}) should be equal to 3 or 4".format(resize_name)
|
||||
assert (resize.is_in_port_connected(0) and (resize.is_in_port_connected(2) or resize.is_in_port_connected(3))), \
|
||||
"Scales or sizes inputs must be connected to Node {} with op {}.".format(resize.soft_get("name", resize.id),
|
||||
resize.op)
|
||||
|
||||
assert resize.soft_get('coordinate_transformation_mode') != 'tf_crop_and_resize', \
|
||||
'Mode tf_crop_and_resize is not supported for op {} with name {}'.format(resize.op, resize_name)
|
||||
'Mode tf_crop_and_resize is not supported for op {} with name {}'.format(resize.op,
|
||||
resize.soft_get("name", resize.id))
|
||||
|
||||
layout = graph.graph['layout']
|
||||
|
||||
@@ -74,7 +75,7 @@ def replace_resize(graph: Graph, resize: Node):
|
||||
{'name': resize_name + '/axis',
|
||||
'value': int64_array(np.arange(begin_dim, end_dim))}).create_node()
|
||||
|
||||
shape_calculation_mode = 'scales' if num_of_inputs == 3 else 'sizes'
|
||||
shape_calculation_mode = 'sizes' if resize.is_in_port_connected(3) else 'scales'
|
||||
|
||||
interpolate_node = Interpolate(graph, {'version': 'opset4',
|
||||
'mode': convert_mode(resize.mode),
|
||||
@@ -96,7 +97,7 @@ def replace_resize(graph: Graph, resize: Node):
|
||||
|
||||
dst_dtype = np.float32 # even if data_type=FP16 use float32 for shape values
|
||||
|
||||
if num_of_inputs == 3:
|
||||
if not resize.is_in_port_connected(3):
|
||||
cast_shape_to_float = Cast(graph, {'dst_type': dst_dtype}).create_node()
|
||||
mul_node = Mul(graph, {'name': resize_name + '/Mul'}).create_node()
|
||||
shape_of.out_port(0).connect(cast_shape_to_float.in_port(0))
|
||||
|
||||
@@ -35,14 +35,15 @@ class ONNXResize11Op(Op):
|
||||
if input_shape is None:
|
||||
return
|
||||
|
||||
num_of_in_nodes = len(node.in_nodes())
|
||||
assert num_of_in_nodes in {3, 4}, \
|
||||
"Node {} with op {} number of inputs must be equal to 3 or 4.".format(node.name, node.op)
|
||||
assert (node.is_in_port_connected(0) and (node.is_in_port_connected(2) or node.is_in_port_connected(3))), \
|
||||
"One of the scales or sizes inputs must be connected to Node {} with op {}.".format(node.soft_get("name", node.id),
|
||||
node.op)
|
||||
|
||||
assert node.coordinate_transformation_mode != 'tf_crop_and_resize', \
|
||||
'Mode tf_crop_and_resize is not supported for op {} with name {}'.format(node.op, node.name)
|
||||
'Mode tf_crop_and_resize is not supported for op {} with name {}'.format(node.op,
|
||||
node.soft_get("name", node.id))
|
||||
|
||||
if num_of_in_nodes == 3:
|
||||
if not node.is_in_port_connected(3):
|
||||
# i.e. input 'sizes' is not given
|
||||
input2_value = node.in_port(2).data.get_value()
|
||||
assert input2_value is not None, \
|
||||
@@ -53,7 +54,7 @@ class ONNXResize11Op(Op):
|
||||
# i.e. input 'sizes' is given
|
||||
sizes = node.in_port(3).data.get_value()
|
||||
assert sizes is not None, \
|
||||
"Node {} with op {} has no value in input port 3".format(node.name, node.op)
|
||||
"Node {} with op {} has no value in input port 3".format(node.soft_get("name", node.id), node.op)
|
||||
output_shape = input_shape.copy()
|
||||
spatial_dimension_indices = range(2, len(input_shape))
|
||||
output_shape[spatial_dimension_indices] = int64_array(sizes)[2:]
|
||||
|
||||
@@ -33,10 +33,8 @@ graph_edges_sizes = [
|
||||
('input', 'input_data'),
|
||||
('roi', 'roi_data'),
|
||||
('sizes', 'sizes_data'),
|
||||
('scales', 'scales_data'),
|
||||
('input_data', 'onnx_resize11', {'in': 0}),
|
||||
('roi_data', 'onnx_resize11', {'in': 1}),
|
||||
('scales_data', 'onnx_resize11', {'in': 2}),
|
||||
('sizes_data', 'onnx_resize11', {'in': 3}),
|
||||
('onnx_resize11', 'onnx_resize11_data'),
|
||||
('onnx_resize11_data', 'op_output'),
|
||||
@@ -125,3 +123,69 @@ class TestONNXResize11Op(unittest.TestCase):
|
||||
|
||||
self.assertTrue(np.array_equal(graph.node['onnx_resize11_data']['shape'], int64_array(output_shape)),
|
||||
msg.format(scales, output_shape, graph.node['onnx_resize11_data']['shape']))
|
||||
|
||||
@generate(*[([1, 260, 100, 150], [1, 260, 200, 350], [1, 260, 200, 350], [1.0, 1.0, 1.0, 1.0]),
|
||||
([1, 260, 100, 150], [1, 260, 200, 350], [1, 1, 200, 350], [1.0, 1.0, 1.0, 1.0]),
|
||||
([5, 14, 300, 40], [5, 14, 140, 280], [1, 1, 140, 280], [1.0, 1.0, 1.0, 1.0]),
|
||||
([5, 14, 300, 40], [5, 14, 140, 280], [5, 14, 140, 280], [1.0, 1.0, 1.0, 1.0]),
|
||||
([1, 3, 260, 100, 150], [1, 3, 780, 200, 350], [1, 3, 780, 200, 350], [1.0, 1.0, 1.0, 1.0, 1.0]),
|
||||
([1, 3, 450, 100, 150], [1, 3, 260, 200, 350], [1, 3, 260, 200, 350], [1.0, 1.0, 1.0, 1.0, 1.0]),
|
||||
([5, 14, 1000, 300, 40], [5, 14, 500, 140, 280], [1, 1, 500, 140, 280], [1.0, 1.0, 1.0, 1.0, 1.0]),
|
||||
([5, 14, 1000, 300, 40], [5, 14, 500, 140, 280], [5, 14, 500, 140, 280], [1.0, 1.0, 1.0, 1.0, 1.0])])
|
||||
def test_onnx_resize11_using_sizes_without_roi_input(self, input_shape, output_shape, sizes, scales):
|
||||
np_scales = np.array(scales)
|
||||
np_sizes = int64_array(sizes)
|
||||
graph = build_graph(nodes_attrs=graph_node_attrs_sizes,
|
||||
edges=[('input', 'input_data'),
|
||||
('sizes', 'sizes_data'),
|
||||
('input_data', 'onnx_resize11', {'in': 0}),
|
||||
('sizes_data', 'onnx_resize11', {'in': 3}),
|
||||
('onnx_resize11', 'onnx_resize11_data'),
|
||||
('onnx_resize11_data', 'op_output'),
|
||||
],
|
||||
update_attributes={
|
||||
'input_data': {'shape': int64_array(input_shape)},
|
||||
'scales': {'shape': int64_array(np_scales.shape), 'value': np_scales},
|
||||
'scales_data': {'shape': int64_array(np_scales.shape), 'value': np_scales},
|
||||
'sizes': {'shape': int64_array(np_sizes.shape), 'value': np_sizes},
|
||||
'sizes_data': {'shape': int64_array(np_sizes.shape), 'value': np_sizes},
|
||||
})
|
||||
node = Node(graph, 'onnx_resize11')
|
||||
ONNXResize11Op.onnx_resize_infer(node)
|
||||
|
||||
msg = "ONNXResize11 infer failed for case: sizes={}, scales={}, expected_shape={}, actual_shape={}"
|
||||
|
||||
self.assertTrue(np.array_equal(graph.node['onnx_resize11_data']['shape'], int64_array(output_shape)),
|
||||
msg.format(sizes, scales, output_shape, graph.node['onnx_resize11_data']['shape']))
|
||||
|
||||
@generate(*[([1, 260, 100, 150], [1, 260, 200, 350], [1.0, 1.0, 2.0, 350 / 150]),
|
||||
([1, 3, 100, 200], [1, 3, 350, 150], [1.0, 1.0, 3.5, 150 / 200]),
|
||||
([5, 14, 300, 40], [5, 14, 140, 280], [1.0, 1.0, 140 / 300, 7.0]),
|
||||
([5, 14, 300, 40], [5, 14, 140, 560], [1.0, 1.0, 140 / 300, 14.0]),
|
||||
([1, 3, 260, 100, 150], [1, 3, 780, 200, 350], [1.0, 1.0, 3.0, 2.0, 350 / 150]),
|
||||
([1, 3, 450, 100, 150], [1, 3, 260, 200, 350], [1.0, 1.0, 260 / 450, 2.0, 350 / 150]),
|
||||
([5, 14, 1000, 300, 40], [5, 14, 500, 140, 280], [1.0, 1.0, 0.5, 140 / 300, 7.0]),
|
||||
([4, 3, 180, 1340], [4, 3, 60, 804], [1.0, 1.0, 0.33333334, 0.6]),
|
||||
([4, 3, 500, 180, 1340], [4, 3, 750, 60, 804], [1.0, 1.0, 1.5, 0.33333334, 0.6])])
|
||||
def test_onnx_resize_using_scales_without_roi(self, input_shape, output_shape, scales):
|
||||
np_scales = np.array(scales)
|
||||
graph = build_graph(nodes_attrs=graph_node_attrs_scales,
|
||||
edges=[('input', 'input_data'),
|
||||
('scales', 'scales_data'),
|
||||
('input_data', 'onnx_resize11', {'in': 0}),
|
||||
('scales_data', 'onnx_resize11', {'in': 2}),
|
||||
('onnx_resize11', 'onnx_resize11_data'),
|
||||
('onnx_resize11_data', 'op_output'),
|
||||
],
|
||||
update_attributes={
|
||||
'input_data': {'shape': int64_array(input_shape)},
|
||||
'scales': {'shape': int64_array(np_scales.shape), 'value': np_scales},
|
||||
'scales_data': {'shape': int64_array(np_scales.shape), 'value': np_scales},
|
||||
})
|
||||
node = Node(graph, 'onnx_resize11')
|
||||
ONNXResize11Op.onnx_resize_infer(node)
|
||||
|
||||
msg = "ONNXResize11 infer failed for case: scales={}, expected_shape={}, actual_shape={}"
|
||||
|
||||
self.assertTrue(np.array_equal(graph.node['onnx_resize11_data']['shape'], int64_array(output_shape)),
|
||||
msg.format(scales, output_shape, graph.node['onnx_resize11_data']['shape']))
|
||||
|
||||
Reference in New Issue
Block a user