[PT FE] Add support for aten::numpy_T and aten::feature_dropout (#20136)
* Add support for aten::numpy_t and aten::feature_dropout * Update tests/layer_tests/pytorch_tests/test_transpose.py Co-authored-by: Ekaterina Aidova <ekaterina.aidova@intel.com> --------- Co-authored-by: Ekaterina Aidova <ekaterina.aidova@intel.com>
This commit is contained in:
parent
08bc3b2d7c
commit
7a6c5d0d41
@ -301,6 +301,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
|
||||
{"aten::eye", op::translate_eye},
|
||||
{"aten::fake_quantize_per_channel_affine", op::translate_fake_quantize_per_channel_affine},
|
||||
{"aten::fake_quantize_per_tensor_affine", op::translate_fake_quantize_per_tensor_affine},
|
||||
{"aten::feature_dropout", op::skip_node},
|
||||
{"aten::fill_", op::inplace_op<op::translate_fill_>},
|
||||
{"aten::flatten", op::quantizable_op<op::translate_flatten>},
|
||||
{"aten::flip", op::translate_flip},
|
||||
@ -384,6 +385,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
|
||||
{"aten::nonzero", op::translate_nonzero},
|
||||
{"aten::norm", op::translate_norm},
|
||||
{"aten::numel", op::translate_numel},
|
||||
{"aten::numpy_T", op::translate_t},
|
||||
{"aten::one_hot", op::translate_one_hot},
|
||||
{"aten::ones", op::translate_ones},
|
||||
{"aten::ones_like", op::translate_ones_like},
|
||||
|
@ -57,12 +57,14 @@ class TestTSmall(PytorchLayerTest):
|
||||
return (np.array(num_dims).astype(input_dtype),)
|
||||
return (np.random.randn(*shape[:num_dims]).astype(input_dtype),)
|
||||
|
||||
def create_model(self, num_dims=2, inplace=False):
|
||||
def create_model(self, mode):
|
||||
class aten_transpose(torch.nn.Module):
|
||||
def __init__(self, inplace):
|
||||
def __init__(self, mode):
|
||||
super(aten_transpose, self).__init__()
|
||||
if inplace:
|
||||
if mode == "inplace":
|
||||
self.forward = self.forward_inplace
|
||||
elif mode == "numpy":
|
||||
self.forward = self.forward_numpy_t
|
||||
|
||||
def forward(self, x):
|
||||
return x.t(), x
|
||||
@ -70,18 +72,21 @@ class TestTSmall(PytorchLayerTest):
|
||||
def forward_inplace(self, x):
|
||||
return x.t_(), x
|
||||
|
||||
def forward_numpy_t(self, x):
|
||||
return x.T, x
|
||||
|
||||
ref_net = None
|
||||
|
||||
return aten_transpose(inplace), ref_net, "aten::t" if not inplace else "aten::t_"
|
||||
return aten_transpose(mode), ref_net, "aten::t_" if mode == "inplace" else ("aten::numpy_T" if mode == "numpy" else "aten::t")
|
||||
|
||||
@pytest.mark.parametrize("num_dims", [0, 1, 2])
|
||||
@pytest.mark.parametrize("input_dtype", ["float32", "int32"])
|
||||
@pytest.mark.parametrize("inplace", [True, False])
|
||||
@pytest.mark.parametrize("mode", [None, "inplace", "numpy"])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_t_small(self, num_dims, input_dtype, inplace, ie_device, precision, ir_version):
|
||||
def test_t_small(self, num_dims, input_dtype, mode, ie_device, precision, ir_version):
|
||||
self._test(
|
||||
*self.create_model(num_dims, inplace),
|
||||
*self.create_model(mode),
|
||||
ie_device,
|
||||
precision,
|
||||
ir_version,
|
||||
|
Loading…
Reference in New Issue
Block a user