[PT FE] Use weight share switch in frontend (#18993)
* [PT FE] Use weight share switch in frontend
* Return static for function
* Update src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py
* Fix issue with quantized constants
* Add tests for shared
parent 726abefbaa
commit 5f71679fb9
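In short: weight sharing between the original PyTorch model and the converted OpenVINO model can now be toggled from the conversion API and is propagated into the frontend decoder. A minimal usage sketch, condensed from the new tests added below; the "CPU" device string is an assumption here (the tests use the ie_device fixture instead):

    import numpy as np
    import torch
    from openvino.runtime import Core, Type
    from openvino.tools.ovc import convert_model


    class DataModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.data = torch.tensor([1, 2, 3, 4])

        def forward(self, x):
            return self.data, x


    model = DataModel()
    # share_weights=False makes the frontend copy weights into the OV model,
    # so mutating the original torch tensor afterwards does not affect inference.
    ov_model = convert_model(model, input=([4], Type.i32), share_weights=False)
    model.data[0] *= 2
    result = Core().compile_model(ov_model, "CPU")(np.array([0, 0, 0, 0]))
    assert np.array_equal(result[0], [1, 2, 3, 4])  # still the original values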
@@ -7,7 +7,7 @@
 from openvino.frontend.pytorch.py_pytorch_frontend import _FrontEndPytorchDecoder as Decoder
 from openvino.frontend.pytorch.py_pytorch_frontend import _Type as DecoderType
 from openvino.runtime import op, PartialShape, Type as OVType, OVAny
-from openvino.frontend.pytorch.utils import ivalue_to_constant, get_value_from_getattr, pt_to_ov_type_map
+from openvino.frontend.pytorch.utils import ivalue_to_constant, get_value_from_getattr, pt_to_ov_type_map, torch_tensor_to_ov_const
 from openvino.runtime import opset11 as ops
 
 import typing
@@ -29,11 +29,12 @@ class ModelWrapper(torch.nn.Module):
 
 
 class TorchScriptPythonDecoder (Decoder):
-    def __init__(self, pt_module, graph_element=None, example_input=None, alias_db=None):
+    def __init__(self, pt_module, graph_element=None, example_input=None, alias_db=None, shared_memory=True):
         Decoder.__init__(self)
         # We store every decoder created by this decoder so that all them are not deleted until the first decoder is deleted
         self.m_decoders = []
         self._input_signature = None
+        self._shared_memory = shared_memory
         if graph_element is None:
             try:
                 pt_module = self._get_scripted_model(pt_module, example_input)
@@ -43,10 +44,9 @@
                     help_msg = ""
                 else:
                     msg = "scripting"
-                    help_msg = "Tracing sometimes provide better results, "
-                    "please provide valid 'example_input' argument. "
+                    help_msg = "\nTracing sometimes provide better results, please provide valid 'example_input' argument. "
                 raise RuntimeError(
-                    f"Couldn't get TorchScript module by {msg}. {help_msg}"
+                    f"Couldn't get TorchScript module by {msg}. With exception:\n{e}\n {help_msg}"
                     "You can also provide TorchScript module that you obtained"
                     " yourself, please refer to PyTorch documentation: "
                     "https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html.")
@@ -160,8 +160,11 @@ class TorchScriptPythonDecoder (Decoder):
             except Exception:
                 try:
                     scripted = torch.jit.script(pt_module)
-                except Exception:
-                    scripted = torch.jit.trace(pt_module, **input_parameters, strict=False)
+                except Exception as se:
+                    try:
+                        scripted = torch.jit.trace(pt_module, **input_parameters, strict=False)
+                    except Exception as te:
+                        raise f"Tracing failed with exception {te}\nScripting failed with exception: {se}"
         skip_freeze = False
         for n in scripted.inlined_graph.nodes():
             # TODO: switch off freezing for all traced models
@@ -283,7 +286,7 @@ class TorchScriptPythonDecoder (Decoder):
     def visit_subgraph(self, node_visitor) -> None:
         # make sure topological order is satisfied
         for node in self.graph_element.nodes():
-            decoder = TorchScriptPythonDecoder(self.pt_module, node, alias_db=self.alias_db)
+            decoder = TorchScriptPythonDecoder(self.pt_module, node, alias_db=self.alias_db, shared_memory=self._shared_memory)
             self.m_decoders.append(decoder)
             node_visitor(decoder)
 
@@ -299,7 +302,7 @@ class TorchScriptPythonDecoder (Decoder):
         return list(self.graph_element.blocks())
 
     def get_subgraph_decoder(self, index: int):
-        decoder = TorchScriptPythonDecoder(self.pt_module, self.get_subgraphs()[index], alias_db=self.alias_db)
+        decoder = TorchScriptPythonDecoder(self.pt_module, self.get_subgraphs()[index], alias_db=self.alias_db, shared_memory=self._shared_memory)
         self.m_decoders.append(decoder)
         return decoder
 
@@ -336,7 +339,7 @@ class TorchScriptPythonDecoder (Decoder):
         return node
 
     @staticmethod
-    def convert_quantized_tensor(qtensor: torch.Tensor):
+    def convert_quantized_tensor(qtensor: torch.Tensor, shared_memory: bool):
         # need to represent as Constant(u8) -> Convert(f32) -> Subtract(zero_point) -> Multiply (scale)
         qscheme = qtensor.qscheme()  # torch.per_channel_affine (per_tensor)
         if qscheme == torch.per_channel_affine:
@@ -350,7 +353,7 @@ class TorchScriptPythonDecoder (Decoder):
             zero_point_bc = np.reshape(zero_point, new_shape)
             scale_bc = np.reshape(scale, new_shape)
 
-            int8_const = op.Constant(int8_tensor.numpy())
+            int8_const = torch_tensor_to_ov_const(int8_tensor, shared_memory=shared_memory)
             convert = ops.convert(int8_const, np.float32)
             sub = ops.subtract(convert, zero_point_bc)
             return ops.multiply(sub, scale_bc).outputs()
@@ -359,7 +362,7 @@ class TorchScriptPythonDecoder (Decoder):
             scale = np.float32(qtensor.q_scale())
             zero_point = np.float32(qtensor.q_zero_point())
 
-            int8_const = op.Constant(int8_tensor.numpy())
+            int8_const = torch_tensor_to_ov_const(int8_tensor, shared_memory=shared_memory)
             convert = ops.convert(int8_const, np.float32)
             sub = ops.subtract(convert, zero_point)
             return ops.multiply(sub, scale).outputs()
@@ -372,7 +375,7 @@ class TorchScriptPythonDecoder (Decoder):
             # We assume this is __torch__.torch.classes.quantized.Conv2dPackedParamsBase or __torch__.torch.classes.quantized.LinearPackedParamsBase
             # TODO: but can be anything. Figure a better way to distinguish
             weight, bias = pt_value.unpack()
-            res = self.convert_quantized_tensor(weight)
+            res = self.convert_quantized_tensor(weight, self._shared_memory)
             if isinstance(bias, torch.Tensor):
                 res += ivalue_to_constant(bias)
             else:
@@ -383,12 +386,15 @@ class TorchScriptPythonDecoder (Decoder):
                 padding = pt_value.padding()
                 dilation = pt_value.dilation()
                 groups = pt_value.groups()
-                res += ivalue_to_constant(stride) + ivalue_to_constant(padding) + ivalue_to_constant(dilation) + ivalue_to_constant(groups)
+                res += ivalue_to_constant(stride, shared_memory=self._shared_memory)
+                res += ivalue_to_constant(padding, shared_memory=self._shared_memory)
+                res += ivalue_to_constant(dilation, shared_memory=self._shared_memory)
+                res += ivalue_to_constant(groups, shared_memory=self._shared_memory)
             except:
                 pass
             return res
         elif not isinstance(pt_value, (torch.jit.ScriptModule, torch.jit.TracedModule)):
-            return ivalue_to_constant(pt_value)
+            return ivalue_to_constant(pt_value, shared_memory=self._shared_memory)
         else:
             return []
 
@@ -400,10 +406,10 @@ class TorchScriptPythonDecoder (Decoder):
         pt_value = self._raw_output(0)
         pt_type = pt_value.type()
         if isinstance(pt_type, torch.TensorType):
-            return ivalue_to_constant(pt_value.toIValue())
+            return ivalue_to_constant(pt_value.toIValue(), shared_memory=self._shared_memory)
         if isinstance(pt_type, torch.ListType):
             return self._as_constant_list(pt_value)
-        return ivalue_to_constant(pt_value.toIValue())
+        return ivalue_to_constant(pt_value.toIValue(), shared_memory=self._shared_memory)
 
     def as_string(self):
         if self.get_op_type() == "prim::Constant":
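For completeness, a hedged sketch of driving the decoder directly with the new switch; the Linear module is only a placeholder, and the import path follows the ts_decoder.py file touched by this commit:

    import torch
    from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder

    # shared_memory=True (the default) keeps constants backed by the original torch
    # buffers; False makes the decoder copy weight data into the OV constants.
    decoder = TorchScriptPythonDecoder(torch.nn.Linear(4, 2), shared_memory=False)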
@@ -55,7 +55,26 @@ def get_type_from_py_type(value):
     return OVType.dynamic
 
 
-def ivalue_to_constant(ivalue):
+def torch_tensor_to_ov_const(torch_t: torch.Tensor, shared_memory=True):
+    torch_t = torch_t.to(memory_format=torch.contiguous_format)
+    if torch_t.dtype == torch.bfloat16:
+        # reinterpret bfloat16 data as float16 to allow conversion to numpy
+        torch_t = torch_t.view(torch.float16)
+        narr = torch_t.numpy(force=True)
+        if not narr.flags['C_CONTIGUOUS']:
+            narr = np.ascontiguousarray(narr)
+        # TODO: this tensor doesn't share memory with initial tensor
+        tensor = Tensor(narr, torch_t.shape, OVType.bf16)
+        ov_const = op.Constant(tensor, shared_memory=shared_memory)
+    else:
+        narr = torch_t.numpy(force=True)
+        if not narr.flags['C_CONTIGUOUS']:
+            narr = np.ascontiguousarray(narr)
+        ov_const = op.Constant(narr, shared_memory=shared_memory)
+    return ov_const
+
+
+def ivalue_to_constant(ivalue, shared_memory=True):
     ov_type = get_type_from_py_type(ivalue)
     if ov_type.is_static():
         return op.Constant(ov_type, Shape([]), [ivalue]).outputs()
@@ -67,22 +86,7 @@ def ivalue_to_constant(ivalue):
         return op.Constant(ov_type, Shape([len(ivalue)]), ivalue).outputs()
 
     if isinstance(ivalue, torch.Tensor):
-        ivalue = ivalue.to(memory_format=torch.contiguous_format)
-        if ivalue.dtype == torch.bfloat16:
-            # reinterpret bfloat16 data as float16 to allow conversion to numpy
-            ivalue = ivalue.view(torch.float16)
-            narr = ivalue.numpy(force=True)
-            if not narr.flags['C_CONTIGUOUS']:
-                narr = np.ascontiguousarray(narr)
-            # TODO: this tensor doesn't share memory with initial tensor
-            tensor = Tensor(narr, ivalue.shape, OVType.bf16)
-            ov_const = op.Constant(tensor, shared_memory=True)
-        else:
-            narr = ivalue.numpy(force=True)
-            if not narr.flags['C_CONTIGUOUS']:
-                narr = np.ascontiguousarray(narr)
-            ov_const = op.Constant(narr, shared_memory=True)
-        return ov_const.outputs()
+        return torch_tensor_to_ov_const(ivalue, shared_memory=shared_memory).outputs()
     return None
 
 def get_value_from_getattr(getattr_node, self_module):
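A small usage sketch for the new helper; the tensor is illustrative, while the function name, keyword, and import path come from the hunks above:

    import torch
    from openvino.frontend.pytorch.utils import torch_tensor_to_ov_const

    weights = torch.ones(2, 2)
    # wraps the torch buffer (for contiguous, non-bfloat16 tensors)
    shared_const = torch_tensor_to_ov_const(weights, shared_memory=True)
    # owns its own copy of the data
    copied_const = torch_tensor_to_ov_const(weights, shared_memory=False)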
@@ -1011,6 +1011,64 @@ class TestMoConvertPyTorch(CommonMOConvertTest):
         self._test_by_ref_graph(temp_dir, test_params,
                                 graph_ref, compare_tensor_names=False)
 
+    @pytest.mark.precommit
+    def test_sharing_memory_switched_off(self, ie_device, precision, ir_version, temp_dir):
+        from openvino.tools.ovc import convert_model
+        from openvino.runtime import Core
+
+        class DataModel(torch.nn.Module):
+            def __init__(self):
+                super(DataModel, self).__init__()
+                self.data = torch.tensor([1, 2, 3, 4])
+
+            def forward(self, x):
+                return self.data, x
+
+        data_model = DataModel()
+        test_input = np.array([0, 0, 0, 0])
+
+        # Convert model to OV
+        ov_model = convert_model(data_model, input=([4], Type.i32), share_weights=False)
+
+        # Change value of variables in original model
+        data_model.data[0] *= 2
+
+        # Check model inference
+        core = Core()
+        cmp_model = core.compile_model(ov_model, ie_device)
+        ov_infer1 = cmp_model(test_input)
+
+        assert np.array_equal(ov_infer1[0], [1, 2, 3, 4])
+
+    @pytest.mark.precommit
+    def test_sharing_memory_switched_on(self, ie_device, precision, ir_version, temp_dir):
+        from openvino.tools.ovc import convert_model
+        from openvino.runtime import Core
+
+        class DataModel(torch.nn.Module):
+            def __init__(self):
+                super(DataModel, self).__init__()
+                self.data = torch.tensor([1, 2, 3, 4])
+
+            def forward(self, x):
+                return self.data, x
+
+        data_model = DataModel()
+        test_input = np.array([0, 0, 0, 0])
+
+        # Convert model to OV
+        ov_model = convert_model(data_model, input=([4], Type.i32), share_weights=True)
+
+        # Change value of variables in original model
+        data_model.data[0] *= 2
+
+        # Check model inference
+        core = Core()
+        cmp_model = core.compile_model(ov_model, ie_device)
+        ov_infer1 = cmp_model(test_input)
+
+        assert np.array_equal(ov_infer1[0], [2, 2, 3, 4])
+
 
 def create_pt_model_with_custom_op():
     #
@@ -31,7 +31,7 @@ def get_pytorch_decoder(model, input_shape, example_inputs, args):
         except:
             pass
     inputs = prepare_torch_inputs(example_inputs)
-    decoder = TorchScriptPythonDecoder(model, example_input=inputs)
+    decoder = TorchScriptPythonDecoder(model, example_input=inputs, shared_memory=args.get("share_weights", True))
     args['input_model'] = decoder
    args["framework"] = "pytorch"
     args["example_input"] = inputs
@@ -31,7 +31,7 @@ def get_pytorch_decoder(model, example_inputs, args):
         except:
             pass
     inputs = prepare_torch_inputs(example_inputs)
-    decoder = TorchScriptPythonDecoder(model, example_input=inputs)
+    decoder = TorchScriptPythonDecoder(model, example_input=inputs, shared_memory=args.get("share_weights", True))
     args['input_model'] = decoder
     args["example_input"] = inputs
 