[PT FE] Add aten::empty operator with layer test (#15490)
This commit is contained in:
parent
e77c2ab6d7
commit
92788b1838
@ -159,6 +159,20 @@ OutputVector translate_new_ones(NodeContext& context) {
|
|||||||
return {base_translate_full_with_convertlike(context, sizes, value, input)};
|
return {base_translate_full_with_convertlike(context, sizes, value, input)};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
OutputVector translate_empty(NodeContext& context) {
|
||||||
|
auto sizes = context.get_input(0);
|
||||||
|
// In OV uninitialised data is not supported, so we create a tensor filled with zeros with a given shape and type.
|
||||||
|
auto value = context.mark_node(ov::op::v0::Constant::create(element::f32, Shape{}, {0}));
|
||||||
|
int dtype_id = 1;
|
||||||
|
ov::Output<ov::Node> empty;
|
||||||
|
if (!context.input_is_none(dtype_id)) {
|
||||||
|
empty = base_translate_full_with_convert(context, sizes, value, dtype_id);
|
||||||
|
} else {
|
||||||
|
empty = base_translate_full(context, sizes, value);
|
||||||
|
}
|
||||||
|
return {empty};
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace op
|
} // namespace op
|
||||||
} // namespace pytorch
|
} // namespace pytorch
|
||||||
} // namespace frontend
|
} // namespace frontend
|
||||||
|
@ -32,6 +32,7 @@ OP_CONVERTER(translate_convolution_mode);
|
|||||||
OP_CONVERTER(translate_dim);
|
OP_CONVERTER(translate_dim);
|
||||||
OP_CONVERTER(translate_div);
|
OP_CONVERTER(translate_div);
|
||||||
OP_CONVERTER(translate_elu);
|
OP_CONVERTER(translate_elu);
|
||||||
|
OP_CONVERTER(translate_empty);
|
||||||
OP_CONVERTER(translate_embedding);
|
OP_CONVERTER(translate_embedding);
|
||||||
OP_CONVERTER(translate_expand);
|
OP_CONVERTER(translate_expand);
|
||||||
OP_CONVERTER(translate_expand_as);
|
OP_CONVERTER(translate_expand_as);
|
||||||
@ -177,6 +178,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
|||||||
{"aten::dropout_", op::skip_node},
|
{"aten::dropout_", op::skip_node},
|
||||||
{"aten::elu", op::translate_elu},
|
{"aten::elu", op::translate_elu},
|
||||||
{"aten::embedding", op::translate_embedding},
|
{"aten::embedding", op::translate_embedding},
|
||||||
|
{"aten::empty", op::translate_empty},
|
||||||
{"aten::eq", op::translate_1to1_match_2_inputs_align_types<opset10::Equal>},
|
{"aten::eq", op::translate_1to1_match_2_inputs_align_types<opset10::Equal>},
|
||||||
{"aten::exp", op::translate_1to1_match_1_inputs<opset10::Exp>},
|
{"aten::exp", op::translate_1to1_match_1_inputs<opset10::Exp>},
|
||||||
{"aten::expand", op::translate_expand},
|
{"aten::expand", op::translate_expand},
|
||||||
|
77
tests/layer_tests/pytorch_tests/test_empty.py
Normal file
77
tests/layer_tests/pytorch_tests/test_empty.py
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
# Copyright (C) 2018-2023 Intel Corporation
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from pytorch_layer_test_class import PytorchLayerTest
|
||||||
|
|
||||||
|
|
||||||
|
class TestEmptyNumeric(PytorchLayerTest):
|
||||||
|
|
||||||
|
def _prepare_input(self):
|
||||||
|
return (np.random.randn(2, 3),)
|
||||||
|
|
||||||
|
def create_model(self, dtype):
|
||||||
|
|
||||||
|
class aten_empty(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, dtype) -> None:
|
||||||
|
dtype_map = {
|
||||||
|
"float32": torch.float32,
|
||||||
|
"float64": torch.float64,
|
||||||
|
"int64": torch.int64,
|
||||||
|
"int32": torch.int32,
|
||||||
|
"uint8": torch.uint8,
|
||||||
|
"int8": torch.int8
|
||||||
|
}
|
||||||
|
super().__init__()
|
||||||
|
self.dtype = dtype_map[dtype]
|
||||||
|
self.zero = torch.tensor([0], dtype=dtype_map[dtype])
|
||||||
|
|
||||||
|
def forward(self, input_tensor):
|
||||||
|
size = input_tensor.shape
|
||||||
|
empty = torch.empty(size, dtype=self.dtype)
|
||||||
|
# We don't want to compare values, just shape and type,
|
||||||
|
# so we multiply the tensor by zero.
|
||||||
|
return empty*self.zero
|
||||||
|
|
||||||
|
ref_net = None
|
||||||
|
|
||||||
|
return aten_empty(dtype), ref_net, "aten::empty"
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('dtype', ("float32", "float64", "int64", "int32", "uint8", "int8"))
|
||||||
|
@pytest.mark.nightly
|
||||||
|
@pytest.mark.precommit
|
||||||
|
def test_empty(self, ie_device, precision, ir_version, dtype):
|
||||||
|
self._test(*self.create_model(dtype), ie_device, precision, ir_version)
|
||||||
|
|
||||||
|
class TestEmptyBoolean(PytorchLayerTest):
|
||||||
|
|
||||||
|
def _prepare_input(self):
|
||||||
|
return (np.random.randn(3, 4, 3),)
|
||||||
|
|
||||||
|
def create_model(self):
|
||||||
|
|
||||||
|
class aten_empty(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.false = torch.tensor([False])
|
||||||
|
|
||||||
|
def forward(self, input_tensor):
|
||||||
|
size = input_tensor.shape
|
||||||
|
empty = torch.empty(size, dtype=torch.bool)
|
||||||
|
# We don't want to compare values, just shape and type,
|
||||||
|
# so we do "and" operation with False.
|
||||||
|
return empty & self.false
|
||||||
|
|
||||||
|
ref_net = None
|
||||||
|
|
||||||
|
return aten_empty(), ref_net, "aten::empty"
|
||||||
|
|
||||||
|
@pytest.mark.nightly
|
||||||
|
@pytest.mark.precommit
|
||||||
|
def test_empty_bool(self, ie_device, precision, ir_version, ):
|
||||||
|
self._test(*self.create_model(), ie_device, precision, ir_version)
|
Loading…
Reference in New Issue
Block a user