diff --git a/model-optimizer/extensions/front/tf/assign_elimination.py b/model-optimizer/extensions/front/tf/assign_elimination.py index 9823fa72e45..ec3424854dd 100644 --- a/model-optimizer/extensions/front/tf/assign_elimination.py +++ b/model-optimizer/extensions/front/tf/assign_elimination.py @@ -16,68 +16,20 @@ import logging as log -import networkx as nx - -from mo.front.common.replacement import FrontReplacementOp +from mo.front.common.replacement import FrontReplacementPattern from mo.graph.graph import Graph -from mo.utils.error import Error -class AssignElimination(FrontReplacementOp): - op = "Assign" +class AssignAndAssertElimination(FrontReplacementPattern): + # The solution with removal of Assign and Assert operations is temporary. + # The proper solution is to keep these operations until the partial inference + # phase when control flow edges are properly handled and later unnecessary ones are eliminated. + # In order to achieve this we need to implement control flow inference function + # for these operations similar to "Merge" and "Switch" operations. enabled = True - def replace_sub_graph(self, graph: Graph, match: dict): - node = match['op'] - # here we request all data flow output edges (control flow edges will not be listed) - out_edges = node.out_edges() - if len(out_edges) == 0: - graph.remove_node(node.id) - log.debug('Assign op was removed {}'.format(node.id)) - else: - raise Error('Data flow edge coming out of Assign node {}'.format(node.id)) - - -class AssignSubElimination(FrontReplacementOp): - op = "AssignSub" - enabled = True - - def replace_sub_graph(self, graph: Graph, match: dict): - node = match['op'] - # here we request all data flow output edges (control flow edges will not be listed) - out_edges = node.out_edges() - if len(out_edges) == 0: - graph.remove_node(node.id) - log.debug('AssignSub op was removed {}'.format(node.id)) - else: - raise Error('Data flow edge coming out of AssignSub node {}'.format(node.id)) - - -class AssignAddElimination(FrontReplacementOp): - op = "AssignAdd" - enabled = True - - def replace_sub_graph(self, graph: Graph, match: dict): - node = match['op'] - # here we request all data flow output edges (control flow edges will not be listed) - out_edges = node.out_edges() - if len(out_edges) == 0: - graph.remove_node(node.id) - log.debug('AssignAdd op was removed {}'.format(node.id)) - else: - raise Error('Data flow edge coming out of AssignAdd node {}'.format(node.id)) - - -class AssertElimination(FrontReplacementOp): - op = "Assert" - enabled = True - - def replace_sub_graph(self, graph: nx.MultiDiGraph, match: dict): - node = match['op'] - # here we request all data flow output edges (control flow edges will not be listed) - out_edges = node.out_edges() - if len(out_edges) == 0: - graph.remove_node(node.id) - log.debug('Assert op was removed {}'.format(node.id)) - else: - raise Error('Data flow edge coming out of Assert node {}'.format(node.id)) + def find_and_replace_pattern(self, graph: Graph): + for node in graph.get_op_nodes(): + if node.soft_get('op') in ["Assign", "AssignSub", "AssignAdd", "Assert"]: + log.debug('"{}" op with id="{}" was removed'.format(node.op, node.id)) + graph.remove_node(node.id)