[PT FE] Add aten::empty operator with layer test (#15490)
This commit is contained in:
parent
e77c2ab6d7
commit
92788b1838
@ -159,6 +159,20 @@ OutputVector translate_new_ones(NodeContext& context) {
|
||||
return {base_translate_full_with_convertlike(context, sizes, value, input)};
|
||||
};
|
||||
|
||||
OutputVector translate_empty(NodeContext& context) {
|
||||
auto sizes = context.get_input(0);
|
||||
// In OV uninitialised data is not supported, so we create a tensor filled with zeros with a given shape and type.
|
||||
auto value = context.mark_node(ov::op::v0::Constant::create(element::f32, Shape{}, {0}));
|
||||
int dtype_id = 1;
|
||||
ov::Output<ov::Node> empty;
|
||||
if (!context.input_is_none(dtype_id)) {
|
||||
empty = base_translate_full_with_convert(context, sizes, value, dtype_id);
|
||||
} else {
|
||||
empty = base_translate_full(context, sizes, value);
|
||||
}
|
||||
return {empty};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
|
@ -32,6 +32,7 @@ OP_CONVERTER(translate_convolution_mode);
|
||||
OP_CONVERTER(translate_dim);
|
||||
OP_CONVERTER(translate_div);
|
||||
OP_CONVERTER(translate_elu);
|
||||
OP_CONVERTER(translate_empty);
|
||||
OP_CONVERTER(translate_embedding);
|
||||
OP_CONVERTER(translate_expand);
|
||||
OP_CONVERTER(translate_expand_as);
|
||||
@ -177,6 +178,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"aten::dropout_", op::skip_node},
|
||||
{"aten::elu", op::translate_elu},
|
||||
{"aten::embedding", op::translate_embedding},
|
||||
{"aten::empty", op::translate_empty},
|
||||
{"aten::eq", op::translate_1to1_match_2_inputs_align_types<opset10::Equal>},
|
||||
{"aten::exp", op::translate_1to1_match_1_inputs<opset10::Exp>},
|
||||
{"aten::expand", op::translate_expand},
|
||||
|
77
tests/layer_tests/pytorch_tests/test_empty.py
Normal file
77
tests/layer_tests/pytorch_tests/test_empty.py
Normal file
@ -0,0 +1,77 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from pytorch_layer_test_class import PytorchLayerTest
|
||||
|
||||
|
||||
class TestEmptyNumeric(PytorchLayerTest):
|
||||
|
||||
def _prepare_input(self):
|
||||
return (np.random.randn(2, 3),)
|
||||
|
||||
def create_model(self, dtype):
|
||||
|
||||
class aten_empty(torch.nn.Module):
|
||||
|
||||
def __init__(self, dtype) -> None:
|
||||
dtype_map = {
|
||||
"float32": torch.float32,
|
||||
"float64": torch.float64,
|
||||
"int64": torch.int64,
|
||||
"int32": torch.int32,
|
||||
"uint8": torch.uint8,
|
||||
"int8": torch.int8
|
||||
}
|
||||
super().__init__()
|
||||
self.dtype = dtype_map[dtype]
|
||||
self.zero = torch.tensor([0], dtype=dtype_map[dtype])
|
||||
|
||||
def forward(self, input_tensor):
|
||||
size = input_tensor.shape
|
||||
empty = torch.empty(size, dtype=self.dtype)
|
||||
# We don't want to compare values, just shape and type,
|
||||
# so we multiply the tensor by zero.
|
||||
return empty*self.zero
|
||||
|
||||
ref_net = None
|
||||
|
||||
return aten_empty(dtype), ref_net, "aten::empty"
|
||||
|
||||
@pytest.mark.parametrize('dtype', ("float32", "float64", "int64", "int32", "uint8", "int8"))
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_empty(self, ie_device, precision, ir_version, dtype):
|
||||
self._test(*self.create_model(dtype), ie_device, precision, ir_version)
|
||||
|
||||
class TestEmptyBoolean(PytorchLayerTest):
|
||||
|
||||
def _prepare_input(self):
|
||||
return (np.random.randn(3, 4, 3),)
|
||||
|
||||
def create_model(self):
|
||||
|
||||
class aten_empty(torch.nn.Module):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.false = torch.tensor([False])
|
||||
|
||||
def forward(self, input_tensor):
|
||||
size = input_tensor.shape
|
||||
empty = torch.empty(size, dtype=torch.bool)
|
||||
# We don't want to compare values, just shape and type,
|
||||
# so we do "and" operation with False.
|
||||
return empty & self.false
|
||||
|
||||
ref_net = None
|
||||
|
||||
return aten_empty(), ref_net, "aten::empty"
|
||||
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_empty_bool(self, ie_device, precision, ir_version, ):
|
||||
self._test(*self.create_model(), ie_device, precision, ir_version)
|
Loading…
Reference in New Issue
Block a user