[PT FE] Add torch.int16 dtype support (#20735)

* Add torch.int16 dtype support

* Add test
This commit is contained in:
Maxim Vafin 2023-10-30 10:12:55 +01:00 committed by GitHub
parent 82f191b0e7
commit 7b9db3d81b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 23 additions and 0 deletions

View File

@ -118,6 +118,7 @@ pt_to_ov_type_map = {
"torch.float64": OVType.f64,
"torch.uint8": OVType.u8,
"torch.int8": OVType.i8,
"torch.int16": OVType.i16,
"torch.int32": OVType.i32,
"torch.int64": OVType.i64,
"torch.bool": OVType.boolean,

View File

@ -249,6 +249,28 @@ def test_pytorch_decoder_can_convert_i8_tensor():
assert ov_const[0].get_partial_shape() == PartialShape([2])
@pytest.mark.precommit
def test_pytorch_decoder_can_convert_i16_tensor():
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
from openvino.runtime import PartialShape, Type
class SomeTensor(torch.nn.Module):
def forward(self):
return torch.tensor([1, 2], dtype=torch.int16)
model = get_scripted_model(SomeTensor())
consts = [n for n in model.inlined_graph.nodes() if n.kind() ==
"prim::Constant"]
assert len(consts) > 0
some_const = consts[0]
nc_decoder = TorchScriptPythonDecoder(model, some_const)
ov_const = nc_decoder.as_constant()
assert ov_const is not None
assert len(ov_const) == 1
assert ov_const[0].get_element_type() == Type.i16
assert ov_const[0].get_partial_shape() == PartialShape([2])
@pytest.mark.precommit
def test_pytorch_decoder_can_convert_i32_tensor():
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder