[PT FE] Add tests for Speech-Transformer (#20847)
* Add tests for Speech-Transformer
* Update tests/model_hub_tests/torch_tests/test_speech-transformer.py
* Update tests/model_hub_tests/torch_tests/test_speech-transformer.py
parent dcdf6750a7
commit e976e7b90c
72  tests/model_hub_tests/torch_tests/test_speech-transformer.py  Normal file
@@ -0,0 +1,72 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import os
import sys
import tempfile
import torch
import pytest
import subprocess

from models_hub_common.test_convert_model import TestConvertModel
from openvino import convert_model


# To make tests reproducible we seed the random generator
torch.manual_seed(0)


class TestSpeechTransformerConvertModel(TestConvertModel):
    def setup_class(self):
        self.repo_dir = tempfile.TemporaryDirectory()
        os.system(
            f"git clone https://github.com/mvafin/Speech-Transformer.git {self.repo_dir.name}")
        subprocess.check_call(["git", "checkout", "071eebb7549b66bae2cb93e3391fe99749389456"], cwd=self.repo_dir.name)
        checkpoint_url = "https://github.com/foamliu/Speech-Transformer/releases/download/v1.0/speech-transformer-cn.pt"
        subprocess.check_call(["wget", checkpoint_url], cwd=self.repo_dir.name)

    def load_model(self, model_name, model_link):
        sys.path.append(self.repo_dir.name)
        from transformer.transformer import Transformer

        filename = os.path.join(self.repo_dir.name, 'speech-transformer-cn.pt')
        m = Transformer()
        m.load_state_dict(torch.load(
            filename, map_location=torch.device('cpu')))

        self.example = (torch.randn(32, 209, 320),
                        torch.stack(sorted(torch.randint(55, 250, [32]), reverse=True)),
                        torch.randint(-1, 4232, [32, 20]))
        self.input = (torch.randn(32, 209, 320),
                      torch.stack(sorted(torch.randint(55, 400, [32]), reverse=True)),
                      torch.randint(-1, 4232, [32, 25]))
        return m

    def get_inputs_info(self, model_obj):
        return None

    def prepare_inputs(self, inputs_info):
        return [i.numpy() for i in self.input]

    def convert_model(self, model_obj):
        m = convert_model(model_obj, example_input=self.example)
        return m

    def infer_fw_model(self, model_obj, inputs):
        fw_outputs = model_obj(*[torch.from_numpy(i) for i in inputs])
        if isinstance(fw_outputs, dict):
            for k in fw_outputs.keys():
                fw_outputs[k] = fw_outputs[k].numpy(force=True)
        elif isinstance(fw_outputs, (list, tuple)):
            fw_outputs = [o.numpy(force=True) for o in fw_outputs]
        else:
            fw_outputs = [fw_outputs.numpy(force=True)]
        return fw_outputs

    def teardown_class(self):
        # remove all downloaded files from cache
        self.repo_dir.cleanup()

    @pytest.mark.nightly
    def test_convert_model(self, ie_device):
        self.run("speech-transformer", None, ie_device)
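For context, the flow this test exercises reduces to tracing a torch.nn.Module with an example input via openvino.convert_model and then running the converted model through the OpenVINO runtime (the base TestConvertModel class drives conversion, inference, and comparison). Below is a minimal standalone sketch of that conversion step only; the ToyModel module and the "CPU" device name are illustrative assumptions, not part of this commit or of the Speech-Transformer checkpoint.

# --- Editor's sketch (not part of the commit): minimal convert_model usage ---
# Assumptions: ToyModel and the "CPU" device are illustrative placeholders.
import numpy as np
import torch
from openvino import Core, convert_model


class ToyModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1.0


example = torch.randn(2, 8)
# Trace the module with an example input, mirroring convert_model(..., example_input=...)
ov_model = convert_model(ToyModel(), example_input=(example,))

# Compile and run on CPU; calling the compiled model returns a mapping of output tensors
compiled = Core().compile_model(ov_model, "CPU")
result = compiled([example.numpy()])
print(next(iter(result.values())).shape)

In the actual test, this whole cycle is triggered through pytest with the nightly marker shown above, with ie_device selecting the target OpenVINO device.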