[PT FE] Unify hub tests, add marks for timm models (#21283)
* Unify hub tests, add marks for timm models * Fix hf tests
This commit is contained in:
parent
4fdbb2d4e8
commit
e52f922d35
@ -8,7 +8,7 @@ import tempfile
|
||||
import torch
|
||||
import pytest
|
||||
import subprocess
|
||||
from models_hub_common.test_convert_model import TestConvertModel
|
||||
from torch_utils import TestTorchConvertModel
|
||||
from openvino import convert_model, Model, PartialShape, Type
|
||||
import openvino.runtime.opset12 as ops
|
||||
from openvino.frontend import ConversionExtension
|
||||
@ -71,12 +71,13 @@ def read_image(path, idx):
|
||||
return img_tensor.unsqueeze_(0)
|
||||
|
||||
|
||||
class TestAlikedConvertModel(TestConvertModel):
|
||||
class TestAlikedConvertModel(TestTorchConvertModel):
|
||||
def setup_class(self):
|
||||
self.repo_dir = tempfile.TemporaryDirectory()
|
||||
os.system(
|
||||
f"git clone https://github.com/mvafin/ALIKED.git {self.repo_dir.name}")
|
||||
subprocess.check_call(["git", "checkout", "6008af43942925eec7e32006814ef41fbd0858d8"], cwd=self.repo_dir.name)
|
||||
subprocess.check_call(
|
||||
["git", "checkout", "6008af43942925eec7e32006814ef41fbd0858d8"], cwd=self.repo_dir.name)
|
||||
subprocess.check_call([sys.executable, "-m", "pip", "install",
|
||||
"-r", os.path.join(self.repo_dir.name, "requirements.txt")])
|
||||
subprocess.check_call(["sh", "build.sh"], cwd=os.path.join(
|
||||
@ -92,15 +93,9 @@ class TestAlikedConvertModel(TestConvertModel):
|
||||
self.example = (img_tensor,)
|
||||
img_tensor2 = read_image(os.path.join(
|
||||
self.repo_dir.name, "assets", "st_pauls_cathedral"), 2)
|
||||
self.input = (img_tensor2,)
|
||||
self.inputs = (img_tensor2,)
|
||||
return m
|
||||
|
||||
def get_inputs_info(self, model_obj):
|
||||
return None
|
||||
|
||||
def prepare_inputs(self, inputs_info):
|
||||
return [i.numpy() for i in self.input]
|
||||
|
||||
def convert_model(self, model_obj):
|
||||
m = convert_model(model_obj,
|
||||
example_input=self.example,
|
||||
|
@ -4,12 +4,11 @@
|
||||
import os
|
||||
import pytest
|
||||
import torch
|
||||
from models_hub_common.test_convert_model import TestConvertModel
|
||||
from openvino import convert_model
|
||||
from torch_utils import TestTorchConvertModel, process_pytest_marks
|
||||
from models_hub_common.utils import get_models_list, compare_two_tensors
|
||||
|
||||
|
||||
class TestDetectron2ConvertModel(TestConvertModel):
|
||||
class TestDetectron2ConvertModel(TestTorchConvertModel):
|
||||
def setup_class(self):
|
||||
from PIL import Image
|
||||
import requests
|
||||
@ -53,16 +52,6 @@ class TestDetectron2ConvertModel(TestConvertModel):
|
||||
self.example = adapter.flattened_inputs
|
||||
return adapter
|
||||
|
||||
def get_inputs_info(self, model_obj):
|
||||
return None
|
||||
|
||||
def prepare_inputs(self, inputs_info):
|
||||
return [i.numpy() for i in self.example]
|
||||
|
||||
def convert_model(self, model_obj):
|
||||
ov_model = convert_model(model_obj, example_input=self.example)
|
||||
return ov_model
|
||||
|
||||
def infer_fw_model(self, model_obj, inputs):
|
||||
fw_outputs = model_obj(*[torch.from_numpy(i) for i in inputs])
|
||||
if isinstance(fw_outputs, dict):
|
||||
@ -98,7 +87,7 @@ class TestDetectron2ConvertModel(TestConvertModel):
|
||||
self.run(name, None, ie_device)
|
||||
|
||||
@pytest.mark.parametrize("name",
|
||||
[pytest.param(n, marks=pytest.mark.xfail(reason=r)) if m == "xfail" else n for n, _, m, r in get_models_list(os.path.join(os.path.dirname(__file__), "detectron2_models"))])
|
||||
process_pytest_marks(os.path.join(os.path.dirname(__file__), "detectron2_models")))
|
||||
@pytest.mark.nightly
|
||||
def test_detectron2_all_models(self, name, ie_device):
|
||||
self.run(name, None, ie_device)
|
||||
|
@ -5,28 +5,10 @@ import os
|
||||
import pytest
|
||||
import torch
|
||||
from huggingface_hub import model_info
|
||||
from models_hub_common.test_convert_model import TestConvertModel
|
||||
from openvino import convert_model
|
||||
from models_hub_common.utils import get_models_list, cleanup_dir
|
||||
from torch_utils import TestTorchConvertModel
|
||||
from models_hub_common.utils import cleanup_dir
|
||||
from models_hub_common.constants import hf_hub_cache_dir
|
||||
|
||||
|
||||
def flattenize_tuples(list_input):
|
||||
unpacked_pt_res = []
|
||||
for r in list_input:
|
||||
if isinstance(r, (tuple, list)):
|
||||
unpacked_pt_res.extend(flattenize_tuples(r))
|
||||
else:
|
||||
unpacked_pt_res.append(r)
|
||||
return unpacked_pt_res
|
||||
|
||||
|
||||
def flattenize_outputs(outputs):
|
||||
if not isinstance(outputs, dict):
|
||||
outputs = flattenize_tuples(outputs)
|
||||
return [i.numpy(force=True) for i in outputs]
|
||||
else:
|
||||
return dict((k, v.numpy(force=True)) for k, v in outputs.items())
|
||||
from torch_utils import process_pytest_marks
|
||||
|
||||
|
||||
def filter_example(model, example):
|
||||
@ -48,7 +30,7 @@ def filter_example(model, example):
|
||||
torch.manual_seed(0)
|
||||
|
||||
|
||||
class TestTransformersModel(TestConvertModel):
|
||||
class TestTransformersModel(TestTorchConvertModel):
|
||||
def setup_class(self):
|
||||
from PIL import Image
|
||||
import requests
|
||||
@ -105,11 +87,14 @@ class TestTransformersModel(TestConvertModel):
|
||||
model = VIT_GPT2_Model(model)
|
||||
example = (encoded_input.pixel_values,)
|
||||
elif "mms-lid" in name:
|
||||
# mms-lid model config does not have auto_model attribute, only direct loading aviable
|
||||
# mms-lid model config does not have auto_model attribute, only direct loading available
|
||||
from transformers import Wav2Vec2ForSequenceClassification, AutoFeatureExtractor
|
||||
model = Wav2Vec2ForSequenceClassification.from_pretrained(name, torchscript=True)
|
||||
model = Wav2Vec2ForSequenceClassification.from_pretrained(
|
||||
name, torchscript=True)
|
||||
processor = AutoFeatureExtractor.from_pretrained(name)
|
||||
input_values = processor(torch.randn(16000).numpy(), sampling_rate=16_000, return_tensors="pt")
|
||||
input_values = processor(torch.randn(16000).numpy(),
|
||||
sampling_rate=16_000,
|
||||
return_tensors="pt")
|
||||
example = {"input_values": input_values.input_values}
|
||||
elif "retribert" in mi.tags:
|
||||
from transformers import RetriBertTokenizer
|
||||
@ -211,7 +196,9 @@ class TestTransformersModel(TestConvertModel):
|
||||
processor = AutoProcessor.from_pretrained(name)
|
||||
model = AutoModelForSpeechSeq2Seq.from_pretrained(
|
||||
name, torchscript=True)
|
||||
inputs = processor(torch.randn(1000).numpy(), sampling_rate=16000, return_tensors="pt")
|
||||
inputs = processor(torch.randn(1000).numpy(),
|
||||
sampling_rate=16000,
|
||||
return_tensors="pt")
|
||||
example = dict(inputs)
|
||||
elif auto_model == "AutoModelForCTC":
|
||||
from transformers import AutoProcessor, AutoModelForCTC
|
||||
@ -219,7 +206,8 @@ class TestTransformersModel(TestConvertModel):
|
||||
processor = AutoProcessor.from_pretrained(name)
|
||||
model = AutoModelForCTC.from_pretrained(
|
||||
name, torchscript=True)
|
||||
input_values = processor(torch.randn(1000).numpy(), return_tensors="pt")
|
||||
input_values = processor(torch.randn(1000).numpy(),
|
||||
return_tensors="pt")
|
||||
example = dict(input_values)
|
||||
elif auto_model == "AutoModelForTableQuestionAnswering":
|
||||
import pandas as pd
|
||||
@ -273,29 +261,6 @@ class TestTransformersModel(TestConvertModel):
|
||||
model(*self.example)
|
||||
return model
|
||||
|
||||
def get_inputs_info(self, model_obj):
|
||||
return None
|
||||
|
||||
def prepare_inputs(self, inputs_info):
|
||||
if isinstance(self.example, dict):
|
||||
return dict((k, v.numpy()) for k, v in self.example.items())
|
||||
else:
|
||||
return [i.numpy() for i in self.example]
|
||||
|
||||
def convert_model(self, model_obj):
|
||||
ov_model = convert_model(model_obj,
|
||||
example_input=self.example,
|
||||
verbose=True)
|
||||
return ov_model
|
||||
|
||||
def infer_fw_model(self, model_obj, inputs):
|
||||
if isinstance(inputs, dict):
|
||||
inps = dict((k, torch.from_numpy(v)) for k, v in inputs.items())
|
||||
fw_outputs = model_obj(**inps)
|
||||
else:
|
||||
fw_outputs = model_obj(*[torch.from_numpy(i) for i in inputs])
|
||||
return flattenize_outputs(fw_outputs)
|
||||
|
||||
def teardown_method(self):
|
||||
# remove all downloaded files from cache
|
||||
cleanup_dir(hf_hub_cache_dir)
|
||||
@ -312,8 +277,7 @@ class TestTransformersModel(TestConvertModel):
|
||||
def test_convert_model_precommit(self, name, type, ie_device):
|
||||
self.run(model_name=name, model_link=type, ie_device=ie_device)
|
||||
|
||||
@pytest.mark.parametrize("name",
|
||||
[pytest.param(n, marks=pytest.mark.xfail(reason=r) if m == "xfail" else pytest.mark.skip(reason=r)) if m else n for n, _, m, r in get_models_list(os.path.join(os.path.dirname(__file__), "hf_transformers_models"))])
|
||||
@pytest.mark.parametrize("name", process_pytest_marks(os.path.join(os.path.dirname(__file__), "hf_transformers_models")))
|
||||
@pytest.mark.nightly
|
||||
def test_convert_model_all_models(self, name, ie_device):
|
||||
self.run(model_name=name, model_link=None, ie_device=ie_device)
|
||||
|
@ -10,20 +10,21 @@ import subprocess
|
||||
|
||||
from models_hub_common.test_convert_model import TestConvertModel
|
||||
from openvino import convert_model
|
||||
from torch_utils import TestTorchConvertModel
|
||||
|
||||
|
||||
# To make tests reproducible we seed the random generator
|
||||
torch.manual_seed(0)
|
||||
|
||||
|
||||
class TestSpeechTransformerConvertModel(TestConvertModel):
|
||||
class TestSpeechTransformerConvertModel(TestTorchConvertModel):
|
||||
def setup_class(self):
|
||||
self.repo_dir = tempfile.TemporaryDirectory()
|
||||
os.system(
|
||||
f"git clone https://github.com/mvafin/Speech-Transformer.git {self.repo_dir.name}")
|
||||
subprocess.check_call(["git", "checkout", "071eebb7549b66bae2cb93e3391fe99749389456"], cwd=self.repo_dir.name)
|
||||
checkpoint_url = "https://github.com/foamliu/Speech-Transformer/releases/download/v1.0/speech-transformer-cn.pt"
|
||||
subprocess.check_call(["wget", checkpoint_url], cwd=self.repo_dir.name)
|
||||
subprocess.check_call(["wget", "-nv", checkpoint_url], cwd=self.repo_dir.name)
|
||||
|
||||
def load_model(self, model_name, model_link):
|
||||
sys.path.append(self.repo_dir.name)
|
||||
@ -37,21 +38,11 @@ class TestSpeechTransformerConvertModel(TestConvertModel):
|
||||
self.example = (torch.randn(32, 209, 320),
|
||||
torch.stack(sorted(torch.randint(55, 250, [32]), reverse=True)),
|
||||
torch.randint(-1, 4232, [32, 20]))
|
||||
self.input = (torch.randn(32, 209, 320),
|
||||
self.inputs = (torch.randn(32, 209, 320),
|
||||
torch.stack(sorted(torch.randint(55, 400, [32]), reverse=True)),
|
||||
torch.randint(-1, 4232, [32, 25]))
|
||||
return m
|
||||
|
||||
def get_inputs_info(self, model_obj):
|
||||
return None
|
||||
|
||||
def prepare_inputs(self, inputs_info):
|
||||
return [i.numpy() for i in self.input]
|
||||
|
||||
def convert_model(self, model_obj):
|
||||
m = convert_model(model_obj, example_input=self.example)
|
||||
return m
|
||||
|
||||
def infer_fw_model(self, model_obj, inputs):
|
||||
fw_outputs = model_obj(*[torch.from_numpy(i) for i in inputs])
|
||||
if isinstance(fw_outputs, dict):
|
||||
|
@ -1,13 +1,13 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import os
|
||||
import timm
|
||||
import torch
|
||||
import pytest
|
||||
from models_hub_common.test_convert_model import TestConvertModel
|
||||
from torch_utils import TestTorchConvertModel, process_pytest_marks
|
||||
from models_hub_common.constants import hf_hub_cache_dir
|
||||
from models_hub_common.utils import cleanup_dir
|
||||
from openvino import convert_model
|
||||
from models_hub_common.utils import cleanup_dir, get_models_list
|
||||
|
||||
|
||||
def filter_timm(timm_list: list) -> list:
|
||||
@ -15,7 +15,7 @@ def filter_timm(timm_list: list) -> list:
|
||||
filtered_list = []
|
||||
ignore_set = {"base", "mini", "small", "xxtiny", "xtiny", "tiny", "lite", "nano", "pico", "medium", "big",
|
||||
"large", "xlarge", "xxlarge", "huge", "gigantic", "giant", "enormous", "xs", "xxs", "s", "m", "l", "xl"}
|
||||
for name in timm_list:
|
||||
for name in sorted(timm_list):
|
||||
# first: remove datasets
|
||||
name_parts = name.split(".")
|
||||
_name = "_".join(name.split(".")[:-1]) if len(name_parts) > 1 else name
|
||||
@ -30,15 +30,14 @@ def filter_timm(timm_list: list) -> list:
|
||||
|
||||
|
||||
def get_all_models() -> list:
|
||||
m_list = timm.list_pretrained()
|
||||
return filter_timm(m_list)
|
||||
return process_pytest_marks(os.path.join(os.path.dirname(__file__), "timm_models"))
|
||||
|
||||
|
||||
# To make tests reproducible we seed the random generator
|
||||
torch.manual_seed(0)
|
||||
|
||||
|
||||
class TestTimmConvertModel(TestConvertModel):
|
||||
class TestTimmConvertModel(TestTorchConvertModel):
|
||||
def load_model(self, model_name, model_link):
|
||||
m = timm.create_model(model_name, pretrained=True)
|
||||
cfg = timm.get_pretrained_cfg(model_name)
|
||||
@ -47,16 +46,6 @@ class TestTimmConvertModel(TestConvertModel):
|
||||
self.inputs = (torch.randn(shape),)
|
||||
return m
|
||||
|
||||
def get_inputs_info(self, model_obj):
|
||||
return None
|
||||
|
||||
def prepare_inputs(self, inputs_info):
|
||||
return [i.numpy() for i in self.inputs]
|
||||
|
||||
def convert_model(self, model_obj):
|
||||
ov_model = convert_model(model_obj, example_input=self.example)
|
||||
return ov_model
|
||||
|
||||
def infer_fw_model(self, model_obj, inputs):
|
||||
fw_outputs = model_obj(*[torch.from_numpy(i) for i in inputs])
|
||||
if isinstance(fw_outputs, dict):
|
||||
@ -85,3 +74,11 @@ class TestTimmConvertModel(TestConvertModel):
|
||||
@pytest.mark.parametrize("name", get_all_models())
|
||||
def test_convert_model_all_models(self, name, ie_device):
|
||||
self.run(name, None, ie_device)
|
||||
|
||||
@pytest.mark.nightly
|
||||
def test_models_list_complete(self, ie_device):
|
||||
m_list = timm.list_pretrained()
|
||||
all_models_ref = set(filter_timm(m_list))
|
||||
all_models = set([m for m, _, _, _ in get_models_list(
|
||||
os.path.join(os.path.dirname(__file__), "timm_models"))])
|
||||
assert all_models == all_models_ref, f"Lists of models are not equal."
|
||||
|
@ -6,9 +6,7 @@ import pytest
|
||||
import torch
|
||||
import tempfile
|
||||
import torchvision.transforms.functional as F
|
||||
from openvino import convert_model
|
||||
from models_hub_common.test_convert_model import TestConvertModel
|
||||
from models_hub_common.utils import get_models_list
|
||||
from torch_utils import process_pytest_marks, TestTorchConvertModel
|
||||
|
||||
|
||||
def get_all_models() -> list:
|
||||
@ -52,7 +50,7 @@ def prepare_frames_for_raft(name, frames1, frames2):
|
||||
torch.manual_seed(0)
|
||||
|
||||
|
||||
class TestTorchHubConvertModel(TestConvertModel):
|
||||
class TestTorchHubConvertModel(TestTorchConvertModel):
|
||||
def setup_class(self):
|
||||
self.cache_dir = tempfile.TemporaryDirectory()
|
||||
# set temp dir for torch cache
|
||||
@ -83,16 +81,6 @@ class TestTorchHubConvertModel(TestConvertModel):
|
||||
self.inputs = (torch.randn(1, 3, 224, 224),)
|
||||
return m
|
||||
|
||||
def get_inputs_info(self, model_obj):
|
||||
return None
|
||||
|
||||
def prepare_inputs(self, inputs_info):
|
||||
return [i.numpy() for i in self.inputs]
|
||||
|
||||
def convert_model(self, model_obj):
|
||||
ov_model = convert_model(model_obj, example_input=self.example)
|
||||
return ov_model
|
||||
|
||||
def infer_fw_model(self, model_obj, inputs):
|
||||
fw_outputs = model_obj(*[torch.from_numpy(i) for i in inputs])
|
||||
if isinstance(fw_outputs, dict):
|
||||
@ -114,8 +102,7 @@ class TestTorchHubConvertModel(TestConvertModel):
|
||||
def test_convert_model_precommit(self, model_name, ie_device):
|
||||
self.run(model_name, None, ie_device)
|
||||
|
||||
@pytest.mark.parametrize("name",
|
||||
[pytest.param(n, marks=pytest.mark.xfail(reason=r)) if m == "xfail" else n for n, _, m, r in get_models_list(os.path.join(os.path.dirname(__file__), "torchvision_models"))])
|
||||
@pytest.mark.parametrize("name", process_pytest_marks(os.path.join(os.path.dirname(__file__), "torchvision_models")))
|
||||
@pytest.mark.nightly
|
||||
def test_convert_model_all_models(self, name, ie_device):
|
||||
self.run(name, None, ie_device)
|
||||
|
499
tests/model_hub_tests/torch_tests/timm_models
Normal file
499
tests/model_hub_tests/torch_tests/timm_models
Normal file
@ -0,0 +1,499 @@
|
||||
bat_resnext26ts.ch_in1k,None
|
||||
beit_base_patch16_224.in22k_ft_in22k,None
|
||||
beitv2_base_patch16_224.in1k_ft_in1k,None
|
||||
botnet26t_256.c1_in1k,None
|
||||
caformer_b36.sail_in1k,None
|
||||
caformer_m36.sail_in1k,None
|
||||
caformer_s18.sail_in1k,None
|
||||
caformer_s36.sail_in1k,None
|
||||
cait_m36_384.fb_dist_in1k,None
|
||||
cait_m48_448.fb_dist_in1k,None
|
||||
cait_s24_224.fb_dist_in1k,None
|
||||
cait_s36_384.fb_dist_in1k,None
|
||||
cait_xs24_384.fb_dist_in1k,None
|
||||
cait_xxs24_224.fb_dist_in1k,None
|
||||
cait_xxs36_224.fb_dist_in1k,None
|
||||
coat_lite_medium.in1k,None
|
||||
coatnet_0_rw_224.sw_in1k,None
|
||||
coatnet_bn_0_rw_224.sw_in1k,None
|
||||
coatnet_rmlp_1_rw2_224.sw_in12k,None
|
||||
coatnet_rmlp_1_rw_224.sw_in1k,None
|
||||
coatnext_nano_rw_224.sw_in1k,None
|
||||
convformer_b36.sail_in1k,None
|
||||
convformer_m36.sail_in1k,None
|
||||
convformer_s18.sail_in1k,None
|
||||
convformer_s36.sail_in1k,None
|
||||
convit_base.fb_in1k,None,xfail,Trace failed
|
||||
convmixer_1024_20_ks9_p14.in1k,None
|
||||
convmixer_1536_20.in1k,None
|
||||
convnext_atto.d2_in1k,None
|
||||
convnext_atto_ols.a2_in1k,None
|
||||
convnext_base.clip_laion2b,None
|
||||
convnext_femto.d1_in1k,None
|
||||
convnext_femto_ols.d1_in1k,None
|
||||
convnext_large_mlp.clip_laion2b_augreg,None
|
||||
convnext_nano_ols.d1h_in1k,None
|
||||
convnext_tiny_hnf.a2h_in1k,None
|
||||
convnextv2_atto.fcmae,None
|
||||
convnextv2_base.fcmae,None
|
||||
convnextv2_femto.fcmae,None
|
||||
crossvit_15_240.in1k,None
|
||||
crossvit_15_dagger_240.in1k,None
|
||||
cs3darknet_focus_l.c2ns_in1k,None
|
||||
cs3darknet_l.c2ns_in1k,None
|
||||
cs3darknet_x.c2ns_in1k,None
|
||||
cs3edgenet_x.c2_in1k,None
|
||||
cs3se_edgenet_x.c2ns_in1k,None
|
||||
cs3sedarknet_l.c2ns_in1k,None
|
||||
cs3sedarknet_x.c2ns_in1k,None
|
||||
cspdarknet53.ra_in1k,None
|
||||
cspresnet50.ra_in1k,None
|
||||
cspresnext50.ra_in1k,None
|
||||
darknet53.c2ns_in1k,None
|
||||
darknetaa53.c2ns_in1k,None
|
||||
davit_base.msft_in1k,None
|
||||
deit3_base_patch16_224.fb_in1k,None
|
||||
deit3_huge_patch14_224.fb_in1k,None
|
||||
deit_base_distilled_patch16_224.fb_in1k,None
|
||||
deit_base_patch16_224.fb_in1k,None
|
||||
densenet121.ra_in1k,None
|
||||
densenet161.tv_in1k,None
|
||||
densenet169.tv_in1k,None
|
||||
densenet201.tv_in1k,None
|
||||
densenetblur121d.ra_in1k,None
|
||||
dla102.in1k,None
|
||||
dla102x.in1k,None
|
||||
dla102x2.in1k,None
|
||||
dla169.in1k,None
|
||||
dla34.in1k,None
|
||||
dla46_c.in1k,None
|
||||
dla46x_c.in1k,None
|
||||
dla60.in1k,None
|
||||
dla60_res2net.in1k,None
|
||||
dla60_res2next.in1k,None
|
||||
dla60x.in1k,None
|
||||
dla60x_c.in1k,None
|
||||
dm_nfnet_f0.dm_in1k,None
|
||||
dm_nfnet_f1.dm_in1k,None
|
||||
dm_nfnet_f2.dm_in1k,None
|
||||
dm_nfnet_f3.dm_in1k,None
|
||||
dm_nfnet_f4.dm_in1k,None
|
||||
dm_nfnet_f5.dm_in1k,None
|
||||
dm_nfnet_f6.dm_in1k,None
|
||||
dpn107.mx_in1k,None
|
||||
dpn131.mx_in1k,None
|
||||
dpn68.mx_in1k,None
|
||||
dpn68b.mx_in1k,None
|
||||
dpn92.mx_in1k,None
|
||||
dpn98.mx_in1k,None
|
||||
eca_botnext26ts_256.c1_in1k,None
|
||||
eca_halonext26ts.c1_in1k,None
|
||||
eca_nfnet_l0.ra2_in1k,None
|
||||
eca_nfnet_l1.ra2_in1k,None
|
||||
eca_nfnet_l2.ra3_in1k,None
|
||||
eca_resnet33ts.ra2_in1k,None
|
||||
eca_resnext26ts.ch_in1k,None
|
||||
ecaresnet101d.miil_in1k,None
|
||||
ecaresnet101d_pruned.miil_in1k,None
|
||||
ecaresnet269d.ra2_in1k,None
|
||||
ecaresnet26t.ra2_in1k,None
|
||||
ecaresnet50d.miil_in1k,None
|
||||
ecaresnet50d_pruned.miil_in1k,None
|
||||
ecaresnet50t.a1_in1k,None
|
||||
ecaresnetlight.miil_in1k,None
|
||||
edgenext_base.in21k_ft_in1k,None
|
||||
edgenext_small_rw.sw_in1k,None
|
||||
edgenext_x_small.in1k,None
|
||||
edgenext_xx_small.in1k,None
|
||||
efficientformer_l1.snap_dist_in1k,None
|
||||
efficientformer_l3.snap_dist_in1k,None
|
||||
efficientformer_l7.snap_dist_in1k,None
|
||||
efficientformerv2_l.snap_dist_in1k,None
|
||||
efficientformerv2_s0.snap_dist_in1k,None
|
||||
efficientformerv2_s1.snap_dist_in1k,None
|
||||
efficientformerv2_s2.snap_dist_in1k,None
|
||||
efficientnet_b0.ra_in1k,None
|
||||
efficientnet_b1.ft_in1k,None
|
||||
efficientnet_b1_pruned.in1k,None
|
||||
efficientnet_b2.ra_in1k,None
|
||||
efficientnet_b2_pruned.in1k,None
|
||||
efficientnet_b3.ra2_in1k,None
|
||||
efficientnet_b3_pruned.in1k,None
|
||||
efficientnet_b4.ra2_in1k,None
|
||||
efficientnet_b5.sw_in12k,None
|
||||
efficientnet_el.ra_in1k,None
|
||||
efficientnet_el_pruned.in1k,None
|
||||
efficientnet_em.ra2_in1k,None
|
||||
efficientnet_es.ra_in1k,None
|
||||
efficientnet_es_pruned.in1k,None
|
||||
efficientnet_lite0.ra_in1k,None
|
||||
efficientnetv2_rw_m.agc_in1k,None
|
||||
efficientnetv2_rw_t.ra2_in1k,None
|
||||
efficientvit_b0.r224_in1k,None
|
||||
efficientvit_b1.r224_in1k,None
|
||||
efficientvit_b2.r224_in1k,None
|
||||
efficientvit_b3.r224_in1k,None
|
||||
efficientvit_m0.r224_in1k,None
|
||||
efficientvit_m1.r224_in1k,None
|
||||
efficientvit_m2.r224_in1k,None
|
||||
efficientvit_m3.r224_in1k,None
|
||||
efficientvit_m4.r224_in1k,None
|
||||
efficientvit_m5.r224_in1k,None
|
||||
ese_vovnet19b_dw.ra_in1k,None
|
||||
ese_vovnet39b.ra_in1k,None
|
||||
eva02_base_patch14_224.mim_in22k,None
|
||||
eva02_base_patch16_clip_224.merged2b,None
|
||||
eva02_enormous_patch14_clip_224.laion2b,None
|
||||
eva_giant_patch14_224.clip_ft_in1k,None
|
||||
eva_giant_patch14_clip_224.laion400m,None
|
||||
fastvit_ma36.apple_dist_in1k,None
|
||||
fastvit_s12.apple_dist_in1k,None
|
||||
fastvit_sa12.apple_dist_in1k,None
|
||||
fastvit_sa24.apple_dist_in1k,None
|
||||
fastvit_sa36.apple_dist_in1k,None
|
||||
fastvit_t12.apple_dist_in1k,None
|
||||
fastvit_t8.apple_dist_in1k,None
|
||||
fbnetc_100.rmsp_in1k,None
|
||||
fbnetv3_b.ra2_in1k,None
|
||||
fbnetv3_d.ra2_in1k,None
|
||||
fbnetv3_g.ra2_in1k,None
|
||||
flexivit_base.1000ep_in21k,None
|
||||
focalnet_base_lrf.ms_in1k,None
|
||||
focalnet_base_srf.ms_in1k,None
|
||||
focalnet_huge_fl3.ms_in22k,None
|
||||
focalnet_huge_fl4.ms_in22k,None
|
||||
gc_efficientnetv2_rw_t.agc_in1k,None
|
||||
gcresnet33ts.ra2_in1k,None,xfail,Descriptors shape is incompatible with provided dimensions
|
||||
gcresnet50t.ra2_in1k,None,xfail,Descriptors shape is incompatible with provided dimensions
|
||||
gcresnext26ts.ch_in1k,None,xfail,Descriptors shape is incompatible with provided dimensions
|
||||
gcresnext50ts.ch_in1k,None,xfail,Descriptors shape is incompatible with provided dimensions
|
||||
gcvit_base.in1k,None
|
||||
gernet_l.idstcv_in1k,None
|
||||
ghostnet_100.in1k,None
|
||||
ghostnetv2_100.in1k,None
|
||||
gmixer_24_224.ra3_in1k,None
|
||||
gmlp_s16_224.ra3_in1k,None
|
||||
halo2botnet50ts_256.a1h_in1k,None
|
||||
halonet26t.a1h_in1k,None
|
||||
halonet50ts.a1h_in1k,None
|
||||
haloregnetz_b.ra3_in1k,None
|
||||
hardcorenas_a.miil_green_in1k,None
|
||||
hardcorenas_b.miil_green_in1k,None
|
||||
hardcorenas_c.miil_green_in1k,None
|
||||
hardcorenas_d.miil_green_in1k,None
|
||||
hardcorenas_e.miil_green_in1k,None
|
||||
hardcorenas_f.miil_green_in1k,None
|
||||
hrnet_w18.ms_aug_in1k,None
|
||||
hrnet_w18_small_v2.gluon_in1k,None
|
||||
hrnet_w18_ssld.paddle_in1k,None
|
||||
hrnet_w30.ms_in1k,None
|
||||
hrnet_w32.ms_in1k,None
|
||||
hrnet_w40.ms_in1k,None
|
||||
hrnet_w44.ms_in1k,None
|
||||
hrnet_w48.ms_in1k,None
|
||||
hrnet_w48_ssld.paddle_in1k,None
|
||||
hrnet_w64.ms_in1k,None
|
||||
inception_next_base.sail_in1k,None
|
||||
inception_resnet_v2.tf_ens_adv_in1k,None
|
||||
inception_v3.gluon_in1k,None
|
||||
inception_v4.tf_in1k,None
|
||||
lambda_resnet26rpt_256.c1_in1k,None
|
||||
lambda_resnet26t.c1_in1k,None
|
||||
lambda_resnet50ts.a1h_in1k,None
|
||||
lamhalobotnet50ts_256.a1h_in1k,None
|
||||
lcnet_050.ra2_in1k,None
|
||||
legacy_senet154.in1k,None
|
||||
legacy_seresnet101.in1k,None
|
||||
legacy_seresnet152.in1k,None
|
||||
legacy_seresnet18.in1k,None
|
||||
legacy_seresnet34.in1k,None
|
||||
legacy_seresnet50.in1k,None
|
||||
legacy_seresnext101_32x4d.in1k,None
|
||||
legacy_seresnext26_32x4d.in1k,None
|
||||
legacy_seresnext50_32x4d.in1k,None
|
||||
legacy_xception.tf_in1k,None
|
||||
levit_128.fb_dist_in1k,None
|
||||
levit_128s.fb_dist_in1k,None
|
||||
levit_conv_128.fb_dist_in1k,None
|
||||
levit_conv_128s.fb_dist_in1k,None
|
||||
maxvit_base_tf_224.in1k,None
|
||||
maxvit_nano_rw_256.sw_in1k,None
|
||||
maxvit_rmlp_base_rw_224.sw_in12k,None
|
||||
maxxvit_rmlp_nano_rw_256.sw_in1k,None
|
||||
maxxvitv2_nano_rw_256.sw_in1k,None
|
||||
maxxvitv2_rmlp_base_rw_224.sw_in12k,None
|
||||
mixer_b16_224.goog_in21k,None
|
||||
mixer_l16_224.goog_in21k,None
|
||||
mixnet_l.ft_in1k,None
|
||||
mnasnet_100.rmsp_in1k,None
|
||||
mobilenetv2_050.lamb_in1k,None
|
||||
mobilenetv2_110d.ra_in1k,None
|
||||
mobilenetv2_120d.ra_in1k,None
|
||||
mobilenetv3_large_100.miil_in21k,None
|
||||
mobilenetv3_rw.rmsp_in1k,None
|
||||
mobileone_s0.apple_in1k,None
|
||||
mobileone_s1.apple_in1k,None
|
||||
mobileone_s2.apple_in1k,None
|
||||
mobileone_s3.apple_in1k,None
|
||||
mobileone_s4.apple_in1k,None
|
||||
mobilevit_s.cvnets_in1k,None
|
||||
mobilevitv2_050.cvnets_in1k,None
|
||||
mvitv2_base.fb_in1k,None
|
||||
mvitv2_base_cls.fb_inw21k,None
|
||||
nasnetalarge.tf_in1k,None
|
||||
nest_base_jx.goog_in1k,None
|
||||
nf_regnet_b1.ra2_in1k,None
|
||||
nf_resnet50.ra2_in1k,None
|
||||
nfnet_l0.ra2_in1k,None
|
||||
pit_b_224.in1k,None
|
||||
pit_b_distilled_224.in1k,None
|
||||
pit_s_224.in1k,None
|
||||
pit_s_distilled_224.in1k,None
|
||||
pit_ti_224.in1k,None
|
||||
pit_ti_distilled_224.in1k,None
|
||||
pnasnet5large.tf_in1k,None
|
||||
poolformer_m36.sail_in1k,None
|
||||
poolformer_m48.sail_in1k,None
|
||||
poolformer_s12.sail_in1k,None
|
||||
poolformer_s24.sail_in1k,None
|
||||
poolformer_s36.sail_in1k,None
|
||||
poolformerv2_m36.sail_in1k,None
|
||||
poolformerv2_m48.sail_in1k,None
|
||||
poolformerv2_s12.sail_in1k,None
|
||||
poolformerv2_s24.sail_in1k,None
|
||||
poolformerv2_s36.sail_in1k,None
|
||||
pvt_v2_b0.in1k,None
|
||||
pvt_v2_b1.in1k,None
|
||||
pvt_v2_b2.in1k,None
|
||||
pvt_v2_b2_li.in1k,None
|
||||
pvt_v2_b3.in1k,None
|
||||
pvt_v2_b4.in1k,None
|
||||
pvt_v2_b5.in1k,None
|
||||
regnetv_040.ra3_in1k,None
|
||||
regnetx_002.pycls_in1k,None
|
||||
regnetx_004_tv.tv2_in1k,None
|
||||
regnety_002.pycls_in1k,None
|
||||
regnety_008_tv.tv2_in1k,None
|
||||
regnetz_040.ra3_in1k,None
|
||||
regnetz_040_h.ra3_in1k,None
|
||||
regnetz_b16.ra3_in1k,None
|
||||
regnetz_c16.ra3_in1k,None
|
||||
regnetz_c16_evos.ch_in1k,None
|
||||
regnetz_d32.ra3_in1k,None
|
||||
regnetz_d8.ra3_in1k,None
|
||||
regnetz_d8_evos.ch_in1k,None
|
||||
regnetz_e8.ra3_in1k,None
|
||||
repghostnet_050.in1k,None
|
||||
repvgg_a0.rvgg_in1k,None
|
||||
repvgg_a1.rvgg_in1k,None
|
||||
repvgg_a2.rvgg_in1k,None
|
||||
repvgg_b0.rvgg_in1k,None
|
||||
repvgg_b1.rvgg_in1k,None
|
||||
repvgg_b1g4.rvgg_in1k,None
|
||||
repvgg_b2.rvgg_in1k,None
|
||||
repvgg_b2g4.rvgg_in1k,None
|
||||
repvgg_b3.rvgg_in1k,None
|
||||
repvgg_b3g4.rvgg_in1k,None
|
||||
repvgg_d2se.rvgg_in1k,None
|
||||
repvit_m1.dist_in1k,None
|
||||
repvit_m2.dist_in1k,None
|
||||
repvit_m3.dist_in1k,None
|
||||
res2net101_26w_4s.in1k,None
|
||||
res2net101d.in1k,None
|
||||
res2net50_14w_8s.in1k,None
|
||||
res2net50_26w_4s.in1k,None
|
||||
res2net50_26w_6s.in1k,None
|
||||
res2net50_26w_8s.in1k,None
|
||||
res2net50_48w_2s.in1k,None
|
||||
res2net50d.in1k,None
|
||||
res2next50.in1k,None
|
||||
resmlp_12_224.fb_dino,None
|
||||
resnest101e.in1k,None
|
||||
resnest14d.gluon_in1k,None
|
||||
resnest200e.in1k,None
|
||||
resnest269e.in1k,None
|
||||
resnest26d.gluon_in1k,None
|
||||
resnest50d.in1k,None
|
||||
resnest50d_1s4x24d.in1k,None
|
||||
resnest50d_4s2x40d.in1k,None
|
||||
resnet101.a1_in1k,None
|
||||
resnet101c.gluon_in1k,None
|
||||
resnet101d.gluon_in1k,None
|
||||
resnet101s.gluon_in1k,None
|
||||
resnet10t.c3_in1k,None
|
||||
resnet14t.c3_in1k,None
|
||||
resnet152.a1_in1k,None
|
||||
resnet152c.gluon_in1k,None
|
||||
resnet152d.gluon_in1k,None
|
||||
resnet152s.gluon_in1k,None
|
||||
resnet18.a1_in1k,None
|
||||
resnet18d.ra2_in1k,None
|
||||
resnet200d.ra2_in1k,None
|
||||
resnet26.bt_in1k,None
|
||||
resnet26d.bt_in1k,None
|
||||
resnet26t.ra2_in1k,None
|
||||
resnet32ts.ra2_in1k,None
|
||||
resnet33ts.ra2_in1k,None
|
||||
resnet34.a1_in1k,None
|
||||
resnet34d.ra2_in1k,None
|
||||
resnet50.a1_in1k,None
|
||||
resnet50_gn.a1h_in1k,None
|
||||
resnet50c.gluon_in1k,None
|
||||
resnet50d.a1_in1k,None
|
||||
resnet50s.gluon_in1k,None
|
||||
resnet51q.ra2_in1k,None
|
||||
resnet61q.ra2_in1k,None
|
||||
resnetaa101d.sw_in12k,None
|
||||
resnetaa50.a1h_in1k,None
|
||||
resnetaa50d.d_in12k,None
|
||||
resnetblur50.bt_in1k,None
|
||||
resnetrs101.tf_in1k,None
|
||||
resnetrs152.tf_in1k,None
|
||||
resnetrs200.tf_in1k,None
|
||||
resnetrs270.tf_in1k,None
|
||||
resnetrs350.tf_in1k,None
|
||||
resnetrs420.tf_in1k,None
|
||||
resnetrs50.tf_in1k,None
|
||||
resnetv2_101.a1h_in1k,None
|
||||
resnetv2_101x1_bit.goog_in21k,None
|
||||
resnetv2_101x3_bit.goog_in21k,None
|
||||
resnetv2_152x2_bit.goog_in21k,None
|
||||
resnetv2_152x4_bit.goog_in21k,None
|
||||
resnetv2_50d_evos.ah_in1k,None
|
||||
resnetv2_50d_gn.ah_in1k,None
|
||||
resnetv2_50x1_bit.goog_distilled_in1k,None
|
||||
resnetv2_50x3_bit.goog_in21k,None
|
||||
resnext101_32x16d.fb_ssl_yfcc100m_ft_in1k,None
|
||||
resnext101_32x32d.fb_wsl_ig1b_ft_in1k,None
|
||||
resnext101_32x4d.fb_ssl_yfcc100m_ft_in1k,None
|
||||
resnext101_32x8d.fb_ssl_yfcc100m_ft_in1k,None
|
||||
resnext101_64x4d.c1_in1k,None
|
||||
resnext26ts.ra2_in1k,None
|
||||
resnext50_32x4d.a1_in1k,None
|
||||
resnext50d_32x4d.bt_in1k,None
|
||||
rexnet_100.nav_in1k,None
|
||||
rexnetr_200.sw_in12k,None
|
||||
samvit_base_patch16.sa1b,None
|
||||
sebotnet33ts_256.a1h_in1k,None
|
||||
sehalonet33ts.ra2_in1k,None
|
||||
selecsls42b.in1k,None
|
||||
selecsls60.in1k,None
|
||||
selecsls60b.in1k,None
|
||||
semnasnet_075.rmsp_in1k,None
|
||||
senet154.gluon_in1k,None
|
||||
sequencer2d_l.in1k,None,xfail,Unsupported aten::lstm
|
||||
seresnet152d.ra2_in1k,None
|
||||
seresnet33ts.ra2_in1k,None
|
||||
seresnet50.a1_in1k,None
|
||||
seresnext101_32x4d.gluon_in1k,None
|
||||
seresnext101_32x8d.ah_in1k,None
|
||||
seresnext101_64x4d.gluon_in1k,None
|
||||
seresnext101d_32x8d.ah_in1k,None
|
||||
seresnext26d_32x4d.bt_in1k,None
|
||||
seresnext26t_32x4d.bt_in1k,None
|
||||
seresnext26ts.ch_in1k,None
|
||||
seresnext50_32x4d.gluon_in1k,None
|
||||
seresnextaa101d_32x8d.ah_in1k,None
|
||||
seresnextaa201d_32x8d.sw_in12k,None
|
||||
skresnet18.ra_in1k,None
|
||||
skresnet34.ra_in1k,None
|
||||
skresnext50_32x4d.ra_in1k,None
|
||||
spnasnet_100.rmsp_in1k,None
|
||||
swin_base_patch4_window12_384.ms_in1k,None
|
||||
swin_base_patch4_window7_224.ms_in1k,None
|
||||
swin_s3_base_224.ms_in1k,None
|
||||
swinv2_base_window12_192.ms_in22k,None
|
||||
swinv2_base_window12to16_192to256.ms_in22k_ft_in1k,None
|
||||
swinv2_base_window12to24_192to384.ms_in22k_ft_in1k,None
|
||||
swinv2_base_window16_256.ms_in1k,None
|
||||
swinv2_base_window8_256.ms_in1k,None
|
||||
swinv2_cr_small_224.sw_in1k,None
|
||||
swinv2_cr_small_ns_224.sw_in1k,None
|
||||
tf_efficientnet_b0.aa_in1k,None
|
||||
tf_efficientnet_b1.aa_in1k,None
|
||||
tf_efficientnet_b2.aa_in1k,None
|
||||
tf_efficientnet_b3.aa_in1k,None
|
||||
tf_efficientnet_b4.aa_in1k,None
|
||||
tf_efficientnet_b5.aa_in1k,None
|
||||
tf_efficientnet_b6.aa_in1k,None
|
||||
tf_efficientnet_b7.aa_in1k,None
|
||||
tf_efficientnet_b8.ap_in1k,None
|
||||
tf_efficientnet_cc_b0_4e.in1k,None,xfail,Unsupported dynamic weights shape
|
||||
tf_efficientnet_cc_b0_8e.in1k,None,xfail,Unsupported dynamic weights shape
|
||||
tf_efficientnet_cc_b1_8e.in1k,None,xfail,Unsupported dynamic weights shape
|
||||
tf_efficientnet_el.in1k,None
|
||||
tf_efficientnet_em.in1k,None
|
||||
tf_efficientnet_es.in1k,None
|
||||
tf_efficientnet_l2.ns_jft_in1k,None
|
||||
tf_efficientnet_lite0.in1k,None
|
||||
tf_efficientnet_lite1.in1k,None
|
||||
tf_efficientnet_lite2.in1k,None
|
||||
tf_efficientnet_lite3.in1k,None
|
||||
tf_efficientnet_lite4.in1k,None
|
||||
tf_efficientnetv2_b0.in1k,None
|
||||
tf_efficientnetv2_b1.in1k,None
|
||||
tf_efficientnetv2_b2.in1k,None
|
||||
tf_efficientnetv2_b3.in1k,None
|
||||
tf_efficientnetv2_l.in1k,None
|
||||
tf_mixnet_l.in1k,None
|
||||
tf_mobilenetv3_large_075.in1k,None
|
||||
tf_mobilenetv3_large_minimal_100.in1k,None
|
||||
tiny_vit_11m_224.dist_in22k,None
|
||||
tiny_vit_21m_224.dist_in22k,None
|
||||
tiny_vit_5m_224.dist_in22k,None
|
||||
tinynet_a.in1k,None
|
||||
tinynet_b.in1k,None
|
||||
tinynet_c.in1k,None
|
||||
tinynet_d.in1k,None
|
||||
tinynet_e.in1k,None
|
||||
tnt_s_patch16_224,None
|
||||
tresnet_l.miil_in1k,None
|
||||
tresnet_v2_l.miil_in21k,None
|
||||
twins_pcpvt_base.in1k,None
|
||||
twins_svt_base.in1k,None
|
||||
vgg11.tv_in1k,None
|
||||
vgg11_bn.tv_in1k,None
|
||||
vgg13.tv_in1k,None
|
||||
vgg13_bn.tv_in1k,None
|
||||
vgg16.tv_in1k,None
|
||||
vgg16_bn.tv_in1k,None
|
||||
vgg19.tv_in1k,None
|
||||
vgg19_bn.tv_in1k,None
|
||||
visformer_small.in1k,None
|
||||
vit_base_patch14_dinov2.lvd142m,None
|
||||
vit_base_patch16_224.augreg2_in21k_ft_in1k,None
|
||||
vit_base_patch16_224_miil.in21k,None
|
||||
vit_base_patch16_clip_224.datacompxl,None
|
||||
vit_base_patch16_rpn_224.sw_in1k,None
|
||||
vit_base_patch32_224.augreg_in1k,None
|
||||
vit_base_patch32_clip_224.laion2b,None
|
||||
vit_base_patch8_224.augreg2_in21k_ft_in1k,None
|
||||
vit_base_r50_s16_224.orig_in21k,None
|
||||
vit_giant_patch14_clip_224.laion2b,None
|
||||
vit_gigantic_patch16_224_ijepa.in22k,None
|
||||
vit_huge_patch14_224.mae,None
|
||||
vit_huge_patch14_224_ijepa.in1k,None
|
||||
vit_large_r50_s32_224.augreg_in21k,None
|
||||
vit_medium_patch16_gap_240.sw_in12k,None
|
||||
vit_relpos_base_patch16_224.sw_in1k,None
|
||||
vit_relpos_base_patch16_clsgap_224.sw_in1k,None
|
||||
vit_relpos_base_patch32_plus_rpn_256.sw_in1k,None
|
||||
vit_relpos_medium_patch16_cls_224.sw_in1k,None
|
||||
vit_relpos_medium_patch16_rpn_224.sw_in1k,None
|
||||
vit_small_r26_s32_224.augreg_in21k,None
|
||||
vit_srelpos_medium_patch16_224.sw_in1k,None
|
||||
vit_tiny_r_s16_p8_224.augreg_in21k,None
|
||||
volo_d1_224.sail_in1k,None,xfail,Unsupported aten::col2im
|
||||
volo_d2_224.sail_in1k,None,xfail,Unsupported aten::col2im
|
||||
volo_d3_224.sail_in1k,None,xfail,Unsupported aten::col2im
|
||||
volo_d4_224.sail_in1k,None,xfail,Unsupported aten::col2im
|
||||
volo_d5_224.sail_in1k,None,xfail,Unsupported aten::col2im
|
||||
wide_resnet101_2.tv2_in1k,None
|
||||
wide_resnet50_2.racm_in1k,None
|
||||
xception41.tf_in1k,None
|
||||
xception41p.ra3_in1k,None
|
||||
xception65.ra3_in1k,None
|
||||
xception65p.ra3_in1k,None
|
||||
xception71.tf_in1k,None
|
||||
xcit_large_24_p16_224.fb_dist_in1k,None
|
||||
xcit_large_24_p8_224.fb_dist_in1k,None
|
60
tests/model_hub_tests/torch_tests/torch_utils.py
Normal file
60
tests/model_hub_tests/torch_tests/torch_utils.py
Normal file
@ -0,0 +1,60 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from models_hub_common.utils import get_models_list
|
||||
from models_hub_common.test_convert_model import TestConvertModel
|
||||
from openvino import convert_model
|
||||
|
||||
|
||||
def flattenize_tuples(list_input):
|
||||
if not isinstance(list_input, (tuple, list)):
|
||||
return [list_input]
|
||||
unpacked_pt_res = []
|
||||
for r in list_input:
|
||||
unpacked_pt_res.extend(flattenize_tuples(r))
|
||||
return unpacked_pt_res
|
||||
|
||||
|
||||
def flattenize_structure(outputs):
|
||||
if not isinstance(outputs, dict):
|
||||
outputs = flattenize_tuples(outputs)
|
||||
return [i.numpy(force=True) if isinstance(i, torch.Tensor) else i for i in outputs]
|
||||
else:
|
||||
return dict((k, v.numpy(force=True) if isinstance(v, torch.Tensor) else v) for k, v in outputs.items())
|
||||
|
||||
|
||||
def process_pytest_marks(filepath: str):
|
||||
return [pytest.param(n, marks=pytest.mark.xfail(reason=r) if m == "xfail" else pytest.mark.skip(reason=r)) if m else n for n, _, m, r in get_models_list(filepath)]
|
||||
|
||||
|
||||
class TestTorchConvertModel(TestConvertModel):
|
||||
def setup_class(self):
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
def load_model(self, model_name, model_link):
|
||||
raise "load_model is not implemented"
|
||||
|
||||
def get_inputs_info(self, model_obj):
|
||||
return None
|
||||
|
||||
def prepare_inputs(self, inputs_info):
|
||||
inputs = getattr(self, "inputs", self.example)
|
||||
if isinstance(inputs, dict):
|
||||
return dict((k, v.numpy()) for k, v in inputs.items())
|
||||
else:
|
||||
return [i.numpy() for i in inputs]
|
||||
|
||||
def convert_model(self, model_obj):
|
||||
ov_model = convert_model(
|
||||
model_obj, example_input=self.example, verbose=True)
|
||||
return ov_model
|
||||
|
||||
def infer_fw_model(self, model_obj, inputs):
|
||||
if isinstance(inputs, dict):
|
||||
inps = dict((k, torch.from_numpy(v)) for k, v in inputs.items())
|
||||
fw_outputs = model_obj(**inps)
|
||||
else:
|
||||
fw_outputs = model_obj(*[torch.from_numpy(i) for i in inputs])
|
||||
return flattenize_structure(fw_outputs)
|
Loading…
Reference in New Issue
Block a user