[PT FE] Partially disable freezing for int8 and uint8 weights (#18827)

Maxim Vafin 2023-07-28 13:37:01 +02:00 committed by GitHub
parent 9bf5b6effb
commit 481721e979
3 changed files with 21 additions and 15 deletions
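For context, a minimal sketch of the approach (the QuantLinear/Model classes and their attribute layout are illustrative, not taken from the patch): rather than skipping torch.jit.freeze whenever 8-bit constants are present, the decoder now collects the names of submodules whose weight is stored as torch.int8 or torch.uint8 and passes them to torch.jit.freeze via preserved_attrs, so the compressed weights remain module attributes instead of being folded into the frozen graph.

import torch


class QuantLinear(torch.nn.Module):
    # Toy submodule holding a compressed int8 weight plus a float scale (illustrative only).
    def __init__(self):
        super().__init__()
        self.weight = torch.randint(-128, 127, (4, 4), dtype=torch.int8)
        self.scale = 0.1

    def forward(self, x):
        # decompress on the fly
        return x @ (self.weight.to(torch.float32) * self.scale)


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.quant_linear = QuantLinear()

    def forward(self, x):
        return self.quant_linear(x)


scripted = torch.jit.script(Model().eval())

# Same idea as the new _get_preserved_attributes helper: collect the names of
# submodules whose weight is stored as int8/uint8.
preserved_attrs = [name for name, m in scripted.named_modules()
                   if hasattr(m, "weight") and m.weight.dtype in (torch.int8, torch.uint8)]

# Freezing stays enabled; the compressed weights are kept as module attributes
# instead of being folded into the graph as constants.
frozen = torch.jit.freeze(scripted, preserved_attrs=preserved_attrs)

Note that torch.jit.freeze requires the module to be in eval mode, hence the eval() call before scripting.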

@@ -14,7 +14,7 @@ import typing
 import torch
 import numpy as np
-wrapper_template="""
+wrapper_template = """
 import torch
 from typing import *
@@ -27,6 +27,7 @@ class ModelWrapper(torch.nn.Module):
         return self.model({example_input})
 """
 
 class TorchScriptPythonDecoder (Decoder):
     def __init__(self, pt_module, graph_element=None, example_input=None, alias_db=None):
         Decoder.__init__(self)
@@ -64,6 +65,15 @@ class TorchScriptPythonDecoder (Decoder):
         self._transform_tensor_list_constants_to_listconstruct(self.graph_element)
         self._transform_optional_constants(self.graph_element)
 
+    @staticmethod
+    def _get_preserved_attributes(model) -> list:
+        preserved_attributes = []
+        for name, module in model.named_modules():
+            if hasattr(module, "weight"):
+                if module.weight.dtype in [torch.int8, torch.uint8]:
+                    preserved_attributes.append(name)
+        return preserved_attributes
+
     def _get_scripted_model(self, pt_module, example_inputs=None):
         import torch
         import inspect
@@ -156,12 +166,13 @@ class TorchScriptPythonDecoder (Decoder):
                     first_input = next(n.inputs())
                     if first_input.node().kind() == "prim::Constant":
                         ivalue = first_input.toIValue()
-                        if ivalue is not None and ivalue.dtype in [torch.uint8, torch.int8, torch.bfloat16, torch.float16]:
+                        if ivalue is not None and ivalue.dtype in [torch.bfloat16, torch.float16]:
                             # do not freeze models with compressed constants
                             skip_freeze = True
                             break
             if not skip_freeze:
-                f_model = torch.jit.freeze(scripted)
+                preserved_attrs = self._get_preserved_attributes(scripted)
+                f_model = torch.jit.freeze(scripted, preserved_attrs=preserved_attrs)
             else:
                 f_model = scripted
         else:
@@ -493,4 +504,4 @@ class TorchScriptPythonDecoder (Decoder):
                 const_input = graph.insertConstant(value)
                 const_input.node().moveBefore(node)
                 const_input.node().copyMetadata(node)
-                node.output().replaceAllUsesWith(const_input)
\ No newline at end of file
+                node.output().replaceAllUsesWith(const_input)
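Complementing the hunks above, a hedged sketch of the condition that still disables freezing (the helper name has_half_precision_constants and the aten::to filter are assumptions for illustration; the hunk only shows the inner dtype check): models whose inlined graph feeds bfloat16 or float16 constants into cast nodes keep the conservative skip-freeze behavior, while int8/uint8 weights now go through the preserved_attrs path sketched earlier.

import torch

def has_half_precision_constants(scripted: torch.jit.ScriptModule) -> bool:
    # Walk the inlined TorchScript graph and look for prim::Constant values of
    # half precision feeding cast ("aten::to") nodes, mirroring the check kept
    # in the diff above.
    for n in scripted.inlined_graph.nodes():
        if "aten::to" not in n.kind():
            continue
        first_input = next(n.inputs())
        if first_input.node().kind() == "prim::Constant":
            ivalue = first_input.toIValue()
            if ivalue is not None and ivalue.dtype in [torch.bfloat16, torch.float16]:
                return True
    return False

# skip_freeze = has_half_precision_constants(scripted)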

@@ -112,9 +112,7 @@ public:
     /// Returns new nodes for inputs inlined in the op itself
     // Used in Torch.FX decoder
-    virtual OutputVector inlined_inputs(size_t start_index) const {
-        return {};
-    }
+    virtual OutputVector inlined_inputs(size_t start_index) const = 0;
 };
 } // namespace pytorch

@@ -162,14 +162,11 @@ std::shared_ptr<Model> TranslateSession::convert_pytorch_model(
                 // TODO: Eliminate duplication with the main code for Parameters creation
                 PartialShape ps = node->get_input_shape(i);
                 auto type = simplified_type_interpret(node->get_input_type(i));
-                // TODO: Use special API to set custom type specification
-                std::shared_ptr<v0::Parameter> parameter;
-                // TODO: Use decoder type or explore adding the missing cast types to Torchscript path
-                const char* torch_tracing_mode = std::getenv("PYTORCH_TRACING_MODE");
-                if ((torch_tracing_mode != nullptr) && std::strcmp(torch_tracing_mode, "TORCHFX") == 0)
-                    parameter = std::make_shared<v0::Parameter>(type.as<element::Type>(), ps);
-                else
-                    parameter = std::make_shared<v0::Parameter>(element::dynamic, ps);
+                auto dtype = element::dynamic;
+                if (type.is<element::Type>()) {
+                    dtype = type.as<element::Type>();
+                }
+                auto parameter = std::make_shared<v0::Parameter>(dtype, ps);
                 // TODO: Missing get_input_transpose_order handling for not trivial layouts
                 (*tensor_map)[input] = parameter;
                 // set name of parameter to the index of node in the model