diff --git a/src/bindings/python/src/openvino/frontend/pytorch/decoder.py b/src/bindings/python/src/openvino/frontend/pytorch/decoder.py index 04d95ed55c7..bbecf3014d8 100644 --- a/src/bindings/python/src/openvino/frontend/pytorch/decoder.py +++ b/src/bindings/python/src/openvino/frontend/pytorch/decoder.py @@ -182,8 +182,17 @@ class TorchScriptPythonDecoder (Decoder): for n in scripted.inlined_graph.nodes(): # TODO: switch off freezing for all traced models if "quantize" in n.kind(): + # do not freeze quantized models skip_freeze = True break + elif "aten::to" in n.kind(): + first_input = next(n.inputs()) + if first_input.node().kind() == "prim::Constant": + ivalue = first_input.toIValue() + if ivalue is not None and ivalue.dtype in [torch.uint8, torch.int8, torch.bfloat16, torch.float16]: + # do not freeze models with compressed constants + skip_freeze = True + break if not skip_freeze: f_model = torch.jit.freeze(scripted) else: diff --git a/src/frontends/pytorch/src/op/to.cpp b/src/frontends/pytorch/src/op/to.cpp index 2499b8346f5..f10ac4c3692 100644 --- a/src/frontends/pytorch/src/op/to.cpp +++ b/src/frontends/pytorch/src/op/to.cpp @@ -50,8 +50,7 @@ OutputVector translate_to(const NodeContext& context) { } } else if (context.get_input_size() == 8) { // aten::to(Tensor(a) self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? - // pin_memory=None, - // bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) + // pin_memory=None, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) dtype_idx = 1; memory_format_idx = 7; if (context.input_is_none(dtype_idx)) { diff --git a/tests/layer_tests/mo_python_api_tests/test_mo_convert_pytorch.py b/tests/layer_tests/mo_python_api_tests/test_mo_convert_pytorch.py index 10a54d3ceba..f695da13c17 100644 --- a/tests/layer_tests/mo_python_api_tests/test_mo_convert_pytorch.py +++ b/tests/layer_tests/mo_python_api_tests/test_mo_convert_pytorch.py @@ -78,9 +78,10 @@ def make_pt_model_with_optional_input(): if z is None: logits = self.linear_relu_stack(x * y) return logits - + return NeuralNetwork() + def make_ref_pt_model_one_input(shape, dtype=np.float32): shape = PartialShape(shape) param1 = ov.opset8.parameter(shape, name="input_0", dtype=dtype) @@ -105,7 +106,7 @@ def make_ref_pt_model_two_inputs(shape, dtype=np.float32): param1 = ov.opset8.parameter(shape, name="input_0", dtype=dtype) param2 = ov.opset8.parameter(shape, name="input_1", dtype=dtype) mul = ov.opset8.multiply(param1, param2) - relu = ov.opset8.relu(mul) + relu = ov.opset8.relu(mul) if dtype != np.float32: relu = ov.opset8.convert(relu, np.float32) sigm = ov.opset8.sigmoid(relu) @@ -125,9 +126,10 @@ def make_ref_pt_model_with_optional_inputs(shape, dtype=np.float32, z_exist=Fals shape = PartialShape(shape) param1 = ov.opset8.parameter(shape, name="input_0", dtype=dtype) param2 = ov.opset8.parameter(shape, name="input_1", dtype=dtype) - - op = ov.opset8.multiply(param1, param2) if not z_exist else ov.opset8.add(param1, param2) - relu = ov.opset8.relu(op) + + op = ov.opset8.multiply( + param1, param2) if not z_exist else ov.opset8.add(param1, param2) + relu = ov.opset8.relu(op) if dtype != np.float32: relu = ov.opset8.convert(relu, np.float32) sigm = ov.opset8.sigmoid(relu) @@ -217,10 +219,12 @@ def create_pytorch_nn_module_case7(tmp_dir): sample_input = torch.zeros(1, 3, 10, 10, dtype=torch.int32) - ref_model = make_ref_pt_model_one_input(PartialShape([1, 3, 20, 20]), dtype=np.int32) + ref_model = make_ref_pt_model_one_input( + PartialShape([1, 3, 20, 20]), dtype=np.int32) return pt_model, ref_model, {'example_input': sample_input, "input": ([1, 3, 20, 20], np.int32)} + def create_pytorch_nn_module_torch_size(tmp_dir): pt_model = make_pt_model_one_input() ref_model = make_ref_pt_model_one_input([1, 3, 2, 10]) @@ -277,7 +281,6 @@ def create_pytorch_jit_script_function(tmp_dir): return scripted_fn, ref_model, {'input': [(inp_shape), (inp_shape)]} - def create_pytorch_nn_module_layout_list(tmp_dir): from openvino.runtime import Layout pt_model = make_pt_model_two_inputs() @@ -290,7 +293,7 @@ def create_pytorch_nn_module_layout_list(tmp_dir): return pt_model, ref_model, { 'input_shape': [shape, shape], 'layout': ['nchw', Layout('nhwc')], - } + } def create_pytorch_nn_module_layout_list_case2(tmp_dir): @@ -370,7 +373,7 @@ def create_pytorch_nn_module_mean_list_compression_enabled(tmp_dir): ref_model = Model([sigm], parameter_list, "test") return pt_model, ref_model, { - 'input_shape': [shape, shape], 'mean_values': [[0, 0, 0], [0, 0, 0]], + 'input_shape': [shape, shape], 'mean_values': [[0, 0, 0], [0, 0, 0]], 'compress_to_fp16': False} @@ -425,9 +428,11 @@ def create_pytorch_nn_module_scale_list_compression_enabled(tmp_dir): param1 = ov.opset8.parameter(shape) param2 = ov.opset8.parameter(shape) const1 = ov.opset8.constant([[[[1, 1, 1]]]], dtype=np.float16) - const1_decompressed = ov.opset8.convert(const1, destination_type=np.float32) + const1_decompressed = ov.opset8.convert( + const1, destination_type=np.float32) const2 = ov.opset8.constant([[[[1, 1, 1]]]], dtype=np.float16) - const2_decompressed = ov.opset8.convert(const2, destination_type=np.float32) + const2_decompressed = ov.opset8.convert( + const2, destination_type=np.float32) mul1 = ov.opset8.multiply(param1, const1_decompressed) mul2 = ov.opset8.multiply(param2, const2_decompressed) mul3 = ov.opset8.multiply(mul1, mul2) @@ -650,37 +655,77 @@ def create_pytorch_module_convert_pytorch_frontend_oob(tmp_dir): def create_pytorch_module_with_optional_inputs_case1(tmp_dir): net = make_pt_model_with_optional_input() - example_input = {"x": torch.zeros((1,3,10,10)), "y": torch.ones((1,3,10,10))} + example_input = {"x": torch.zeros( + (1, 3, 10, 10)), "y": torch.ones((1, 3, 10, 10))} ref_model = make_ref_pt_model_with_optional_inputs([-1, -1, -1, -1]) return net, ref_model, {"example_input": example_input} def create_pytorch_module_with_optional_inputs_case2(tmp_dir): net = make_pt_model_with_optional_input() - example_input = {"x": torch.zeros((1,3,10,10)), "z": torch.ones((1,3,10,10))} - ref_model = make_ref_pt_model_with_optional_inputs([-1, -1, -1, -1], z_exist=True) + example_input = {"x": torch.zeros( + (1, 3, 10, 10)), "z": torch.ones((1, 3, 10, 10))} + ref_model = make_ref_pt_model_with_optional_inputs( + [-1, -1, -1, -1], z_exist=True) return net, ref_model, {"example_input": example_input} def create_pytorch_module_with_optional_inputs_case3(tmp_dir): net = make_pt_model_with_optional_input() - example_input = {"x": torch.zeros((1,3,10,10)), "z": torch.ones((1,3,10,10))} - ref_model = make_ref_pt_model_with_optional_inputs([3, 3, 3, 3], z_exist=True) + example_input = {"x": torch.zeros( + (1, 3, 10, 10)), "z": torch.ones((1, 3, 10, 10))} + ref_model = make_ref_pt_model_with_optional_inputs( + [3, 3, 3, 3], z_exist=True) return net, ref_model, {"example_input": example_input, "input_shape": [[3, 3, 3, 3], [3, 3, 3, 3]]} def create_pytorch_module_with_optional_inputs_case4(tmp_dir): net = make_pt_model_with_optional_input() - ref_model = make_ref_pt_model_with_optional_inputs([3, 3, 3, 3], z_exist=True) + ref_model = make_ref_pt_model_with_optional_inputs( + [3, 3, 3, 3], z_exist=True) return net, ref_model, {"input": [("x", [3, 3, 3, 3]), ("z", [3, 3, 3, 3])]} def create_pytorch_module_with_optional_inputs_case5(tmp_dir): net = make_pt_model_with_optional_input() - ref_model = make_ref_pt_model_with_optional_inputs([1, 3, -1, -1], z_exist=True) + ref_model = make_ref_pt_model_with_optional_inputs( + [1, 3, -1, -1], z_exist=True) return net, ref_model, {"input": ["x", "z"], "input_shape": [[1, 3, -1, -1], [1, 3, -1, -1]]} +def create_pytorch_module_with_compressed_int8_constant(tmp_dir): + import torch + import torch.nn.functional as F + + class Int8Model(torch.nn.Module): + def __init__(self): + super(Int8Model, self).__init__() + self.weights = torch.randint(-127, 128, + [1, 3, 3, 3], dtype=torch.int8) + + def forward(self, x): + cast = self.weights.to(torch.float32) + sub = cast - 0.5 + mul = sub * 0.02 + return F.conv2d(x, mul) + + net = Int8Model() + example_input = (torch.rand((1, 3, 10, 10)),) + traced_model = torch.jit.trace(net, example_input) + shape = [-1, -1, -1, -1] + shape = PartialShape(shape) + param1 = ov.opset10.parameter(shape, dtype=np.float32) + weights = ov.opset10.constant(net.weights.numpy(force=True)) + cast1 = ov.opset10.convert(weights, np.float32) + sub1 = ov.opset10.subtract(cast1, np.float32(0.5).reshape(1, 1, 1, 1)) + mul1 = ov.opset10.multiply(sub1, np.float32(0.02).reshape(1, 1, 1, 1)) + conv = ov.opset10.convolution(param1, mul1, strides=[1, 1], + pads_begin=[0, 0], pads_end=[0, 0], + dilations=[1, 1]) + ref_model = Model([conv], [param1], "test") + return traced_model, ref_model, {"example_input": example_input} + + class TestMoConvertPyTorch(CommonMOConvertTest): test_data = [ create_pytorch_nn_module_case1, @@ -724,6 +769,7 @@ class TestMoConvertPyTorch(CommonMOConvertTest): create_pytorch_module_with_optional_inputs_case4, create_pytorch_module_with_optional_inputs_case5, create_pytorch_nn_module_with_scalar_input, + create_pytorch_module_with_compressed_int8_constant, ] @ pytest.mark.parametrize("create_model", test_data)