Skip test_lstm_simple_precommit on gpu (#6293)
* skip test on gpu * try to set timeout * add timeout param * rename to infer_timeout * fix infer_timeout
This commit is contained in:
parent
31907e51e9
commit
267f9a3b77
@ -24,7 +24,7 @@ class CommonLayerTest:
|
||||
def get_framework_results(self, inputs_dict, model_path):
|
||||
pass
|
||||
|
||||
def _test(self, framework_model, ref_net, ie_device, precision, ir_version, temp_dir,
|
||||
def _test(self, framework_model, ref_net, ie_device, precision, ir_version, temp_dir, infer_timeout=60,
|
||||
enabled_transforms='', disabled_transforms='', **kwargs):
|
||||
"""
|
||||
:param enabled_transforms/disabled_transforms: string with idxs of transforms that should be enabled/disabled.
|
||||
@ -74,7 +74,7 @@ class CommonLayerTest:
|
||||
ie_engine = IEInfer(model=path_to_xml,
|
||||
weights=path_to_bin,
|
||||
device=ie_device)
|
||||
infer_res = ie_engine.infer(input_data=inputs_dict)
|
||||
infer_res = ie_engine.infer(input_data=inputs_dict, infer_timeout=infer_timeout)
|
||||
|
||||
if hasattr(self, 'skip_framework') and self.skip_framework:
|
||||
warnings.warn('Framework is skipped')
|
||||
|
@ -25,8 +25,8 @@ class BaseInfer:
|
||||
def fw_infer(self, input_data):
|
||||
raise RuntimeError("This is base class, please implement infer function for the specific framework")
|
||||
|
||||
def infer(self, input_data):
|
||||
self.res = multiprocessing_run(self.fw_infer, [input_data], self.name, timeout=60)
|
||||
def infer(self, input_data, infer_timeout=60):
|
||||
self.res = multiprocessing_run(self.fw_infer, [input_data], self.name, infer_timeout)
|
||||
return self.res
|
||||
|
||||
|
||||
|
@ -140,10 +140,12 @@ class TestLSTM(Caffe2OnnxLayerTest):
|
||||
return onnx_net, None
|
||||
|
||||
@pytest.mark.precommit
|
||||
@pytest.mark.timeout(250)
|
||||
@pytest.mark.parametrize('direction', ["forward", "bidirectional", "reverse"])
|
||||
@pytest.mark.parametrize('cell_type', ["LSTM", "GRU", "RNN"])
|
||||
def test_lstm_simple_precommit(self, direction, cell_type, ie_device, precision, ir_version, temp_dir):
|
||||
self._test(*self.create_lstm(direction, cell_type), ie_device, precision, ir_version, temp_dir=temp_dir)
|
||||
self._test(*self.create_lstm(direction, cell_type), ie_device, precision, ir_version, temp_dir=temp_dir,
|
||||
infer_timeout=150)
|
||||
|
||||
# LSTM/RNN/GRU Sequence Generation
|
||||
@pytest.mark.parametrize('direction', ["forward", "bidirectional", "reverse"])
|
||||
|
Loading…
Reference in New Issue
Block a user