Fixed extracting TF models with string constants (#5110)

* Fixed extracting TF models with string constants

* Fixed parsing strings from the TF models
This commit is contained in:
Evgeny Lazarev 2021-04-06 10:33:40 +03:00 committed by GitHub
parent 9c11f38aba
commit 1d26a5600c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 20 additions and 5 deletions

View File

@ -55,6 +55,8 @@ def tf_tensor_content(tf_dtype, shape, pb_tensor):
raise Error("Data type is unsupported: {}. " +
refer_to_faq_msg(50), tf_dtype)
decode_err_msg = 'Failed to parse a tensor with Unicode characters. Note that Inference Engine does not support ' \
'string literals, so the string constant should be eliminated from the graph.'
if pb_tensor.tensor_content:
value = np.array(np.frombuffer(pb_tensor.tensor_content, type_helper[0]))
else:
@ -65,16 +67,17 @@ def tf_tensor_content(tf_dtype, shape, pb_tensor):
try:
value = np.array(type_helper[1](pb_tensor), dtype=type_helper[0])
except UnicodeDecodeError:
log.error(
'Failed to parse a tensor with Unicode characters. Note that Inference Engine does not support '
'string literals, so the string constant should be eliminated from the graph.',
extra={'is_warning': True})
log.error(decode_err_msg, extra={'is_warning': True})
value = np.array(type_helper[1](pb_tensor))
if len(shape) == 0 or shape.prod() == 0:
if len(value) == 1:
# return scalar if shape is [] otherwise broadcast according to shape
return np.array(value[0], dtype=type_helper[0])
try:
return np.array(value[0], dtype=type_helper[0])
except UnicodeDecodeError:
log.error(decode_err_msg, extra={'is_warning': True})
return np.array(value[0])
else:
# no shape, return value as is
return value

View File

@ -199,3 +199,15 @@ class TensorContentParsing(unittest.TestCase):
self.assertEqual([warning_message], cm.output)
self.assertEqual(ref_val, result)
def test_str_decode_list(self):
pb_tensor = PB({
'dtype': 7,
'string_val': [b'\377\330\377\377\330\377'],
})
shape = int64_array([])
warning_message = 'ERROR:root:Failed to parse a tensor with Unicode characters. Note that Inference Engine ' \
'does not support string literals, so the string constant should be eliminated from the ' \
'graph.'
with self.assertLogs(log.getLogger(), level="ERROR") as cm:
result = tf_tensor_content(pb_tensor.dtype, shape, pb_tensor)
self.assertEqual([warning_message, warning_message], cm.output)