add gptq model test in pytorch hub tests (#21399)
* conflict * refactor test * fix for case if can not instantiate config
This commit is contained in:
parent
0e2bde2397
commit
d1f72e2d01
@ -294,6 +294,7 @@ openai-gpt,openai-gpt
|
||||
OpenAssistant/oasst-rm-2-pythia-6.9b-epoch-1,gpt_neox_reward_model,skip,Load problem
|
||||
openmmlab/upernet-convnext-small,upernet,skip,Load problem
|
||||
openMUSE/clip-vit-large-patch14-text-enc,clip_text_model,skip,Load problem
|
||||
OpenVINO/opt-125m-gptq,opt
|
||||
PatrickHaller/ngme-llama-264M,ngme,skip,Load problem
|
||||
patrickvonplaten/bert2gpt2-cnn_dailymail-fp16,encoder_decoder,skip,Load problem
|
||||
paulhindemith/test-zeroshot,test-zeroshot,skip,Load problem
|
||||
|
@ -1,10 +1,12 @@
|
||||
-c ../../constraints.txt
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
auto-gptq>=0.5.1
|
||||
av
|
||||
basicsr
|
||||
datasets
|
||||
facexlib
|
||||
numpy
|
||||
optimum
|
||||
pandas
|
||||
protobuf
|
||||
pyctcdecode
|
||||
|
@ -12,6 +12,61 @@ from models_hub_common.utils import cleanup_dir
|
||||
from torch_utils import TestTorchConvertModel
|
||||
from torch_utils import process_pytest_marks
|
||||
|
||||
def is_gptq_model(config):
|
||||
config_dict = config.to_dict() if not isinstance(config, dict) else config
|
||||
quantization_config = config_dict.get("quantization_config", None)
|
||||
return quantization_config and quantization_config["quant_method"] == "gptq"
|
||||
|
||||
|
||||
def patch_gptq():
|
||||
orig_cuda_check = torch.cuda.is_available
|
||||
orig_post_init_model = None
|
||||
torch.set_default_dtype(torch.float32)
|
||||
torch.cuda.is_available = lambda: True
|
||||
|
||||
from optimum.gptq import GPTQQuantizer
|
||||
|
||||
orig_post_init_model = GPTQQuantizer.post_init_model
|
||||
|
||||
def post_init_model(self, model):
|
||||
from auto_gptq import exllama_set_max_input_length
|
||||
|
||||
class StoreAttr(object):
|
||||
pass
|
||||
|
||||
model.quantize_config = StoreAttr()
|
||||
model.quantize_config.desc_act = self.desc_act
|
||||
if self.desc_act and not self.disable_exllama and self.max_input_length is not None:
|
||||
model = exllama_set_max_input_length(model, self.max_input_length)
|
||||
return model
|
||||
|
||||
GPTQQuantizer.post_init_model = post_init_model
|
||||
return orig_cuda_check, orig_post_init_model
|
||||
|
||||
|
||||
def unpatch_gptq(orig_cuda_check, orig_post_init_model):
|
||||
from optimum.gptq import GPTQQuantizer
|
||||
torch.cuda.is_available = orig_cuda_check
|
||||
GPTQQuantizer.post_init_model = orig_post_init_model
|
||||
|
||||
|
||||
def flattenize_tuples(list_input):
|
||||
unpacked_pt_res = []
|
||||
for r in list_input:
|
||||
if isinstance(r, (tuple, list)):
|
||||
unpacked_pt_res.extend(flattenize_tuples(r))
|
||||
else:
|
||||
unpacked_pt_res.append(r)
|
||||
return unpacked_pt_res
|
||||
|
||||
|
||||
def flattenize_outputs(outputs):
|
||||
if not isinstance(outputs, dict):
|
||||
outputs = flattenize_tuples(outputs)
|
||||
return [i.numpy(force=True) for i in outputs]
|
||||
else:
|
||||
return dict((k, v.numpy(force=True)) for k, v in outputs.items())
|
||||
|
||||
|
||||
def filter_example(model, example):
|
||||
try:
|
||||
@ -41,12 +96,23 @@ class TestTransformersModel(TestTorchConvertModel):
|
||||
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
self.image = Image.open(requests.get(url, stream=True).raw)
|
||||
self.cuda_available, self.gptq_postinit = None, None
|
||||
|
||||
def load_model(self, name, type):
|
||||
from transformers import AutoConfig
|
||||
mi = model_info(name)
|
||||
auto_processor = None
|
||||
model = None
|
||||
example = None
|
||||
try:
|
||||
config = AutoConfig.from_pretrained(name)
|
||||
except Exception:
|
||||
config = {}
|
||||
is_gptq = is_gptq_model(config)
|
||||
model_kwargs = {"torchscript": True}
|
||||
if is_gptq:
|
||||
self.cuda_available, self.gptq_postinit = patch_gptq()
|
||||
model_kwargs["torch_dtype"] = torch.float32
|
||||
try:
|
||||
auto_model = mi.transformersInfo['auto_model']
|
||||
if "processor" in mi.transformersInfo:
|
||||
@ -73,7 +139,7 @@ class TestTransformersModel(TestTorchConvertModel):
|
||||
elif "vit-gpt2" in name:
|
||||
from transformers import VisionEncoderDecoderModel, ViTImageProcessor
|
||||
model = VisionEncoderDecoderModel.from_pretrained(
|
||||
name, torchscript=True)
|
||||
name, **model_kwargs)
|
||||
feature_extractor = ViTImageProcessor.from_pretrained(name)
|
||||
encoded_input = feature_extractor(
|
||||
images=[self.image], return_tensors="pt")
|
||||
@ -90,7 +156,7 @@ class TestTransformersModel(TestTorchConvertModel):
|
||||
example = (encoded_input.pixel_values,)
|
||||
elif 'pix2struct' in mi.tags:
|
||||
from transformers import AutoProcessor, Pix2StructForConditionalGeneration
|
||||
model = Pix2StructForConditionalGeneration.from_pretrained(name)
|
||||
model = Pix2StructForConditionalGeneration.from_pretrained(name, **model_kwargs)
|
||||
processor = AutoProcessor.from_pretrained(name)
|
||||
|
||||
import requests
|
||||
@ -112,7 +178,7 @@ class TestTransformersModel(TestTorchConvertModel):
|
||||
# mms-lid model config does not have auto_model attribute, only direct loading available
|
||||
from transformers import Wav2Vec2ForSequenceClassification, AutoFeatureExtractor
|
||||
model = Wav2Vec2ForSequenceClassification.from_pretrained(
|
||||
name, torchscript=True)
|
||||
name, **model_kwargs)
|
||||
processor = AutoFeatureExtractor.from_pretrained(name)
|
||||
input_values = processor(torch.randn(16000).numpy(),
|
||||
sampling_rate=16_000,
|
||||
@ -178,7 +244,7 @@ class TestTransformersModel(TestTorchConvertModel):
|
||||
elif 'musicgen' in mi.tags:
|
||||
from transformers import AutoProcessor, AutoModelForTextToWaveform
|
||||
processor = AutoProcessor.from_pretrained(name)
|
||||
model = AutoModelForTextToWaveform.from_pretrained(name, torchscript=True)
|
||||
model = AutoModelForTextToWaveform.from_pretrained(name, **model_kwargs)
|
||||
|
||||
inputs = processor(
|
||||
text=["80s pop track with bassy drums and synth"],
|
||||
@ -193,10 +259,10 @@ class TestTransformersModel(TestTorchConvertModel):
|
||||
else:
|
||||
try:
|
||||
if auto_model == "AutoModelForCausalLM":
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
|
||||
tokenizer = AutoTokenizer.from_pretrained(name)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
name, torchscript=True)
|
||||
name, **model_kwargs)
|
||||
text = "Replace me by any text you'd like."
|
||||
encoded_input = tokenizer(text, return_tensors='pt')
|
||||
inputs_dict = dict(encoded_input)
|
||||
@ -207,7 +273,7 @@ class TestTransformersModel(TestTorchConvertModel):
|
||||
from transformers import AutoTokenizer, AutoModelForMaskedLM
|
||||
tokenizer = AutoTokenizer.from_pretrained(name)
|
||||
model = AutoModelForMaskedLM.from_pretrained(
|
||||
name, torchscript=True)
|
||||
name, **model_kwargs)
|
||||
text = "Replace me by any text you'd like."
|
||||
encoded_input = tokenizer(text, return_tensors='pt')
|
||||
example = dict(encoded_input)
|
||||
@ -215,7 +281,7 @@ class TestTransformersModel(TestTorchConvertModel):
|
||||
from transformers import AutoProcessor, AutoModelForImageClassification
|
||||
processor = AutoProcessor.from_pretrained(name)
|
||||
model = AutoModelForImageClassification.from_pretrained(
|
||||
name, torchscript=True)
|
||||
name, **model_kwargs)
|
||||
encoded_input = processor(
|
||||
images=self.image, return_tensors="pt")
|
||||
example = dict(encoded_input)
|
||||
@ -223,7 +289,7 @@ class TestTransformersModel(TestTorchConvertModel):
|
||||
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
||||
tokenizer = AutoTokenizer.from_pretrained(name)
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained(
|
||||
name, torchscript=True)
|
||||
name, **model_kwargs)
|
||||
inputs = tokenizer(
|
||||
"Studies have been shown that owning a dog is good for you", return_tensors="pt")
|
||||
decoder_inputs = tokenizer(
|
||||
@ -238,7 +304,7 @@ class TestTransformersModel(TestTorchConvertModel):
|
||||
from datasets import load_dataset
|
||||
processor = AutoProcessor.from_pretrained(name)
|
||||
model = AutoModelForSpeechSeq2Seq.from_pretrained(
|
||||
name, torchscript=True)
|
||||
name, **model_kwargs)
|
||||
inputs = processor(torch.randn(1000).numpy(),
|
||||
sampling_rate=16000,
|
||||
return_tensors="pt")
|
||||
@ -248,7 +314,7 @@ class TestTransformersModel(TestTorchConvertModel):
|
||||
from datasets import load_dataset
|
||||
processor = AutoProcessor.from_pretrained(name)
|
||||
model = AutoModelForCTC.from_pretrained(
|
||||
name, torchscript=True)
|
||||
name, **model_kwargs)
|
||||
input_values = processor(torch.randn(1000).numpy(),
|
||||
return_tensors="pt")
|
||||
example = dict(input_values)
|
||||
@ -257,7 +323,7 @@ class TestTransformersModel(TestTorchConvertModel):
|
||||
from transformers import AutoTokenizer, AutoModelForTableQuestionAnswering
|
||||
tokenizer = AutoTokenizer.from_pretrained(name)
|
||||
model = AutoModelForTableQuestionAnswering.from_pretrained(
|
||||
name, torchscript=True)
|
||||
name, **model_kwargs)
|
||||
data = {"Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"],
|
||||
"Number of movies": ["87", "53", "69"]}
|
||||
queries = ["What is the name of the first actor?",
|
||||
@ -287,7 +353,7 @@ class TestTransformersModel(TestTorchConvertModel):
|
||||
pass
|
||||
if model is None:
|
||||
from transformers import AutoModel
|
||||
model = AutoModel.from_pretrained(name, torchscript=True)
|
||||
model = AutoModel.from_pretrained(name, **model_kwargs)
|
||||
if hasattr(model, "set_default_language"):
|
||||
model.set_default_language("en_XX")
|
||||
if example is None:
|
||||
@ -307,6 +373,10 @@ class TestTransformersModel(TestTorchConvertModel):
|
||||
def teardown_method(self):
|
||||
# remove all downloaded files from cache
|
||||
cleanup_dir(hf_hub_cache_dir)
|
||||
# restore after gptq patching
|
||||
if self.cuda_available is not None:
|
||||
unpatch_gptq(self.cuda_available, self.gptq_postinit)
|
||||
self.cuda_available, self.gptq_postinit = None, None
|
||||
super().teardown_method()
|
||||
|
||||
@pytest.mark.parametrize("name,type", [("allenai/led-base-16384", "led"),
|
||||
@ -314,7 +384,8 @@ class TestTransformersModel(TestTorchConvertModel):
|
||||
("google/flan-t5-base", "t5"),
|
||||
("google/tapas-large-finetuned-wtq", "tapas"),
|
||||
("gpt2", "gpt2"),
|
||||
("openai/clip-vit-large-patch14", "clip")
|
||||
("openai/clip-vit-large-patch14", "clip"),
|
||||
("OpenVINO/opt-125m-gptq", 'opt')
|
||||
])
|
||||
@pytest.mark.precommit
|
||||
def test_convert_model_precommit(self, name, type, ie_device):
|
||||
|
Loading…
Reference in New Issue
Block a user