[TF FE] Test low-level While support by TF FE (#16257)
This commit is contained in:
parent
f0f1c47063
commit
3b71286f1d
@ -18,13 +18,9 @@ class TestWhile(CommonTFLayerTest):
|
||||
inputs_data['y'] = np.random.randint(-50, 50, y_shape).astype(np.int32)
|
||||
return inputs_data
|
||||
|
||||
def create_while_net(self, y_shape, data_type):
|
||||
tf.compat.v1.reset_default_graph()
|
||||
# Create the graph and model
|
||||
with tf.compat.v1.Session() as sess:
|
||||
x = tf.compat.v1.placeholder(data_type, [], 'x')
|
||||
y = tf.compat.v1.placeholder(data_type, y_shape, 'y')
|
||||
|
||||
def create_while_net(self, y_shape, data_type, lower_control_flow):
|
||||
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
|
||||
def while_function(x, y):
|
||||
@tf.function
|
||||
def cond(x, y):
|
||||
return tf.less(x, 10)
|
||||
@ -35,16 +31,26 @@ class TestWhile(CommonTFLayerTest):
|
||||
x_new = tf.add(x, 1)
|
||||
return x_new, y_new
|
||||
|
||||
tf.while_loop(cond, body, [x, y])
|
||||
tf.compat.v1.global_variables_initializer()
|
||||
return tf.while_loop(cond, body, [x, y])
|
||||
|
||||
tf_net = sess.graph_def
|
||||
tf_while_graph = tf.function(while_function)
|
||||
x = np.random.randint(1, 10, []).astype(data_type)
|
||||
y = np.random.randint(-50, 50, y_shape).astype(data_type)
|
||||
concrete_func = tf_while_graph.get_concrete_function(x, y)
|
||||
|
||||
return tf_net, None
|
||||
# lower_control_flow defines representation of While operation
|
||||
# in case of lower_control_flow=True it is decomposed into LoopCond, NextIteration and TensorArray operations
|
||||
frozen_func = convert_variables_to_constants_v2(concrete_func,
|
||||
lower_control_flow=lower_control_flow)
|
||||
|
||||
graph_def = frozen_func.graph.as_graph_def(add_shapes=True)
|
||||
return graph_def, None
|
||||
|
||||
test_data_basic = [
|
||||
dict(y_shape=[2, 3], data_type=tf.int32),
|
||||
dict(y_shape=[2, 1, 4], data_type=tf.int32),
|
||||
dict(y_shape=[2, 3], data_type=np.int32, lower_control_flow=False),
|
||||
dict(y_shape=[2, 1, 4], data_type=np.int32, lower_control_flow=False),
|
||||
pytest.param(dict(y_shape=[2, 1, 4], data_type=np.int32, lower_control_flow=True),
|
||||
marks=pytest.mark.xfail(reason="105670"))
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize("params", test_data_basic)
|
||||
@ -68,13 +74,9 @@ class TestWhileShapeVariant(CommonTFLayerTest):
|
||||
inputs_data['y'] = np.random.randint(-50, 50, y_shape).astype(np.float32)
|
||||
return inputs_data
|
||||
|
||||
def create_while_net(self, y_shape):
|
||||
tf.compat.v1.reset_default_graph()
|
||||
# Create the graph and model
|
||||
with tf.compat.v1.Session() as sess:
|
||||
x = tf.compat.v1.placeholder(tf.int32, [], 'x')
|
||||
y = tf.compat.v1.placeholder(tf.float32, y_shape, 'y')
|
||||
|
||||
def create_while_net(self, y_shape, lower_control_flow):
|
||||
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
|
||||
def while_function(x, y):
|
||||
@tf.function
|
||||
def cond(x, y):
|
||||
return tf.less(x, 10)
|
||||
@ -86,18 +88,28 @@ class TestWhileShapeVariant(CommonTFLayerTest):
|
||||
x_new = tf.add(x, tf.constant(1, tf.int32))
|
||||
return x_new, y_new
|
||||
|
||||
tf.while_loop(cond, body, [x, y],
|
||||
shape_invariants=[tf.TensorShape([]),
|
||||
tf.TensorShape([None] + y_shape[1:])])
|
||||
tf.compat.v1.global_variables_initializer()
|
||||
return tf.while_loop(cond, body, [x, y],
|
||||
shape_invariants=[tf.TensorShape([]),
|
||||
tf.TensorShape([None] + y_shape[1:])])
|
||||
|
||||
tf_net = sess.graph_def
|
||||
tf_while_graph = tf.function(while_function)
|
||||
x = np.random.randint(1, 10, []).astype(np.int32)
|
||||
y = np.random.randint(-50, 50, y_shape).astype(np.float32)
|
||||
concrete_func = tf_while_graph.get_concrete_function(x, y)
|
||||
|
||||
return tf_net, None
|
||||
# lower_control_flow defines representation of While operation
|
||||
# in case of lower_control_flow=True it is decomposed into LoopCond, NextIteration and TensorArray operations
|
||||
frozen_func = convert_variables_to_constants_v2(concrete_func,
|
||||
lower_control_flow=lower_control_flow)
|
||||
|
||||
graph_def = frozen_func.graph.as_graph_def(add_shapes=True)
|
||||
return graph_def, None
|
||||
|
||||
test_data_basic = [
|
||||
dict(y_shape=[2, 3]),
|
||||
dict(y_shape=[2, 1, 4]),
|
||||
dict(y_shape=[2, 3], lower_control_flow=False),
|
||||
dict(y_shape=[2, 1, 4], lower_control_flow=False),
|
||||
pytest.param(dict(y_shape=[2, 1, 4], lower_control_flow=True),
|
||||
marks=pytest.mark.xfail(reason="105670"))
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize("params", test_data_basic)
|
||||
|
Loading…
Reference in New Issue
Block a user