Fix ChangeOutputTypeAttributes BackReplacementPattern (#6949)

* Hot fix

* Add unit test
This commit is contained in:
iliya mironov 2021-08-09 19:22:21 +03:00 committed by GitHub
parent 838e701e5e
commit eadeae6c47
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 13 additions and 3 deletions

View File

@ -52,7 +52,7 @@ class ChangeOutputTypeAttributes(BackReplacementPattern):
if node[dst_type] in [np.float32, np.float64] and ir_data_type == np.float16 and \
not node.has_and_set('returns_shape_value'):
final_type = np.float16
elif node.has_and_set('returns_shape_value') and node.dst_type == np.float16:
elif node.has_and_set('returns_shape_value') and node[dst_type] == np.float16:
# return back FP32 for all nodes with shape values
final_type = np.float32

View File

@ -26,6 +26,13 @@ class ChangeOutputTypeAttributesTests(unittest.TestCase):
(flag, resp) = compare_graphs(graph, graph_ref, 'res', check_op_attrs=True)
self.assertTrue(flag, resp)
def test_range_correct_case_returns_shape_value(self):
graph, graph_ref = build_range_test_graphs(start=0, limit=10, delta=1, dst_type_str='FP32',
src_type_str='FP16', returns_shape_value=True)
ChangeOutputTypeAttributes().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'res', check_op_attrs=True)
self.assertTrue(flag, resp)
# starting from ~1000 FP16 absolute difference between neighbor values is more than 1
# fails because of shape inconsistency
def test_range_different_values(self):
@ -58,13 +65,15 @@ class ChangeOutputTypeAttributesTests(unittest.TestCase):
self.assertRaises(Error, ChangeOutputTypeAttributes().find_and_replace_pattern, graph)
def build_range_test_graphs(start=0, limit=10, delta=1, dst_type_str='FP16'):
def build_range_test_graphs(start=0, limit=10, delta=1, dst_type_str='FP16',
src_type_str='FP32', returns_shape_value=None):
nodes = {
**valued_const_with_data('start', float32_array(start)),
**valued_const_with_data('limit', float32_array(limit)),
**valued_const_with_data('delta', float32_array(delta)),
**regular_op_with_empty_data('range', {'type': 'Range', 'op': 'Range',
'output_type': np.float32,
'returns_shape_value': returns_shape_value,
'output_type': data_type_str_to_np(src_type_str),
'infer': Range.infer}),
**result('res'),
}
@ -72,6 +81,7 @@ def build_range_test_graphs(start=0, limit=10, delta=1, dst_type_str='FP16'):
nodes_ref = deepcopy(nodes)
nodes_ref.update({
**regular_op_with_empty_data('range', {'type': 'Range', 'op': 'Range',
'returns_shape_value': returns_shape_value,
'output_type': data_type_str_to_np(dst_type_str),
'infer': Range.infer}),
})