[TF FE] Fix conversion of TF1 OD models out-of-the-box (#20916)
* [TF FE] Fix conversion of TF1 OD models out-of-the-box Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Add test While with nested If operation Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Update tests/layer_tests/tensorflow_tests/test_tf_While.py --------- Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
parent
ac1fb7b955
commit
c6ca7865fb
@ -15,6 +15,7 @@
|
||||
#include "openvino/op/if.hpp"
|
||||
#include "openvino/op/parameter.hpp"
|
||||
#include "openvino/op/result.hpp"
|
||||
#include "openvino/op/util/multi_subgraph_base.hpp"
|
||||
#include "tf_utils.hpp"
|
||||
|
||||
using namespace ov;
|
||||
@ -151,6 +152,16 @@ void insert_result_before_merge(const shared_ptr<Merge>& merge_node,
|
||||
} // namespace
|
||||
|
||||
bool pass::SwitchMergeResolver::run_on_model(const shared_ptr<Model>& m) {
|
||||
// run this transformation recursively since this is a model pass
|
||||
for (const auto& op : m->get_ordered_ops()) {
|
||||
auto multisubgraph_op = as_type_ptr<ov::op::util::MultiSubGraphOp>(op);
|
||||
if (multisubgraph_op) {
|
||||
for (size_t i = 0; i < multisubgraph_op->get_internal_subgraphs_size(); ++i) {
|
||||
run_on_model(multisubgraph_op->get_function(static_cast<int>(i)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// split set of Switch and Merge nodes to clusters
|
||||
// where each cluster of Switch and Merge nodes will represent
|
||||
// the single If operation for fusing
|
||||
|
@ -16,6 +16,8 @@ namespace tensorflow {
|
||||
namespace pass {
|
||||
|
||||
bool ConstToResultRemover::run_on_model(const std::shared_ptr<ov::Model>& m) {
|
||||
// Note: need to perform this transformation only on the main ov::Model graph
|
||||
// no need to apply it for sub-graphs!
|
||||
ResultVector results_to_remove;
|
||||
// look for isolated UnsupportedConst->Result sub-graphs to remove
|
||||
// also, find isolated Constant->Result sub-graphs to remove
|
||||
|
@ -50,6 +50,7 @@ class TestWhile(CommonTFLayerTest):
|
||||
|
||||
test_data_basic = [
|
||||
dict(y_shape=[2, 3], data_type=np.int32, lower_control_flow=False),
|
||||
dict(y_shape=[2, 3], data_type=np.int32, lower_control_flow=True),
|
||||
dict(y_shape=[2, 1, 4], data_type=np.int32, lower_control_flow=False),
|
||||
dict(y_shape=[2, 1, 4], data_type=np.int32, lower_control_flow=True)
|
||||
]
|
||||
@ -109,6 +110,7 @@ class TestWhileShapeVariant(CommonTFLayerTest):
|
||||
|
||||
test_data_basic = [
|
||||
dict(y_shape=[2, 3], lower_control_flow=False),
|
||||
dict(y_shape=[2, 3], lower_control_flow=True),
|
||||
dict(y_shape=[2, 1, 4], lower_control_flow=False),
|
||||
dict(y_shape=[2, 1, 4], lower_control_flow=True)
|
||||
]
|
||||
@ -122,3 +124,77 @@ class TestWhileShapeVariant(CommonTFLayerTest):
|
||||
self._test(*self.create_while_net(**params),
|
||||
ie_device, precision, ir_version, temp_dir=temp_dir,
|
||||
use_new_frontend=use_new_frontend, use_old_api=use_old_api)
|
||||
|
||||
|
||||
class TestWhileWithNestedIf(CommonTFLayerTest):
|
||||
def _prepare_input(self, inputs_info):
|
||||
assert 'x' in inputs_info, "Test error: inputs_info must contain `x`"
|
||||
assert 'y' in inputs_info, "Test error: inputs_info must contain `y`"
|
||||
x_shape = inputs_info['x']
|
||||
y_shape = inputs_info['y']
|
||||
inputs_data = {}
|
||||
inputs_data['x'] = np.random.randint(1, 10, x_shape).astype(np.int32)
|
||||
inputs_data['y'] = np.random.randint(-50, 50, y_shape).astype(np.int32)
|
||||
return inputs_data
|
||||
|
||||
def create_while_with_nested_if_net(self, y_shape, data_type, lower_control_flow):
|
||||
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
|
||||
def while_function(x, y):
|
||||
@tf.function
|
||||
def cond(x, y):
|
||||
return tf.less(x, 10)
|
||||
|
||||
@tf.function
|
||||
def body(x, y):
|
||||
# create If operation inside While body
|
||||
# use different logic for updating y based on x
|
||||
def if_op(cond, y):
|
||||
def then_branch():
|
||||
y_new = tf.multiply(y, tf.constant(2, dtype=data_type))
|
||||
return y_new
|
||||
|
||||
def else_branch():
|
||||
y_new = tf.subtract(y, tf.constant(55, dtype=data_type))
|
||||
return y_new
|
||||
|
||||
if_op = tf.cond(cond, then_branch, else_branch)
|
||||
output = tf.identity(if_op, name='if_op')
|
||||
return output
|
||||
|
||||
y_new = tf.add(y, tf.constant(2, dtype=data_type))
|
||||
cond = tf.less(x, 5)
|
||||
y_new = if_op(cond, y_new)
|
||||
x_new = tf.add(x, 1)
|
||||
return x_new, y_new
|
||||
|
||||
return tf.while_loop(cond, body, [x, y])
|
||||
|
||||
tf_while_graph = tf.function(while_function)
|
||||
x = np.random.randint(9, 10, []).astype(data_type)
|
||||
y = np.random.randint(-50, 50, y_shape).astype(data_type)
|
||||
concrete_func = tf_while_graph.get_concrete_function(x, y)
|
||||
|
||||
# lower_control_flow defines representation of While operation
|
||||
# in case of lower_control_flow=True it is decomposed into LoopCond, NextIteration and TensorArray operations
|
||||
frozen_func = convert_variables_to_constants_v2(concrete_func,
|
||||
lower_control_flow=lower_control_flow)
|
||||
|
||||
graph_def = frozen_func.graph.as_graph_def(add_shapes=True)
|
||||
return graph_def, None
|
||||
|
||||
test_data_basic = [
|
||||
dict(y_shape=[2, 3], data_type=np.int32, lower_control_flow=False),
|
||||
dict(y_shape=[2, 3], data_type=np.int32, lower_control_flow=True),
|
||||
dict(y_shape=[2, 1, 4], data_type=np.int32, lower_control_flow=False),
|
||||
dict(y_shape=[2, 1, 4], data_type=np.int32, lower_control_flow=True)
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize("params", test_data_basic)
|
||||
@pytest.mark.precommit_tf_fe
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.skipif(platform == 'darwin', reason="Ticket - 122182")
|
||||
def test_while_with_nested_if_basic(self, params, ie_device, precision, ir_version, temp_dir,
|
||||
use_new_frontend, use_old_api):
|
||||
self._test(*self.create_while_with_nested_if_net(**params),
|
||||
ie_device, precision, ir_version, temp_dir=temp_dir,
|
||||
use_new_frontend=use_new_frontend, use_old_api=use_old_api)
|
||||
|
Loading…
Reference in New Issue
Block a user