[PT FE] Add tests for Speech-Transformer (#20847)

* Add tests for Speech-Transformer

* Update tests/model_hub_tests/torch_tests/test_speech-transformer.py

* Update tests/model_hub_tests/torch_tests/test_speech-transformer.py
This commit is contained in:
Maxim Vafin 2023-11-07 08:32:22 +01:00 committed by GitHub
parent dcdf6750a7
commit e976e7b90c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -0,0 +1,72 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import os
import sys
import tempfile
import torch
import pytest
import subprocess
from models_hub_common.test_convert_model import TestConvertModel
from openvino import convert_model
# To make tests reproducible we seed the random generator
torch.manual_seed(0)
class TestSpeechTransformerConvertModel(TestConvertModel):
def setup_class(self):
self.repo_dir = tempfile.TemporaryDirectory()
os.system(
f"git clone https://github.com/mvafin/Speech-Transformer.git {self.repo_dir.name}")
subprocess.check_call(["git", "checkout", "071eebb7549b66bae2cb93e3391fe99749389456"], cwd=self.repo_dir.name)
checkpoint_url = "https://github.com/foamliu/Speech-Transformer/releases/download/v1.0/speech-transformer-cn.pt"
subprocess.check_call(["wget", checkpoint_url], cwd=self.repo_dir.name)
def load_model(self, model_name, model_link):
sys.path.append(self.repo_dir.name)
from transformer.transformer import Transformer
filename = os.path.join(self.repo_dir.name, 'speech-transformer-cn.pt')
m = Transformer()
m.load_state_dict(torch.load(
filename, map_location=torch.device('cpu')))
self.example = (torch.randn(32, 209, 320),
torch.stack(sorted(torch.randint(55, 250, [32]), reverse=True)),
torch.randint(-1, 4232, [32, 20]))
self.input = (torch.randn(32, 209, 320),
torch.stack(sorted(torch.randint(55, 400, [32]), reverse=True)),
torch.randint(-1, 4232, [32, 25]))
return m
def get_inputs_info(self, model_obj):
return None
def prepare_inputs(self, inputs_info):
return [i.numpy() for i in self.input]
def convert_model(self, model_obj):
m = convert_model(model_obj, example_input=self.example)
return m
def infer_fw_model(self, model_obj, inputs):
fw_outputs = model_obj(*[torch.from_numpy(i) for i in inputs])
if isinstance(fw_outputs, dict):
for k in fw_outputs.keys():
fw_outputs[k] = fw_outputs[k].numpy(force=True)
elif isinstance(fw_outputs, (list, tuple)):
fw_outputs = [o.numpy(force=True) for o in fw_outputs]
else:
fw_outputs = [fw_outputs.numpy(force=True)]
return fw_outputs
def teardown_class(self):
# remove all downloaded files from cache
self.repo_dir.cleanup()
@pytest.mark.nightly
def test_convert_model(self, ie_device):
self.run("speech-transformer", None, ie_device)