Files
openvino/tests/layer_tests/common/utils/tf_utils.py
2022-01-19 01:07:49 +03:00

131 lines
4.8 KiB
Python

# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import os
import re
import tensorflow as tf
import numpy as np
from openvino.tools.mo.ops.op import PermuteAttrs
# Presumably meant to silence TensorFlow's native (C++) log output.
# NOTE(review): this assignment runs after `import tensorflow` above, so
# messages emitted during that import may not be affected — confirm whether
# it should be moved before the import.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
def load_graph(model_file, output_nodes_for_freeze=None):
    """Load a serialized TensorFlow model into a fresh ``tf.Graph``.

    The format is chosen by file extension:

    * ``.meta`` -- parsed as a ``MetaGraphDef``. Its variables are restored
      from the checkpoint whose prefix is ``model_file`` without the trailing
      ``.meta``. They are then frozen into constants with
      ``convert_variables_to_constants``.
    * anything else -- parsed as a binary ``GraphDef``.

    The ``device`` field of every node is cleared before the graph is imported.

    :param model_file: path to the serialized GraphDef or MetaGraphDef.
    :param output_nodes_for_freeze: output node names passed to
        ``convert_variables_to_constants``. Only used for ``.meta`` files.
    :return: a ``tf.Graph`` with the model imported under no name prefix.
    """
    is_meta = os.path.splitext(model_file)[-1] == ".meta"

    tf.compat.v1.reset_default_graph()
    graph = tf.Graph()
    graph_def = tf.compat.v1.MetaGraphDef() if is_meta else tf.compat.v1.GraphDef()
    with open(model_file, "rb") as f:
        graph_def.ParseFromString(f.read())

    # A MetaGraphDef wraps its GraphDef in the `graph_def` field.
    nodes_to_clear_device = graph_def.graph_def.node if is_meta else graph_def.node
    for node in nodes_to_clear_device:
        node.device = ""

    if is_meta:
        with tf.compat.v1.Session() as sess:
            restorer = tf.compat.v1.train.import_meta_graph(graph_def)
            # Raw string: '\.' in a plain literal is an invalid escape sequence.
            restorer.restore(sess, re.sub(r'\.meta$', '', model_file))
            graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
                sess, graph_def.graph_def, output_nodes_for_freeze)

    with graph.as_default():
        tf.import_graph_def(graph_def, name='')
    return graph
def collect_tf_references(model_path, feed_dict, out_layer, output_nodes_for_freeze=None):
    """Run a TensorFlow model and return the values of the requested outputs.

    :param model_path: model file accepted by ``load_graph``.
    :param feed_dict: mapping of node name -> value to feed.

        * If the named node is a ``Placeholder``, the value is fed to its
          ``:0`` output.
        * Otherwise, the same value is fed to every data input of that node
          whose producer op is not ``Const``/``Assign``/``NoOp``/``Assert``.

    :param out_layer: iterable of node names whose ``:0`` output is fetched.
    :param output_nodes_for_freeze: forwarded to ``load_graph``.
    :return: dict mapping each name in ``out_layer`` to its computed value.
    :raises KeyError: if a fed node or one of its producers is not in the graph.
    """
    skipped_producer_ops = ('Const', 'Assign', 'NoOp', 'Assert')
    graph = load_graph(model_path, output_nodes_for_freeze)
    # Serialize the graph once and index nodes by name, instead of
    # re-serializing and linearly scanning it for every lookup.
    nodes_by_name = {node.name: node for node in graph.as_graph_def().node}

    _feed_dict = dict()
    for input_name, value in feed_dict.items():
        input_node = nodes_by_name[input_name]
        if input_node.op == "Placeholder":
            _feed_dict[graph.get_tensor_by_name(input_name + ":0")] = value
            continue
        for parent_input in input_node.input:
            # Control dependencies ("^name") carry no data and cannot be fed.
            if parent_input.startswith('^'):
                continue
            # Data inputs may reference a specific output port ("name:1").
            producer_name, _, port = parent_input.partition(':')
            if nodes_by_name[producer_name].op in skipped_producer_ops:
                continue
            tensor_name = parent_input if port else parent_input + ":0"
            _feed_dict[graph.get_tensor_by_name(tensor_name)] = value

    # Materialize once so a one-shot iterable can be used for both fetch and zip.
    outputs_list = list(out_layer)
    output_tensors_list = [graph.get_tensor_by_name(output + ":0") for output in outputs_list]
    with graph.as_default():
        with tf.compat.v1.Session(graph=graph) as sess:
            outputs = sess.run(output_tensors_list, feed_dict=_feed_dict)
    return dict(zip(outputs_list, outputs))
