[PT FE] Fix issue with http error when using torch.hub (#19901)
* [PT FE] Fix issue with http error when using torch.hub * Mark failing models as xfail * Remove incorrect model names
This commit is contained in:
parent
dbab89f047
commit
c10b45fe9e
@ -99,7 +99,7 @@ class TestConvertModel:
|
||||
fw_outputs = self.infer_fw_model(fw_model, inputs)
|
||||
print("Infer ov::Model")
|
||||
ov_outputs = self.infer_ov_model(ov_model, inputs, ie_device)
|
||||
print("Compare TensorFlow and OpenVINO results")
|
||||
print("Compare framework and OpenVINO results")
|
||||
self.compare_results(fw_outputs, ov_outputs)
|
||||
|
||||
def run(self, model_name, model_link, ie_device):
|
||||
|
@ -22,7 +22,7 @@ def get_models_list(file_name: str):
|
||||
model_name, model_link = model_info.split(',')
|
||||
elif len(model_info.split(',')) == 4:
|
||||
model_name, model_link, mark, reason = model_info.split(',')
|
||||
assert mark == "skip", "Incorrect failure mark for model info {}".format(model_info)
|
||||
assert mark in ["skip", "xfail"], "Incorrect failure mark for model info {}".format(model_info)
|
||||
models.append((model_name, model_link, mark, reason))
|
||||
|
||||
return models
|
||||
|
@ -1,16 +1,18 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import os
|
||||
import pytest
|
||||
import torch
|
||||
import tempfile
|
||||
import torchvision.transforms.functional as F
|
||||
from models_hub_common.test_convert_model import TestConvertModel
|
||||
from openvino import convert_model
|
||||
from models_hub_common.test_convert_model import TestConvertModel
|
||||
from models_hub_common.utils import get_models_list
|
||||
|
||||
|
||||
def get_all_models() -> list:
|
||||
m_list = torch.hub.list("pytorch/vision")
|
||||
m_list = torch.hub.list("pytorch/vision", skip_validation=True)
|
||||
m_list.remove("get_model_weights")
|
||||
m_list.remove("get_weight")
|
||||
return m_list
|
||||
@ -36,7 +38,8 @@ def get_video():
|
||||
|
||||
|
||||
def prepare_frames_for_raft(name, frames1, frames2):
|
||||
w = torch.hub.load("pytorch/vision", "get_model_weights", name=name).DEFAULT
|
||||
w = torch.hub.load("pytorch/vision", "get_model_weights",
|
||||
name=name, skip_validation=True).DEFAULT
|
||||
img1_batch = torch.stack(frames1)
|
||||
img2_batch = torch.stack(frames2)
|
||||
img1_batch = F.resize(img1_batch, size=[520, 960], antialias=False)
|
||||
@ -50,13 +53,14 @@ torch.manual_seed(0)
|
||||
|
||||
|
||||
class TestTorchHubConvertModel(TestConvertModel):
|
||||
def setup_method(self):
|
||||
def setup_class(self):
|
||||
self.cache_dir = tempfile.TemporaryDirectory()
|
||||
# set temp dir for torch cache
|
||||
torch.hub.set_dir(str(self.cache_dir.name))
|
||||
|
||||
def load_model(self, model_name, model_link):
|
||||
m = torch.hub.load("pytorch/vision", model_name, weights='DEFAULT')
|
||||
m = torch.hub.load("pytorch/vision", model_name,
|
||||
weights='DEFAULT', skip_validation=True)
|
||||
m.eval()
|
||||
if model_name == "s3d" or any([m in model_name for m in ["swin3d", "r3d_18", "mc3_18", "r2plus1d_18"]]):
|
||||
self.example = (torch.randn([1, 3, 224, 224, 224]),)
|
||||
@ -109,7 +113,8 @@ class TestTorchHubConvertModel(TestConvertModel):
|
||||
def test_convert_model_precommit(self, model_name, ie_device):
|
||||
self.run(model_name, None, ie_device)
|
||||
|
||||
@pytest.mark.parametrize("model_name", get_all_models())
|
||||
@pytest.mark.parametrize("name",
|
||||
[pytest.param(n, marks=pytest.mark.xfail) if m == "xfail" else n for n, _, m, r in get_models_list(os.path.join(os.path.dirname(__file__), "torchvision_models"))])
|
||||
@pytest.mark.nightly
|
||||
def test_convert_model_all_models(self, model_name, ie_device):
|
||||
self.run(model_name, None, ie_device)
|
||||
def test_convert_model_all_models(self, name, ie_device):
|
||||
self.run(name, None, ie_device)
|
||||
|
97
tests/model_hub_tests/torch_tests/torchvision_models
Normal file
97
tests/model_hub_tests/torch_tests/torchvision_models
Normal file
@ -0,0 +1,97 @@
|
||||
alexnet,none
|
||||
convnext_base,none
|
||||
convnext_large,none
|
||||
convnext_small,none
|
||||
convnext_tiny,none
|
||||
deeplabv3_mobilenet_v3_large,none
|
||||
deeplabv3_resnet101,none
|
||||
deeplabv3_resnet50,none
|
||||
densenet121,none
|
||||
densenet161,none
|
||||
densenet169,none
|
||||
densenet201,none
|
||||
efficientnet_b0,none
|
||||
efficientnet_b1,none
|
||||
efficientnet_b2,none
|
||||
efficientnet_b3,none
|
||||
efficientnet_b4,none
|
||||
efficientnet_b5,none
|
||||
efficientnet_b6,none
|
||||
efficientnet_b7,none
|
||||
efficientnet_v2_l,none
|
||||
efficientnet_v2_m,none
|
||||
efficientnet_v2_s,none
|
||||
fcn_resnet101,none
|
||||
fcn_resnet50,none
|
||||
googlenet,none
|
||||
inception_v3,none
|
||||
lraspp_mobilenet_v3_large,none
|
||||
maxvit_t,none
|
||||
mc3_18,none
|
||||
mnasnet0_5,none
|
||||
mnasnet0_75,none
|
||||
mnasnet1_0,none
|
||||
mnasnet1_3,none
|
||||
mobilenet_v2,none
|
||||
mobilenet_v3_large,none
|
||||
mobilenet_v3_small,none
|
||||
mvit_v1_b,none
|
||||
mvit_v2_s,none
|
||||
r2plus1d_18,none
|
||||
r3d_18,none
|
||||
raft_large,none
|
||||
raft_small,none
|
||||
regnet_x_16gf,none
|
||||
regnet_x_1_6gf,none
|
||||
regnet_x_32gf,none
|
||||
regnet_x_3_2gf,none
|
||||
regnet_x_400mf,none
|
||||
regnet_x_800mf,none
|
||||
regnet_x_8gf,none
|
||||
regnet_y_128gf,none
|
||||
regnet_y_16gf,none
|
||||
regnet_y_1_6gf,none
|
||||
regnet_y_32gf,none
|
||||
regnet_y_3_2gf,none
|
||||
regnet_y_400mf,none
|
||||
regnet_y_800mf,none
|
||||
regnet_y_8gf,none
|
||||
resnet101,none
|
||||
resnet152,none
|
||||
resnet18,none
|
||||
resnet34,none
|
||||
resnet50,none
|
||||
resnext101_32x8d,none
|
||||
resnext101_64x4d,none
|
||||
resnext50_32x4d,none
|
||||
s3d,none
|
||||
shufflenet_v2_x0_5,none
|
||||
shufflenet_v2_x1_0,none
|
||||
shufflenet_v2_x1_5,none
|
||||
shufflenet_v2_x2_0,none
|
||||
squeezenet1_0,none
|
||||
squeezenet1_1,none
|
||||
swin3d_b,none
|
||||
swin3d_s,none
|
||||
swin3d_t,none
|
||||
swin_b,none
|
||||
swin_s,none
|
||||
swin_t,none
|
||||
swin_v2_b,none
|
||||
swin_v2_s,none
|
||||
swin_v2_t,none
|
||||
vgg11,none
|
||||
vgg11_bn,none
|
||||
vgg13,none
|
||||
vgg13_bn,none
|
||||
vgg16,none
|
||||
vgg16_bn,none
|
||||
vgg19,none
|
||||
vgg19_bn,none
|
||||
vit_b_16,none,xfail,Tracing fails
|
||||
vit_b_32,none,xfail,Tracing fails
|
||||
vit_h_14,none,xfail,Tracing fails
|
||||
vit_l_16,none,xfail,Tracing fails
|
||||
vit_l_32,none,xfail,Tracing fails
|
||||
wide_resnet101_2,none
|
||||
wide_resnet50_2,none
|
Loading…
Reference in New Issue
Block a user