TFlite layer tests (#14760)
* tflite layer tests * tflite inference * removed part to remove * clean-ups * removes input preprocessing for tflite * Apply suggestions from code review Co-authored-by: Evgenya Stepyreva <evgenya.stepyreva@intel.com> Co-authored-by: Evgenya Stepyreva <eva.my.link@gmail.com>
This commit is contained in:
parent
a7b3ae6a9d
commit
250e075ee9
@ -39,13 +39,42 @@ def save_to_pb(tf_model, path_to_saved_tf_model):
|
|||||||
return os.path.join(path_to_saved_tf_model, 'model.pb')
|
return os.path.join(path_to_saved_tf_model, 'model.pb')
|
||||||
|
|
||||||
|
|
||||||
class CommonTFLayerTest(CommonLayerTest):
|
def save_pb_to_tflite(pb_model):
|
||||||
def produce_model_path(self, framework_model, save_path):
|
import tensorflow as tf
|
||||||
return save_to_pb(framework_model, save_path)
|
|
||||||
|
|
||||||
def get_framework_results(self, inputs_dict, model_path):
|
graph_summary = summarize_graph(pb_model)
|
||||||
# Evaluate model via Tensorflow and IE
|
inputs = [k for k in graph_summary['inputs'].keys()]
|
||||||
# Load the Tensorflow model
|
outputs = graph_summary['outputs']
|
||||||
|
|
||||||
|
converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(pb_model, inputs, outputs)
|
||||||
|
tflite_model = converter.convert()
|
||||||
|
|
||||||
|
tflite_model_path = os.path.join(os.path.dirname(pb_model), 'model.tflite')
|
||||||
|
with tf.io.gfile.GFile(tflite_model_path, 'wb') as f:
|
||||||
|
f.write(tflite_model)
|
||||||
|
|
||||||
|
return tflite_model_path
|
||||||
|
|
||||||
|
|
||||||
|
class CommonTFLayerTest(CommonLayerTest):
|
||||||
|
def prepare_tf_inputs(self, inputs_dict):
|
||||||
|
input = dict()
|
||||||
|
for key in inputs_dict.keys():
|
||||||
|
data = inputs_dict.get(key)
|
||||||
|
if self.use_old_api or self.use_new_frontend:
|
||||||
|
key += ':0'
|
||||||
|
input[key] = transpose_nchw_to_nhwc(data, self.use_new_frontend, self.use_old_api)
|
||||||
|
|
||||||
|
return input
|
||||||
|
|
||||||
|
def produce_model_path(self, framework_model, save_path):
|
||||||
|
if not getattr(self, 'tflite', False):
|
||||||
|
return save_to_pb(framework_model, save_path)
|
||||||
|
else:
|
||||||
|
pb_model = save_to_pb(framework_model, save_path)
|
||||||
|
return save_pb_to_tflite(pb_model)
|
||||||
|
|
||||||
|
def get_tf_results(self, inputs_dict, model_path):
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from tensorflow.python.platform import gfile
|
from tensorflow.python.platform import gfile
|
||||||
|
|
||||||
@ -61,14 +90,7 @@ class CommonTFLayerTest(CommonLayerTest):
|
|||||||
sess.graph.as_default()
|
sess.graph.as_default()
|
||||||
tf.compat.v1.import_graph_def(graph_def, name='')
|
tf.compat.v1.import_graph_def(graph_def, name='')
|
||||||
|
|
||||||
input = dict()
|
tf_res = sess.run([out + ":0" for out in outputs_list], inputs_dict)
|
||||||
for key in inputs_dict.keys():
|
|
||||||
data = inputs_dict.get(key)
|
|
||||||
if self.use_old_api or self.use_new_frontend:
|
|
||||||
key += ':0'
|
|
||||||
input[key] = transpose_nchw_to_nhwc(data, self.use_new_frontend, self.use_old_api)
|
|
||||||
|
|
||||||
tf_res = sess.run([out + ":0" for out in outputs_list], input)
|
|
||||||
|
|
||||||
result = dict()
|
result = dict()
|
||||||
for i, output in enumerate(outputs_list):
|
for i, output in enumerate(outputs_list):
|
||||||
@ -76,3 +98,40 @@ class CommonTFLayerTest(CommonLayerTest):
|
|||||||
result[output] = transpose_nhwc_to_nchw(_tf_res, self.use_new_frontend,
|
result[output] = transpose_nhwc_to_nchw(_tf_res, self.use_new_frontend,
|
||||||
self.use_old_api)
|
self.use_old_api)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def get_tflite_results(self, inputs_dict, model_path):
|
||||||
|
import tensorflow as tf
|
||||||
|
interpreter = tf.compat.v1.lite.Interpreter(model_path=model_path)
|
||||||
|
interpreter.allocate_tensors()
|
||||||
|
input_details = interpreter.get_input_details()
|
||||||
|
output_details = interpreter.get_output_details()
|
||||||
|
|
||||||
|
input_name_to_id_mapping = {input['name']: input['index'] for input in input_details}
|
||||||
|
|
||||||
|
for layer, data in inputs_dict.items():
|
||||||
|
tensor_index = input_name_to_id_mapping[layer]
|
||||||
|
tensor_id = next(i for i, tensor in enumerate(input_details) if tensor['index'] == tensor_index)
|
||||||
|
interpreter.set_tensor(input_details[tensor_id]['index'], data)
|
||||||
|
|
||||||
|
interpreter.invoke()
|
||||||
|
tf_result = dict()
|
||||||
|
for output in output_details:
|
||||||
|
tf_result[output['name']] = interpreter.get_tensor(output['index'])
|
||||||
|
|
||||||
|
result = dict()
|
||||||
|
for out in tf_result.keys():
|
||||||
|
_tf_res = tf_result[out]
|
||||||
|
result[out] = transpose_nhwc_to_nchw(_tf_res, self.use_new_frontend,
|
||||||
|
self.use_old_api)
|
||||||
|
|
||||||
|
return tf_result
|
||||||
|
|
||||||
|
def get_framework_results(self, inputs_dict, model_path):
|
||||||
|
if not getattr(self, 'tflite', False):
|
||||||
|
# prepare inputs
|
||||||
|
inputs_dict = self.prepare_tf_inputs(inputs_dict)
|
||||||
|
# get results from tensorflow
|
||||||
|
return self.get_tf_results(inputs_dict, model_path)
|
||||||
|
else:
|
||||||
|
# get results from tflite
|
||||||
|
return self.get_tflite_results(inputs_dict, model_path)
|
||||||
|
@ -70,6 +70,11 @@ def pytest_addoption(parser):
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Use old API for model processing in Inference Engine",
|
help="Use old API for model processing in Inference Engine",
|
||||||
)
|
)
|
||||||
|
parser.addoption(
|
||||||
|
"--tflite",
|
||||||
|
required=False,
|
||||||
|
action="store_true",
|
||||||
|
help="Switch to tflite tests version")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
@ -90,6 +95,12 @@ def use_old_api(request):
|
|||||||
return request.config.getoption('use_old_api')
|
return request.config.getoption('use_old_api')
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def tflite(request):
|
||||||
|
"""Fixture function for command-line option."""
|
||||||
|
return request.config.getoption('tflite')
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
def checks_for_keys_usage(request):
|
def checks_for_keys_usage(request):
|
||||||
if request.config.getoption('use_old_api') and request.config.getoption('use_new_frontend'):
|
if request.config.getoption('use_old_api') and request.config.getoption('use_new_frontend'):
|
||||||
|
@ -14,6 +14,7 @@ from common.utils.common_utils import copy_files_by_pattern
|
|||||||
def pytest_generate_tests(metafunc):
|
def pytest_generate_tests(metafunc):
|
||||||
test_gen_attrs_names = list(inspect.signature(get_params).parameters)
|
test_gen_attrs_names = list(inspect.signature(get_params).parameters)
|
||||||
params = get_params()
|
params = get_params()
|
||||||
|
setattr(metafunc.cls, 'tflite', metafunc.config.getoption('tflite'))
|
||||||
metafunc.parametrize(test_gen_attrs_names, params, scope="function")
|
metafunc.parametrize(test_gen_attrs_names, params, scope="function")
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user