diff --git a/model-optimizer/extensions/front/tf/while_ext.py b/model-optimizer/extensions/front/tf/while_ext.py index bb29379e05c..e3247ca82c5 100644 --- a/model-optimizer/extensions/front/tf/while_ext.py +++ b/model-optimizer/extensions/front/tf/while_ext.py @@ -51,6 +51,12 @@ def update_body_graph(body_graph: Graph, subgraph_proto: dict, # add incoming edges based on data_nodes_map for dst_port, inp in enumerate(pb_node.input): orig_src_id = inp.split(":")[0] + + # TODO: avoid this temporal workaround for TF 2.4 or higher RNN layers: + # skip control flow dependency + if orig_src_id[0] == '^': + continue + src_id = map_original_name[orig_src_id] src_port = 0 if len(inp.split(":")) == 1 else int(inp.split(":")[-1]) assert (body_graph.has_node(src_id))