Update tests with lstm cases

This commit is contained in:
Malinin, Nikita
2021-10-29 15:11:39 +03:00
parent d94fe7d758
commit 16fe385117
8 changed files with 48 additions and 31 deletions

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ba5a652c26a0ff2d72d0a20fa4ea89801432e1a07f130d434bbac51f87704987
size 186

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:43ebd524dc790f0cdf08f4693f74d9b5be550c503c67aa1819bd13496fbee0be
size 341880

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0af1173b44db9f52b9878032b87679cc16521ec0c21fc413228e0854da4efed0
size 526864

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:de67711db3e1fb404a8b61be29954ab5b336d76b8cf7d3c5c98f7b9faab7c6e4
size 283

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:eafca25515ecb19d073a96fbaa234c5c9004bf9646419bb1d3f9c0faa25c19c2
size 12739

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a60904b4395f63eb2e6ce515cb23be6468738200badbebaff6b5914e065042a9
size 126817

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:76959e68b61e35bca1e4d0a815a2fdd2a50fbfa1ca6d4cb6218d6d57f76603b2
size 23366

View File

@@ -18,13 +18,14 @@ CPU_CONFIG_PATH = HARDWARE_CONFIG_PATH / 'cpu.json'
GNA_CONFIG_PATH = HARDWARE_CONFIG_PATH / 'gna.json'
TEST_MODELS = [
('mobilenetv2_example', 'pytorch'),
('resnet_example', 'pytorch'),
('googlenet_example', 'pytorch'),
('mobilenetv2_ssd_example', 'pytorch'),
('densenet121_example', 'pytorch'),
('multiple_out_ports_net', 'tf'), # multiple output ports in node case check,
# ('rm_nnet4a', 'kaldi')
('mobilenetv2_example', 'pytorch', 'ANY'),
('resnet_example', 'pytorch', 'ANY'),
('googlenet_example', 'pytorch', 'ANY'),
('mobilenetv2_ssd_example', 'pytorch', 'ANY'),
('densenet121_example', 'pytorch', 'ANY'),
('multiple_out_ports_net', 'tf', 'ANY'),
('lstm_example', 'pytorch', 'GNA'),
('multiple_outputs_net_example', 'dldt', 'GNA')
]
CASCADE_MAP = Dict({
@@ -37,15 +38,16 @@ CASCADE_MAP = Dict({
@pytest.mark.parametrize(
'model_name, model_framework', TEST_MODELS,
ids=['{}_{}'.format(m[0], m[1]) for m in TEST_MODELS])
def test_build_quantization_graph(tmp_path, models, model_name, model_framework):
'model_name, model_framework, target_device', TEST_MODELS,
ids=['{}_{}'.format(m[0], m[1], m[2]) for m in TEST_MODELS])
def test_build_quantization_graph(tmp_path, models, model_name, model_framework, target_device):
model = models.get(model_name, model_framework, tmp_path)
model = load_model(model.model_params)
model = load_model(model.model_params, target_device=target_device)
hardware_config = HardwareConfig.from_json(CPU_CONFIG_PATH.as_posix())
if model_framework == 'kaldi':
if target_device == 'GNA':
hardware_config = HardwareConfig.from_json(GNA_CONFIG_PATH.as_posix())
else:
hardware_config = HardwareConfig.from_json(CPU_CONFIG_PATH.as_posix())
quantization_model = GraphTransformer(hardware_config).insert_fake_quantize(model)
@@ -246,27 +248,21 @@ def test_multibranch_propagation_without_fq_moving():
MODELS_WITH_LSTM = [
# ('rm_lstm4f', 'kaldi', {
# 'prev_memory_output69':
# ['next_lstm_output108', 'lstmprojectedstreams/Shape', 'input_fullyconnected/WithoutBiases'],
# 'prev_memory_state82':
# ['state_filtered_tahn100', 'clamp_scaleshift101/Mul_', 'next_lstm_state98'],
# 'prev_memory_output':
# ['next_lstm_output', 'affinetransform/WithoutBiases'],
# 'prev_memory_state':
# ['state_filtered_tahn', 'clamp_scaleshift/Mul_', 'next_lstm_state']
# })
('lstm_example', 'pytorch', {
'LSTM_15/TensorIterator/22/variable_1':
['Assign_298'],
'LSTM_15/TensorIterator/24/variable_2':
['Assign_305'],
'LSTM_19/TensorIterator/22/variable_1':
['Assign_327'],
'LSTM_19/TensorIterator/24/variable_2':
['Assign_334']
})
]
@pytest.fixture(scope='module', params=MODELS_WITH_LSTM,
ids=['{}_{}'.format(m[0], m[1]) for m in MODELS_WITH_LSTM])
def _params(request):
return request.param
def test_lstm_ends(_params, tmp_path, models):
model_name, model_framework, lstm_ends_ref = _params
def test_lstm_ends(tmp_path, models):
model_name, model_framework, lstm_ends_ref = MODELS_WITH_LSTM[0]
model = models.get(model_name, model_framework, tmp_path)
model = load_model(model.model_params)
read_values = get_nodes_by_type(model, ['ReadValue'])