[PT FE] aten::to extend op support for device tracing case (#15712)

Ekaterina Aidova 2023-02-15 12:04:09 +04:00 committed by GitHub
parent 38943434b6
commit 55d667ce32
2 changed files with 29 additions and 0 deletions


@@ -30,6 +30,13 @@ OutputVector translate_to(NodeContext& context) {
        // Input with index 1 is the device, so we skip that input.
        dtype_idx = 2;
        memory_format_idx = 5;
    } else if (context.get_input_size() == 8) {
        // aten::to(Tensor(a) self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None,
        //          bool? pin_memory=None, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None)
        dtype_idx = 1;
        memory_format_idx = 7;
    } else {
        FRONT_END_OP_CONVERSION_CHECK(false, "Unknown aten::to format");
    }
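
For context (not part of the diff): the 8-input overload handled above is the form the PyTorch tracer typically records for a device-only tensor.to(device) call. A minimal sketch, assuming a recent PyTorch build (the exact overload recorded can vary between versions):

import torch

class ToDevice(torch.nn.Module):
    def forward(self, x, y):
        # Under tracing, a device-only .to() is usually lowered to the full
        # aten::to(self, dtype, layout, device, pin_memory, non_blocking, copy, memory_format)
        # call, i.e. 8 inputs with dtype at index 1 and memory_format at index 7.
        return x.to(y.device)

traced = torch.jit.trace(ToDevice(), (torch.rand(3), torch.rand(3)))
print(traced.graph)  # inspect the aten::to node and count its inputs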


@@ -91,3 +91,25 @@ class TestAtenTo(PytorchLayerTest):
        self.input_type = input_type
        with pytest.raises(OpConversionFailure) as e:
            self._test(*self.create_model(output_type, memory_format=memory_format), ie_device, precision, ir_version)


class TestAtenToDevice(PytorchLayerTest):
    def _prepare_input(self):
        return (np.random.uniform(low=0.0, high=50.0, size=(3,)), np.random.uniform(low=0.0, high=50.0, size=(3,)))

    def create_model(self):
        import torch

        class aten_to(torch.nn.Module):
            def forward(self, x, y):
                return x.to(y.device)

        ref_net = None

        return aten_to(), ref_net, "aten::to"

    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_aten_to_device(self, ie_device, precision, ir_version):
        self._test(*self.create_model(), ie_device, precision, ir_version, trace_model=True)
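
As a rough standalone check (not part of the commit), the traced module can also be pushed through model conversion directly. This is a sketch that assumes an OpenVINO build where openvino.convert_model accepts a traced PyTorch module (newer Python API; older releases exposed conversion through openvino.tools.mo.convert_model):

import numpy as np
import torch
import openvino as ov

class ToDevice(torch.nn.Module):
    def forward(self, x, y):
        return x.to(y.device)

example = (torch.rand(3), torch.rand(3))
traced = torch.jit.trace(ToDevice(), example)

# The 8-input aten::to branch added above skips the device argument and reads
# only dtype and memory_format, so the traced graph should now convert.
ov_model = ov.convert_model(traced, example_input=example)
compiled = ov.compile_model(ov_model, "CPU")

x = np.random.uniform(0.0, 50.0, 3).astype(np.float32)
y = np.random.uniform(0.0, 50.0, 3).astype(np.float32)
print(compiled([x, y])[0])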