Fix ChangeOutputTypeAttributes BackReplacementPattern (#6949)
* Hot fix * Add unit test
This commit is contained in:
parent
838e701e5e
commit
eadeae6c47
@ -52,7 +52,7 @@ class ChangeOutputTypeAttributes(BackReplacementPattern):
|
||||
if node[dst_type] in [np.float32, np.float64] and ir_data_type == np.float16 and \
|
||||
not node.has_and_set('returns_shape_value'):
|
||||
final_type = np.float16
|
||||
elif node.has_and_set('returns_shape_value') and node.dst_type == np.float16:
|
||||
elif node.has_and_set('returns_shape_value') and node[dst_type] == np.float16:
|
||||
# return back FP32 for all nodes with shape values
|
||||
final_type = np.float32
|
||||
|
||||
|
@ -26,6 +26,13 @@ class ChangeOutputTypeAttributesTests(unittest.TestCase):
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'res', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_range_correct_case_returns_shape_value(self):
|
||||
graph, graph_ref = build_range_test_graphs(start=0, limit=10, delta=1, dst_type_str='FP32',
|
||||
src_type_str='FP16', returns_shape_value=True)
|
||||
ChangeOutputTypeAttributes().find_and_replace_pattern(graph)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'res', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
# starting from ~1000 FP16 absolute difference between neighbor values is more than 1
|
||||
# fails because of shape inconsistency
|
||||
def test_range_different_values(self):
|
||||
@ -58,13 +65,15 @@ class ChangeOutputTypeAttributesTests(unittest.TestCase):
|
||||
self.assertRaises(Error, ChangeOutputTypeAttributes().find_and_replace_pattern, graph)
|
||||
|
||||
|
||||
def build_range_test_graphs(start=0, limit=10, delta=1, dst_type_str='FP16'):
|
||||
def build_range_test_graphs(start=0, limit=10, delta=1, dst_type_str='FP16',
|
||||
src_type_str='FP32', returns_shape_value=None):
|
||||
nodes = {
|
||||
**valued_const_with_data('start', float32_array(start)),
|
||||
**valued_const_with_data('limit', float32_array(limit)),
|
||||
**valued_const_with_data('delta', float32_array(delta)),
|
||||
**regular_op_with_empty_data('range', {'type': 'Range', 'op': 'Range',
|
||||
'output_type': np.float32,
|
||||
'returns_shape_value': returns_shape_value,
|
||||
'output_type': data_type_str_to_np(src_type_str),
|
||||
'infer': Range.infer}),
|
||||
**result('res'),
|
||||
}
|
||||
@ -72,6 +81,7 @@ def build_range_test_graphs(start=0, limit=10, delta=1, dst_type_str='FP16'):
|
||||
nodes_ref = deepcopy(nodes)
|
||||
nodes_ref.update({
|
||||
**regular_op_with_empty_data('range', {'type': 'Range', 'op': 'Range',
|
||||
'returns_shape_value': returns_shape_value,
|
||||
'output_type': data_type_str_to_np(dst_type_str),
|
||||
'infer': Range.infer}),
|
||||
})
|
||||
|
Loading…
Reference in New Issue
Block a user