Add workaround for control edges to support TF 2.4 RNN (#4634)
Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>
This commit is contained in:
parent
18cb230af4
commit
ff73955354
@ -51,6 +51,12 @@ def update_body_graph(body_graph: Graph, subgraph_proto: dict,
|
|||||||
# add incoming edges based on data_nodes_map
|
# add incoming edges based on data_nodes_map
|
||||||
for dst_port, inp in enumerate(pb_node.input):
|
for dst_port, inp in enumerate(pb_node.input):
|
||||||
orig_src_id = inp.split(":")[0]
|
orig_src_id = inp.split(":")[0]
|
||||||
|
|
||||||
|
# TODO: avoid this temporal workaround for TF 2.4 or higher RNN layers:
|
||||||
|
# skip control flow dependency
|
||||||
|
if orig_src_id[0] == '^':
|
||||||
|
continue
|
||||||
|
|
||||||
src_id = map_original_name[orig_src_id]
|
src_id = map_original_name[orig_src_id]
|
||||||
src_port = 0 if len(inp.split(":")) == 1 else int(inp.split(":")[-1])
|
src_port = 0 if len(inp.split(":")) == 1 else int(inp.split(":")[-1])
|
||||||
assert (body_graph.has_node(src_id))
|
assert (body_graph.has_node(src_id))
|
||||||
|
Loading…
Reference in New Issue
Block a user