[PT FE] Add aten::new_empty (#16312)

* Add new_empty

* Remove duplicated code for new_empty
Mateusz Mikolajczyk, 2023-03-20 10:03:33 +01:00, committed by GitHub
parent c472b020b7
commit 134ebb8889
3 changed files with 79 additions and 11 deletions
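
For context, aten::new_empty corresponds to torch.Tensor.new_empty: it returns a tensor of uninitialized data with the requested shape, inheriting dtype and device from the source tensor unless an explicit dtype argument overrides them. A minimal illustration of the semantics the new translator has to reproduce:

import torch

x = torch.ones(2, 3, dtype=torch.int32)

a = x.new_empty((4, 5))
print(a.shape, a.dtype)  # torch.Size([4, 5]) torch.int32 (dtype inherited from x)

b = x.new_empty((4, 5), dtype=torch.float64)
print(b.dtype)  # torch.float64 (an explicit dtype wins)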

File 1 of 3:

@@ -65,7 +65,7 @@ OutputVector translate_full_like(NodeContext& context) {
     auto input = context.get_input(0);
     auto value = context.get_input(1);
     auto sizes = context.mark_node(std::make_shared<v3::ShapeOf>(input, element::i32));
-    if (context.get_input_size() == 7) {
+    if (context.get_input_size() == 7 && !context.input_is_none(2)) {
         return {base_translate_full_with_convert(context, sizes, value, 2)};
     }
     auto out = context.input_is_none(3) ? input : context.get_input(3);
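
The added input_is_none guard handles traces where the optional dtype slot is present but set to None (e.g. a scripted torch.full_like(x, 5) call where dtype defaults to None): the translator must then fall through to the path that inherits the element type from the out/input tensor instead of converting to a dtype that was never given. The same guard is added to translate_zeros_like and translate_ones_like in the hunks below. A small PyTorch-level illustration of the two cases:

import torch

x = torch.ones(2, 2, dtype=torch.int32)

print(torch.full_like(x, 5).dtype)                       # torch.int32 (dtype=None, inherited)
print(torch.full_like(x, 5, dtype=torch.float32).dtype)  # torch.float32 (explicit conversion)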
@@ -113,7 +113,7 @@ OutputVector translate_zeros_like(NodeContext& context) {
     auto input = context.get_input(0);
     auto value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {0}));
     auto sizes = context.mark_node(std::make_shared<v3::ShapeOf>(input, element::i32));
-    if (context.get_input_size() == 6) {
+    if (context.get_input_size() == 6 && !context.input_is_none(1)) {
         return {base_translate_full_with_convert(context, sizes, value, 1)};
     }
     auto out = context.input_is_none(2) ? input : context.get_input(2);
@@ -153,7 +153,7 @@ OutputVector translate_ones_like(NodeContext& context) {
     auto input = context.get_input(0);
     auto value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {1}));
     auto sizes = context.mark_node(std::make_shared<v3::ShapeOf>(input, element::i32));
-    if (context.get_input_size() == 6) {
+    if (context.get_input_size() == 6 && !context.input_is_none(1)) {
         return {base_translate_full_with_convert(context, sizes, value, 1)};
     }
     auto out = context.input_is_none(2) ? input : context.get_input(2);
@@ -172,7 +172,7 @@ OutputVector translate_new_ones(NodeContext& context) {
 };

 OutputVector translate_empty(NodeContext& context) {
-    num_inputs_check(context, 1, 2);
+    num_inputs_check(context, 1, 5);
     auto sizes = context.get_input(0);
     // In OV uninitialised data is not supported, so we create a tensor filled with zeros with a given shape and type.
     auto value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {0}));
@@ -185,7 +185,6 @@ OutputVector translate_empty(NodeContext& context) {
     }
     return {empty};
 };

 } // namespace op
 } // namespace pytorch
 } // namespace frontend
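
Because OpenVINO cannot represent uninitialized tensor data, translate_empty (and aten::new_empty, mapped in the next file) materializes the result as a zero-filled tensor of the requested shape and element type, as the code comment above states. Numerically, the emitted subgraph behaves like this Python sketch (an illustration of the semantics only, not OpenVINO API):

import torch

def empty_as_lowered(sizes, dtype=torch.float32):
    # Broadcast of a scalar 0 constant over `sizes`, followed by a
    # conversion to the requested element type.
    return torch.zeros(*sizes, dtype=dtype)

print(empty_as_lowered([2, 3]).shape)  # torch.Size([2, 3]), all zeros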

File 2 of 3:

@@ -254,6 +254,7 @@ const std::map<std::string, PytorchCreatorFunction> get_supported_ops() {
         {"aten::narrow", op::translate_narrow},
         {"aten::ne", op::translate_1to1_match_2_inputs_align_types<opset10::NotEqual>},
         {"aten::neg", op::translate_neg},
+        {"aten::new_empty", op::translate_new_zeros},
         {"aten::new_full", op::translate_new_full},
         {"aten::new_ones", op::translate_new_ones},
         {"aten::new_zeros", op::translate_new_zeros},
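
The new table entry points at the existing op::translate_new_zeros rather than a dedicated translator, which is presumably what the "Remove duplicated code" part of the commit message refers to. Once uninitialized data is ruled out, new_empty and new_zeros agree on everything a consumer can observe here, namely shape and dtype:

import torch

x = torch.ones(2, 3, dtype=torch.int32)
e = x.new_empty((4, 5))
z = x.new_zeros((4, 5))
assert e.shape == z.shape and e.dtype == z.dtype  # only the (unspecified) values differ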

File 3 of 3:

@@ -11,7 +11,7 @@ from pytorch_layer_test_class import PytorchLayerTest

 class TestEmptyNumeric(PytorchLayerTest):

     def _prepare_input(self):
-        return (np.random.randn(2, 3),)
+        return (np.random.randn(10, 10, 10),)

     def create_model(self, dtype):
@@ -28,14 +28,14 @@ class TestEmptyNumeric(PytorchLayerTest):
                }
                super().__init__()
                self.dtype = dtype_map[dtype]
-                self.zero = torch.tensor([0], dtype=dtype_map[dtype])

            def forward(self, input_tensor):
                size = input_tensor.shape
                empty = torch.empty(size, dtype=self.dtype)
                # We don't want to compare values, just shape and type,
-                # so we multiply the tensor by zero.
-                return empty*self.zero
+                # so we call zeros_like on data. Multiplying by zero would
+                # produce sporadic errors if nan would be in empty.
+                return torch.zeros_like(empty)

        ref_net = None
@@ -50,7 +50,7 @@ class TestEmptyNumeric(PytorchLayerTest):

 class TestEmptyBoolean(PytorchLayerTest):

     def _prepare_input(self):
-        return (np.random.randn(3, 4, 3),)
+        return (np.random.randn(10, 10, 10),)

     def create_model(self):
@@ -75,3 +75,71 @@ class TestEmptyBoolean(PytorchLayerTest):
     @pytest.mark.precommit
     def test_empty_bool(self, ie_device, precision, ir_version, ):
         self._test(*self.create_model(), ie_device, precision, ir_version)
+
+
+class TestNewEmpty(PytorchLayerTest):
+    def _prepare_input(self, input_dtype=np.float32):
+        return (np.random.randn(1, 3, 10, 10).astype(input_dtype),)
+
+    def create_model(self, shape, dtype=None, used_dtype=False):
+        import torch
+        dtype_map = {
+            "float32": torch.float32,
+            "float64": torch.float64,
+            "int64": torch.int64,
+            "int32": torch.int32,
+            "uint8": torch.uint8,
+            "int8": torch.int8,
+            "bool": torch.bool
+        }
+
+        class aten_empty(torch.nn.Module):
+            def __init__(self, shape):
+                super(aten_empty, self).__init__()
+                self.shape = shape
+                self.zero = torch.tensor([0])
+
+            def forward(self, input_tensor: torch.Tensor):
+                empty = input_tensor.new_empty(self.shape)
+                # We don't want to compare values, just shape and type,
+                # so we call zeros_like on data. Multiplying by zero would
+                # produce sporadic errors if nan would be in empty.
+                return torch.zeros_like(empty)
+
+        class aten_empty_with_dtype(torch.nn.Module):
+            def __init__(self, shape, dtype):
+                super(aten_empty_with_dtype, self).__init__()
+                self.shape = shape
+                self.dtype = dtype
+                self.zero = torch.tensor([0], dtype=self.dtype)
+
+            def forward(self, input_tensor: torch.Tensor):
+                empty = input_tensor.new_empty(self.shape, dtype=self.dtype)
+                # We don't want to compare values, just shape and type,
+                # so we call zeros_like on data. Multiplying by zero would
+                # produce sporadic errors if nan would be in empty.
+                return torch.zeros_like(empty)
+
+        ref_net = None
+        model = aten_empty(shape)
+        if used_dtype:
+            dtype = dtype_map[dtype]
+            model = aten_empty_with_dtype(shape, dtype)
+
+        return model, ref_net, "aten::new_empty"
+
+    @pytest.mark.parametrize("shape", [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5, 6]])
+    @pytest.mark.parametrize("input_dtype", [np.uint8, np.int8, np.int32, np.int64, np.float32, np.float64])
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_new_empty(self, shape, input_dtype, ie_device, precision, ir_version):
+        self._test(*self.create_model(shape), ie_device, precision, ir_version,
+                   kwargs_to_prepare_input={'input_dtype': input_dtype})
+
+    @pytest.mark.parametrize("shape", [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5, 6]])
+    @pytest.mark.parametrize("input_dtype", [bool, np.uint8, np.int8, np.int32, np.int64, np.float32, np.float64])
+    @pytest.mark.parametrize("dtype", ["bool", "uint8", "int8", "int32", "int64", "float32", "float64"])
+    @pytest.mark.nightly
+    def test_new_empty_with_dtype(self, shape, dtype, input_dtype, ie_device, precision, ir_version):
+        self._test(*self.create_model(shape, dtype=dtype, used_dtype=True), ie_device, precision, ir_version,
+                   kwargs_to_prepare_input={'input_dtype': input_dtype})
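
A note on the zeros_like normalization used in the forward passes above: uninitialized memory may happen to contain NaNs, and NaN * 0 is NaN rather than 0, so the older multiply-by-zero approach could produce sporadic mismatches, exactly as the diff comments say. torch.zeros_like keeps the shape and dtype under test while making the values deterministic:

import torch

t = torch.tensor([float("nan"), 1.0])
print(t * 0)                # tensor([nan, 0.]) -> flaky comparison
print(torch.zeros_like(t))  # tensor([0., 0.]) -> deterministic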