Fix of ascii decode problem in TF Const with string value (#3672)
* added try/except to tf_tensor_content * code refactoring * added unittest * update test
This commit is contained in:
parent
278b662e56
commit
870a6c061a
@ -72,7 +72,17 @@ def tf_tensor_content(tf_dtype, shape, pb_tensor):
|
||||
value = np.array(np.frombuffer(pb_tensor.tensor_content, type_helper[0]))
|
||||
else:
|
||||
# load typed value
|
||||
value = np.array(type_helper[1](pb_tensor), dtype=type_helper[0])
|
||||
if type_helper[0] != np.str:
|
||||
value = np.array(type_helper[1](pb_tensor), dtype=type_helper[0])
|
||||
else:
|
||||
try:
|
||||
value = np.array(type_helper[1](pb_tensor), dtype=type_helper[0])
|
||||
except UnicodeDecodeError:
|
||||
log.error(
|
||||
'Failed to parse a tensor with Unicode characters. Note that Inference Engine does not support '
|
||||
'string literals, so the string constant should be eliminated from the graph.',
|
||||
extra={'is_warning': True})
|
||||
value = np.array(type_helper[1](pb_tensor))
|
||||
|
||||
if len(shape) == 0 or shape.prod() == 0:
|
||||
if len(value) == 1:
|
||||
|
@ -14,10 +14,12 @@
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import logging as log
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.front.tf.extractors.utils import collect_tf_attrs, tf_tensor_content
|
||||
from mo.utils.unittest.extractors import PB
|
||||
|
||||
@ -193,3 +195,20 @@ class TensorContentParsing(unittest.TestCase):
|
||||
[[6, 6], [6, 6], [6, 6], [6, 6], [6, 6]]]
|
||||
res = tf_tensor_content(tf_dtype, shape, pb_tensor)
|
||||
self.assertTrue(np.all(res == ref))
|
||||
|
||||
def test_str_decode(self):
|
||||
pb_tensor = PB({
|
||||
'dtype': 7,
|
||||
'string_val': [b"\037\000\036\000\002\000\303\237\035\000\002"]
|
||||
})
|
||||
tf_dtype = pb_tensor.dtype
|
||||
shape = int64_array([1])
|
||||
warning_message = 'ERROR:root:Failed to parse a tensor with Unicode characters. Note that Inference Engine ' \
|
||||
'does not support string literals, so the string constant should be eliminated from the ' \
|
||||
'graph.'
|
||||
ref_val = np.array([b'\x1f\x00\x1e\x00\x02\x00\xc3\x9f\x1d\x00\x02'])
|
||||
with self.assertLogs(log.getLogger(), level="ERROR") as cm:
|
||||
result = tf_tensor_content(tf_dtype, shape, pb_tensor)
|
||||
self.assertEqual([warning_message], cm.output)
|
||||
self.assertEqual(ref_val, result)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user