[PT FE] aten::to extend op support for device tracing case (#15712)
This commit is contained in:
parent
38943434b6
commit
55d667ce32
@ -30,6 +30,13 @@ OutputVector translate_to(NodeContext& context) {
|
||||
// Input with index 1 is device we skip that input.
|
||||
dtype_idx = 2;
|
||||
memory_format_idx = 5;
|
||||
} else if (context.get_input_size() == 8) {
|
||||
// aten::to(Tensor(a) self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool?
|
||||
// pin_memory=None,
|
||||
// bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None)
|
||||
dtype_idx = 1;
|
||||
memory_format_idx = 7;
|
||||
|
||||
} else {
|
||||
FRONT_END_OP_CONVERSION_CHECK(false, "Unknown aten::to format");
|
||||
}
|
||||
|
@ -91,3 +91,25 @@ class TestAtenTo(PytorchLayerTest):
|
||||
self.input_type = input_type
|
||||
with pytest.raises(OpConversionFailure) as e:
|
||||
self._test(*self.create_model(output_type, memory_format=memory_format), ie_device, precision, ir_version)
|
||||
|
||||
|
||||
class TestAtenToDevice(PytorchLayerTest):
|
||||
def _prepare_input(self):
|
||||
return (np.random.uniform(low=0.0, high=50.0, size=(3,)), np.random.uniform(low=0.0, high=50.0, size=(3,)))
|
||||
|
||||
def create_model(self):
|
||||
import torch
|
||||
|
||||
class aten_to(torch.nn.Module):
|
||||
|
||||
def forward(self, x, y):
|
||||
return x.to(y.device)
|
||||
|
||||
ref_net = None
|
||||
|
||||
return aten_to(), ref_net, "aten::to"
|
||||
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_aten_to_device(self, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(), ie_device, precision, ir_version, trace_model=True)
|
||||
|
Loading…
Reference in New Issue
Block a user