diff --git a/tools/pot/tests/data/models/lstm_example/lstm_example.json b/tools/pot/tests/data/models/lstm_example/lstm_example.json new file mode 100644 index 00000000000..de45bbb1b86 --- /dev/null +++ b/tools/pot/tests/data/models/lstm_example/lstm_example.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ba5a652c26a0ff2d72d0a20fa4ea89801432e1a07f130d434bbac51f87704987 +size 186 diff --git a/tools/pot/tests/data/models/lstm_example/lstm_example.onnx b/tools/pot/tests/data/models/lstm_example/lstm_example.onnx new file mode 100644 index 00000000000..30507001ee9 --- /dev/null +++ b/tools/pot/tests/data/models/lstm_example/lstm_example.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:43ebd524dc790f0cdf08f4693f74d9b5be550c503c67aa1819bd13496fbee0be +size 341880 diff --git a/tools/pot/tests/data/models/multiple_outputs_net_example/multiple_outputs_net_example.bin b/tools/pot/tests/data/models/multiple_outputs_net_example/multiple_outputs_net_example.bin new file mode 100644 index 00000000000..6610aea14ea --- /dev/null +++ b/tools/pot/tests/data/models/multiple_outputs_net_example/multiple_outputs_net_example.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0af1173b44db9f52b9878032b87679cc16521ec0c21fc413228e0854da4efed0 +size 526864 diff --git a/tools/pot/tests/data/models/multiple_outputs_net_example/multiple_outputs_net_example.json b/tools/pot/tests/data/models/multiple_outputs_net_example/multiple_outputs_net_example.json new file mode 100644 index 00000000000..4fc2ca9cd5e --- /dev/null +++ b/tools/pot/tests/data/models/multiple_outputs_net_example/multiple_outputs_net_example.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:de67711db3e1fb404a8b61be29954ab5b336d76b8cf7d3c5c98f7b9faab7c6e4 +size 283 diff --git a/tools/pot/tests/data/models/multiple_outputs_net_example/multiple_outputs_net_example.xml b/tools/pot/tests/data/models/multiple_outputs_net_example/multiple_outputs_net_example.xml new file mode 100644 index 00000000000..d7035fe9d87 --- /dev/null +++ b/tools/pot/tests/data/models/multiple_outputs_net_example/multiple_outputs_net_example.xml @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eafca25515ecb19d073a96fbaa234c5c9004bf9646419bb1d3f9c0faa25c19c2 +size 12739 diff --git a/tools/pot/tests/data/reference_models/lstm_example_pytorch.xml b/tools/pot/tests/data/reference_models/lstm_example_pytorch.xml new file mode 100644 index 00000000000..a20b82b8cff --- /dev/null +++ b/tools/pot/tests/data/reference_models/lstm_example_pytorch.xml @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a60904b4395f63eb2e6ce515cb23be6468738200badbebaff6b5914e065042a9 +size 126817 diff --git a/tools/pot/tests/data/reference_models/multiple_outputs_net_example_dldt.xml b/tools/pot/tests/data/reference_models/multiple_outputs_net_example_dldt.xml new file mode 100644 index 00000000000..1d8619041a5 --- /dev/null +++ b/tools/pot/tests/data/reference_models/multiple_outputs_net_example_dldt.xml @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:76959e68b61e35bca1e4d0a815a2fdd2a50fbfa1ca6d4cb6218d6d57f76603b2 +size 23366 diff --git a/tools/pot/tests/test_graph.py b/tools/pot/tests/test_graph.py index febe190f2c7..b4dae2e6793 100755 --- a/tools/pot/tests/test_graph.py +++ b/tools/pot/tests/test_graph.py @@ -18,13 +18,14 @@ CPU_CONFIG_PATH = HARDWARE_CONFIG_PATH / 'cpu.json' GNA_CONFIG_PATH = HARDWARE_CONFIG_PATH / 'gna.json' TEST_MODELS = [ - ('mobilenetv2_example', 'pytorch'), - ('resnet_example', 'pytorch'), - ('googlenet_example', 'pytorch'), - ('mobilenetv2_ssd_example', 'pytorch'), - ('densenet121_example', 'pytorch'), - ('multiple_out_ports_net', 'tf'), # multiple output ports in node case check, - # ('rm_nnet4a', 'kaldi') + ('mobilenetv2_example', 'pytorch', 'ANY'), + ('resnet_example', 'pytorch', 'ANY'), + ('googlenet_example', 'pytorch', 'ANY'), + ('mobilenetv2_ssd_example', 'pytorch', 'ANY'), + ('densenet121_example', 'pytorch', 'ANY'), + ('multiple_out_ports_net', 'tf', 'ANY'), + ('lstm_example', 'pytorch', 'GNA'), + ('multiple_outputs_net_example', 'dldt', 'GNA') ] CASCADE_MAP = Dict({ @@ -37,15 +38,16 @@ CASCADE_MAP = Dict({ @pytest.mark.parametrize( - 'model_name, model_framework', TEST_MODELS, - ids=['{}_{}'.format(m[0], m[1]) for m in TEST_MODELS]) -def test_build_quantization_graph(tmp_path, models, model_name, model_framework): + 'model_name, model_framework, target_device', TEST_MODELS, + ids=['{}_{}'.format(m[0], m[1], m[2]) for m in TEST_MODELS]) +def test_build_quantization_graph(tmp_path, models, model_name, model_framework, target_device): model = models.get(model_name, model_framework, tmp_path) - model = load_model(model.model_params) + model = load_model(model.model_params, target_device=target_device) - hardware_config = HardwareConfig.from_json(CPU_CONFIG_PATH.as_posix()) - if model_framework == 'kaldi': + if target_device == 'GNA': hardware_config = HardwareConfig.from_json(GNA_CONFIG_PATH.as_posix()) + else: + hardware_config = HardwareConfig.from_json(CPU_CONFIG_PATH.as_posix()) quantization_model = GraphTransformer(hardware_config).insert_fake_quantize(model) @@ -246,27 +248,21 @@ def test_multibranch_propagation_without_fq_moving(): MODELS_WITH_LSTM = [ - # ('rm_lstm4f', 'kaldi', { - # 'prev_memory_output69': - # ['next_lstm_output108', 'lstmprojectedstreams/Shape', 'input_fullyconnected/WithoutBiases'], - # 'prev_memory_state82': - # ['state_filtered_tahn100', 'clamp_scaleshift101/Mul_', 'next_lstm_state98'], - # 'prev_memory_output': - # ['next_lstm_output', 'affinetransform/WithoutBiases'], - # 'prev_memory_state': - # ['state_filtered_tahn', 'clamp_scaleshift/Mul_', 'next_lstm_state'] - # }) + ('lstm_example', 'pytorch', { + 'LSTM_15/TensorIterator/22/variable_1': + ['Assign_298'], + 'LSTM_15/TensorIterator/24/variable_2': + ['Assign_305'], + 'LSTM_19/TensorIterator/22/variable_1': + ['Assign_327'], + 'LSTM_19/TensorIterator/24/variable_2': + ['Assign_334'] + }) ] -@pytest.fixture(scope='module', params=MODELS_WITH_LSTM, - ids=['{}_{}'.format(m[0], m[1]) for m in MODELS_WITH_LSTM]) -def _params(request): - return request.param - - -def test_lstm_ends(_params, tmp_path, models): - model_name, model_framework, lstm_ends_ref = _params +def test_lstm_ends(tmp_path, models): + model_name, model_framework, lstm_ends_ref = MODELS_WITH_LSTM[0] model = models.get(model_name, model_framework, tmp_path) model = load_model(model.model_params) read_values = get_nodes_by_type(model, ['ReadValue'])