[MO] Updating condition in TF Const value extracting (#6322)

* update value extracting condition

* added a comment

* updated unittest

* added a scalar test, little fix for scalar values extracting
This commit is contained in:
Yegor Kruglov 2021-07-08 13:55:29 +03:00 committed by GitHub
parent ba3a667730
commit b56164e804
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 39 additions and 2 deletions

View File

@ -70,8 +70,18 @@ def tf_tensor_content(tf_dtype, shape, pb_tensor):
log.error(decode_err_msg, extra={'is_warning': True})
value = np.array(type_helper[1](pb_tensor))
if len(shape) == 0 or shape.prod() == 0:
if len(value) == 1:
# Ignore an empty value, if len(shape) > 1
# For example, value = [] and shape = [1, 1, 0]
# This is needed to reshape this value later and to return reshaped value = [[[]]]
# Otherwise there can be failures during partial inference, because we are storing an empty value with incorrect
# shape
if len(shape) == 0 or (len(shape) == 1 and shape.prod() == 0):
try:
value_length = len(value)
except TypeError:
# case, when value is a scalar
value_length = 0
if value_length == 1:
# return scalar if shape is [] otherwise broadcast according to shape
try:
return np.array(value[0], dtype=type_helper[0])

View File

@ -211,3 +211,30 @@ class TensorContentParsing(unittest.TestCase):
with self.assertLogs(log.getLogger(), level="ERROR") as cm:
result = tf_tensor_content(pb_tensor.dtype, shape, pb_tensor)
self.assertEqual([warning_message, warning_message], cm.output)
def test_empty_value(self):
pb_tensor = PB({
'dtype': 1,
'float_val': []
})
shape = int64_array([1, 1, 0])
tf_dtype = pb_tensor.dtype
ref = np.array([[[]]], dtype=np.float32)
res = tf_tensor_content(tf_dtype, shape, pb_tensor)
self.assertEqual(res.shape, ref.shape)
self.assertTrue(np.all(res == ref))
def test_scalar_value(self):
pb_tensor = PB({
'dtype': 3,
'int_val': 4
})
shape = int64_array([])
tf_dtype = pb_tensor.dtype
ref = np.array(4, dtype=np.int32)
res = tf_tensor_content(tf_dtype, shape, pb_tensor)
self.assertEqual(ref, res)