From 870a6c061a5ee28a724b3e0d56a58a774c5a995e Mon Sep 17 00:00:00 2001
From: Yegor Kruglov
Date: Sun, 27 Dec 2020 22:15:17 +0300
Subject: [PATCH] Fix of ascii decode problem in TF Const with string value (#3672)

* added try/except to tf_tensor_content

* code refactoring

* added unittest

* update test
---
Reviewer notes (this area, between the '---' line and the first 'diff --git'
line, is not part of the commit message and is not applied to the tree):

  * utils.py: when the decoded element type is np.str and building the typed
    array raises UnicodeDecodeError, tf_tensor_content now logs the message
    below (with extra={'is_warning': True}) and rebuilds the array without an
    explicit dtype instead of propagating the exception.
  * utils_test.py: test_str_decode feeds a string_val that is not ASCII
    decodable and checks both the logged message and that the returned array
    holds the original bytes.
  * NOTE(review): np.str is a deprecated NumPy alias of the builtin str;
    a follow-up may want to compare against str directly.
  * NOTE(review): both branches of the new if/else run the same statement; a
    single try/except would be equivalent. Left untouched here so the patch
    still matches the recorded commit.
  * NOTE(review): assertEqual on NumPy arrays only works while the arrays hold
    a single element; np.array_equal would be sturdier.
  * NOTE(review): the From: header carries no e-mail address; confirm before
    feeding this file to tooling that requires one.
  * NOTE(review): leading whitespace of context lines was restored from a
    copy whose line structure had been lost; verify against the target tree.

 .../mo/front/tf/extractors/utils.py           | 12 +++++++++++-
 .../mo/front/tf/extractors/utils_test.py      | 19 +++++++++++++++++++
 2 files changed, 30 insertions(+), 1 deletion(-)

diff --git a/model-optimizer/mo/front/tf/extractors/utils.py b/model-optimizer/mo/front/tf/extractors/utils.py
index f17d5ed3f5e..1b74a71ccb4 100644
--- a/model-optimizer/mo/front/tf/extractors/utils.py
+++ b/model-optimizer/mo/front/tf/extractors/utils.py
@@ -72,7 +72,17 @@ def tf_tensor_content(tf_dtype, shape, pb_tensor):
         value = np.array(np.frombuffer(pb_tensor.tensor_content, type_helper[0]))
     else:
         # load typed value
-        value = np.array(type_helper[1](pb_tensor), dtype=type_helper[0])
+        if type_helper[0] != np.str:
+            value = np.array(type_helper[1](pb_tensor), dtype=type_helper[0])
+        else:
+            try:
+                value = np.array(type_helper[1](pb_tensor), dtype=type_helper[0])
+            except UnicodeDecodeError:
+                log.error(
+                    'Failed to parse a tensor with Unicode characters. Note that Inference Engine does not support '
+                    'string literals, so the string constant should be eliminated from the graph.',
+                    extra={'is_warning': True})
+                value = np.array(type_helper[1](pb_tensor))
 
     if len(shape) == 0 or shape.prod() == 0:
         if len(value) == 1:
diff --git a/model-optimizer/mo/front/tf/extractors/utils_test.py b/model-optimizer/mo/front/tf/extractors/utils_test.py
index f10466b029f..15306206866 100644
--- a/model-optimizer/mo/front/tf/extractors/utils_test.py
+++ b/model-optimizer/mo/front/tf/extractors/utils_test.py
@@ -14,10 +14,12 @@
  limitations under the License.
 """
 
+import logging as log
 import unittest
 
 import numpy as np
 
+from mo.front.common.partial_infer.utils import int64_array
 from mo.front.tf.extractors.utils import collect_tf_attrs, tf_tensor_content
 from mo.utils.unittest.extractors import PB
 
@@ -193,3 +195,20 @@ class TensorContentParsing(unittest.TestCase):
                [[6, 6], [6, 6], [6, 6], [6, 6], [6, 6]]]
         res = tf_tensor_content(tf_dtype, shape, pb_tensor)
         self.assertTrue(np.all(res == ref))
+
+    def test_str_decode(self):
+        pb_tensor = PB({
+            'dtype': 7,
+            'string_val': [b"\037\000\036\000\002\000\303\237\035\000\002"]
+        })
+        tf_dtype = pb_tensor.dtype
+        shape = int64_array([1])
+        warning_message = 'ERROR:root:Failed to parse a tensor with Unicode characters. Note that Inference Engine ' \
+                          'does not support string literals, so the string constant should be eliminated from the ' \
+                          'graph.'
+        ref_val = np.array([b'\x1f\x00\x1e\x00\x02\x00\xc3\x9f\x1d\x00\x02'])
+        with self.assertLogs(log.getLogger(), level="ERROR") as cm:
+            result = tf_tensor_content(tf_dtype, shape, pb_tensor)
+            self.assertEqual([warning_message], cm.output)
+            self.assertEqual(ref_val, result)
+