[PT FE] Use weight share switch in frontend (#18993)

* [PT FE] Use weight share switch in frontend

* Return static for function

* Update src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py

* Fix issue with quantized constants

* Add tests for shared
Maxim Vafin 2023-08-10 21:28:38 +02:00 committed by GitHub
parent 726abefbaa
commit 5f71679fb9
5 changed files with 105 additions and 37 deletions
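
The commit threads the `share_weights` switch of `convert_model` down to the PyTorch frontend decoder as a `shared_memory` flag, so constants built from torch tensors either alias or copy the original weight memory. Below is a minimal sketch of the user-facing behaviour, adapted from the tests added in this commit; the "CPU" device string is an assumption (the tests parametrize the device).

import numpy as np
import torch
from openvino.runtime import Core, Type
from openvino.tools.ovc import convert_model


class DataModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.data = torch.tensor([1, 2, 3, 4])

    def forward(self, x):
        return self.data, x


model = DataModel()
# share_weights=False makes the frontend copy weights instead of aliasing them
ov_model = convert_model(model, input=([4], Type.i32), share_weights=False)
model.data[0] *= 2  # mutate the original tensor after conversion

compiled = Core().compile_model(ov_model, "CPU")  # device is an assumption
result = compiled(np.array([0, 0, 0, 0]))
# The copied constant is unaffected by the mutation; with share_weights=True
# (the default) the first element would be 2 instead of 1.
assert np.array_equal(result[0], [1, 2, 3, 4])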

View File

@@ -7,7 +7,7 @@
 from openvino.frontend.pytorch.py_pytorch_frontend import _FrontEndPytorchDecoder as Decoder
 from openvino.frontend.pytorch.py_pytorch_frontend import _Type as DecoderType
 from openvino.runtime import op, PartialShape, Type as OVType, OVAny
-from openvino.frontend.pytorch.utils import ivalue_to_constant, get_value_from_getattr, pt_to_ov_type_map
+from openvino.frontend.pytorch.utils import ivalue_to_constant, get_value_from_getattr, pt_to_ov_type_map, torch_tensor_to_ov_const
 from openvino.runtime import opset11 as ops
 import typing
@@ -29,11 +29,12 @@ class ModelWrapper(torch.nn.Module):
 class TorchScriptPythonDecoder (Decoder):
-    def __init__(self, pt_module, graph_element=None, example_input=None, alias_db=None):
+    def __init__(self, pt_module, graph_element=None, example_input=None, alias_db=None, shared_memory=True):
         Decoder.__init__(self)
         # We store every decoder created by this decoder so that all them are not deleted until the first decoder is deleted
         self.m_decoders = []
         self._input_signature = None
+        self._shared_memory = shared_memory
         if graph_element is None:
             try:
                 pt_module = self._get_scripted_model(pt_module, example_input)
@@ -43,10 +44,9 @@ class TorchScriptPythonDecoder (Decoder):
                     help_msg = ""
                 else:
                     msg = "scripting"
-                    help_msg = "Tracing sometimes provide better results, "
-                    "please provide valid 'example_input' argument. "
+                    help_msg = "\nTracing sometimes provide better results, please provide valid 'example_input' argument. "
                 raise RuntimeError(
-                    f"Couldn't get TorchScript module by {msg}. {help_msg}"
+                    f"Couldn't get TorchScript module by {msg}. With exception:\n{e}\n {help_msg}"
                     "You can also provide TorchScript module that you obtained"
                     " yourself, please refer to PyTorch documentation: "
                     "https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html.")
@@ -160,8 +160,11 @@ class TorchScriptPythonDecoder (Decoder):
             except Exception:
                 try:
                     scripted = torch.jit.script(pt_module)
-                except Exception:
-                    scripted = torch.jit.trace(pt_module, **input_parameters, strict=False)
+                except Exception as se:
+                    try:
+                        scripted = torch.jit.trace(pt_module, **input_parameters, strict=False)
+                    except Exception as te:
+                        raise f"Tracing failed with exception {te}\nScripting failed with exception: {se}"
         skip_freeze = False
         for n in scripted.inlined_graph.nodes():
             # TODO: switch off freezing for all traced models
@@ -283,7 +286,7 @@ class TorchScriptPythonDecoder (Decoder):
     def visit_subgraph(self, node_visitor) -> None:
         # make sure topological order is satisfied
         for node in self.graph_element.nodes():
-            decoder = TorchScriptPythonDecoder(self.pt_module, node, alias_db=self.alias_db)
+            decoder = TorchScriptPythonDecoder(self.pt_module, node, alias_db=self.alias_db, shared_memory=self._shared_memory)
             self.m_decoders.append(decoder)
             node_visitor(decoder)
@@ -299,7 +302,7 @@ class TorchScriptPythonDecoder (Decoder):
         return list(self.graph_element.blocks())

     def get_subgraph_decoder(self, index: int):
-        decoder = TorchScriptPythonDecoder(self.pt_module, self.get_subgraphs()[index], alias_db=self.alias_db)
+        decoder = TorchScriptPythonDecoder(self.pt_module, self.get_subgraphs()[index], alias_db=self.alias_db, shared_memory=self._shared_memory)
         self.m_decoders.append(decoder)
         return decoder
@@ -336,7 +339,7 @@ class TorchScriptPythonDecoder (Decoder):
         return node

     @staticmethod
-    def convert_quantized_tensor(qtensor: torch.Tensor):
+    def convert_quantized_tensor(qtensor: torch.Tensor, shared_memory: bool):
         # need to represent as Constant(u8) -> Convert(f32) -> Subtract(zero_point) -> Multiply (scale)
         qscheme = qtensor.qscheme()  # torch.per_channel_affine (per_tensor)
         if qscheme == torch.per_channel_affine:
@@ -349,8 +352,8 @@ class TorchScriptPythonDecoder (Decoder):
             new_shape[axis] = -1
             zero_point_bc = np.reshape(zero_point, new_shape)
             scale_bc = np.reshape(scale, new_shape)
-            int8_const = op.Constant(int8_tensor.numpy())
+            int8_const = torch_tensor_to_ov_const(int8_tensor, shared_memory=shared_memory)
             convert = ops.convert(int8_const, np.float32)
             sub = ops.subtract(convert, zero_point_bc)
             return ops.multiply(sub, scale_bc).outputs()
@@ -359,7 +362,7 @@ class TorchScriptPythonDecoder (Decoder):
             scale = np.float32(qtensor.q_scale())
             zero_point = np.float32(qtensor.q_zero_point())
-            int8_const = op.Constant(int8_tensor.numpy())
+            int8_const = torch_tensor_to_ov_const(int8_tensor, shared_memory=shared_memory)
             convert = ops.convert(int8_const, np.float32)
             sub = ops.subtract(convert, zero_point)
             return ops.multiply(sub, scale).outputs()
@@ -372,7 +375,7 @@ class TorchScriptPythonDecoder (Decoder):
             # We assume this is __torch__.torch.classes.quantized.Conv2dPackedParamsBase or __torch__.torch.classes.quantized.LinearPackedParamsBase
             # TODO: but can be anything. Figure a better way to distinguish
             weight, bias = pt_value.unpack()
-            res = self.convert_quantized_tensor(weight)
+            res = self.convert_quantized_tensor(weight, self._shared_memory)
             if isinstance(bias, torch.Tensor):
                 res += ivalue_to_constant(bias)
             else:
@@ -383,12 +386,15 @@ class TorchScriptPythonDecoder (Decoder):
                 padding = pt_value.padding()
                 dilation = pt_value.dilation()
                 groups = pt_value.groups()
-                res += ivalue_to_constant(stride) + ivalue_to_constant(padding) + ivalue_to_constant(dilation) + ivalue_to_constant(groups)
+                res += ivalue_to_constant(stride, shared_memory=self._shared_memory)
+                res += ivalue_to_constant(padding, shared_memory=self._shared_memory)
+                res += ivalue_to_constant(dilation, shared_memory=self._shared_memory)
+                res += ivalue_to_constant(groups, shared_memory=self._shared_memory)
             except:
                 pass
             return res
         elif not isinstance(pt_value, (torch.jit.ScriptModule, torch.jit.TracedModule)):
-            return ivalue_to_constant(pt_value)
+            return ivalue_to_constant(pt_value, shared_memory=self._shared_memory)
         else:
             return []
@@ -400,10 +406,10 @@ class TorchScriptPythonDecoder (Decoder):
         pt_value = self._raw_output(0)
         pt_type = pt_value.type()
         if isinstance(pt_type, torch.TensorType):
-            return ivalue_to_constant(pt_value.toIValue())
+            return ivalue_to_constant(pt_value.toIValue(), shared_memory=self._shared_memory)
         if isinstance(pt_type, torch.ListType):
             return self._as_constant_list(pt_value)
-        return ivalue_to_constant(pt_value.toIValue())
+        return ivalue_to_constant(pt_value.toIValue(), shared_memory=self._shared_memory)

     def as_string(self):
         if self.get_op_type() == "prim::Constant":
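
With the change above, TorchScriptPythonDecoder accepts a shared_memory flag and forwards it to every nested sub-graph decoder and constant conversion. A minimal sketch of constructing the decoder directly with sharing disabled; the toy module and example input are illustrative, not part of this commit.

import torch
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder


class Mul2(torch.nn.Module):
    def forward(self, x):
        return x * 2


# shared_memory=False: constants created from this module's tensors own a copy
# of the data; sub-graph decoders created later inherit the same setting.
decoder = TorchScriptPythonDecoder(Mul2(), example_input=torch.zeros(2, 3), shared_memory=False)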

View File

@@ -55,7 +55,26 @@ def get_type_from_py_type(value):
     return OVType.dynamic


-def ivalue_to_constant(ivalue):
+def torch_tensor_to_ov_const(torch_t: torch.Tensor, shared_memory=True):
+    torch_t = torch_t.to(memory_format=torch.contiguous_format)
+    if torch_t.dtype == torch.bfloat16:
+        # reinterpret bfloat16 data as float16 to allow conversion to numpy
+        torch_t = torch_t.view(torch.float16)
+        narr = torch_t.numpy(force=True)
+        if not narr.flags['C_CONTIGUOUS']:
+            narr = np.ascontiguousarray(narr)
+        # TODO: this tensor doesn't share memory with initial tensor
+        tensor = Tensor(narr, torch_t.shape, OVType.bf16)
+        ov_const = op.Constant(tensor, shared_memory=shared_memory)
+    else:
+        narr = torch_t.numpy(force=True)
+        if not narr.flags['C_CONTIGUOUS']:
+            narr = np.ascontiguousarray(narr)
+        ov_const = op.Constant(narr, shared_memory=shared_memory)
+    return ov_const
+
+
+def ivalue_to_constant(ivalue, shared_memory=True):
     ov_type = get_type_from_py_type(ivalue)
     if ov_type.is_static():
         return op.Constant(ov_type, Shape([]), [ivalue]).outputs()
@@ -67,22 +86,7 @@ def ivalue_to_constant(ivalue):
         return op.Constant(ov_type, Shape([len(ivalue)]), ivalue).outputs()
     if isinstance(ivalue, torch.Tensor):
-        ivalue = ivalue.to(memory_format=torch.contiguous_format)
-        if ivalue.dtype == torch.bfloat16:
-            # reinterpret bfloat16 data as float16 to allow conversion to numpy
-            ivalue = ivalue.view(torch.float16)
-            narr = ivalue.numpy(force=True)
-            if not narr.flags['C_CONTIGUOUS']:
-                narr = np.ascontiguousarray(narr)
-            # TODO: this tensor doesn't share memory with initial tensor
-            tensor = Tensor(narr, ivalue.shape, OVType.bf16)
-            ov_const = op.Constant(tensor, shared_memory=True)
-        else:
-            narr = ivalue.numpy(force=True)
-            if not narr.flags['C_CONTIGUOUS']:
-                narr = np.ascontiguousarray(narr)
-            ov_const = op.Constant(narr, shared_memory=True)
-        return ov_const.outputs()
+        return torch_tensor_to_ov_const(ivalue, shared_memory=shared_memory).outputs()
     return None


 def get_value_from_getattr(getattr_node, self_module):
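
Both the decoder and ivalue_to_constant now delegate tensor handling to the new torch_tensor_to_ov_const helper shown above. A minimal usage sketch, assuming the helper is called standalone (the tensors are illustrative):

import torch
from openvino.frontend.pytorch.utils import torch_tensor_to_ov_const

weight = torch.randn(3, 4)
# May alias the torch tensor's memory (default behaviour).
shared_const = torch_tensor_to_ov_const(weight, shared_memory=True)
# Always owns a detached copy of the data.
copied_const = torch_tensor_to_ov_const(weight, shared_memory=False)

# bfloat16 tensors are viewed as float16 for the numpy round-trip and wrapped
# in a bf16 Tensor; per the TODO in the helper, this path still copies.
bf16_const = torch_tensor_to_ov_const(torch.randn(3, 4, dtype=torch.bfloat16))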

View File

@@ -1011,6 +1011,64 @@ class TestMoConvertPyTorch(CommonMOConvertTest):
         self._test_by_ref_graph(temp_dir, test_params,
                                 graph_ref, compare_tensor_names=False)

+    @pytest.mark.precommit
+    def test_sharing_memory_switched_off(self, ie_device, precision, ir_version, temp_dir):
+        from openvino.tools.ovc import convert_model
+        from openvino.runtime import Core
+
+        class DataModel(torch.nn.Module):
+            def __init__(self):
+                super(DataModel, self).__init__()
+                self.data = torch.tensor([1, 2, 3, 4])
+
+            def forward(self, x):
+                return self.data, x
+
+        data_model = DataModel()
+        test_input = np.array([0, 0, 0, 0])
+
+        # Convert model to OV
+        ov_model = convert_model(data_model, input=([4], Type.i32), share_weights=False)
+
+        # Change value of variables in original model
+        data_model.data[0] *= 2
+
+        # Check model inference
+        core = Core()
+        cmp_model = core.compile_model(ov_model, ie_device)
+        ov_infer1 = cmp_model(test_input)
+
+        assert np.array_equal(ov_infer1[0], [1, 2, 3, 4])
+
+    @pytest.mark.precommit
+    def test_sharing_memory_switched_on(self, ie_device, precision, ir_version, temp_dir):
+        from openvino.tools.ovc import convert_model
+        from openvino.runtime import Core
+
+        class DataModel(torch.nn.Module):
+            def __init__(self):
+                super(DataModel, self).__init__()
+                self.data = torch.tensor([1, 2, 3, 4])
+
+            def forward(self, x):
+                return self.data, x
+
+        data_model = DataModel()
+        test_input = np.array([0, 0, 0, 0])
+
+        # Convert model to OV
+        ov_model = convert_model(data_model, input=([4], Type.i32), share_weights=True)
+
+        # Change value of variables in original model
+        data_model.data[0] *= 2
+
+        # Check model inference
+        core = Core()
+        cmp_model = core.compile_model(ov_model, ie_device)
+        ov_infer1 = cmp_model(test_input)
+
+        assert np.array_equal(ov_infer1[0], [2, 2, 3, 4])
+

 def create_pt_model_with_custom_op():
     #

View File

@@ -31,7 +31,7 @@ def get_pytorch_decoder(model, input_shape, example_inputs, args):
         except:
             pass
     inputs = prepare_torch_inputs(example_inputs)
-    decoder = TorchScriptPythonDecoder(model, example_input=inputs)
+    decoder = TorchScriptPythonDecoder(model, example_input=inputs, shared_memory=args.get("share_weights", True))
    args['input_model'] = decoder
    args["framework"] = "pytorch"
    args["example_input"] = inputs

View File

@@ -31,7 +31,7 @@ def get_pytorch_decoder(model, example_inputs, args):
         except:
             pass
     inputs = prepare_torch_inputs(example_inputs)
-    decoder = TorchScriptPythonDecoder(model, example_input=inputs)
+    decoder = TorchScriptPythonDecoder(model, example_input=inputs, shared_memory=args.get("share_weights", True))
    args['input_model'] = decoder
    args["example_input"] = inputs