[PT FE] Add aten::empty operator with layer test (#15490)

Leonard Sikorski 2023-02-09 08:24:08 +01:00 committed by GitHub
parent e77c2ab6d7
commit 92788b1838
3 changed files with 93 additions and 0 deletions


@@ -159,6 +159,20 @@ OutputVector translate_new_ones(NodeContext& context) {
    return {base_translate_full_with_convertlike(context, sizes, value, input)};
};

OutputVector translate_empty(NodeContext& context) {
    auto sizes = context.get_input(0);
    // OV does not support uninitialized data, so we create a tensor of the given shape and type filled with zeros.
    auto value = context.mark_node(ov::op::v0::Constant::create(element::f32, Shape{}, {0}));
    int dtype_id = 1;
    ov::Output<ov::Node> empty;
    if (!context.input_is_none(dtype_id)) {
        empty = base_translate_full_with_convert(context, sizes, value, dtype_id);
    } else {
        empty = base_translate_full(context, sizes, value);
    }
    return {empty};
};
} // namespace op
} // namespace pytorch
} // namespace frontend
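
For context on the zero fill: torch.empty only guarantees the shape and dtype of its result; the element values are whatever happens to be in the allocated memory, so a converted model cannot reproduce them. A minimal illustrative Python sketch of that contract (not part of the diff):

import torch

# torch.empty allocates uninitialized memory: the values are arbitrary
# and can differ between runs, so only shape and dtype are meaningful.
t = torch.empty(2, 3)

# A zero-filled tensor of the same shape and dtype satisfies everything
# aten::empty actually guarantees, which is why the translator above
# substitutes a zero constant broadcast to the requested shape.
z = torch.zeros(2, 3)
assert z.shape == t.shape and z.dtype == t.dtype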


@@ -32,6 +32,7 @@ OP_CONVERTER(translate_convolution_mode);
OP_CONVERTER(translate_dim);
OP_CONVERTER(translate_div);
OP_CONVERTER(translate_elu);
OP_CONVERTER(translate_empty);
OP_CONVERTER(translate_embedding);
OP_CONVERTER(translate_expand);
OP_CONVERTER(translate_expand_as);
@@ -177,6 +178,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
        {"aten::dropout_", op::skip_node},
        {"aten::elu", op::translate_elu},
        {"aten::embedding", op::translate_embedding},
        {"aten::empty", op::translate_empty},
        {"aten::eq", op::translate_1to1_match_2_inputs_align_types<opset10::Equal>},
        {"aten::exp", op::translate_1to1_match_1_inputs<opset10::Exp>},
        {"aten::expand", op::translate_expand},


@@ -0,0 +1,77 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
import torch

from pytorch_layer_test_class import PytorchLayerTest


class TestEmptyNumeric(PytorchLayerTest):

    def _prepare_input(self):
        return (np.random.randn(2, 3),)

    def create_model(self, dtype):

        class aten_empty(torch.nn.Module):

            def __init__(self, dtype) -> None:
                dtype_map = {
                    "float32": torch.float32,
                    "float64": torch.float64,
                    "int64": torch.int64,
                    "int32": torch.int32,
                    "uint8": torch.uint8,
                    "int8": torch.int8
                }
                super().__init__()
                self.dtype = dtype_map[dtype]
                self.zero = torch.tensor([0], dtype=dtype_map[dtype])

            def forward(self, input_tensor):
                size = input_tensor.shape
                empty = torch.empty(size, dtype=self.dtype)
                # We don't want to compare values, just shape and type,
                # so we multiply the tensor by zero.
                return empty * self.zero

        ref_net = None

        return aten_empty(dtype), ref_net, "aten::empty"

    @pytest.mark.parametrize('dtype', ("float32", "float64", "int64", "int32", "uint8", "int8"))
    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_empty(self, ie_device, precision, ir_version, dtype):
        self._test(*self.create_model(dtype), ie_device, precision, ir_version)


class TestEmptyBoolean(PytorchLayerTest):

    def _prepare_input(self):
        return (np.random.randn(3, 4, 3),)

    def create_model(self):

        class aten_empty(torch.nn.Module):

            def __init__(self) -> None:
                super().__init__()
                self.false = torch.tensor([False])

            def forward(self, input_tensor):
                size = input_tensor.shape
                empty = torch.empty(size, dtype=torch.bool)
                # We don't want to compare values, just shape and type,
                # so we do "and" operation with False.
                return empty & self.false

        ref_net = None

        return aten_empty(), ref_net, "aten::empty"

    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_empty_bool(self, ie_device, precision, ir_version):
        self._test(*self.create_model(), ie_device, precision, ir_version)
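
Both tests rely on the same masking idea: since the translator emits zeros, multiplying the numeric result by a zero tensor (or AND-ing the boolean result with False) makes the PyTorch reference and the converted model agree element-wise while still exercising shape and dtype propagation. A small standalone sketch of the trick (illustrative, not part of the commit):

import torch

# Numeric case: x * 0 is all zeros regardless of what torch.empty
# returned, with shape and dtype preserved.
e = torch.empty(2, 3, dtype=torch.int32)
assert torch.equal(e * torch.tensor([0], dtype=torch.int32),
                   torch.zeros(2, 3, dtype=torch.int32))

# Boolean case: x & False is all False, again preserving shape and dtype.
b = torch.empty(2, 3, dtype=torch.bool)
assert not (b & torch.tensor([False])).any()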