fix DynamicStitch tests (#14005)
This commit is contained in:
@@ -19,7 +19,10 @@ class TestParallelDynamicStitch(CommonTFLayerTest):
|
||||
indices_shape = inputs_info[indices_in_name]
|
||||
num_elements = num_elements + np.prod(indices_shape, dtype=int)
|
||||
|
||||
indices_array = np.arange(np.random.randint(1, num_elements+1), dtype=np.intc)
|
||||
# we support DynamicStitch via decomposition to subgraph with ScatterUpdate op
|
||||
# ScatterUpdate has undefined behavior if there are multiple identical indexes
|
||||
# indices_array = np.arange(np.random.randint(1, num_elements+1), dtype=np.intc)
|
||||
indices_array = np.arange(num_elements, dtype=np.intc)
|
||||
np.random.shuffle(indices_array)
|
||||
indices_array = np.resize(indices_array, num_elements)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user