[PT FE] Add aten::new_empty (#16312)
* Add new_empty * Remove duplicated code for new_empty
This commit is contained in:
parent
c472b020b7
commit
134ebb8889
@ -65,7 +65,7 @@ OutputVector translate_full_like(NodeContext& context) {
|
||||
auto input = context.get_input(0);
|
||||
auto value = context.get_input(1);
|
||||
auto sizes = context.mark_node(std::make_shared<v3::ShapeOf>(input, element::i32));
|
||||
if (context.get_input_size() == 7) {
|
||||
if (context.get_input_size() == 7 && !context.input_is_none(2)) {
|
||||
return {base_translate_full_with_convert(context, sizes, value, 2)};
|
||||
}
|
||||
auto out = context.input_is_none(3) ? input : context.get_input(3);
|
||||
@ -113,7 +113,7 @@ OutputVector translate_zeros_like(NodeContext& context) {
|
||||
auto input = context.get_input(0);
|
||||
auto value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {0}));
|
||||
auto sizes = context.mark_node(std::make_shared<v3::ShapeOf>(input, element::i32));
|
||||
if (context.get_input_size() == 6) {
|
||||
if (context.get_input_size() == 6 && !context.input_is_none(1)) {
|
||||
return {base_translate_full_with_convert(context, sizes, value, 1)};
|
||||
}
|
||||
auto out = context.input_is_none(2) ? input : context.get_input(2);
|
||||
@ -153,7 +153,7 @@ OutputVector translate_ones_like(NodeContext& context) {
|
||||
auto input = context.get_input(0);
|
||||
auto value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {1}));
|
||||
auto sizes = context.mark_node(std::make_shared<v3::ShapeOf>(input, element::i32));
|
||||
if (context.get_input_size() == 6) {
|
||||
if (context.get_input_size() == 6 && !context.input_is_none(1)) {
|
||||
return {base_translate_full_with_convert(context, sizes, value, 1)};
|
||||
}
|
||||
auto out = context.input_is_none(2) ? input : context.get_input(2);
|
||||
@ -172,7 +172,7 @@ OutputVector translate_new_ones(NodeContext& context) {
|
||||
};
|
||||
|
||||
OutputVector translate_empty(NodeContext& context) {
|
||||
num_inputs_check(context, 1, 2);
|
||||
num_inputs_check(context, 1, 5);
|
||||
auto sizes = context.get_input(0);
|
||||
// In OV uninitialised data is not supported, so we create a tensor filled with zeros with a given shape and type.
|
||||
auto value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {0}));
|
||||
@ -185,7 +185,6 @@ OutputVector translate_empty(NodeContext& context) {
|
||||
}
|
||||
return {empty};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
|
@ -254,6 +254,7 @@ const std::map<std::string, PytorchCreatorFunction> get_supported_ops() {
|
||||
{"aten::narrow", op::translate_narrow},
|
||||
{"aten::ne", op::translate_1to1_match_2_inputs_align_types<opset10::NotEqual>},
|
||||
{"aten::neg", op::translate_neg},
|
||||
{"aten::new_empty", op::translate_new_zeros},
|
||||
{"aten::new_full", op::translate_new_full},
|
||||
{"aten::new_ones", op::translate_new_ones},
|
||||
{"aten::new_zeros", op::translate_new_zeros},
|
||||
|
@ -11,7 +11,7 @@ from pytorch_layer_test_class import PytorchLayerTest
|
||||
class TestEmptyNumeric(PytorchLayerTest):
|
||||
|
||||
def _prepare_input(self):
|
||||
return (np.random.randn(2, 3),)
|
||||
return (np.random.randn(10, 10, 10),)
|
||||
|
||||
def create_model(self, dtype):
|
||||
|
||||
@ -28,14 +28,14 @@ class TestEmptyNumeric(PytorchLayerTest):
|
||||
}
|
||||
super().__init__()
|
||||
self.dtype = dtype_map[dtype]
|
||||
self.zero = torch.tensor([0], dtype=dtype_map[dtype])
|
||||
|
||||
def forward(self, input_tensor):
|
||||
size = input_tensor.shape
|
||||
empty = torch.empty(size, dtype=self.dtype)
|
||||
# We don't want to compare values, just shape and type,
|
||||
# so we multiply the tensor by zero.
|
||||
return empty*self.zero
|
||||
# so we call zeros_like on data. Multiplying by zero would
|
||||
# produce sporadic errors if nan would be in empty.
|
||||
return torch.zeros_like(empty)
|
||||
|
||||
ref_net = None
|
||||
|
||||
@ -50,7 +50,7 @@ class TestEmptyNumeric(PytorchLayerTest):
|
||||
class TestEmptyBoolean(PytorchLayerTest):
|
||||
|
||||
def _prepare_input(self):
|
||||
return (np.random.randn(3, 4, 3),)
|
||||
return (np.random.randn(10, 10, 10),)
|
||||
|
||||
def create_model(self):
|
||||
|
||||
@ -75,3 +75,71 @@ class TestEmptyBoolean(PytorchLayerTest):
|
||||
@pytest.mark.precommit
|
||||
def test_empty_bool(self, ie_device, precision, ir_version, ):
|
||||
self._test(*self.create_model(), ie_device, precision, ir_version)
|
||||
|
||||
class TestNewEmpty(PytorchLayerTest):
|
||||
def _prepare_input(self, input_dtype=np.float32):
|
||||
return (np.random.randn(1, 3, 10, 10).astype(input_dtype),)
|
||||
|
||||
def create_model(self, shape, dtype=None, used_dtype=False):
|
||||
import torch
|
||||
dtype_map = {
|
||||
"float32": torch.float32,
|
||||
"float64": torch.float64,
|
||||
"int64": torch.int64,
|
||||
"int32": torch.int32,
|
||||
"uint8": torch.uint8,
|
||||
"int8": torch.int8,
|
||||
"bool": torch.bool
|
||||
}
|
||||
|
||||
class aten_empty(torch.nn.Module):
|
||||
def __init__(self, shape):
|
||||
super(aten_empty, self).__init__()
|
||||
self.shape = shape
|
||||
self.zero = torch.tensor([0])
|
||||
|
||||
def forward(self, input_tensor: torch.Tensor):
|
||||
empty = input_tensor.new_empty(self.shape)
|
||||
# We don't want to compare values, just shape and type,
|
||||
# so we call zeros_like on data. Multiplying by zero would
|
||||
# produce sporadic errors if nan would be in empty.
|
||||
return torch.zeros_like(empty)
|
||||
|
||||
class aten_empty_with_dtype(torch.nn.Module):
|
||||
def __init__(self, shape, dtype):
|
||||
super(aten_empty_with_dtype, self).__init__()
|
||||
self.shape = shape
|
||||
self.dtype = dtype
|
||||
self.zero = torch.tensor([0], dtype=self.dtype)
|
||||
|
||||
def forward(self, input_tensor: torch.Tensor):
|
||||
empty = input_tensor.new_empty(self.shape, dtype=self.dtype)
|
||||
# We don't want to compare values, just shape and type,
|
||||
# so we call zeros_like on data. Multiplying by zero would
|
||||
# produce sporadic errors if nan would be in empty.
|
||||
return torch.zeros_like(empty)
|
||||
|
||||
ref_net = None
|
||||
model = aten_empty(shape)
|
||||
|
||||
if used_dtype:
|
||||
dtype = dtype_map[dtype]
|
||||
model = aten_empty_with_dtype(shape, dtype)
|
||||
|
||||
return model, ref_net, "aten::new_empty"
|
||||
|
||||
@pytest.mark.parametrize("shape", [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5, 6]])
|
||||
@pytest.mark.parametrize("input_dtype", [np.uint8, np.int8, np.int32, np.int64, np.float32, np.float64])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_new_empty(self, shape, input_dtype, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(shape), ie_device, precision, ir_version,
|
||||
kwargs_to_prepare_input={'input_dtype': input_dtype})
|
||||
|
||||
@pytest.mark.parametrize("shape", [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5, 6]])
|
||||
@pytest.mark.parametrize("input_dtype", [bool, np.uint8, np.int8, np.int32, np.int64, np.float32, np.float64])
|
||||
@pytest.mark.parametrize("dtype", ["bool", "uint8", "int8", "int32", "int64", "float32", "float64"])
|
||||
@pytest.mark.nightly
|
||||
def test_new_empty_with_dtype(self, shape, dtype, input_dtype, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(shape, dtype=dtype, used_dtype=True), ie_device, precision, ir_version,
|
||||
kwargs_to_prepare_input={'input_dtype': input_dtype})
|
||||
|
Loading…
Reference in New Issue
Block a user