Fixed extracting TF models with string constants (#5110)
* Fixed extracting TF models with string constants * Fixed parsing strings from the TF models
This commit is contained in:
parent
9c11f38aba
commit
1d26a5600c
@ -55,6 +55,8 @@ def tf_tensor_content(tf_dtype, shape, pb_tensor):
|
||||
raise Error("Data type is unsupported: {}. " +
|
||||
refer_to_faq_msg(50), tf_dtype)
|
||||
|
||||
decode_err_msg = 'Failed to parse a tensor with Unicode characters. Note that Inference Engine does not support ' \
|
||||
'string literals, so the string constant should be eliminated from the graph.'
|
||||
if pb_tensor.tensor_content:
|
||||
value = np.array(np.frombuffer(pb_tensor.tensor_content, type_helper[0]))
|
||||
else:
|
||||
@ -65,16 +67,17 @@ def tf_tensor_content(tf_dtype, shape, pb_tensor):
|
||||
try:
|
||||
value = np.array(type_helper[1](pb_tensor), dtype=type_helper[0])
|
||||
except UnicodeDecodeError:
|
||||
log.error(
|
||||
'Failed to parse a tensor with Unicode characters. Note that Inference Engine does not support '
|
||||
'string literals, so the string constant should be eliminated from the graph.',
|
||||
extra={'is_warning': True})
|
||||
log.error(decode_err_msg, extra={'is_warning': True})
|
||||
value = np.array(type_helper[1](pb_tensor))
|
||||
|
||||
if len(shape) == 0 or shape.prod() == 0:
|
||||
if len(value) == 1:
|
||||
# return scalar if shape is [] otherwise broadcast according to shape
|
||||
return np.array(value[0], dtype=type_helper[0])
|
||||
try:
|
||||
return np.array(value[0], dtype=type_helper[0])
|
||||
except UnicodeDecodeError:
|
||||
log.error(decode_err_msg, extra={'is_warning': True})
|
||||
return np.array(value[0])
|
||||
else:
|
||||
# no shape, return value as is
|
||||
return value
|
||||
|
@ -199,3 +199,15 @@ class TensorContentParsing(unittest.TestCase):
|
||||
self.assertEqual([warning_message], cm.output)
|
||||
self.assertEqual(ref_val, result)
|
||||
|
||||
def test_str_decode_list(self):
|
||||
pb_tensor = PB({
|
||||
'dtype': 7,
|
||||
'string_val': [b'\377\330\377\377\330\377'],
|
||||
})
|
||||
shape = int64_array([])
|
||||
warning_message = 'ERROR:root:Failed to parse a tensor with Unicode characters. Note that Inference Engine ' \
|
||||
'does not support string literals, so the string constant should be eliminated from the ' \
|
||||
'graph.'
|
||||
with self.assertLogs(log.getLogger(), level="ERROR") as cm:
|
||||
result = tf_tensor_content(pb_tensor.dtype, shape, pb_tensor)
|
||||
self.assertEqual([warning_message, warning_message], cm.output)
|
||||
|
Loading…
Reference in New Issue
Block a user