[PT FE] Refactor aten::flatten and aten::transpose conversion (#19098)

* [PT FE] Refactor aten::flatten and aten::transpose conversion

* Fix code style

* Fix code style
This commit is contained in:
Maxim Vafin
2023-08-10 10:28:36 +02:00
committed by GitHub
parent 9deef1480a
commit f5f221a3a9
9 changed files with 21 additions and 42 deletions

View File

@@ -27,7 +27,9 @@ class TestFlatten(PytorchLayerTest):
return aten_flatten(dim0, dim1), ref_net, "aten::flatten"
@pytest.mark.parametrize("dim0,dim1", [[0, 1],
@pytest.mark.parametrize("dim0,dim1", [[0, -1],
[-2, -1],
[0, 1],
[0, 2],
[0, 3],
[1, 2],

View File

@@ -31,13 +31,13 @@ class aten_native_multi_head_attention(torch.nn.Module):
# Float masks raise a warning in PyTorch and are (incorrectly) converted to bool,
# which later returns NaNs as MHA's output
if mask == 0:
self.mask = torch.from_numpy(np.random.randint(0, 2, (SEQ_LENGTH, SEQ_LENGTH)).astype(np.bool))
self.mask = torch.from_numpy(np.random.randint(0, 2, (SEQ_LENGTH, SEQ_LENGTH)).astype("bool"))
self.mask_type = 0
elif mask == 1:
self.mask = torch.from_numpy(np.random.randint(0, 2, (BATCH_SIZE, SEQ_LENGTH)).astype(np.bool))
self.mask = torch.from_numpy(np.random.randint(0, 2, (BATCH_SIZE, SEQ_LENGTH)).astype("bool"))
self.mask_type = 1
elif mask == 2:
self.mask = torch.from_numpy(np.random.randint(0, 2, (BATCH_SIZE, NUM_HEADS, SEQ_LENGTH, SEQ_LENGTH)).astype(np.bool))
self.mask = torch.from_numpy(np.random.randint(0, 2, (BATCH_SIZE, NUM_HEADS, SEQ_LENGTH, SEQ_LENGTH)).astype("bool"))
self.mask_type = 2
else:
self.mask = None