add gptq model test in pytorch hub tests (#21399)

* conflict

* refactor test

* fix for the case when the config cannot be instantiated
Ekaterina Aidova 2023-12-01 14:56:05 +04:00 committed by GitHub
parent 0e2bde2397
commit d1f72e2d01
3 changed files with 88 additions and 14 deletions
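
How the change works: the model's Hugging Face config is inspected for a quantization_config entry with quant_method == "gptq"; when one is found, torch.cuda.is_available and optimum's GPTQQuantizer.post_init_model are temporarily monkeypatched so the quantized checkpoint can be loaded with float32 weights on CPU-only CI machines, and the patches are reverted in teardown_method. A minimal sketch of the detection step follows; the config values and the looks_like_gptq helper are illustrative, not taken from the commit.

# Illustrative config for a GPTQ-quantized checkpoint; field values are
# hypothetical examples, only "quant_method" is what the new helper checks.
example_config = {
    "model_type": "opt",
    "quantization_config": {
        "quant_method": "gptq",
        "bits": 4,
        "group_size": 128,
        "desc_act": False,
    },
}

def looks_like_gptq(config_dict):
    # mirrors is_gptq_model() from the diff below, for a plain dict input
    qc = config_dict.get("quantization_config")
    return bool(qc) and qc.get("quant_method") == "gptq"

assert looks_like_gptq(example_config)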


@@ -294,6 +294,7 @@ openai-gpt,openai-gpt
OpenAssistant/oasst-rm-2-pythia-6.9b-epoch-1,gpt_neox_reward_model,skip,Load problem
openmmlab/upernet-convnext-small,upernet,skip,Load problem
openMUSE/clip-vit-large-patch14-text-enc,clip_text_model,skip,Load problem
OpenVINO/opt-125m-gptq,opt
PatrickHaller/ngme-llama-264M,ngme,skip,Load problem
patrickvonplaten/bert2gpt2-cnn_dailymail-fp16,encoder_decoder,skip,Load problem
paulhindemith/test-zeroshot,test-zeroshot,skip,Load problem


@@ -1,10 +1,12 @@
-c ../../constraints.txt
--extra-index-url https://download.pytorch.org/whl/cpu
auto-gptq>=0.5.1
av
basicsr
datasets
facexlib
numpy
optimum
pandas
protobuf
pyctcdecode
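
The new requirements, auto-gptq (>= 0.5.1) and optimum, are what transformers delegates to when a checkpoint carries a GPTQ quantization_config. Below is a rough sketch of the loading path the new test exercises, assuming these packages are installed; without the CUDA patching done in the test, loading may still require a GPU depending on the auto-gptq build.

import torch
from transformers import AutoConfig, AutoModelForCausalLM

name = "OpenVINO/opt-125m-gptq"
config = AutoConfig.from_pretrained(name)
# the GPTQ metadata that transformers hands over to optimum/auto-gptq
print(config.to_dict().get("quantization_config", {}).get("quant_method"))

model = AutoModelForCausalLM.from_pretrained(
    name, torchscript=True, torch_dtype=torch.float32)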


@@ -12,6 +12,61 @@ from models_hub_common.utils import cleanup_dir
from torch_utils import TestTorchConvertModel
from torch_utils import process_pytest_marks


def is_gptq_model(config):
    config_dict = config.to_dict() if not isinstance(config, dict) else config
    quantization_config = config_dict.get("quantization_config", None)
    return quantization_config and quantization_config["quant_method"] == "gptq"


def patch_gptq():
    orig_cuda_check = torch.cuda.is_available
    orig_post_init_model = None
    torch.set_default_dtype(torch.float32)
    # pretend CUDA is available so optimum's GPTQ integration loads on CPU-only hosts
    torch.cuda.is_available = lambda: True
    from optimum.gptq import GPTQQuantizer
    orig_post_init_model = GPTQQuantizer.post_init_model

    # simplified replacement for optimum's post_init_model that works without a GPU
    def post_init_model(self, model):
        from auto_gptq import exllama_set_max_input_length

        class StoreAttr(object):
            pass

        model.quantize_config = StoreAttr()
        model.quantize_config.desc_act = self.desc_act
        if self.desc_act and not self.disable_exllama and self.max_input_length is not None:
            model = exllama_set_max_input_length(model, self.max_input_length)
        return model

    GPTQQuantizer.post_init_model = post_init_model
    return orig_cuda_check, orig_post_init_model


def unpatch_gptq(orig_cuda_check, orig_post_init_model):
    from optimum.gptq import GPTQQuantizer
    torch.cuda.is_available = orig_cuda_check
    GPTQQuantizer.post_init_model = orig_post_init_model


def flattenize_tuples(list_input):
    unpacked_pt_res = []
    for r in list_input:
        if isinstance(r, (tuple, list)):
            unpacked_pt_res.extend(flattenize_tuples(r))
        else:
            unpacked_pt_res.append(r)
    return unpacked_pt_res


def flattenize_outputs(outputs):
    if not isinstance(outputs, dict):
        outputs = flattenize_tuples(outputs)
        return [i.numpy(force=True) for i in outputs]
    else:
        return dict((k, v.numpy(force=True)) for k, v in outputs.items())


def filter_example(model, example):
    try:
@@ -41,12 +96,23 @@ class TestTransformersModel(TestTorchConvertModel):
        url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        self.image = Image.open(requests.get(url, stream=True).raw)
        self.cuda_available, self.gptq_postinit = None, None

    def load_model(self, name, type):
        from transformers import AutoConfig
        mi = model_info(name)
        auto_processor = None
        model = None
        example = None
        try:
            config = AutoConfig.from_pretrained(name)
        except Exception:
            # fall back to an empty dict when the config cannot be instantiated
            config = {}
        is_gptq = is_gptq_model(config)
        model_kwargs = {"torchscript": True}
        if is_gptq:
            self.cuda_available, self.gptq_postinit = patch_gptq()
            model_kwargs["torch_dtype"] = torch.float32
        try:
            auto_model = mi.transformersInfo['auto_model']
            if "processor" in mi.transformersInfo:
@@ -73,7 +139,7 @@ class TestTransformersModel(TestTorchConvertModel):
        elif "vit-gpt2" in name:
            from transformers import VisionEncoderDecoderModel, ViTImageProcessor
            model = VisionEncoderDecoderModel.from_pretrained(
-               name, torchscript=True)
+               name, **model_kwargs)
            feature_extractor = ViTImageProcessor.from_pretrained(name)
            encoded_input = feature_extractor(
                images=[self.image], return_tensors="pt")
@@ -90,7 +156,7 @@ class TestTransformersModel(TestTorchConvertModel):
            example = (encoded_input.pixel_values,)
        elif 'pix2struct' in mi.tags:
            from transformers import AutoProcessor, Pix2StructForConditionalGeneration
-           model = Pix2StructForConditionalGeneration.from_pretrained(name)
+           model = Pix2StructForConditionalGeneration.from_pretrained(name, **model_kwargs)
            processor = AutoProcessor.from_pretrained(name)

            import requests
@@ -112,7 +178,7 @@ class TestTransformersModel(TestTorchConvertModel):
            # mms-lid model config does not have auto_model attribute, only direct loading available
            from transformers import Wav2Vec2ForSequenceClassification, AutoFeatureExtractor
            model = Wav2Vec2ForSequenceClassification.from_pretrained(
-               name, torchscript=True)
+               name, **model_kwargs)
            processor = AutoFeatureExtractor.from_pretrained(name)
            input_values = processor(torch.randn(16000).numpy(),
                                     sampling_rate=16_000,
@@ -178,7 +244,7 @@ class TestTransformersModel(TestTorchConvertModel):
        elif 'musicgen' in mi.tags:
            from transformers import AutoProcessor, AutoModelForTextToWaveform
            processor = AutoProcessor.from_pretrained(name)
-           model = AutoModelForTextToWaveform.from_pretrained(name, torchscript=True)
+           model = AutoModelForTextToWaveform.from_pretrained(name, **model_kwargs)

            inputs = processor(
                text=["80s pop track with bassy drums and synth"],
@@ -193,10 +259,10 @@ class TestTransformersModel(TestTorchConvertModel):
        else:
            try:
                if auto_model == "AutoModelForCausalLM":
-                   from transformers import AutoTokenizer, AutoModelForCausalLM
+                   from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
                    tokenizer = AutoTokenizer.from_pretrained(name)
                    model = AutoModelForCausalLM.from_pretrained(
-                       name, torchscript=True)
+                       name, **model_kwargs)
                    text = "Replace me by any text you'd like."
                    encoded_input = tokenizer(text, return_tensors='pt')
                    inputs_dict = dict(encoded_input)
@@ -207,7 +273,7 @@ class TestTransformersModel(TestTorchConvertModel):
                    from transformers import AutoTokenizer, AutoModelForMaskedLM
                    tokenizer = AutoTokenizer.from_pretrained(name)
                    model = AutoModelForMaskedLM.from_pretrained(
-                       name, torchscript=True)
+                       name, **model_kwargs)
                    text = "Replace me by any text you'd like."
                    encoded_input = tokenizer(text, return_tensors='pt')
                    example = dict(encoded_input)
@@ -215,7 +281,7 @@ class TestTransformersModel(TestTorchConvertModel):
                    from transformers import AutoProcessor, AutoModelForImageClassification
                    processor = AutoProcessor.from_pretrained(name)
                    model = AutoModelForImageClassification.from_pretrained(
-                       name, torchscript=True)
+                       name, **model_kwargs)
                    encoded_input = processor(
                        images=self.image, return_tensors="pt")
                    example = dict(encoded_input)
@@ -223,7 +289,7 @@ class TestTransformersModel(TestTorchConvertModel):
                    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
                    tokenizer = AutoTokenizer.from_pretrained(name)
                    model = AutoModelForSeq2SeqLM.from_pretrained(
-                       name, torchscript=True)
+                       name, **model_kwargs)
                    inputs = tokenizer(
                        "Studies have been shown that owning a dog is good for you", return_tensors="pt")
                    decoder_inputs = tokenizer(
@@ -238,7 +304,7 @@ class TestTransformersModel(TestTorchConvertModel):
                    from datasets import load_dataset
                    processor = AutoProcessor.from_pretrained(name)
                    model = AutoModelForSpeechSeq2Seq.from_pretrained(
-                       name, torchscript=True)
+                       name, **model_kwargs)
                    inputs = processor(torch.randn(1000).numpy(),
                                       sampling_rate=16000,
                                       return_tensors="pt")
@@ -248,7 +314,7 @@ class TestTransformersModel(TestTorchConvertModel):
                    from datasets import load_dataset
                    processor = AutoProcessor.from_pretrained(name)
                    model = AutoModelForCTC.from_pretrained(
-                       name, torchscript=True)
+                       name, **model_kwargs)
                    input_values = processor(torch.randn(1000).numpy(),
                                             return_tensors="pt")
                    example = dict(input_values)
@@ -257,7 +323,7 @@ class TestTransformersModel(TestTorchConvertModel):
                    from transformers import AutoTokenizer, AutoModelForTableQuestionAnswering
                    tokenizer = AutoTokenizer.from_pretrained(name)
                    model = AutoModelForTableQuestionAnswering.from_pretrained(
-                       name, torchscript=True)
+                       name, **model_kwargs)
                    data = {"Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"],
                            "Number of movies": ["87", "53", "69"]}
                    queries = ["What is the name of the first actor?",
@@ -287,7 +353,7 @@ class TestTransformersModel(TestTorchConvertModel):
                pass
        if model is None:
            from transformers import AutoModel
-           model = AutoModel.from_pretrained(name, torchscript=True)
+           model = AutoModel.from_pretrained(name, **model_kwargs)
        if hasattr(model, "set_default_language"):
            model.set_default_language("en_XX")
        if example is None:
@@ -307,6 +373,10 @@ class TestTransformersModel(TestTorchConvertModel):
    def teardown_method(self):
        # remove all downloaded files from cache
        cleanup_dir(hf_hub_cache_dir)
        # restore after gptq patching
        if self.cuda_available is not None:
            unpatch_gptq(self.cuda_available, self.gptq_postinit)
            self.cuda_available, self.gptq_postinit = None, None
        super().teardown_method()

    @pytest.mark.parametrize("name,type", [("allenai/led-base-16384", "led"),
@@ -314,7 +384,8 @@ class TestTransformersModel(TestTorchConvertModel):
                                           ("google/flan-t5-base", "t5"),
                                           ("google/tapas-large-finetuned-wtq", "tapas"),
                                           ("gpt2", "gpt2"),
-                                          ("openai/clip-vit-large-patch14", "clip")
+                                          ("openai/clip-vit-large-patch14", "clip"),
+                                          ("OpenVINO/opt-125m-gptq", 'opt')
                                           ])
    @pytest.mark.precommit
    def test_convert_model_precommit(self, name, type, ie_device):
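
For reference, only the new GPTQ case can be selected with pytest's -k filter. The sketch below drives it through pytest.main; the test module path is an assumption, not something stated in this diff, and the run still requires the models-hub test infrastructure (the ie_device fixture, network access to the Hugging Face Hub).

import pytest

# hypothetical path to the modified test module; adjust to the actual location
pytest.main([
    "tests/model_hub_tests/torch_tests/test_hf_transformers.py",
    "-m", "precommit",
    "-k", "opt-125m-gptq",
])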