From 92788b183889eb5d21828d12df74697ab0e7eba2 Mon Sep 17 00:00:00 2001
From: Leonard Sikorski
Date: Thu, 9 Feb 2023 08:24:08 +0100
Subject: [PATCH] [PT FE] Add aten::empty operator with layer test (#15490)

---
 src/frontends/pytorch/src/op/full.cpp         | 14 ++++
 src/frontends/pytorch/src/op_table.cpp        |  2 +
 tests/layer_tests/pytorch_tests/test_empty.py | 77 +++++++++++++++++++
 3 files changed, 93 insertions(+)
 create mode 100644 tests/layer_tests/pytorch_tests/test_empty.py

diff --git a/src/frontends/pytorch/src/op/full.cpp b/src/frontends/pytorch/src/op/full.cpp
index 188192afcd1..26bbe80b0ed 100644
--- a/src/frontends/pytorch/src/op/full.cpp
+++ b/src/frontends/pytorch/src/op/full.cpp
@@ -159,6 +159,20 @@ OutputVector translate_new_ones(NodeContext& context) {
     return {base_translate_full_with_convertlike(context, sizes, value, input)};
 };
 
+OutputVector translate_empty(NodeContext& context) {
+    auto sizes = context.get_input(0);
+    // OV does not support uninitialised data, so we create a tensor filled with zeros of the given shape and type.
+    auto value = context.mark_node(ov::op::v0::Constant::create(element::f32, Shape{}, {0}));
+    int dtype_id = 1;
+    ov::Output<Node> empty;
+    if (!context.input_is_none(dtype_id)) {
+        empty = base_translate_full_with_convert(context, sizes, value, dtype_id);
+    } else {
+        empty = base_translate_full(context, sizes, value);
+    }
+    return {empty};
+};
+
 } // namespace op
 } // namespace pytorch
 } // namespace frontend
diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp
index d0720d0db6c..0b43072cbb9 100644
--- a/src/frontends/pytorch/src/op_table.cpp
+++ b/src/frontends/pytorch/src/op_table.cpp
@@ -32,6 +32,7 @@ OP_CONVERTER(translate_convolution_mode);
 OP_CONVERTER(translate_dim);
 OP_CONVERTER(translate_div);
 OP_CONVERTER(translate_elu);
+OP_CONVERTER(translate_empty);
 OP_CONVERTER(translate_embedding);
 OP_CONVERTER(translate_expand);
 OP_CONVERTER(translate_expand_as);
@@ -177,6 +178,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
         {"aten::dropout_", op::skip_node},
         {"aten::elu", op::translate_elu},
         {"aten::embedding", op::translate_embedding},
+        {"aten::empty", op::translate_empty},
         {"aten::eq", op::translate_1to1_match_2_inputs_align_types<opset10::Equal>},
         {"aten::exp", op::translate_1to1_match_1_inputs<opset10::Exp>},
         {"aten::expand", op::translate_expand},
diff --git a/tests/layer_tests/pytorch_tests/test_empty.py b/tests/layer_tests/pytorch_tests/test_empty.py
new file mode 100644
index 00000000000..2a7a5b1e470
--- /dev/null
+++ b/tests/layer_tests/pytorch_tests/test_empty.py
@@ -0,0 +1,77 @@
+# Copyright (C) 2018-2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import pytest
+import torch
+
+from pytorch_layer_test_class import PytorchLayerTest
+
+
+class TestEmptyNumeric(PytorchLayerTest):
+
+    def _prepare_input(self):
+        return (np.random.randn(2, 3),)
+
+    def create_model(self, dtype):
+
+        class aten_empty(torch.nn.Module):
+
+            def __init__(self, dtype) -> None:
+                dtype_map = {
+                    "float32": torch.float32,
+                    "float64": torch.float64,
+                    "int64": torch.int64,
+                    "int32": torch.int32,
+                    "uint8": torch.uint8,
+                    "int8": torch.int8
+                }
+                super().__init__()
+                self.dtype = dtype_map[dtype]
+                self.zero = torch.tensor([0], dtype=self.dtype)
+
+            def forward(self, input_tensor):
+                size = input_tensor.shape
+                empty = torch.empty(size, dtype=self.dtype)
+                # We don't want to compare values, just shape and type,
+                # so we multiply the tensor by zero.
+                return empty * self.zero
+
+        ref_net = None
+
+        return aten_empty(dtype), ref_net, "aten::empty"
+
+    @pytest.mark.parametrize('dtype', ("float32", "float64", "int64", "int32", "uint8", "int8"))
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_empty(self, ie_device, precision, ir_version, dtype):
+        self._test(*self.create_model(dtype), ie_device, precision, ir_version)
+
+class TestEmptyBoolean(PytorchLayerTest):
+
+    def _prepare_input(self):
+        return (np.random.randn(3, 4, 3),)
+
+    def create_model(self):
+
+        class aten_empty(torch.nn.Module):
+
+            def __init__(self) -> None:
+                super().__init__()
+                self.false = torch.tensor([False])
+
+            def forward(self, input_tensor):
+                size = input_tensor.shape
+                empty = torch.empty(size, dtype=torch.bool)
+                # We don't want to compare values, just shape and type,
+                # so we AND the tensor with False.
+                return empty & self.false
+
+        ref_net = None
+
+        return aten_empty(), ref_net, "aten::empty"
+
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_empty_bool(self, ie_device, precision, ir_version):
+        self._test(*self.create_model(), ie_device, precision, ir_version)
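
For reviewers who want to exercise the translator outside the layer-test harness, here is a
minimal end-to-end sketch. It assumes an OpenVINO build with the PyTorch frontend enabled and
the openvino.tools.mo.convert_model entry point accepting a torch.nn.Module with example_input
(the EmptyModel class and shapes are illustrative only; exact API names may differ between
releases):

    import numpy as np
    import torch
    from openvino.runtime import Core
    from openvino.tools.mo import convert_model

    class EmptyModel(torch.nn.Module):  # hypothetical toy module, not part of the patch
        def forward(self, x):
            # torch.empty returns uninitialized memory; multiplying by zero makes the
            # output deterministic, so only shape and dtype are being checked.
            return torch.empty(x.shape, dtype=torch.float32) * 0.0

    ov_model = convert_model(EmptyModel(), example_input=torch.randn(2, 3))
    compiled = Core().compile_model(ov_model, "CPU")
    out = compiled([np.random.randn(2, 3).astype(np.float32)])[compiled.output(0)]
    # The translator lowers aten::empty to a zero-filled tensor, so all values are zeros.
    assert out.shape == (2, 3) and not out.any()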