[MO] Update Select shape inference function to support dynamic shapes (#8892)
* Update Select shape inference function to support dynamic shapes * Update unit test for Select shape inference
This commit is contained in:
parent
184b602a49
commit
e9a15b70f5
@ -60,7 +60,9 @@ class Select(Op):
|
||||
"But instead got: cond_shape={}, then_shape={}, else_shape={}".format(
|
||||
node_name, condition_shape, a_shape, b_shape)
|
||||
|
||||
assert condition_shape[0] == output_shape[0], msg_tf
|
||||
# check equality only if both values non-dynamic
|
||||
if is_fully_defined(condition_shape[0]) and is_fully_defined(output_shape[0]):
|
||||
assert condition_shape[0] == output_shape[0], msg_tf
|
||||
condition_shape = np.concatenate((condition_shape, np.ones(len(output_shape) - 1)))
|
||||
|
||||
output_shape = bi_directional_shape_broadcasting(output_shape, condition_shape)
|
||||
|
@ -272,6 +272,18 @@ class TestSelect(unittest.TestCase):
|
||||
auto_broadcast='numpy', fw_format='tf')
|
||||
self.assertTrue(flag, msg)
|
||||
|
||||
def test_select_infer_tf_condition_dyn(self):
|
||||
flag, msg = self.build_select_graph_and_infer(condition_value=None,
|
||||
condition_shape=shape_array([dynamic_dimension_value]),
|
||||
then_value=None,
|
||||
then_shape=shape_array([dynamic_dimension_value, 20]),
|
||||
else_value=None,
|
||||
else_shape=shape_array([dynamic_dimension_value, 20]),
|
||||
out_value=None,
|
||||
out_shape=shape_array([dynamic_dimension_value, 20]),
|
||||
auto_broadcast='numpy', fw_format='tf')
|
||||
self.assertTrue(flag, msg)
|
||||
|
||||
def test_select_infer_tf_condition_assert_raises(self):
|
||||
with self.assertRaisesRegex(AssertionError, "if 'condition' is a 1D tensor then it's size"):
|
||||
self.build_select_graph_and_infer(condition_value=None, condition_shape=shape_array([42]),
|
||||
|
Loading…
Reference in New Issue
Block a user