def children(op, graph):
    """Collect every operation that consumes one of ``op``'s outputs.

    :param op: name of the operation, resolved via ``graph.get_operation_by_name``.
    :param graph: graph that owns the operation.
    :return: set of consumer operations. It is empty when no output is consumed.
    """
    operation = graph.get_operation_by_name(op)
    consumers = set()
    for tensor in operation.outputs:
        consumers.update(tensor.consumers())
    return consumers
def summarize_graph(model_path, output_nodes_for_freeze=None, reshape_net=None):
    """Describe the placeholders and likely outputs of a TensorFlow model.

    A node is reported as an output when no operation consumes it, and
    neither its op type nor the last component of its name is in
    ``unlikely_output_types``.

    :param model_path: model file accepted by ``load_graph``.
    :param output_nodes_for_freeze: forwarded to ``load_graph`` and
        ``collect_tf_references`` (needed for ``.meta`` models).
    :param reshape_net: optional mapping of input name -> shape. When given,
        the model is run on all-ones inputs of those shapes. Each placeholder's
        ``'shape'`` is then replaced by the shape of the value it evaluates to.
    :return: ``{'inputs': {name: {'type': str, 'shape': tuple}}, 'outputs': [name, ...]}``.
    """
    placeholders = dict()
    outputs = list()
    graph = load_graph(model_path, output_nodes_for_freeze)
    unlikely_output_types = ['Const', 'Assign', 'NoOp', 'Placeholder', 'Assert', 'switch_t', 'switch_f']
    for node in graph.as_graph_def().node:
        if node.op == 'Placeholder':
            placeholders[node.name] = {
                'type': tf.DType(node.attr['dtype'].type).name,
                # Read dim sizes from the TensorShapeProto directly. Parsing its
                # text form crashes on scalar/unknown-rank shapes and on size-0
                # dims, which proto3 omits from the text output.
                'shape': tuple(dim.size for dim in node.attr['shape'].shape.dim),
            }
        if len(children(node.name, graph)) == 0:
            if node.op not in unlikely_output_types and node.name.split('/')[-1] not in unlikely_output_types:
                outputs.append(node.name)

    result = dict()
    result['inputs'] = placeholders
    result['outputs'] = outputs

    if reshape_net:
        out_layer = list(placeholders.keys()) + outputs
        feed_dict = {name: np.ones(shape=reshape_net[name]) for name in reshape_net}
        # Forward output_nodes_for_freeze. Without it, a .meta model is frozen
        # with no output node names and convert_variables_to_constants fails.
        scoring_res = collect_tf_references(model_path=model_path, feed_dict=feed_dict, out_layer=out_layer,
                                            output_nodes_for_freeze=output_nodes_for_freeze)
        for layer, value in scoring_res.items():
            if layer in placeholders:
                placeholders[layer]['shape'] = value.shape
    return result
def permute_nhwc_to_nchw(shape, use_new_frontend=False):
    """Reorder an NHWC-ordered shape into NCHW order.

    :param shape: sequence of dimension sizes.
    :param use_new_frontend: when true, ``shape`` is returned as-is (same object).
    :return: ``shape`` unchanged, or a numpy array indexed by the permutation
        from ``PermuteAttrs.get_nhwc_to_nchw_permutation`` for ``len(shape)``.
    """
    if not use_new_frontend:
        order = PermuteAttrs.get_nhwc_to_nchw_permutation(len(shape)).perm
        return np.array(shape)[order]
    return shape
def permute_nchw_to_nhwc(shape, use_new_frontend=False):
    """Reorder an NCHW-ordered shape into NHWC order.

    :param shape: sequence of dimension sizes.
    :param use_new_frontend: when true, ``shape`` is returned as-is (same object).
    :return: ``shape`` unchanged, or a numpy array indexed by the permutation
        from ``PermuteAttrs.get_nchw_to_nhwc_permutation`` for ``len(shape)``.
    """
    if not use_new_frontend:
        order = PermuteAttrs.get_nchw_to_nhwc_permutation(len(shape)).perm
        return np.array(shape)[order]
    return shape
def permute_axis(axis, permutation_inv):
    """Map ``axis`` through ``permutation_inv``.

    :param axis: index into ``permutation_inv``.
    :param permutation_inv: indexable permutation (list, tuple or array).
    :return: ``permutation_inv[axis]``.
    """
    mapped_axis = permutation_inv[axis]
    return mapped_axis