Do not freeze models with compressed constants (#18505)
* Do not freeze models with compressed constants
* Add test
parent acb14d5d6b
commit 1d9be8c76e
@@ -182,6 +182,15 @@ class TorchScriptPythonDecoder (Decoder):
             for n in scripted.inlined_graph.nodes():
                 # TODO: switch off freezing for all traced models
                 if "quantize" in n.kind():
                     # do not freeze quantized models
                     skip_freeze = True
                     break
+                elif "aten::to" in n.kind():
+                    first_input = next(n.inputs())
+                    if first_input.node().kind() == "prim::Constant":
+                        ivalue = first_input.toIValue()
+                        if ivalue is not None and ivalue.dtype in [torch.uint8, torch.int8, torch.bfloat16, torch.float16]:
+                            # do not freeze models with compressed constants
+                            skip_freeze = True
+                            break
             if not skip_freeze:
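For context, the check added above can be reproduced outside the decoder. Below is a minimal standalone sketch (not part of the commit, assuming a recent PyTorch): trace a small module that keeps an int8 weight tensor, then walk the inlined graph looking for an aten::to node fed by a prim::Constant with a compressed dtype, which is exactly the pattern the decoder now refuses to freeze. Int8Weights is a hypothetical example module.

import torch
import torch.nn.functional as F


class Int8Weights(torch.nn.Module):
    # hypothetical module: an int8 "compressed" weight decompressed in forward()
    def __init__(self):
        super().__init__()
        self.weights = torch.randint(-127, 128, [1, 3, 3, 3], dtype=torch.int8)

    def forward(self, x):
        return F.conv2d(x, self.weights.to(torch.float32) * 0.02)


traced = torch.jit.trace(Int8Weights(), torch.rand(1, 3, 10, 10))

skip_freeze = False
for n in traced.inlined_graph.nodes():
    if "aten::to" in n.kind():
        first_input = next(n.inputs())
        if first_input.node().kind() == "prim::Constant":
            ivalue = first_input.toIValue()
            if ivalue is not None and ivalue.dtype in [torch.uint8, torch.int8,
                                                       torch.bfloat16, torch.float16]:
                # freezing would constant-fold the cast and lose the compressed constant
                skip_freeze = True
                break

print("skip freezing:", skip_freeze)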
@@ -50,8 +50,7 @@ OutputVector translate_to(const NodeContext& context) {
         }
     } else if (context.get_input_size() == 8) {
         // aten::to(Tensor(a) self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool?
-        // pin_memory=None,
-        // bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None)
+        // pin_memory=None, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None)
         dtype_idx = 1;
         memory_format_idx = 7;
         if (context.input_is_none(dtype_idx)) {
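The merged comment above documents the 8-argument aten::to overload, where dtype sits at input 1 and memory_format at input 7, which is what dtype_idx and memory_format_idx encode. A small probe like the one below, illustrative only and assuming a recent PyTorch, shows which aten::to overload tracing actually records for a plain dtype cast and how many inputs it carries; CastOnly is a hypothetical helper module.

import torch


class CastOnly(torch.nn.Module):
    # hypothetical module that only performs a dtype cast
    def forward(self, x):
        return x.to(torch.float16)


traced = torch.jit.trace(CastOnly(), torch.rand(2, 2))
for n in traced.inlined_graph.nodes():
    if "aten::to" in n.kind():
        # dump the node and count its inputs to see which overload was recorded
        print(n)
        print("inputs:", len(list(n.inputs())))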
@@ -81,6 +81,7 @@ def make_pt_model_with_optional_input():

     return NeuralNetwork()

+
 def make_ref_pt_model_one_input(shape, dtype=np.float32):
     shape = PartialShape(shape)
     param1 = ov.opset8.parameter(shape, name="input_0", dtype=dtype)
@@ -126,7 +127,8 @@ def make_ref_pt_model_with_optional_inputs(shape, dtype=np.float32, z_exist=False):
     param1 = ov.opset8.parameter(shape, name="input_0", dtype=dtype)
     param2 = ov.opset8.parameter(shape, name="input_1", dtype=dtype)

-    op = ov.opset8.multiply(param1, param2) if not z_exist else ov.opset8.add(param1, param2)
+    op = ov.opset8.multiply(
+        param1, param2) if not z_exist else ov.opset8.add(param1, param2)
     relu = ov.opset8.relu(op)
     if dtype != np.float32:
         relu = ov.opset8.convert(relu, np.float32)
@@ -217,10 +219,12 @@ def create_pytorch_nn_module_case7(tmp_dir):

     sample_input = torch.zeros(1, 3, 10, 10, dtype=torch.int32)

-    ref_model = make_ref_pt_model_one_input(PartialShape([1, 3, 20, 20]), dtype=np.int32)
+    ref_model = make_ref_pt_model_one_input(
+        PartialShape([1, 3, 20, 20]), dtype=np.int32)

     return pt_model, ref_model, {'example_input': sample_input, "input": ([1, 3, 20, 20], np.int32)}

+
 def create_pytorch_nn_module_torch_size(tmp_dir):
     pt_model = make_pt_model_one_input()
     ref_model = make_ref_pt_model_one_input([1, 3, 2, 10])
@@ -277,7 +281,6 @@ def create_pytorch_jit_script_function(tmp_dir):
     return scripted_fn, ref_model, {'input': [(inp_shape), (inp_shape)]}


-
 def create_pytorch_nn_module_layout_list(tmp_dir):
     from openvino.runtime import Layout
     pt_model = make_pt_model_two_inputs()
@@ -425,9 +428,11 @@ def create_pytorch_nn_module_scale_list_compression_enabled(tmp_dir):
     param1 = ov.opset8.parameter(shape)
     param2 = ov.opset8.parameter(shape)
     const1 = ov.opset8.constant([[[[1, 1, 1]]]], dtype=np.float16)
-    const1_decompressed = ov.opset8.convert(const1, destination_type=np.float32)
+    const1_decompressed = ov.opset8.convert(
+        const1, destination_type=np.float32)
     const2 = ov.opset8.constant([[[[1, 1, 1]]]], dtype=np.float16)
-    const2_decompressed = ov.opset8.convert(const2, destination_type=np.float32)
+    const2_decompressed = ov.opset8.convert(
+        const2, destination_type=np.float32)
     mul1 = ov.opset8.multiply(param1, const1_decompressed)
     mul2 = ov.opset8.multiply(param2, const2_decompressed)
     mul3 = ov.opset8.multiply(mul1, mul2)
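The reference model above encodes the FP16 compression pattern these tests expect: a float16 Constant followed by a Convert back to float32. A standalone sketch (not part of the commit; it assumes the same openvino.runtime opset8 API the test file already imports as ov) that rebuilds just that pattern and inspects it:

import numpy as np
import openvino.runtime as ov
from openvino.runtime import Model, PartialShape

# compressed constant: stored as f16, decompressed to f32 through Convert
const = ov.opset8.constant([[[[1, 1, 1]]]], dtype=np.float16)
decompressed = ov.opset8.convert(const, destination_type=np.float32)
param = ov.opset8.parameter(PartialShape([1, 3, 1, 3]), dtype=np.float32)
mul = ov.opset8.multiply(param, decompressed)
model = Model([mul], [param], "compression_sketch")

# confirm the constant kept its f16 element type and is followed by a Convert
for op in model.get_ops():
    if op.get_type_name() in ("Constant", "Convert"):
        print(op.get_type_name(), "->", op.get_output_element_type(0))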
@@ -650,37 +655,77 @@ def create_pytorch_module_convert_pytorch_frontend_oob(tmp_dir):

 def create_pytorch_module_with_optional_inputs_case1(tmp_dir):
     net = make_pt_model_with_optional_input()
-    example_input = {"x": torch.zeros((1,3,10,10)), "y": torch.ones((1,3,10,10))}
+    example_input = {"x": torch.zeros(
+        (1, 3, 10, 10)), "y": torch.ones((1, 3, 10, 10))}
     ref_model = make_ref_pt_model_with_optional_inputs([-1, -1, -1, -1])
     return net, ref_model, {"example_input": example_input}


 def create_pytorch_module_with_optional_inputs_case2(tmp_dir):
     net = make_pt_model_with_optional_input()
-    example_input = {"x": torch.zeros((1,3,10,10)), "z": torch.ones((1,3,10,10))}
-    ref_model = make_ref_pt_model_with_optional_inputs([-1, -1, -1, -1], z_exist=True)
+    example_input = {"x": torch.zeros(
+        (1, 3, 10, 10)), "z": torch.ones((1, 3, 10, 10))}
+    ref_model = make_ref_pt_model_with_optional_inputs(
+        [-1, -1, -1, -1], z_exist=True)
     return net, ref_model, {"example_input": example_input}


 def create_pytorch_module_with_optional_inputs_case3(tmp_dir):
     net = make_pt_model_with_optional_input()
-    example_input = {"x": torch.zeros((1,3,10,10)), "z": torch.ones((1,3,10,10))}
-    ref_model = make_ref_pt_model_with_optional_inputs([3, 3, 3, 3], z_exist=True)
+    example_input = {"x": torch.zeros(
+        (1, 3, 10, 10)), "z": torch.ones((1, 3, 10, 10))}
+    ref_model = make_ref_pt_model_with_optional_inputs(
+        [3, 3, 3, 3], z_exist=True)
     return net, ref_model, {"example_input": example_input, "input_shape": [[3, 3, 3, 3], [3, 3, 3, 3]]}


 def create_pytorch_module_with_optional_inputs_case4(tmp_dir):
     net = make_pt_model_with_optional_input()
-    ref_model = make_ref_pt_model_with_optional_inputs([3, 3, 3, 3], z_exist=True)
+    ref_model = make_ref_pt_model_with_optional_inputs(
+        [3, 3, 3, 3], z_exist=True)
     return net, ref_model, {"input": [("x", [3, 3, 3, 3]), ("z", [3, 3, 3, 3])]}


 def create_pytorch_module_with_optional_inputs_case5(tmp_dir):
     net = make_pt_model_with_optional_input()
-    ref_model = make_ref_pt_model_with_optional_inputs([1, 3, -1, -1], z_exist=True)
+    ref_model = make_ref_pt_model_with_optional_inputs(
+        [1, 3, -1, -1], z_exist=True)
     return net, ref_model, {"input": ["x", "z"], "input_shape": [[1, 3, -1, -1], [1, 3, -1, -1]]}


+def create_pytorch_module_with_compressed_int8_constant(tmp_dir):
+    import torch
+    import torch.nn.functional as F
+
+    class Int8Model(torch.nn.Module):
+        def __init__(self):
+            super(Int8Model, self).__init__()
+            self.weights = torch.randint(-127, 128,
+                                         [1, 3, 3, 3], dtype=torch.int8)
+
+        def forward(self, x):
+            cast = self.weights.to(torch.float32)
+            sub = cast - 0.5
+            mul = sub * 0.02
+            return F.conv2d(x, mul)
+
+    net = Int8Model()
+    example_input = (torch.rand((1, 3, 10, 10)),)
+    traced_model = torch.jit.trace(net, example_input)
+    shape = [-1, -1, -1, -1]
+    shape = PartialShape(shape)
+    param1 = ov.opset10.parameter(shape, dtype=np.float32)
+    weights = ov.opset10.constant(net.weights.numpy(force=True))
+    cast1 = ov.opset10.convert(weights, np.float32)
+    sub1 = ov.opset10.subtract(cast1, np.float32(0.5).reshape(1, 1, 1, 1))
+    mul1 = ov.opset10.multiply(sub1, np.float32(0.02).reshape(1, 1, 1, 1))
+    conv = ov.opset10.convolution(param1, mul1, strides=[1, 1],
+                                  pads_begin=[0, 0], pads_end=[0, 0],
+                                  dilations=[1, 1])
+    ref_model = Model([conv], [param1], "test")
+    return traced_model, ref_model, {"example_input": example_input}
+
+
 class TestMoConvertPyTorch(CommonMOConvertTest):
     test_data = [
         create_pytorch_nn_module_case1,
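Outside the parametrized harness, the new case can also be driven by hand. The sketch below is illustrative only: it assumes the openvino.tools.mo.convert_model entry point these tests exercise, and passes a throwaway tmp_dir value since the helper does not use it.

from openvino.tools.mo import convert_model

# build the traced int8 model and its expected reference exactly as the new case does
traced_model, ref_model, kwargs = create_pytorch_module_with_compressed_int8_constant(tmp_dir=".")

# with the decoder fix, freezing is skipped, so the int8 constant and its
# decompression Convert survive into the converted model
ov_model = convert_model(traced_model, example_input=kwargs["example_input"])
print(ov_model)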
@@ -724,6 +769,7 @@ class TestMoConvertPyTorch(CommonMOConvertTest):
         create_pytorch_module_with_optional_inputs_case4,
         create_pytorch_module_with_optional_inputs_case5,
         create_pytorch_nn_module_with_scalar_input,
+        create_pytorch_module_with_compressed_int8_constant,
     ]

     @pytest.mark.parametrize("create_model", test_data)
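As a usage note (hypothetical invocation, assuming the usual pytest layout of these MO convert tests): once registered in test_data above, the new case can be selected on its own with a keyword filter such as pytest -k compressed_int8_constant, since the parametrized test id contains the helper's name.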