[MO] Updating condition in TF Const value extracting (#6322)
* update value extracting condition * added a comment * updated unittest * added a scalar test, little fix for scalar values extracting
This commit is contained in:
parent
ba3a667730
commit
b56164e804
@ -70,8 +70,18 @@ def tf_tensor_content(tf_dtype, shape, pb_tensor):
|
||||
log.error(decode_err_msg, extra={'is_warning': True})
|
||||
value = np.array(type_helper[1](pb_tensor))
|
||||
|
||||
if len(shape) == 0 or shape.prod() == 0:
|
||||
if len(value) == 1:
|
||||
# Ignore an empty value, if len(shape) > 1
|
||||
# For example, value = [] and shape = [1, 1, 0]
|
||||
# This is needed to reshape this value later and to return reshaped value = [[[]]]
|
||||
# Otherwise there can be failures during partial inference, because we are storing an empty value with incorrect
|
||||
# shape
|
||||
if len(shape) == 0 or (len(shape) == 1 and shape.prod() == 0):
|
||||
try:
|
||||
value_length = len(value)
|
||||
except TypeError:
|
||||
# case, when value is a scalar
|
||||
value_length = 0
|
||||
if value_length == 1:
|
||||
# return scalar if shape is [] otherwise broadcast according to shape
|
||||
try:
|
||||
return np.array(value[0], dtype=type_helper[0])
|
||||
|
@ -211,3 +211,30 @@ class TensorContentParsing(unittest.TestCase):
|
||||
with self.assertLogs(log.getLogger(), level="ERROR") as cm:
|
||||
result = tf_tensor_content(pb_tensor.dtype, shape, pb_tensor)
|
||||
self.assertEqual([warning_message, warning_message], cm.output)
|
||||
|
||||
def test_empty_value(self):
|
||||
pb_tensor = PB({
|
||||
'dtype': 1,
|
||||
'float_val': []
|
||||
})
|
||||
|
||||
shape = int64_array([1, 1, 0])
|
||||
tf_dtype = pb_tensor.dtype
|
||||
ref = np.array([[[]]], dtype=np.float32)
|
||||
res = tf_tensor_content(tf_dtype, shape, pb_tensor)
|
||||
|
||||
self.assertEqual(res.shape, ref.shape)
|
||||
self.assertTrue(np.all(res == ref))
|
||||
|
||||
def test_scalar_value(self):
|
||||
pb_tensor = PB({
|
||||
'dtype': 3,
|
||||
'int_val': 4
|
||||
})
|
||||
|
||||
shape = int64_array([])
|
||||
tf_dtype = pb_tensor.dtype
|
||||
ref = np.array(4, dtype=np.int32)
|
||||
res = tf_tensor_content(tf_dtype, shape, pb_tensor)
|
||||
|
||||
self.assertEqual(ref, res)
|
||||
|
Loading…
Reference in New Issue
Block a user