[PT FE]: support aten::empty_like (#21258)

* [PT FE]: support aten::empty_like

* Update src/frontends/pytorch/src/op/full.cpp

Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>

---------

Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>
This commit is contained in:
Ekaterina Aidova 2023-11-27 11:15:19 +04:00 committed by GitHub
parent a5d53aeaef
commit eaae00c2ca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 85 additions and 0 deletions

View File

@ -205,6 +205,35 @@ OutputVector translate_empty(const NodeContext& context) {
return {empty};
};
OutputVector translate_empty_like(const NodeContext& context) {
// aten::empty_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool?
// pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
// aten::empty_like.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
num_inputs_check(context, 1, 6);
auto input = context.get_input(0);
auto sizes = context.mark_node(std::make_shared<v3::ShapeOf>(input, element::i32));
// In OV uninitialized data is not supported, so we create a tensor filled with zeros with a given shape and type.
auto value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {0}));
int dtype_id = 1;
Output<Node> empty;
if (context.get_input_size() == 6) {
if (!context.input_is_none(dtype_id)) {
empty = base_translate_full_with_convert(context, sizes, value, dtype_id);
} else {
empty = base_translate_full(context, sizes, value);
}
} else if (context.get_input_size() == 4) {
auto out = context.input_is_none(3) ? input : context.get_input(3);
empty = base_translate_full_with_convertlike(context, sizes, value, out);
if (!context.input_is_none(3)) {
context.mutate_input(3, empty);
}
} else {
FRONT_END_GENERAL_CHECK(false, "Unexpected number of inputs.");
}
return {empty};
};
OutputVector translate_fill_diagonal(const NodeContext& context) {
// aten::fill_diagonal_(Tensor(a!) self, Scalar fill_value, bool wrap=False) -> Tensor(a!)
// realization inspired by numpy:

View File

@ -64,6 +64,7 @@ OP_CONVERTER(translate_elu);
OP_CONVERTER(translate_embedding);
OP_CONVERTER(translate_embedding_bag);
OP_CONVERTER(translate_empty);
OP_CONVERTER(translate_empty_like);
OP_CONVERTER(translate_erf);
OP_CONVERTER(translate_expand);
OP_CONVERTER(translate_expand_as);
@ -332,6 +333,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::embedding", op::translate_embedding},
{"aten::embedding_bag", op::translate_embedding_bag},
{"aten::empty", op::translate_empty},
{"aten::empty_like", op::translate_empty_like},
{"aten::eq", op::translate_1to1_match_2_inputs_align_types<opset10::Equal>},
{"aten::erf", op::translate_erf},
{"aten::erf_", op::inplace_op<op::translate_erf>},

View File

@ -47,6 +47,60 @@ class TestEmptyNumeric(PytorchLayerTest):
def test_empty(self, ie_device, precision, ir_version, dtype):
self._test(*self.create_model(dtype), ie_device, precision, ir_version)
class TestEmptyLike(PytorchLayerTest):
def _prepare_input(self, shape, dtype=np.float32, out=False):
if not out:
return (np.random.randn(*shape).astype(dtype if dtype is not None else np.float32),)
return (np.random.randn(*shape), np.ones(shape, dtype=(dtype if dtype is not None else np.float32)))
def create_model(self, dtype, out):
class aten_empty_like(torch.nn.Module):
def __init__(self, dtype=None, out=False):
dtype_map = {
"float32": torch.float32,
"float64": torch.float64,
"int64": torch.int64,
"int32": torch.int32,
"uint8": torch.uint8,
"int8": torch.int8
}
super().__init__()
self.dtype = dtype_map.get(dtype, None)
if out:
self.forward = self.forward_out
def forward(self, input_tensor):
empty = torch.empty_like(input_tensor, dtype=self.dtype)
# We don't want to compare values, just shape and type,
# so we call zeros_like on data. Multiplying by zero would
# produce sporadic errors if nan would be in empty.
return torch.zeros_like(empty)
def forward_out(self, input_tensor, out_tensor):
torch.empty_like(input_tensor, out=out_tensor)
# We don't want to compare values, just shape and type,
# so we call zeros_like on data. Multiplying by zero would
# produce sporadic errors if nan would be in empty.
return torch.zeros_like(out_tensor)
ref_net = None
return aten_empty_like(dtype, out), ref_net, "aten::empty_like"
@pytest.mark.parametrize('dtype', (None, "float32", "float64", "int64", "int32", "uint8", "int8"))
@pytest.mark.parametrize("input_shape", [[2,], [1, 10], [10, 5, 2]])
@pytest.mark.parametrize("out", [True, False])
@pytest.mark.nightly
@pytest.mark.precommit
def test_empty(self, ie_device, precision, ir_version, dtype, input_shape, out):
self._test(*self.create_model(dtype, out), ie_device, precision, ir_version,
kwargs_to_prepare_input={"shape": input_shape, "out": out, "dtype": dtype})
class TestEmptyBoolean(PytorchLayerTest):
def _prepare_input(self):