Fix of ascii decode problem in TF Const with string value (#3672)

* added try/except to tf_tensor_content

* code refactoring

* added unittest

* update test
This commit is contained in:
Yegor Kruglov 2020-12-27 22:15:17 +03:00 committed by GitHub
parent 278b662e56
commit 870a6c061a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 30 additions and 1 deletions

View File

@ -72,7 +72,17 @@ def tf_tensor_content(tf_dtype, shape, pb_tensor):
value = np.array(np.frombuffer(pb_tensor.tensor_content, type_helper[0]))
else:
# load typed value
value = np.array(type_helper[1](pb_tensor), dtype=type_helper[0])
if type_helper[0] != np.str:
value = np.array(type_helper[1](pb_tensor), dtype=type_helper[0])
else:
try:
value = np.array(type_helper[1](pb_tensor), dtype=type_helper[0])
except UnicodeDecodeError:
log.error(
'Failed to parse a tensor with Unicode characters. Note that Inference Engine does not support '
'string literals, so the string constant should be eliminated from the graph.',
extra={'is_warning': True})
value = np.array(type_helper[1](pb_tensor))
if len(shape) == 0 or shape.prod() == 0:
if len(value) == 1:

View File

@ -14,10 +14,12 @@
limitations under the License.
"""
import logging as log
import unittest
import numpy as np
from mo.front.common.partial_infer.utils import int64_array
from mo.front.tf.extractors.utils import collect_tf_attrs, tf_tensor_content
from mo.utils.unittest.extractors import PB
@ -193,3 +195,20 @@ class TensorContentParsing(unittest.TestCase):
[[6, 6], [6, 6], [6, 6], [6, 6], [6, 6]]]
res = tf_tensor_content(tf_dtype, shape, pb_tensor)
self.assertTrue(np.all(res == ref))
def test_str_decode(self):
    """Check tf_tensor_content on a string tensor holding non-ASCII bytes.

    The tensor is declared with dtype code 7 and a single ``string_val``
    entry. The call is expected to emit exactly one ERROR record on the
    root logger and to return a value equal to the raw byte string.
    """
    tensor_proto = PB({
        'dtype': 7,
        'string_val': [b"\037\000\036\000\002\000\303\237\035\000\002"]
    })
    tensor_shape = int64_array([1])

    # assertLogs renders each captured record as "LEVEL:logger_name:message".
    expected_log = (
        'ERROR:root:Failed to parse a tensor with Unicode characters. Note that Inference Engine '
        'does not support string literals, so the string constant should be eliminated from the '
        'graph.'
    )
    # Same payload as in string_val above, spelled with hex escapes.
    expected_value = np.array([b'\x1f\x00\x1e\x00\x02\x00\xc3\x9f\x1d\x00\x02'])

    root_logger = log.getLogger()
    with self.assertLogs(root_logger, level="ERROR") as captured:
        actual_value = tf_tensor_content(tensor_proto.dtype, tensor_shape, tensor_proto)
        self.assertEqual([expected_log], captured.output)
        self.assertEqual(expected_value, actual_value)