[PT FE] Refactor aten::flatten and aten::transpose conversion (#19098)
* [PT FE] Refactor aten::flatten and aten::transpose conversion * Fix code style * Fix codestyle
This commit is contained in:
@@ -27,7 +27,9 @@ class TestFlatten(PytorchLayerTest):
|
||||
|
||||
return aten_flatten(dim0, dim1), ref_net, "aten::flatten"
|
||||
|
||||
@pytest.mark.parametrize("dim0,dim1", [[0, 1],
|
||||
@pytest.mark.parametrize("dim0,dim1", [[0, -1],
|
||||
[-2, -1],
|
||||
[0, 1],
|
||||
[0, 2],
|
||||
[0, 3],
|
||||
[1, 2],
|
||||
|
||||
@@ -31,13 +31,13 @@ class aten_native_multi_head_attention(torch.nn.Module):
|
||||
# Float masks raise a warning in PyTorch and are (incorrectly) converted to bool,
|
||||
# which later returns NaNs as MHA's output
|
||||
if mask == 0:
|
||||
self.mask = torch.from_numpy(np.random.randint(0, 2, (SEQ_LENGTH, SEQ_LENGTH)).astype(np.bool))
|
||||
self.mask = torch.from_numpy(np.random.randint(0, 2, (SEQ_LENGTH, SEQ_LENGTH)).astype("bool"))
|
||||
self.mask_type = 0
|
||||
elif mask == 1:
|
||||
self.mask = torch.from_numpy(np.random.randint(0, 2, (BATCH_SIZE, SEQ_LENGTH)).astype(np.bool))
|
||||
self.mask = torch.from_numpy(np.random.randint(0, 2, (BATCH_SIZE, SEQ_LENGTH)).astype("bool"))
|
||||
self.mask_type = 1
|
||||
elif mask == 2:
|
||||
self.mask = torch.from_numpy(np.random.randint(0, 2, (BATCH_SIZE, NUM_HEADS, SEQ_LENGTH, SEQ_LENGTH)).astype(np.bool))
|
||||
self.mask = torch.from_numpy(np.random.randint(0, 2, (BATCH_SIZE, NUM_HEADS, SEQ_LENGTH, SEQ_LENGTH)).astype("bool"))
|
||||
self.mask_type = 2
|
||||
else:
|
||||
self.mask = None
|
||||
|
||||
Reference in New Issue
Block a user