[PT FE] Add torch.int16 dtype support (#20735)
* Add torch.int16 dtype support * Add test
This commit is contained in:
parent
82f191b0e7
commit
7b9db3d81b
@ -118,6 +118,7 @@ pt_to_ov_type_map = {
|
||||
"torch.float64": OVType.f64,
|
||||
"torch.uint8": OVType.u8,
|
||||
"torch.int8": OVType.i8,
|
||||
"torch.int16": OVType.i16,
|
||||
"torch.int32": OVType.i32,
|
||||
"torch.int64": OVType.i64,
|
||||
"torch.bool": OVType.boolean,
|
||||
|
@ -249,6 +249,28 @@ def test_pytorch_decoder_can_convert_i8_tensor():
|
||||
assert ov_const[0].get_partial_shape() == PartialShape([2])
|
||||
|
||||
|
||||
@pytest.mark.precommit
|
||||
def test_pytorch_decoder_can_convert_i16_tensor():
|
||||
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
|
||||
from openvino.runtime import PartialShape, Type
|
||||
|
||||
class SomeTensor(torch.nn.Module):
|
||||
def forward(self):
|
||||
return torch.tensor([1, 2], dtype=torch.int16)
|
||||
|
||||
model = get_scripted_model(SomeTensor())
|
||||
consts = [n for n in model.inlined_graph.nodes() if n.kind() ==
|
||||
"prim::Constant"]
|
||||
assert len(consts) > 0
|
||||
some_const = consts[0]
|
||||
nc_decoder = TorchScriptPythonDecoder(model, some_const)
|
||||
ov_const = nc_decoder.as_constant()
|
||||
assert ov_const is not None
|
||||
assert len(ov_const) == 1
|
||||
assert ov_const[0].get_element_type() == Type.i16
|
||||
assert ov_const[0].get_partial_shape() == PartialShape([2])
|
||||
|
||||
|
||||
@pytest.mark.precommit
|
||||
def test_pytorch_decoder_can_convert_i32_tensor():
|
||||
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
|
||||
|
Loading…
Reference in New Issue
Block a user