From 5f71679fb9ec77a9b26ee965c0bdb55ef938937c Mon Sep 17 00:00:00 2001 From: Maxim Vafin Date: Thu, 10 Aug 2023 21:28:38 +0200 Subject: [PATCH] [PT FE] Use weight share switch in frontend (#18993) * [PT FE] Use weight share switch in frontend * Return static for function * Update src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py * Fix issue with quantized constants * Add tests for shared --- .../openvino/frontend/pytorch/ts_decoder.py | 42 ++++++++------ .../src/openvino/frontend/pytorch/utils.py | 38 ++++++------ .../test_mo_convert_pytorch.py | 58 +++++++++++++++++++ .../mo/moc_frontend/pytorch_frontend_utils.py | 2 +- .../moc_frontend/pytorch_frontend_utils.py | 2 +- 5 files changed, 105 insertions(+), 37 deletions(-) diff --git a/src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py b/src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py index e07a23dc5cc..1090dce0163 100644 --- a/src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py +++ b/src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py @@ -7,7 +7,7 @@ from openvino.frontend.pytorch.py_pytorch_frontend import _FrontEndPytorchDecoder as Decoder from openvino.frontend.pytorch.py_pytorch_frontend import _Type as DecoderType from openvino.runtime import op, PartialShape, Type as OVType, OVAny -from openvino.frontend.pytorch.utils import ivalue_to_constant, get_value_from_getattr, pt_to_ov_type_map +from openvino.frontend.pytorch.utils import ivalue_to_constant, get_value_from_getattr, pt_to_ov_type_map, torch_tensor_to_ov_const from openvino.runtime import opset11 as ops import typing @@ -29,11 +29,12 @@ class ModelWrapper(torch.nn.Module): class TorchScriptPythonDecoder (Decoder): - def __init__(self, pt_module, graph_element=None, example_input=None, alias_db=None): + def __init__(self, pt_module, graph_element=None, example_input=None, alias_db=None, shared_memory=True): Decoder.__init__(self) # We store every decoder created by this decoder so that all 
them are not deleted until the first decoder is deleted self.m_decoders = [] self._input_signature = None + self._shared_memory = shared_memory if graph_element is None: try: pt_module = self._get_scripted_model(pt_module, example_input) @@ -43,10 +44,9 @@ class TorchScriptPythonDecoder (Decoder): help_msg = "" else: msg = "scripting" - help_msg = "Tracing sometimes provide better results, " - "please provide valid 'example_input' argument. " + help_msg = "\nTracing sometimes provide better results, please provide valid 'example_input' argument. " raise RuntimeError( - f"Couldn't get TorchScript module by {msg}. {help_msg}" + f"Couldn't get TorchScript module by {msg}. With exception:\n{e}\n {help_msg}" "You can also provide TorchScript module that you obtained" " yourself, please refer to PyTorch documentation: " "https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html.") @@ -160,8 +160,11 @@ class TorchScriptPythonDecoder (Decoder): except Exception: try: scripted = torch.jit.script(pt_module) - except Exception: - scripted = torch.jit.trace(pt_module, **input_parameters, strict=False) + except Exception as se: + try: + scripted = torch.jit.trace(pt_module, **input_parameters, strict=False) + except Exception as te: + raise RuntimeError(f"Tracing failed with exception {te}\nScripting failed with exception: {se}") from te skip_freeze = False for n in scripted.inlined_graph.nodes(): # TODO: switch off freezing for all traced models @@ -283,7 +286,7 @@ class TorchScriptPythonDecoder (Decoder): def visit_subgraph(self, node_visitor) -> None: # make sure topological order is satisfied for node in self.graph_element.nodes(): - decoder = TorchScriptPythonDecoder(self.pt_module, node, alias_db=self.alias_db) + decoder = TorchScriptPythonDecoder(self.pt_module, node, alias_db=self.alias_db, shared_memory=self._shared_memory) self.m_decoders.append(decoder) node_visitor(decoder) @@ -299,7 +302,7 @@ class TorchScriptPythonDecoder (Decoder): return
list(self.graph_element.blocks()) def get_subgraph_decoder(self, index: int): - decoder = TorchScriptPythonDecoder(self.pt_module, self.get_subgraphs()[index], alias_db=self.alias_db) + decoder = TorchScriptPythonDecoder(self.pt_module, self.get_subgraphs()[index], alias_db=self.alias_db, shared_memory=self._shared_memory) self.m_decoders.append(decoder) return decoder @@ -336,7 +339,7 @@ class TorchScriptPythonDecoder (Decoder): return node @staticmethod - def convert_quantized_tensor(qtensor: torch.Tensor): + def convert_quantized_tensor(qtensor: torch.Tensor, shared_memory: bool): # need to represent as Constant(u8) -> Convert(f32) -> Subtract(zero_point) -> Multiply (scale) qscheme = qtensor.qscheme() # torch.per_channel_affine (per_tensor) if qscheme == torch.per_channel_affine: @@ -349,8 +352,8 @@ class TorchScriptPythonDecoder (Decoder): new_shape[axis] = -1 zero_point_bc = np.reshape(zero_point, new_shape) scale_bc = np.reshape(scale, new_shape) - - int8_const = op.Constant(int8_tensor.numpy()) + + int8_const = torch_tensor_to_ov_const(int8_tensor, shared_memory=shared_memory) convert = ops.convert(int8_const, np.float32) sub = ops.subtract(convert, zero_point_bc) return ops.multiply(sub, scale_bc).outputs() @@ -359,7 +362,7 @@ class TorchScriptPythonDecoder (Decoder): scale = np.float32(qtensor.q_scale()) zero_point = np.float32(qtensor.q_zero_point()) - int8_const = op.Constant(int8_tensor.numpy()) + int8_const = torch_tensor_to_ov_const(int8_tensor, shared_memory=shared_memory) convert = ops.convert(int8_const, np.float32) sub = ops.subtract(convert, zero_point) return ops.multiply(sub, scale).outputs() @@ -372,7 +375,7 @@ class TorchScriptPythonDecoder (Decoder): # We assume this is __torch__.torch.classes.quantized.Conv2dPackedParamsBase or __torch__.torch.classes.quantized.LinearPackedParamsBase # TODO: but can be anything. 
Figure a better way to distinguish weight, bias = pt_value.unpack() - res = self.convert_quantized_tensor(weight) + res = self.convert_quantized_tensor(weight, self._shared_memory) if isinstance(bias, torch.Tensor): res += ivalue_to_constant(bias) else: @@ -383,12 +386,15 @@ class TorchScriptPythonDecoder (Decoder): padding = pt_value.padding() dilation = pt_value.dilation() groups = pt_value.groups() - res += ivalue_to_constant(stride) + ivalue_to_constant(padding) + ivalue_to_constant(dilation) + ivalue_to_constant(groups) + res += ivalue_to_constant(stride, shared_memory=self._shared_memory) + res += ivalue_to_constant(padding, shared_memory=self._shared_memory) + res += ivalue_to_constant(dilation, shared_memory=self._shared_memory) + res += ivalue_to_constant(groups, shared_memory=self._shared_memory) except: pass return res elif not isinstance(pt_value, (torch.jit.ScriptModule, torch.jit.TracedModule)): - return ivalue_to_constant(pt_value) + return ivalue_to_constant(pt_value, shared_memory=self._shared_memory) else: return [] @@ -400,10 +406,10 @@ class TorchScriptPythonDecoder (Decoder): pt_value = self._raw_output(0) pt_type = pt_value.type() if isinstance(pt_type, torch.TensorType): - return ivalue_to_constant(pt_value.toIValue()) + return ivalue_to_constant(pt_value.toIValue(), shared_memory=self._shared_memory) if isinstance(pt_type, torch.ListType): return self._as_constant_list(pt_value) - return ivalue_to_constant(pt_value.toIValue()) + return ivalue_to_constant(pt_value.toIValue(), shared_memory=self._shared_memory) def as_string(self): if self.get_op_type() == "prim::Constant": diff --git a/src/bindings/python/src/openvino/frontend/pytorch/utils.py b/src/bindings/python/src/openvino/frontend/pytorch/utils.py index 7c491dc8aa2..0e7ffd66780 100644 --- a/src/bindings/python/src/openvino/frontend/pytorch/utils.py +++ b/src/bindings/python/src/openvino/frontend/pytorch/utils.py @@ -55,7 +55,26 @@ def get_type_from_py_type(value): return OVType.dynamic 
-def ivalue_to_constant(ivalue): +def torch_tensor_to_ov_const(torch_t: torch.Tensor, shared_memory=True): + torch_t = torch_t.to(memory_format=torch.contiguous_format) + if torch_t.dtype == torch.bfloat16: + # reinterpret bfloat16 data as float16 to allow conversion to numpy + torch_t = torch_t.view(torch.float16) + narr = torch_t.numpy(force=True) + if not narr.flags['C_CONTIGUOUS']: + narr = np.ascontiguousarray(narr) + # TODO: this tensor doesn't share memory with initial tensor + tensor = Tensor(narr, torch_t.shape, OVType.bf16) + ov_const = op.Constant(tensor, shared_memory=shared_memory) + else: + narr = torch_t.numpy(force=True) + if not narr.flags['C_CONTIGUOUS']: + narr = np.ascontiguousarray(narr) + ov_const = op.Constant(narr, shared_memory=shared_memory) + return ov_const + + +def ivalue_to_constant(ivalue, shared_memory=True): ov_type = get_type_from_py_type(ivalue) if ov_type.is_static(): return op.Constant(ov_type, Shape([]), [ivalue]).outputs() @@ -67,22 +86,7 @@ def ivalue_to_constant(ivalue): return op.Constant(ov_type, Shape([len(ivalue)]), ivalue).outputs() if isinstance(ivalue, torch.Tensor): - ivalue = ivalue.to(memory_format=torch.contiguous_format) - if ivalue.dtype == torch.bfloat16: - # reinterpret bfloat16 data as float16 to allow conversion to numpy - ivalue = ivalue.view(torch.float16) - narr = ivalue.numpy(force=True) - if not narr.flags['C_CONTIGUOUS']: - narr = np.ascontiguousarray(narr) - # TODO: this tensor doesn't share memory with initial tensor - tensor = Tensor(narr, ivalue.shape, OVType.bf16) - ov_const = op.Constant(tensor, shared_memory=True) - else: - narr = ivalue.numpy(force=True) - if not narr.flags['C_CONTIGUOUS']: - narr = np.ascontiguousarray(narr) - ov_const = op.Constant(narr, shared_memory=True) - return ov_const.outputs() + return torch_tensor_to_ov_const(ivalue, shared_memory=shared_memory).outputs() return None def get_value_from_getattr(getattr_node, self_module): diff --git 
a/tests/layer_tests/mo_python_api_tests/test_mo_convert_pytorch.py b/tests/layer_tests/mo_python_api_tests/test_mo_convert_pytorch.py index ae48f1ce1ae..307a1dcde3f 100644 --- a/tests/layer_tests/mo_python_api_tests/test_mo_convert_pytorch.py +++ b/tests/layer_tests/mo_python_api_tests/test_mo_convert_pytorch.py @@ -1011,6 +1011,64 @@ class TestMoConvertPyTorch(CommonMOConvertTest): self._test_by_ref_graph(temp_dir, test_params, graph_ref, compare_tensor_names=False) + @ pytest.mark.precommit + def test_sharing_memory_switched_off(self, ie_device, precision, ir_version, temp_dir): + from openvino.tools.ovc import convert_model + from openvino.runtime import Core + + class DataModel(torch.nn.Module): + def __init__(self): + super(DataModel, self).__init__() + self.data = torch.tensor([1, 2, 3, 4]) + + def forward(self, x): + return self.data, x + + data_model = DataModel() + test_input = np.array([0, 0, 0, 0]) + + # Convert model to OV + ov_model = convert_model(data_model, input=([4], Type.i32), share_weights=False) + + # Change value of variables in original model + data_model.data[0] *= 2 + + # Check model inference + core = Core() + cmp_model = core.compile_model(ov_model, ie_device) + ov_infer1 = cmp_model(test_input) + + assert np.array_equal(ov_infer1[0], [1, 2, 3, 4]) + + @ pytest.mark.precommit + def test_sharing_memory_switched_on(self, ie_device, precision, ir_version, temp_dir): + from openvino.tools.ovc import convert_model + from openvino.runtime import Core + + class DataModel(torch.nn.Module): + def __init__(self): + super(DataModel, self).__init__() + self.data = torch.tensor([1, 2, 3, 4]) + + def forward(self, x): + return self.data, x + + data_model = DataModel() + test_input = np.array([0, 0, 0, 0]) + + # Convert model to OV + ov_model = convert_model(data_model, input=([4], Type.i32), share_weights=True) + + # Change value of variables in original model + data_model.data[0] *= 2 + + # Check model inference + core = Core() + cmp_model = 
core.compile_model(ov_model, ie_device) + ov_infer1 = cmp_model(test_input) + + assert np.array_equal(ov_infer1[0], [2, 2, 3, 4]) + def create_pt_model_with_custom_op(): # diff --git a/tools/mo/openvino/tools/mo/moc_frontend/pytorch_frontend_utils.py b/tools/mo/openvino/tools/mo/moc_frontend/pytorch_frontend_utils.py index 7cb46d92300..8c489dc4127 100644 --- a/tools/mo/openvino/tools/mo/moc_frontend/pytorch_frontend_utils.py +++ b/tools/mo/openvino/tools/mo/moc_frontend/pytorch_frontend_utils.py @@ -31,7 +31,7 @@ def get_pytorch_decoder(model, input_shape, example_inputs, args): except: pass inputs = prepare_torch_inputs(example_inputs) - decoder = TorchScriptPythonDecoder(model, example_input=inputs) + decoder = TorchScriptPythonDecoder(model, example_input=inputs, shared_memory=args.get("share_weights", True)) args['input_model'] = decoder args["framework"] = "pytorch" args["example_input"] = inputs diff --git a/tools/ovc/openvino/tools/ovc/moc_frontend/pytorch_frontend_utils.py b/tools/ovc/openvino/tools/ovc/moc_frontend/pytorch_frontend_utils.py index 89c5ce11ae5..da2abdb21f3 100644 --- a/tools/ovc/openvino/tools/ovc/moc_frontend/pytorch_frontend_utils.py +++ b/tools/ovc/openvino/tools/ovc/moc_frontend/pytorch_frontend_utils.py @@ -31,7 +31,7 @@ def get_pytorch_decoder(model, example_inputs, args): except: pass inputs = prepare_torch_inputs(example_inputs) - decoder = TorchScriptPythonDecoder(model, example_input=inputs) + decoder = TorchScriptPythonDecoder(model, example_input=inputs, shared_memory=args.get("share_weights", True)) args['input_model'] = decoder args["example_input"] = inputs