diff --git a/tests/model_hub_tests/torch_tests/hf_transformers_models b/tests/model_hub_tests/torch_tests/hf_transformers_models
index cf699be6289..73e884d49ea 100644
--- a/tests/model_hub_tests/torch_tests/hf_transformers_models
+++ b/tests/model_hub_tests/torch_tests/hf_transformers_models
@@ -294,6 +294,7 @@ openai-gpt,openai-gpt
 OpenAssistant/oasst-rm-2-pythia-6.9b-epoch-1,gpt_neox_reward_model,skip,Load problem
 openmmlab/upernet-convnext-small,upernet,skip,Load problem
 openMUSE/clip-vit-large-patch14-text-enc,clip_text_model,skip,Load problem
+OpenVINO/opt-125m-gptq,opt
 PatrickHaller/ngme-llama-264M,ngme,skip,Load problem
 patrickvonplaten/bert2gpt2-cnn_dailymail-fp16,encoder_decoder,skip,Load problem
 paulhindemith/test-zeroshot,test-zeroshot,skip,Load problem
diff --git a/tests/model_hub_tests/torch_tests/requirements.txt b/tests/model_hub_tests/torch_tests/requirements.txt
index 25a4a0b4f4b..af6f4cf2212 100644
--- a/tests/model_hub_tests/torch_tests/requirements.txt
+++ b/tests/model_hub_tests/torch_tests/requirements.txt
@@ -1,10 +1,12 @@
 -c ../../constraints.txt
 --extra-index-url https://download.pytorch.org/whl/cpu
+auto-gptq>=0.5.1
 av
 basicsr
 datasets
 facexlib
 numpy
+optimum
 pandas
 protobuf
 pyctcdecode
diff --git a/tests/model_hub_tests/torch_tests/test_hf_transformers.py b/tests/model_hub_tests/torch_tests/test_hf_transformers.py
index e13fa5a6b9c..24d878408a5 100644
--- a/tests/model_hub_tests/torch_tests/test_hf_transformers.py
+++ b/tests/model_hub_tests/torch_tests/test_hf_transformers.py
@@ -12,6 +12,61 @@
 from models_hub_common.utils import cleanup_dir
 from torch_utils import TestTorchConvertModel
 from torch_utils import process_pytest_marks
+def is_gptq_model(config):
+    config_dict = config.to_dict() if not isinstance(config, dict) else config
+    quantization_config = config_dict.get("quantization_config", None)
+    return quantization_config and quantization_config["quant_method"] == "gptq"
+
+
+def patch_gptq():
+    orig_cuda_check = torch.cuda.is_available
+    orig_post_init_model = None
+    torch.set_default_dtype(torch.float32)
+    torch.cuda.is_available = lambda: True
+
+    from optimum.gptq import GPTQQuantizer
+
+    orig_post_init_model = GPTQQuantizer.post_init_model
+
+    def post_init_model(self, model):
+        from auto_gptq import exllama_set_max_input_length
+
+        class StoreAttr(object):
+            pass
+
+        model.quantize_config = StoreAttr()
+        model.quantize_config.desc_act = self.desc_act
+        if self.desc_act and not self.disable_exllama and self.max_input_length is not None:
+            model = exllama_set_max_input_length(model, self.max_input_length)
+        return model
+
+    GPTQQuantizer.post_init_model = post_init_model
+    return orig_cuda_check, orig_post_init_model
+
+
+def unpatch_gptq(orig_cuda_check, orig_post_init_model):
+    from optimum.gptq import GPTQQuantizer
+    torch.cuda.is_available = orig_cuda_check
+    GPTQQuantizer.post_init_model = orig_post_init_model
+
+
+def flattenize_tuples(list_input):
+    unpacked_pt_res = []
+    for r in list_input:
+        if isinstance(r, (tuple, list)):
+            unpacked_pt_res.extend(flattenize_tuples(r))
+        else:
+            unpacked_pt_res.append(r)
+    return unpacked_pt_res
+
+
+def flattenize_outputs(outputs):
+    if not isinstance(outputs, dict):
+        outputs = flattenize_tuples(outputs)
+        return [i.numpy(force=True) for i in outputs]
+    else:
+        return dict((k, v.numpy(force=True)) for k, v in outputs.items())
+
 
 def filter_example(model, example):
     try:
@@ -41,12 +96,23 @@ class TestTransformersModel(TestTorchConvertModel):
         url = "http://images.cocodataset.org/val2017/000000039769.jpg"
         self.image = Image.open(requests.get(url, stream=True).raw)
+        self.cuda_available, self.gptq_postinit = None, None
 
     def load_model(self, name, type):
+        from transformers import AutoConfig
         mi = model_info(name)
         auto_processor = None
         model = None
         example = None
+        try:
+            config = AutoConfig.from_pretrained(name)
+        except Exception:
+            config = {}
+        is_gptq = is_gptq_model(config)
+        model_kwargs = {"torchscript": True}
+        if is_gptq:
+            self.cuda_available, self.gptq_postinit = patch_gptq()
+            model_kwargs["torch_dtype"] = torch.float32
         try:
             auto_model = mi.transformersInfo['auto_model']
             if "processor" in mi.transformersInfo:
@@ -73,7 +139,7 @@ class TestTransformersModel(TestTorchConvertModel):
         elif "vit-gpt2" in name:
             from transformers import VisionEncoderDecoderModel, ViTImageProcessor
             model = VisionEncoderDecoderModel.from_pretrained(
-                name, torchscript=True)
+                name, **model_kwargs)
             feature_extractor = ViTImageProcessor.from_pretrained(name)
             encoded_input = feature_extractor(
                 images=[self.image], return_tensors="pt")
@@ -90,7 +156,7 @@ class TestTransformersModel(TestTorchConvertModel):
             example = (encoded_input.pixel_values,)
         elif 'pix2struct' in mi.tags:
             from transformers import AutoProcessor, Pix2StructForConditionalGeneration
-            model = Pix2StructForConditionalGeneration.from_pretrained(name)
+            model = Pix2StructForConditionalGeneration.from_pretrained(name, **model_kwargs)
             processor = AutoProcessor.from_pretrained(name)
 
             import requests
@@ -112,7 +178,7 @@ class TestTransformersModel(TestTorchConvertModel):
             # mms-lid model config does not have auto_model attribute, only direct loading available
             from transformers import Wav2Vec2ForSequenceClassification, AutoFeatureExtractor
             model = Wav2Vec2ForSequenceClassification.from_pretrained(
-                name, torchscript=True)
+                name, **model_kwargs)
             processor = AutoFeatureExtractor.from_pretrained(name)
             input_values = processor(torch.randn(16000).numpy(),
                                      sampling_rate=16_000,
@@ -178,7 +244,7 @@ class TestTransformersModel(TestTorchConvertModel):
         elif 'musicgen' in mi.tags:
             from transformers import AutoProcessor, AutoModelForTextToWaveform
             processor = AutoProcessor.from_pretrained(name)
-            model = AutoModelForTextToWaveform.from_pretrained(name, torchscript=True)
+            model = AutoModelForTextToWaveform.from_pretrained(name, **model_kwargs)
 
             inputs = processor(
                 text=["80s pop track with bassy drums and synth"],
@@ -193,10 +259,10 @@ class TestTransformersModel(TestTorchConvertModel):
         else:
             try:
                 if auto_model == "AutoModelForCausalLM":
-                    from transformers import AutoTokenizer, AutoModelForCausalLM
+                    from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
                     tokenizer = AutoTokenizer.from_pretrained(name)
                     model = AutoModelForCausalLM.from_pretrained(
-                        name, torchscript=True)
+                        name, **model_kwargs)
                     text = "Replace me by any text you'd like."
                     encoded_input = tokenizer(text, return_tensors='pt')
                     inputs_dict = dict(encoded_input)
@@ -207,7 +273,7 @@ class TestTransformersModel(TestTorchConvertModel):
                     from transformers import AutoTokenizer, AutoModelForMaskedLM
                     tokenizer = AutoTokenizer.from_pretrained(name)
                     model = AutoModelForMaskedLM.from_pretrained(
-                        name, torchscript=True)
+                        name, **model_kwargs)
                     text = "Replace me by any text you'd like."
                     encoded_input = tokenizer(text, return_tensors='pt')
                     example = dict(encoded_input)
@@ -215,7 +281,7 @@ class TestTransformersModel(TestTorchConvertModel):
                     from transformers import AutoProcessor, AutoModelForImageClassification
                     processor = AutoProcessor.from_pretrained(name)
                     model = AutoModelForImageClassification.from_pretrained(
-                        name, torchscript=True)
+                        name, **model_kwargs)
                     encoded_input = processor(
                         images=self.image, return_tensors="pt")
                     example = dict(encoded_input)
@@ -223,7 +289,7 @@ class TestTransformersModel(TestTorchConvertModel):
                     from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
                     tokenizer = AutoTokenizer.from_pretrained(name)
                     model = AutoModelForSeq2SeqLM.from_pretrained(
-                        name, torchscript=True)
+                        name, **model_kwargs)
                     inputs = tokenizer(
                         "Studies have been shown that owning a dog is good for you", return_tensors="pt")
                     decoder_inputs = tokenizer(
@@ -238,7 +304,7 @@ class TestTransformersModel(TestTorchConvertModel):
                     from datasets import load_dataset
                     processor = AutoProcessor.from_pretrained(name)
                     model = AutoModelForSpeechSeq2Seq.from_pretrained(
-                        name, torchscript=True)
+                        name, **model_kwargs)
                     inputs = processor(torch.randn(1000).numpy(),
                                        sampling_rate=16000,
                                        return_tensors="pt")
@@ -248,7 +314,7 @@ class TestTransformersModel(TestTorchConvertModel):
                     from datasets import load_dataset
                     processor = AutoProcessor.from_pretrained(name)
                     model = AutoModelForCTC.from_pretrained(
-                        name, torchscript=True)
+                        name, **model_kwargs)
                     input_values = processor(torch.randn(1000).numpy(),
                                              return_tensors="pt")
                     example = dict(input_values)
@@ -257,7 +323,7 @@ class TestTransformersModel(TestTorchConvertModel):
                    from transformers import AutoTokenizer, AutoModelForTableQuestionAnswering
                    tokenizer = AutoTokenizer.from_pretrained(name)
                    model = AutoModelForTableQuestionAnswering.from_pretrained(
-                       name, torchscript=True)
+                       name, **model_kwargs)
                    data = {"Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"],
                            "Number of movies": ["87", "53", "69"]}
                    queries = ["What is the name of the first actor?",
@@ -287,7 +353,7 @@ class TestTransformersModel(TestTorchConvertModel):
                 pass
             if model is None:
                 from transformers import AutoModel
-                model = AutoModel.from_pretrained(name, torchscript=True)
+                model = AutoModel.from_pretrained(name, **model_kwargs)
                 if hasattr(model, "set_default_language"):
                     model.set_default_language("en_XX")
                 if example is None:
@@ -307,6 +373,10 @@ class TestTransformersModel(TestTorchConvertModel):
     def teardown_method(self):
         # remove all downloaded files from cache
         cleanup_dir(hf_hub_cache_dir)
+        # restore after gptq patching
+        if self.cuda_available is not None:
+            unpatch_gptq(self.cuda_available, self.gptq_postinit)
+            self.cuda_available, self.gptq_postinit = None, None
         super().teardown_method()
 
     @pytest.mark.parametrize("name,type", [("allenai/led-base-16384", "led"),
@@ -314,7 +384,8 @@ class TestTransformersModel(TestTorchConvertModel):
                                            ("google/flan-t5-base", "t5"),
                                            ("google/tapas-large-finetuned-wtq", "tapas"),
                                            ("gpt2", "gpt2"),
-                                           ("openai/clip-vit-large-patch14", "clip")
+                                           ("openai/clip-vit-large-patch14", "clip"),
+                                           ("OpenVINO/opt-125m-gptq", 'opt')
                                            ])
     @pytest.mark.precommit
     def test_convert_model_precommit(self, name, type, ie_device):
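
Usage sketch (illustration only, not part of the diff): the snippet below shows how the GPTQ handling added above is meant to fit together — detect a GPTQ checkpoint from its config, apply `patch_gptq()` so the model can be loaded on CPU in fp32, and restore the patched globals afterwards, mirroring what `load_model` and `teardown_method` do. It assumes the helpers from `test_hf_transformers.py` are importable and that `torch`, `transformers`, `optimum` and `auto-gptq` are installed; the model name is the one added to the precommit list.

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

# Helpers defined in the diff above; assumed importable for this sketch.
from test_hf_transformers import is_gptq_model, patch_gptq, unpatch_gptq

name = "OpenVINO/opt-125m-gptq"
config = AutoConfig.from_pretrained(name)

model_kwargs = {"torchscript": True}
cuda_available, gptq_postinit = None, None
if is_gptq_model(config):
    # GPTQ checkpoints normally require CUDA; the patch fakes
    # torch.cuda.is_available and replaces GPTQQuantizer.post_init_model
    # so the model can be loaded on CPU in fp32.
    cuda_available, gptq_postinit = patch_gptq()
    model_kwargs["torch_dtype"] = torch.float32

try:
    model = AutoModelForCausalLM.from_pretrained(name, **model_kwargs)
    tokenizer = AutoTokenizer.from_pretrained(name)
    example = dict(tokenizer("Replace me by any text you'd like.",
                             return_tensors="pt"))
finally:
    # Mirror teardown_method: undo the patching once loading is done.
    if cuda_available is not None:
        unpatch_gptq(cuda_available, gptq_postinit)
```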