[PT FE] Partially disable freezing for int8 and uint8 weights (#18827)
parent 9bf5b6effb
commit 481721e979
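Summary: instead of skipping torch.jit.freeze whenever a model carries int8/uint8 constants, the TorchScript decoder now freezes such models while preserving the attributes of modules whose weights are int8 or uint8; only bfloat16/float16 constants still disable freezing. The commit also makes Decoder::inlined_inputs() pure virtual and simplifies Parameter dtype selection in TranslateSession::convert_pytorch_model, dropping the PYTORCH_TRACING_MODE environment-variable check.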
@@ -14,7 +14,7 @@ import typing
 import torch
 import numpy as np

-wrapper_template="""
+wrapper_template = """
 import torch
 from typing import *

@@ -27,6 +27,7 @@ class ModelWrapper(torch.nn.Module):
         return self.model({example_input})
 """

+
 class TorchScriptPythonDecoder (Decoder):
     def __init__(self, pt_module, graph_element=None, example_input=None, alias_db=None):
         Decoder.__init__(self)
@@ -64,6 +65,15 @@ class TorchScriptPythonDecoder (Decoder):
         self._transform_tensor_list_constants_to_listconstruct(self.graph_element)
         self._transform_optional_constants(self.graph_element)

+    @staticmethod
+    def _get_preserved_attributes(model) -> list:
+        preserved_attributes = []
+        for name, module in model.named_modules():
+            if hasattr(module, "weight"):
+                if module.weight.dtype in [torch.int8, torch.uint8]:
+                    preserved_attributes.append(name)
+        return preserved_attributes
+
     def _get_scripted_model(self, pt_module, example_inputs=None):
         import torch
         import inspect
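For context, a minimal sketch of what the new helper collects; the toy module below is illustrative and not part of the commit. named_modules() walks every submodule, and only those exposing a weight of dtype int8/uint8 are recorded by name:

    import torch

    class Int8Block(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # Raw int8 payload, as a weight-compressed checkpoint might store it
            self.register_buffer("weight", torch.zeros(4, 4, dtype=torch.int8))

        def forward(self, x):
            return x + self.weight.to(torch.float32)

    model = torch.nn.Sequential(Int8Block(), torch.nn.ReLU())
    # Same loop as _get_preserved_attributes(model): ReLU has no weight
    # attribute, so only submodule "0" is collected.
    names = [name for name, m in model.named_modules()
             if hasattr(m, "weight") and m.weight.dtype in (torch.int8, torch.uint8)]
    assert names == ["0"]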
@@ -156,12 +166,13 @@ class TorchScriptPythonDecoder (Decoder):
                         first_input = next(n.inputs())
                         if first_input.node().kind() == "prim::Constant":
                             ivalue = first_input.toIValue()
-                            if ivalue is not None and ivalue.dtype in [torch.uint8, torch.int8, torch.bfloat16, torch.float16]:
+                            if ivalue is not None and ivalue.dtype in [torch.bfloat16, torch.float16]:
                                 # do not freeze models with compressed constants
                                 skip_freeze = True
                                 break
                 if not skip_freeze:
-                    f_model = torch.jit.freeze(scripted)
+                    preserved_attrs = self._get_preserved_attributes(scripted)
+                    f_model = torch.jit.freeze(scripted, preserved_attrs=preserved_attrs)
                 else:
                     f_model = scripted
             else:
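The effect of preserved_attrs can be reproduced outside the decoder with stock torch.jit APIs. A hedged, self-contained sketch (not code from the commit):

    import torch

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("weight", torch.zeros(4, dtype=torch.int8))

        def forward(self, x):
            return x + self.weight.to(torch.float32)

    scripted = torch.jit.script(M().eval())
    # Plain freezing would fold `weight` into the graph as a constant.
    # Listing it in preserved_attrs keeps the int8 tensor as a module
    # attribute, so the original (compressed) dtype stays visible.
    frozen = torch.jit.freeze(scripted, preserved_attrs=["weight"])
    assert frozen.weight.dtype == torch.int8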
@@ -493,4 +504,4 @@ class TorchScriptPythonDecoder (Decoder):
         const_input = graph.insertConstant(value)
         const_input.node().moveBefore(node)
         const_input.node().copyMetadata(node)
-        node.output().replaceAllUsesWith(const_input)
\ No newline at end of file
+        node.output().replaceAllUsesWith(const_input)
@@ -112,9 +112,7 @@ public:

     /// Returns new nodes for inputs inlined in the op itself
     // Used in Torch.FX decoder
-    virtual OutputVector inlined_inputs(size_t start_index) const {
-        return {};
-    }
+    virtual OutputVector inlined_inputs(size_t start_index) const = 0;
 };

 }  // namespace pytorch
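Making inlined_inputs() pure virtual means every concrete decoder must now define it rather than inherit the empty default. A hypothetical sketch, in Python for consistency with the examples above, of how a TorchScript-path decoder could satisfy the contract (the body mirrors the removed C++ default; this is not verbatim code from the commit):

    # Hypothetical: per the "Used in Torch.FX decoder" comment, only the
    # Torch.FX path returns a non-empty list here.
    def inlined_inputs(self, start_index):
        return []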
@@ -162,14 +162,11 @@ std::shared_ptr<Model> TranslateSession::convert_pytorch_model(
             // TODO: Eliminate duplication with the main code for Parameters creation
             PartialShape ps = node->get_input_shape(i);
             auto type = simplified_type_interpret(node->get_input_type(i));
-            // TODO: Use special API to set custom type specification
-            std::shared_ptr<v0::Parameter> parameter;
-            // TODO: Use decoder type or explore adding the missing cast types to Torchscript path
-            const char* torch_tracing_mode = std::getenv("PYTORCH_TRACING_MODE");
-            if ((torch_tracing_mode != nullptr) && std::strcmp(torch_tracing_mode, "TORCHFX") == 0)
-                parameter = std::make_shared<v0::Parameter>(type.as<element::Type>(), ps);
-            else
-                parameter = std::make_shared<v0::Parameter>(element::dynamic, ps);
+            auto dtype = element::dynamic;
+            if (type.is<element::Type>()) {
+                dtype = type.as<element::Type>();
+            }
+            auto parameter = std::make_shared<v0::Parameter>(dtype, ps);
             // TODO: Missing get_input_transpose_order handling for not trivial layouts
             (*tensor_map)[input] = parameter;
             // set name of parameter to the index of node in the model
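The same decision, reduced to a library-free Python sketch; every name below is an illustrative placeholder for the C++ entities, not a real OpenVINO API:

    DYNAMIC = "dynamic"  # plays the role of ov::element::dynamic

    def parameter_dtype(decoded_type):
        # Old code consulted the PYTORCH_TRACING_MODE environment variable to
        # decide whether to trust the decoder-reported type; the new code uses
        # the reported element type whenever one exists and otherwise falls
        # back to a dynamic element type.
        return decoded_type if decoded_type is not None else DYNAMIC

    assert parameter_dtype("i8") == "i8"      # type known -> keep it
    assert parameter_dtype(None) == DYNAMIC   # unknown -> dynamic parameter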