Used share_memory param in tf.Graph decoder. (#18747)
parent be02d1a3c9
commit 5aad9ee652
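
For orientation, a minimal illustrative sketch of the user-facing behaviour this change wires through, modeled on the tests added and updated in this diff (LayerModel, openvino.tools.ovc.convert_model and the share_weights flag are all taken from those tests):

    import numpy as np
    import tensorflow as tf
    from openvino.tools.ovc import convert_model
    from openvino.runtime import compile_model

    class LayerModel(tf.Module):
        def __init__(self):
            super(LayerModel, self).__init__()
            self.var1 = tf.Variable([7., 5., 6.], name='var1')
            self.var2 = tf.Variable([5., 7., 3.], name='var2')

        @tf.function()
        def __call__(self, input):
            return input * self.var1 + self.var2

    # share_weights=True: the converted model reuses the TF weight buffers (zero copy),
    # so later changes to the TF variables are visible to the compiled OV model.
    # share_weights=False: the weights are copied and the OV model keeps the original values.
    ov_model = convert_model(LayerModel(), input=[1], share_weights=True)
    compiled = compile_model(ov_model)
    print(compiled(np.array(7.).astype(np.float32)))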
@@ -10,12 +10,13 @@ from openvino.frontend.tensorflow.py_tensorflow_frontend import _FrontEndPyGraph
 class GraphIteratorTFGraph(GraphIterator):
-    def __init__(self, tf_graph: tf.Graph, inner_graph: bool = False):
+    def __init__(self, tf_graph: tf.Graph, share_weights: bool, inner_graph: bool = False):
         GraphIterator.__init__(self)
         self.m_graph = tf_graph
         self.m_node_index = 0
         self.m_decoders = []
         self.m_inner_graph = inner_graph
+        self.m_share_weights = share_weights

         self.m_vars = None
         if hasattr(tf_graph, "variables"):
@@ -24,7 +25,7 @@ class GraphIteratorTFGraph(GraphIterator):
             self.m_vars = tf_graph.variables

         for op in tf_graph.get_operations():
-            self.m_decoders.append(TFGraphNodeDecoder(op, inner_graph))
+            self.m_decoders.append(TFGraphNodeDecoder(op, share_weights, inner_graph))

         self.m_iterators = {}
         for func_name, _ in self.m_graph._functions.items():
@@ -85,5 +86,7 @@ class GraphIteratorTFGraph(GraphIterator):
         if func_name not in self.m_iterators:
             return None
         if self.m_iterators[func_name] is None:
-            self.m_iterators[func_name] = GraphIteratorTFGraph(self.m_graph._functions[func_name].graph, True)
+            self.m_iterators[func_name] = GraphIteratorTFGraph(self.m_graph._functions[func_name].graph,
+                                                               self.m_share_weights,
+                                                               True)
         return self.m_iterators[func_name]
@@ -47,7 +47,7 @@ def tf_attr_to_ov(attr):


 class TFGraphNodeDecoder(DecoderBase):
-    def __init__(self, operation: tf.Operation, inner_graph: bool):
+    def __init__(self, operation: tf.Operation, share_weights: bool, inner_graph: bool):
         DecoderBase.__init__(self)
         assert isinstance(operation, tf.Operation), "Unknown operation type. " \
                                                     "Expected tf.Operation, got {}".format(type(operation))
@@ -57,9 +57,7 @@ class TFGraphNodeDecoder(DecoderBase):

         # Copies value from inner buffer of TF_Operation to NodeDef class.
         self.m_node_def = self.m_operation.node_def
-
-        # TODO: Create parameter in convert_model() for turning on/off shared memory, ticket: 114971
-        self.m_shared_memory = True
+        self.m_shared_memory = share_weights

         if self.m_operation.type == "Const":
             self.m_data_type = tf.dtypes.DType(self.m_node_def.attr["dtype"].type).name
@@ -143,7 +143,7 @@ def type_supported_by_tf_fe(input_model):
     return False


-def create_tf_graph_iterator(input_model, placeholder_shapes, placeholder_data_types, example_input):
+def create_tf_graph_iterator(input_model, placeholder_shapes, placeholder_data_types, example_input, share_weights):
     input_model = trace_tf_model_if_needed(input_model, placeholder_shapes, placeholder_data_types, example_input)

     import tensorflow as tf
@@ -151,9 +151,9 @@ def create_tf_graph_iterator(input_model, placeholder_shapes, placeholder_data_t
     if model_is_graph_iterator(input_model):
         return input_model
     if isinstance(input_model, tf.Graph):
-        return GraphIteratorTFGraph(input_model)
+        return GraphIteratorTFGraph(input_model, share_weights)
     elif isinstance(input_model, tf.types.experimental.ConcreteFunction):
-        return GraphIteratorTFGraph(input_model.graph)
+        return GraphIteratorTFGraph(input_model.graph, share_weights)
     raise Exception("Could not wrap model of type {} to GraphIteratorTFGraph.".format(type(input_model)))
@@ -688,7 +688,7 @@ class TestMoConvertTF(CommonMOConvertTest):

     def test_zero_copy(self, ie_device, precision, ir_version, temp_dir):
         import tensorflow as tf
-        from openvino.tools.mo import convert_model
+        from openvino.tools.ovc import convert_model
         from openvino.runtime import compile_model
         class LayerModel(tf.Module):
             def __init__(self):
@@ -710,7 +710,7 @@ class TestMoConvertTF(CommonMOConvertTest):
         test_input = np.array(7.).astype(np.float32)

         # Convert model to OV
-        ov_model = convert_model(keras_model, input_shape=[1])
+        ov_model = convert_model(keras_model, input=[1], share_weights=True)
         cmp_model = compile_model(ov_model)

         # Check model inference
@@ -734,8 +734,55 @@ class TestMoConvertTF(CommonMOConvertTest):
         assert np.array_equal(ov_infer2['Identity:0'], fw_infer2)
         assert np.array_equal(ov_infer2['Identity:0'], [ 0., 8., 16.])

+    def test_turn_off_sharing(self, ie_device, precision, ir_version, temp_dir):
+        import tensorflow as tf
+        from openvino.tools.ovc import convert_model
+        from openvino.runtime import compile_model
+        class LayerModel(tf.Module):
+            def __init__(self):
+                super(LayerModel, self).__init__()
+                self.var1 = tf.Variable([7., 5., 6.], name='var1')
+                self.var2 = tf.Variable([5., 7., 3.], name='var2')
+
+            @tf.function
+            def sub_function(self, input):
+                return input * self.var1 + self.var2
+
+            @tf.function()
+            def __call__(self, input):
+                return self.sub_function(input)
+
+        # Create TF model with variables
+        keras_model = LayerModel()
+        test_input = np.array(7.).astype(np.float32)
+
+        # Convert model to OV
+        ov_model = convert_model(keras_model, input=[1], share_weights=False)
+        cmp_model = compile_model(ov_model)
+
+        # Check model inference
+        ov_infer1 = cmp_model(test_input, ie_device)
+        fw_infer1 = keras_model(test_input).numpy()
+
+        assert np.array_equal(ov_infer1['Identity:0'], fw_infer1)
+        assert np.array_equal(ov_infer1['Identity:0'], [54., 42., 45.])
+
+        # Change value of variables in original model
+        for val in keras_model.variables:
+            arr = val.value().__array__()
+            arr[0] = 0
+            arr[1] = 1
+            arr[2] = 2
+
+        # Check model inference
+        ov_infer2 = cmp_model(test_input)
+        fw_infer2 = keras_model(test_input).numpy()
+
+        # Check model inference calculated with old constant values
+        assert not np.array_equal(ov_infer2['Identity:0'], fw_infer2)
+        assert np.array_equal(ov_infer2['Identity:0'], [54., 42., 45.])
+
     def test_memory_loss(self, ie_device, precision, ir_version, temp_dir):
         # This test checks that the memory allocated for constants
         # is not lost after returning the model from convert_model() method.
@@ -823,7 +870,7 @@ class TestTFLoadByModel(unittest.TestCase):
             return tf_net
         from openvino.frontend.tensorflow.graph_iterator import GraphIteratorTFGraph
         from openvino.frontend import FrontEndManager
-        model = GraphIteratorTFGraph(simple_tf_model())
+        model = GraphIteratorTFGraph(simple_tf_model(), True)
         fem = FrontEndManager()
         fe = fem.load_by_model(model)
         assert fe is not None
@@ -396,7 +396,8 @@ def prepare_ir(argv: argparse.Namespace):
             argv.input_model = create_tf_graph_iterator(argv.input_model,
                                                         argv.placeholder_shapes,
                                                         argv.placeholder_data_types,
-                                                        getattr(argv, "example_input", None))
+                                                        getattr(argv, "example_input", None),
+                                                        argv.share_weights)
         try:
             t.send_event("mo", "conversion_method", moc_front_end.get_name() + "_frontend")
             moc_front_end.add_extension(TelemetryExtension("mo", t.send_event, t.send_error, t.send_stack_trace))
@@ -136,7 +136,8 @@ def prepare_ir(argv: argparse.Namespace):
             argv.input_model = create_tf_graph_iterator(argv.input_model,
                                                         argv.placeholder_shapes,
                                                         argv.placeholder_data_types,
-                                                        getattr(argv, "example_input", None))
+                                                        getattr(argv, "example_input", None),
+                                                        argv.share_weights)
         t.send_event("mo", "conversion_method", moc_front_end.get_name() + "_frontend")
         moc_front_end.add_extension(TelemetryExtension("mo", t.send_event, t.send_error, t.send_stack_trace))
         if new_extensions_used(argv):