Skip test_lstm_simple_precommit on gpu (#6293)

* skip test on gpu

* try to set timeout

* add timeout param

* rename to infer_timeout

* fix infer_timeout
This commit is contained in:
Victor Kuznetsov 2021-07-01 20:25:27 +03:00 committed by GitHub
parent 31907e51e9
commit 267f9a3b77
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 7 additions and 5 deletions

View File

@ -24,7 +24,7 @@ class CommonLayerTest:
def get_framework_results(self, inputs_dict, model_path):
pass
def _test(self, framework_model, ref_net, ie_device, precision, ir_version, temp_dir,
def _test(self, framework_model, ref_net, ie_device, precision, ir_version, temp_dir, infer_timeout=60,
enabled_transforms='', disabled_transforms='', **kwargs):
"""
:param enabled_transforms/disabled_transforms: string with idxs of transforms that should be enabled/disabled.
@ -74,7 +74,7 @@ class CommonLayerTest:
ie_engine = IEInfer(model=path_to_xml,
weights=path_to_bin,
device=ie_device)
infer_res = ie_engine.infer(input_data=inputs_dict)
infer_res = ie_engine.infer(input_data=inputs_dict, infer_timeout=infer_timeout)
if hasattr(self, 'skip_framework') and self.skip_framework:
warnings.warn('Framework is skipped')

View File

@ -25,8 +25,8 @@ class BaseInfer:
def fw_infer(self, input_data):
raise RuntimeError("This is base class, please implement infer function for the specific framework")
def infer(self, input_data):
self.res = multiprocessing_run(self.fw_infer, [input_data], self.name, timeout=60)
def infer(self, input_data, infer_timeout=60):
self.res = multiprocessing_run(self.fw_infer, [input_data], self.name, infer_timeout)
return self.res

View File

@ -140,10 +140,12 @@ class TestLSTM(Caffe2OnnxLayerTest):
return onnx_net, None
@pytest.mark.precommit
@pytest.mark.timeout(250)
@pytest.mark.parametrize('direction', ["forward", "bidirectional", "reverse"])
@pytest.mark.parametrize('cell_type', ["LSTM", "GRU", "RNN"])
def test_lstm_simple_precommit(self, direction, cell_type, ie_device, precision, ir_version, temp_dir):
self._test(*self.create_lstm(direction, cell_type), ie_device, precision, ir_version, temp_dir=temp_dir)
self._test(*self.create_lstm(direction, cell_type), ie_device, precision, ir_version, temp_dir=temp_dir,
infer_timeout=150)
# LSTM/RNN/GRU Sequence Generation
@pytest.mark.parametrize('direction', ["forward", "bidirectional", "reverse"])