From 250e075ee9b58dbd24a49c0e7b2a02e1e2ac89e5 Mon Sep 17 00:00:00 2001 From: Ruslan Nugmanov Date: Fri, 23 Dec 2022 03:44:44 +0100 Subject: [PATCH] TFlite layer tests (#14760) * tflite layer tests * tflite inference * removed part to remove * clean-ups * removes input preprocessing for tflite * Apply suggestions from code review Co-authored-by: Evgenya Stepyreva Co-authored-by: Evgenya Stepyreva --- .../layer_tests/common/tf_layer_test_class.py | 87 ++++++++++++++++--- tests/layer_tests/conftest.py | 11 +++ .../layer_tests/tensorflow_tests/conftest.py | 1 + 3 files changed, 85 insertions(+), 14 deletions(-) diff --git a/tests/layer_tests/common/tf_layer_test_class.py b/tests/layer_tests/common/tf_layer_test_class.py index ffcf03f1c24..375cfc420cc 100644 --- a/tests/layer_tests/common/tf_layer_test_class.py +++ b/tests/layer_tests/common/tf_layer_test_class.py @@ -39,13 +39,42 @@ def save_to_pb(tf_model, path_to_saved_tf_model): return os.path.join(path_to_saved_tf_model, 'model.pb') -class CommonTFLayerTest(CommonLayerTest): - def produce_model_path(self, framework_model, save_path): - return save_to_pb(framework_model, save_path) +def save_pb_to_tflite(pb_model): + import tensorflow as tf - def get_framework_results(self, inputs_dict, model_path): - # Evaluate model via Tensorflow and IE - # Load the Tensorflow model + graph_summary = summarize_graph(pb_model) + inputs = [k for k in graph_summary['inputs'].keys()] + outputs = graph_summary['outputs'] + + converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(pb_model, inputs, outputs) + tflite_model = converter.convert() + + tflite_model_path = os.path.join(os.path.dirname(pb_model), 'model.tflite') + with tf.io.gfile.GFile(tflite_model_path, 'wb') as f: + f.write(tflite_model) + + return tflite_model_path + + +class CommonTFLayerTest(CommonLayerTest): + def prepare_tf_inputs(self, inputs_dict): + input = dict() + for key in inputs_dict.keys(): + data = inputs_dict.get(key) + if self.use_old_api or self.use_new_frontend: + key += ':0' + input[key] = transpose_nchw_to_nhwc(data, self.use_new_frontend, self.use_old_api) + + return input + + def produce_model_path(self, framework_model, save_path): + if not getattr(self, 'tflite', False): + return save_to_pb(framework_model, save_path) + else: + pb_model = save_to_pb(framework_model, save_path) + return save_pb_to_tflite(pb_model) + + def get_tf_results(self, inputs_dict, model_path): import tensorflow as tf from tensorflow.python.platform import gfile @@ -61,14 +90,7 @@ class CommonTFLayerTest(CommonLayerTest): sess.graph.as_default() tf.compat.v1.import_graph_def(graph_def, name='') - input = dict() - for key in inputs_dict.keys(): - data = inputs_dict.get(key) - if self.use_old_api or self.use_new_frontend: - key += ':0' - input[key] = transpose_nchw_to_nhwc(data, self.use_new_frontend, self.use_old_api) - - tf_res = sess.run([out + ":0" for out in outputs_list], input) + tf_res = sess.run([out + ":0" for out in outputs_list], inputs_dict) result = dict() for i, output in enumerate(outputs_list): @@ -76,3 +98,40 @@ class CommonTFLayerTest(CommonLayerTest): result[output] = transpose_nhwc_to_nchw(_tf_res, self.use_new_frontend, self.use_old_api) return result + + def get_tflite_results(self, inputs_dict, model_path): + import tensorflow as tf + interpreter = tf.compat.v1.lite.Interpreter(model_path=model_path) + interpreter.allocate_tensors() + input_details = interpreter.get_input_details() + output_details = interpreter.get_output_details() + + input_name_to_id_mapping = {input['name']: input['index'] for input in input_details} + + for layer, data in inputs_dict.items(): + tensor_index = input_name_to_id_mapping[layer] + tensor_id = next(i for i, tensor in enumerate(input_details) if tensor['index'] == tensor_index) + interpreter.set_tensor(input_details[tensor_id]['index'], data) + + interpreter.invoke() + tf_result = dict() + for output in output_details: + tf_result[output['name']] = interpreter.get_tensor(output['index']) + + result = dict() + for out in tf_result.keys(): + _tf_res = tf_result[out] + result[out] = transpose_nhwc_to_nchw(_tf_res, self.use_new_frontend, + self.use_old_api) + + return tf_result + + def get_framework_results(self, inputs_dict, model_path): + if not getattr(self, 'tflite', False): + # prepare inputs + inputs_dict = self.prepare_tf_inputs(inputs_dict) + # get results from tensorflow + return self.get_tf_results(inputs_dict, model_path) + else: + # get results from tflite + return self.get_tflite_results(inputs_dict, model_path) diff --git a/tests/layer_tests/conftest.py b/tests/layer_tests/conftest.py index 78c0d3a2eba..c7ef4dac77d 100644 --- a/tests/layer_tests/conftest.py +++ b/tests/layer_tests/conftest.py @@ -70,6 +70,11 @@ def pytest_addoption(parser): action="store_true", help="Use old API for model processing in Inference Engine", ) + parser.addoption( + "--tflite", + required=False, + action="store_true", + help="Switch to tflite tests version") @pytest.fixture(scope="session") @@ -90,6 +95,12 @@ def use_old_api(request): return request.config.getoption('use_old_api') +@pytest.fixture(scope="session") +def tflite(request): + """Fixture function for command-line option.""" + return request.config.getoption('tflite') + + @pytest.fixture(scope="session", autouse=True) def checks_for_keys_usage(request): if request.config.getoption('use_old_api') and request.config.getoption('use_new_frontend'): diff --git a/tests/layer_tests/tensorflow_tests/conftest.py b/tests/layer_tests/tensorflow_tests/conftest.py index cc957ad4aac..90e4df4b6b7 100644 --- a/tests/layer_tests/tensorflow_tests/conftest.py +++ b/tests/layer_tests/tensorflow_tests/conftest.py @@ -14,6 +14,7 @@ from common.utils.common_utils import copy_files_by_pattern def pytest_generate_tests(metafunc): test_gen_attrs_names = list(inspect.signature(get_params).parameters) params = get_params() + setattr(metafunc.cls, 'tflite', metafunc.config.getoption('tflite')) metafunc.parametrize(test_gen_attrs_names, params, scope="function")