[PT FE]: fixes type checking for freezing conditions (#19199)
This commit is contained in:
parent
71ac5ee301
commit
89956b65e3
@ -77,7 +77,7 @@ class TorchScriptPythonDecoder (Decoder):
|
|||||||
preserved_attributes = []
|
preserved_attributes = []
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
if hasattr(module, "weight"):
|
if hasattr(module, "weight"):
|
||||||
if module.weight.dtype in [torch.int8, torch.uint8]:
|
if module.weight is not None and module.weight.dtype in [torch.int8, torch.uint8]:
|
||||||
preserved_attributes.append(name)
|
preserved_attributes.append(name)
|
||||||
return preserved_attributes
|
return preserved_attributes
|
||||||
|
|
||||||
@ -176,7 +176,7 @@ class TorchScriptPythonDecoder (Decoder):
|
|||||||
first_input = next(n.inputs())
|
first_input = next(n.inputs())
|
||||||
if first_input.node().kind() == "prim::Constant":
|
if first_input.node().kind() == "prim::Constant":
|
||||||
ivalue = first_input.toIValue()
|
ivalue = first_input.toIValue()
|
||||||
if ivalue is not None and ivalue.dtype in [torch.bfloat16, torch.float16]:
|
if isinstance(ivalue, torch.Tensor) and ivalue.dtype in [torch.bfloat16, torch.float16]:
|
||||||
# do not freeze models with compressed constants
|
# do not freeze models with compressed constants
|
||||||
skip_freeze = True
|
skip_freeze = True
|
||||||
break
|
break
|
||||||
|
Loading…
Reference in New Issue
Block a user