String tensors in tf.Graph decoder (#18461)
* String consts in TF Decoder. * Fixed test. * Small corrections. * Added test, minor corrections. * Test correction, removed decoding. * Removed wrong changes. * Removed not needed code.
This commit is contained in:
committed by
GitHub
parent
826f345daf
commit
6af1bb307c
@@ -587,6 +587,29 @@ def create_keras_layer_with_tf_function_call_no_signature_single_input(tmp_dir):
|
||||
return model, model_ref, {'example_input': example_input}
|
||||
|
||||
|
||||
def create_keras_layer_with_string_tensor(tmp_dir):
|
||||
import tensorflow as tf
|
||||
class LayerModel(tf.Module):
|
||||
def __init__(self):
|
||||
super(LayerModel, self).__init__()
|
||||
self.var = tf.Variable("Text_1", dtype=tf.string)
|
||||
self.const = tf.constant("Text_2", dtype=tf.string)
|
||||
|
||||
@tf.function(input_signature=[tf.TensorSpec([1], tf.float32), tf.TensorSpec([1], tf.float32)])
|
||||
def __call__(self, input1, input2):
|
||||
return input1 + input2, self.var, self.const
|
||||
|
||||
model = LayerModel()
|
||||
|
||||
param1 = ov.opset8.parameter([1], dtype=np.float32)
|
||||
param2 = ov.opset8.parameter([1], dtype=np.float32)
|
||||
add = ov.opset8.add(param1, param2)
|
||||
parameter_list = [param1, param2]
|
||||
model_ref = Model([add], parameter_list, "test")
|
||||
|
||||
return model, model_ref, {}
|
||||
|
||||
|
||||
class TestMoConvertTF(CommonMOConvertTest):
|
||||
test_data = [
|
||||
# TF2
|
||||
@@ -608,6 +631,7 @@ class TestMoConvertTF(CommonMOConvertTest):
|
||||
create_keras_layer_with_tf_function_call,
|
||||
create_keras_layer_with_tf_function_call_no_signature,
|
||||
create_keras_layer_with_tf_function_call_no_signature_single_input,
|
||||
create_keras_layer_with_string_tensor,
|
||||
|
||||
# TF1
|
||||
create_tf_graph,
|
||||
|
||||
Reference in New Issue
Block a user