[MO] Update Select shape inference function to support dynamic shapes (#8892)

* Update Select shape inference function to support dynamic shapes

* Update unit test for Select shape inference
This commit is contained in:
Anton Chetverikov 2021-12-01 05:00:55 +03:00 committed by GitHub
parent 184b602a49
commit e9a15b70f5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 15 additions and 1 deletions

View File

@ -60,7 +60,9 @@ class Select(Op):
"But instead got: cond_shape={}, then_shape={}, else_shape={}".format(
node_name, condition_shape, a_shape, b_shape)
assert condition_shape[0] == output_shape[0], msg_tf
# check equality only if both values non-dynamic
if is_fully_defined(condition_shape[0]) and is_fully_defined(output_shape[0]):
assert condition_shape[0] == output_shape[0], msg_tf
condition_shape = np.concatenate((condition_shape, np.ones(len(output_shape) - 1)))
output_shape = bi_directional_shape_broadcasting(output_shape, condition_shape)

View File

@ -272,6 +272,18 @@ class TestSelect(unittest.TestCase):
auto_broadcast='numpy', fw_format='tf')
self.assertTrue(flag, msg)
def test_select_infer_tf_condition_dyn(self):
flag, msg = self.build_select_graph_and_infer(condition_value=None,
condition_shape=shape_array([dynamic_dimension_value]),
then_value=None,
then_shape=shape_array([dynamic_dimension_value, 20]),
else_value=None,
else_shape=shape_array([dynamic_dimension_value, 20]),
out_value=None,
out_shape=shape_array([dynamic_dimension_value, 20]),
auto_broadcast='numpy', fw_format='tf')
self.assertTrue(flag, msg)
def test_select_infer_tf_condition_assert_raises(self):
with self.assertRaisesRegex(AssertionError, "if 'condition' is a 1D tensor then it's size"):
self.build_select_graph_and_infer(condition_value=None, condition_shape=shape_array([42]),