Move FakeOutput resolving to back phase (#2033)

This commit is contained in:
Maxim Vafin 2020-09-07 10:20:24 +03:00 committed by GitHub
parent 51564f415c
commit 6730cab192
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 103 additions and 99 deletions

View File

@ -19,6 +19,7 @@ extensions/back/CropToStridedSlice.py
extensions/back/CutMemory.py
extensions/back/disable_unsupported_ND_operations.py
extensions/back/EnableConstantStridedSlice.py
extensions/back/FakeOutputResolver.py
extensions/back/ForceStrictPrecision.py
extensions/back/fuse_sub_div_min.py
extensions/back/FuseTransposesSequence.py
@ -120,7 +121,6 @@ extensions/front/disable_weights_quantize_value_propagation.py
extensions/front/div.py
extensions/front/eltwise_n.py
extensions/front/ExpandDimsToUnsqueeze.py
extensions/front/FakeOutputResolver.py
extensions/front/FillToBroadcast.py
extensions/front/flatten_to_reshape.py
extensions/front/freeze_placeholder_value.py

View File

@ -15,18 +15,19 @@
"""
from extensions.ops.elementwise import Add
from mo.front.common.replacement import FrontReplacementPattern
from mo.back.replacement import BackReplacementPattern
from mo.front.common.partial_infer.utils import int64_array
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Graph, rename_nodes, rename_node
class FakeOutputResolver(FrontReplacementPattern):
class FakeOutputResolver(BackReplacementPattern):
"""
This transformation removes FakeOutput nodes. If producer of FakeOutput have only one consumer (FakeOutput itself)
the name of FakeOutput is inherited by its producer, otherwise FakeOutput is replaced with op which does nothing.
"""
enabled = True
force_clean_up = True
def find_and_replace_pattern(self, graph: Graph):
for fake_output in graph.get_op_nodes(op='FakeOutput'):
@ -46,5 +47,7 @@ class FakeOutputResolver(FrontReplacementPattern):
fake_output.in_port(0).get_connection().set_destination(add.in_port(0))
fake_output.out_port(0).get_connection().set_source(add.out_port(0))
else:
graph.erase_node(fake_output)
rename_node(producer, name)
result_in_port = fake_output.out_port(0).get_destination()
result_in_port.disconnect()
fake_output.in_port(0).get_connection().set_destination(result_in_port)
rename_nodes([(fake_output, name + '/TBD'), (producer, name)])

View File

@ -0,0 +1,93 @@
"""
Copyright (C) 2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from extensions.back.FakeOutputResolver import FakeOutputResolver
from mo.front.common.partial_infer.utils import int64_array
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph, result, regular_op_with_empty_data, const_with_data, connect, \
empty_data
class FakeOutputResolverTest(unittest.TestCase):
def test_one(self):
nodes = {
**regular_op_with_empty_data('input', {'type': 'Parameter'}),
**regular_op_with_empty_data('some_op', {'type': 'SomeOp', 'name': 'some_op_name'}),
**regular_op_with_empty_data('fake_output',
{'type': None, 'kind': 'op', 'op': 'FakeOutput', 'name': 'my_output_name'}),
**result('result'),
}
edges = [*connect('input', 'some_op'),
*connect('some_op', 'fake_output'),
*connect('fake_output', 'result'),
]
graph = build_graph(nodes, edges)
edges_ref = [*connect('input', 'some_op'),
*connect('some_op', 'result'),
]
graph_ref = build_graph(nodes, edges_ref, {'some_op': {'name': 'my_output_name'}})
FakeOutputResolver().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
self.assertTrue(flag, resp)
def test_multi(self):
nodes = {
**regular_op_with_empty_data('input', {'type': 'Parameter'}),
**regular_op_with_empty_data('some_op', {'type': 'SomeOp', 'name': 'some_op_name'}),
**empty_data('some_op_d2'),
**regular_op_with_empty_data('fake_output1',
{'type': None, 'kind': 'op', 'op': 'FakeOutput', 'name': 'my_output_name1'}),
**regular_op_with_empty_data('fake_output2',
{'type': None, 'kind': 'op', 'op': 'FakeOutput', 'name': 'my_output_name2'}),
**const_with_data('const1', int64_array(0)),
**const_with_data('const2', int64_array(0)),
**regular_op_with_empty_data('add1', {'type': None, 'kind': 'op', 'op': 'Add', 'name': 'my_output_name1'}),
**regular_op_with_empty_data('add2', {'type': None, 'kind': 'op', 'op': 'Add', 'name': 'my_output_name2'}),
**result('result1'),
**result('result2'),
}
edges = [*connect('input', 'some_op'),
*connect('some_op', 'fake_output1'),
('some_op', 'some_op_d2'),
('some_op_d2', 'fake_output2'),
*connect('fake_output1', 'result1'),
*connect('fake_output2', 'result2'),
]
graph = build_graph(nodes, edges)
edges_ref = [*connect('input', 'some_op'),
*connect('some_op', '0:add1'),
*connect('const1', '1:add1'),
('some_op', 'some_op_d2'),
('some_op_d2', 'add2', {'in': 0}),
*connect('const2', '1:add2'),
*connect('add1', 'result1'),
*connect('add2', 'result2'),
]
graph_ref = build_graph(nodes, edges_ref)
FakeOutputResolver().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result1')
self.assertTrue(flag, resp)

View File

@ -1,92 +0,0 @@
"""
Copyright (C) 2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from extensions.front.FakeOutputResolver import FakeOutputResolver
from mo.front.common.partial_infer.utils import int64_array
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph, result, regular_op, const
class FakeOutputResolverTest(unittest.TestCase):
def test_one(self):
nodes = {
**regular_op('input', {'type': 'Parameter'}),
**regular_op('some_op', {'type': 'SomeOp', 'name': 'some_op_name'}),
**regular_op('fake_output', {'type': None, 'kind': 'op', 'op': 'FakeOutput', 'name': 'my_output_name'}),
**result('result'),
}
edges = [('input', 'some_op'),
('some_op', 'fake_output'),
('fake_output', 'result'),
]
graph = build_graph(nodes, edges)
graph.graph['layout'] = 'NCHW'
graph.stage = 'front'
edges_ref = [('input', 'some_op'),
('some_op', 'result'),
]
graph_ref = build_graph(nodes, edges_ref, {'some_op': {'name': 'my_output_name'}})
FakeOutputResolver().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
self.assertTrue(flag, resp)
def test_multi(self):
nodes = {
**regular_op('input', {'type': 'Parameter'}),
**regular_op('some_op', {'type': 'SomeOp', 'name': 'some_op_name'}),
**regular_op('fake_output1', {'type': None, 'kind': 'op', 'op': 'FakeOutput', 'name': 'my_output_name1'}),
**regular_op('fake_output2', {'type': None, 'kind': 'op', 'op': 'FakeOutput', 'name': 'my_output_name2'}),
**const('const1', int64_array(0)),
**const('const2', int64_array(0)),
**regular_op('add1', {'type': None, 'kind': 'op', 'op': 'Add', 'name': 'my_output_name1'}),
**regular_op('add2', {'type': None, 'kind': 'op', 'op': 'Add', 'name': 'my_output_name2'}),
**result('result1'),
**result('result2'),
}
edges = [('input', 'some_op'),
('some_op', 'fake_output1'),
('some_op', 'fake_output2'),
('fake_output1', 'result1'),
('fake_output2', 'result2'),
]
graph = build_graph(nodes, edges)
graph.graph['layout'] = 'NCHW'
graph.stage = 'front'
edges_ref = [('input', 'some_op'),
('some_op', 'add1'),
('const1', 'add1'),
('some_op', 'add2'),
('const2', 'add2'),
('add1', 'result1'),
('add2', 'result2'),
]
graph_ref = build_graph(nodes, edges_ref)
FakeOutputResolver().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result1')
self.assertTrue(flag, resp)

View File

@ -13,7 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
from mo.front.common.partial_infer.elemental import copy_shape_infer, copy_value
from mo.graph.graph import Graph
from mo.ops.op import Op
@ -31,7 +31,7 @@ class FakeOutput(Op):
'type': None,
'version': None,
'infer': None,
'infer': lambda n: copy_shape_infer(n, copy_value),
'type_infer': None,