Move FakeOutput resolving to back phase (#2033)
This commit is contained in:
parent
51564f415c
commit
6730cab192
@ -19,6 +19,7 @@ extensions/back/CropToStridedSlice.py
|
||||
extensions/back/CutMemory.py
|
||||
extensions/back/disable_unsupported_ND_operations.py
|
||||
extensions/back/EnableConstantStridedSlice.py
|
||||
extensions/back/FakeOutputResolver.py
|
||||
extensions/back/ForceStrictPrecision.py
|
||||
extensions/back/fuse_sub_div_min.py
|
||||
extensions/back/FuseTransposesSequence.py
|
||||
@ -120,7 +121,6 @@ extensions/front/disable_weights_quantize_value_propagation.py
|
||||
extensions/front/div.py
|
||||
extensions/front/eltwise_n.py
|
||||
extensions/front/ExpandDimsToUnsqueeze.py
|
||||
extensions/front/FakeOutputResolver.py
|
||||
extensions/front/FillToBroadcast.py
|
||||
extensions/front/flatten_to_reshape.py
|
||||
extensions/front/freeze_placeholder_value.py
|
||||
|
@ -15,18 +15,19 @@
|
||||
"""
|
||||
|
||||
from extensions.ops.elementwise import Add
|
||||
from mo.front.common.replacement import FrontReplacementPattern
|
||||
from mo.back.replacement import BackReplacementPattern
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.front.tf.graph_utils import create_op_with_const_inputs
|
||||
from mo.graph.graph import Graph, rename_nodes, rename_node
|
||||
|
||||
|
||||
class FakeOutputResolver(FrontReplacementPattern):
|
||||
class FakeOutputResolver(BackReplacementPattern):
|
||||
"""
|
||||
This transformation removes FakeOutput nodes. If producer of FakeOutput have only one consumer (FakeOutput itself)
|
||||
the name of FakeOutput is inherited by its producer, otherwise FakeOutput is replaced with op which does nothing.
|
||||
"""
|
||||
enabled = True
|
||||
force_clean_up = True
|
||||
|
||||
def find_and_replace_pattern(self, graph: Graph):
|
||||
for fake_output in graph.get_op_nodes(op='FakeOutput'):
|
||||
@ -46,5 +47,7 @@ class FakeOutputResolver(FrontReplacementPattern):
|
||||
fake_output.in_port(0).get_connection().set_destination(add.in_port(0))
|
||||
fake_output.out_port(0).get_connection().set_source(add.out_port(0))
|
||||
else:
|
||||
graph.erase_node(fake_output)
|
||||
rename_node(producer, name)
|
||||
result_in_port = fake_output.out_port(0).get_destination()
|
||||
result_in_port.disconnect()
|
||||
fake_output.in_port(0).get_connection().set_destination(result_in_port)
|
||||
rename_nodes([(fake_output, name + '/TBD'), (producer, name)])
|
93
model-optimizer/extensions/back/FakeOutputResolver_test.py
Normal file
93
model-optimizer/extensions/back/FakeOutputResolver_test.py
Normal file
@ -0,0 +1,93 @@
|
||||
"""
|
||||
Copyright (C) 2020 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
from extensions.back.FakeOutputResolver import FakeOutputResolver
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from mo.utils.unittest.graph import build_graph, result, regular_op_with_empty_data, const_with_data, connect, \
|
||||
empty_data
|
||||
|
||||
|
||||
class FakeOutputResolverTest(unittest.TestCase):
|
||||
def test_one(self):
|
||||
nodes = {
|
||||
**regular_op_with_empty_data('input', {'type': 'Parameter'}),
|
||||
**regular_op_with_empty_data('some_op', {'type': 'SomeOp', 'name': 'some_op_name'}),
|
||||
**regular_op_with_empty_data('fake_output',
|
||||
{'type': None, 'kind': 'op', 'op': 'FakeOutput', 'name': 'my_output_name'}),
|
||||
**result('result'),
|
||||
}
|
||||
edges = [*connect('input', 'some_op'),
|
||||
*connect('some_op', 'fake_output'),
|
||||
*connect('fake_output', 'result'),
|
||||
]
|
||||
graph = build_graph(nodes, edges)
|
||||
|
||||
edges_ref = [*connect('input', 'some_op'),
|
||||
*connect('some_op', 'result'),
|
||||
]
|
||||
|
||||
graph_ref = build_graph(nodes, edges_ref, {'some_op': {'name': 'my_output_name'}})
|
||||
|
||||
FakeOutputResolver().find_and_replace_pattern(graph)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_multi(self):
|
||||
nodes = {
|
||||
**regular_op_with_empty_data('input', {'type': 'Parameter'}),
|
||||
**regular_op_with_empty_data('some_op', {'type': 'SomeOp', 'name': 'some_op_name'}),
|
||||
**empty_data('some_op_d2'),
|
||||
**regular_op_with_empty_data('fake_output1',
|
||||
{'type': None, 'kind': 'op', 'op': 'FakeOutput', 'name': 'my_output_name1'}),
|
||||
**regular_op_with_empty_data('fake_output2',
|
||||
{'type': None, 'kind': 'op', 'op': 'FakeOutput', 'name': 'my_output_name2'}),
|
||||
|
||||
**const_with_data('const1', int64_array(0)),
|
||||
**const_with_data('const2', int64_array(0)),
|
||||
**regular_op_with_empty_data('add1', {'type': None, 'kind': 'op', 'op': 'Add', 'name': 'my_output_name1'}),
|
||||
**regular_op_with_empty_data('add2', {'type': None, 'kind': 'op', 'op': 'Add', 'name': 'my_output_name2'}),
|
||||
**result('result1'),
|
||||
**result('result2'),
|
||||
}
|
||||
edges = [*connect('input', 'some_op'),
|
||||
*connect('some_op', 'fake_output1'),
|
||||
('some_op', 'some_op_d2'),
|
||||
('some_op_d2', 'fake_output2'),
|
||||
*connect('fake_output1', 'result1'),
|
||||
*connect('fake_output2', 'result2'),
|
||||
]
|
||||
graph = build_graph(nodes, edges)
|
||||
|
||||
edges_ref = [*connect('input', 'some_op'),
|
||||
*connect('some_op', '0:add1'),
|
||||
*connect('const1', '1:add1'),
|
||||
('some_op', 'some_op_d2'),
|
||||
('some_op_d2', 'add2', {'in': 0}),
|
||||
*connect('const2', '1:add2'),
|
||||
*connect('add1', 'result1'),
|
||||
*connect('add2', 'result2'),
|
||||
]
|
||||
|
||||
graph_ref = build_graph(nodes, edges_ref)
|
||||
|
||||
FakeOutputResolver().find_and_replace_pattern(graph)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result1')
|
||||
self.assertTrue(flag, resp)
|
@ -1,92 +0,0 @@
|
||||
"""
|
||||
Copyright (C) 2020 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
from extensions.front.FakeOutputResolver import FakeOutputResolver
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from mo.utils.unittest.graph import build_graph, result, regular_op, const
|
||||
|
||||
|
||||
class FakeOutputResolverTest(unittest.TestCase):
|
||||
def test_one(self):
|
||||
nodes = {
|
||||
**regular_op('input', {'type': 'Parameter'}),
|
||||
**regular_op('some_op', {'type': 'SomeOp', 'name': 'some_op_name'}),
|
||||
**regular_op('fake_output', {'type': None, 'kind': 'op', 'op': 'FakeOutput', 'name': 'my_output_name'}),
|
||||
**result('result'),
|
||||
}
|
||||
edges = [('input', 'some_op'),
|
||||
('some_op', 'fake_output'),
|
||||
('fake_output', 'result'),
|
||||
]
|
||||
graph = build_graph(nodes, edges)
|
||||
|
||||
graph.graph['layout'] = 'NCHW'
|
||||
graph.stage = 'front'
|
||||
|
||||
edges_ref = [('input', 'some_op'),
|
||||
('some_op', 'result'),
|
||||
]
|
||||
|
||||
graph_ref = build_graph(nodes, edges_ref, {'some_op': {'name': 'my_output_name'}})
|
||||
|
||||
FakeOutputResolver().find_and_replace_pattern(graph)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_multi(self):
|
||||
nodes = {
|
||||
**regular_op('input', {'type': 'Parameter'}),
|
||||
**regular_op('some_op', {'type': 'SomeOp', 'name': 'some_op_name'}),
|
||||
**regular_op('fake_output1', {'type': None, 'kind': 'op', 'op': 'FakeOutput', 'name': 'my_output_name1'}),
|
||||
**regular_op('fake_output2', {'type': None, 'kind': 'op', 'op': 'FakeOutput', 'name': 'my_output_name2'}),
|
||||
|
||||
**const('const1', int64_array(0)),
|
||||
**const('const2', int64_array(0)),
|
||||
**regular_op('add1', {'type': None, 'kind': 'op', 'op': 'Add', 'name': 'my_output_name1'}),
|
||||
**regular_op('add2', {'type': None, 'kind': 'op', 'op': 'Add', 'name': 'my_output_name2'}),
|
||||
**result('result1'),
|
||||
**result('result2'),
|
||||
}
|
||||
edges = [('input', 'some_op'),
|
||||
('some_op', 'fake_output1'),
|
||||
('some_op', 'fake_output2'),
|
||||
('fake_output1', 'result1'),
|
||||
('fake_output2', 'result2'),
|
||||
]
|
||||
graph = build_graph(nodes, edges)
|
||||
|
||||
graph.graph['layout'] = 'NCHW'
|
||||
graph.stage = 'front'
|
||||
|
||||
edges_ref = [('input', 'some_op'),
|
||||
('some_op', 'add1'),
|
||||
('const1', 'add1'),
|
||||
('some_op', 'add2'),
|
||||
('const2', 'add2'),
|
||||
('add1', 'result1'),
|
||||
('add2', 'result2'),
|
||||
]
|
||||
|
||||
graph_ref = build_graph(nodes, edges_ref)
|
||||
|
||||
FakeOutputResolver().find_and_replace_pattern(graph)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result1')
|
||||
self.assertTrue(flag, resp)
|
@ -13,7 +13,7 @@
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
from mo.front.common.partial_infer.elemental import copy_shape_infer, copy_value
|
||||
from mo.graph.graph import Graph
|
||||
from mo.ops.op import Op
|
||||
|
||||
@ -31,7 +31,7 @@ class FakeOutput(Op):
|
||||
'type': None,
|
||||
'version': None,
|
||||
|
||||
'infer': None,
|
||||
'infer': lambda n: copy_shape_infer(n, copy_value),
|
||||
|
||||
'type_infer': None,
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user