[ MO TF ] IdentityN support (#529)

This commit is contained in:
Evgenya Stepyreva 2020-05-25 10:52:58 +03:00 committed by GitHub
parent 507c06c8bc
commit b6a05c232e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 170 additions and 24 deletions

View File

@ -379,6 +379,7 @@ extensions/front/tf/gather_ext.py
extensions/front/tf/GatherTree_ext.py
extensions/front/tf/GNMT_DynamicSequenceLengths.py
extensions/front/tf/identity_ext.py
extensions/front/tf/identityN_to_identity.py
extensions/front/tf/InterpolateTransposes.py
extensions/front/tf/IteratorGetNext_ext.py
extensions/front/tf/LoopCond_ext.py

View File

@ -13,7 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
from extensions.ops.identity import IdentityOp
from extensions.ops.identity import Identity
from mo.front.common.replacement import FrontReplacementOp
from mo.graph.graph import Graph
@ -33,7 +33,7 @@ class SplitToIdentity(FrontReplacementOp):
def replace_sub_graph(self, graph: Graph, match: dict):
node = match['op']
identity = IdentityOp(graph, {'name': node.soft_get('name', node.id)}).create_node()
identity = Identity(graph, {'name': node.soft_get('name', node.id)}).create_node()
node.in_port(0).get_connection().set_destination(identity.in_port(0))
for idx, port in node.out_ports().items():

View File

@ -14,7 +14,7 @@
limitations under the License.
"""
from extensions.ops.identity import IdentityOp
from extensions.ops.identity import Identity
from mo.front.extractor import FrontExtractorOp
from mo.graph.graph import Node
@ -25,5 +25,5 @@ class BlockGradExt(FrontExtractorOp):
@classmethod
def extract(cls, node: Node):
IdentityOp.update_node_stat(node, {})
Identity.update_node_stat(node, {})
return cls.enabled

View File

@ -14,7 +14,7 @@
limitations under the License.
"""
from extensions.ops.identity import IdentityOp
from extensions.ops.identity import Identity
from mo.front.extractor import FrontExtractorOp
from mo.graph.graph import Node
@ -25,5 +25,5 @@ class CopyExt(FrontExtractorOp):
@classmethod
def extract(cls, node: Node):
IdentityOp.update_node_stat(node, {})
Identity.update_node_stat(node, {})
return cls.enabled

View File

@ -14,7 +14,7 @@
limitations under the License.
"""
from extensions.ops.identity import IdentityOp
from extensions.ops.identity import Identity
from mo.front.extractor import FrontExtractorOp
from mo.graph.graph import Node
@ -25,5 +25,5 @@ class DropoutExt(FrontExtractorOp):
@classmethod
def extract(cls, node: Node):
IdentityOp.update_node_stat(node, {})
Identity.update_node_stat(node, {})
return cls.enabled

View File

@ -14,7 +14,7 @@
limitations under the License.
"""
from extensions.ops.identity import IdentityOp
from extensions.ops.identity import Identity
from mo.front.extractor import FrontExtractorOp
from mo.front.onnx.extractors.utils import onnx_attr
from mo.utils.error import Error
@ -32,5 +32,5 @@ class DropoutFrontExtractor(FrontExtractorOp):
raise Error('Dropout node {} has more than one consumer. Unsupported.', node.name)
if not is_test:
raise Error('Dropout node {} has is_test: 0. This means training mode which is not supported.', node.name)
IdentityOp.update_node_stat(node)
Identity.update_node_stat(node)
return cls.enabled

View File

@ -0,0 +1,52 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from extensions.ops.identity import Identity
from mo.front.common.replacement import FrontReplacementPattern
from mo.graph.graph import Graph, Node
class IdentityN_to_Identity(FrontReplacementPattern):
"""
Replaces IdentityN op with several Identity ops.
Example:
input_0 input_1 input_0 input_1
\ / | |
IdentityN Identity Identity
/ \ | |
output_0 output_1 output_0 output_1
"""
enabled = True
@staticmethod
def replace_identityN(node: Node):
graph = node.graph
name = node.soft_get('name', node.id)
assert node.has_valid('data_types'), 'IdentityN {} has no `data_types` attribute'.format(name)
dtypes = node.data_types
for idx, port in node.in_ports().items():
assert node.is_out_port_connected(idx), 'IdentityN {} has inconsistent input and output ports'.format(name)
assert idx < len(dtypes), 'IdentityN {} has inconsistent `data_types` attribute {}'.format(name, dtypes)
identity = Identity(graph, {'name': '{}/{}_port'.format(name, idx), 'data_type': dtypes[idx]}).create_node()
port.get_connection().set_destination(identity.in_port(0))
node.out_port(idx).get_connection().set_source(identity.out_port(0))
def find_and_replace_pattern(self, graph: Graph):
for identityN in graph.get_op_nodes(op='IdentityN'):
self.replace_identityN(identityN)

View File

@ -0,0 +1,63 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np
from extensions.front.tf.identityN_to_identity import IdentityN_to_Identity
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import result, regular_op_with_shaped_data, \
regular_op_with_empty_data, build_graph, connect, empty_data
nodes = {
**regular_op_with_shaped_data('placeholder_0', [1, 227, 227, 3], {'type': 'Parameter'}),
**regular_op_with_shaped_data('placeholder_1', [1, 227, 227, 3], {'type': 'Parameter'}),
**regular_op_with_empty_data('identityN', {'op': 'IdentityN', 'type': None, 'data_types': [np.int32, np.float],
'name': 'my_identity'}),
**empty_data('identityN_1_d'),
**regular_op_with_empty_data('identity0', {'op': 'Identity', 'type': None, 'data_type': np.int32,
'name': 'my_identity/0_port'}),
**regular_op_with_empty_data('identity1', {'op': 'Identity', 'type': None, 'data_type': np.float,
'name': 'my_identity/1_port'}),
**result('output0'),
**result('output1'),
}
class TestIdentityN(unittest.TestCase):
def test_identityN(self):
graph = build_graph(nodes, [
*connect('placeholder_0', '0:identityN'),
*connect('placeholder_1', '1:identityN'),
*connect('identityN:0', 'output0'),
('identityN', 'identityN_1_d', {'out': 1}),
('identityN_1_d', 'output1', {'out': 1}),
], nodes_with_edges_only=True)
IdentityN_to_Identity().find_and_replace_pattern(graph)
graph_ref = build_graph(nodes, [
*connect('placeholder_0', 'identity0'),
*connect('placeholder_1', 'identity1'),
*connect('identity0', 'output0'),
*connect('identity1', 'output1'),
], nodes_with_edges_only=True)
(flag, resp) = compare_graphs(graph, graph_ref, 'output0', check_op_attrs=True)
self.assertTrue(flag, resp)

View File

@ -13,7 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
from extensions.ops.identity import IdentityOp
from extensions.ops.identity import Identity, IdentityN
from mo.front.extractor import FrontExtractorOp
from mo.front.tf.extractors.utils import tf_dtype_extractor
from mo.graph.graph import Node
@ -25,19 +25,34 @@ class IdentityFrontExtractor(FrontExtractorOp):
@classmethod
def extract(cls, node: Node):
IdentityOp.update_node_stat(node, {
Identity.update_node_stat(node, {
'data_type': tf_dtype_extractor(node.pb.attr["T"].type),
})
return cls.enabled
class IdentityNFrontExtractor(FrontExtractorOp):
op = 'IdentityN'
enabled = True
@classmethod
def extract(cls, node: Node):
dtypes = [tf_dtype_extractor(t) for t in node.pb.attr["T"].list.type]
IdentityN.update_node_stat(node, {
'data_types': dtypes,
'in_ports_count': len(dtypes),
'out_ports_count': len(dtypes),
})
return cls.enabled
class ReadVariableOpFrontExtractor(FrontExtractorOp):
op = 'ReadVariableOp'
enabled = True
@classmethod
def extract(cls, node: Node):
IdentityOp.update_node_stat(node, {
Identity.update_node_stat(node, {
'data_type': tf_dtype_extractor(node.pb.attr["T"].type),
})
return cls.enabled
@ -49,5 +64,5 @@ class StopGradientExtractor(FrontExtractorOp):
@classmethod
def extract(cls, node: Node):
IdentityOp.update_node_stat(node, {'op': 'StopGradient'})
Identity.update_node_stat(node, {'op': 'StopGradient'})
return cls.enabled

View File

@ -13,24 +13,39 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
from mo.front.common.partial_infer.elemental import copy_shape_infer, copy_value
from mo.graph.graph import Graph
from mo.ops.op import Op
class IdentityOp(Op):
class Identity(Op):
op = 'Identity'
enabled = True
def __init__(self, graph: Graph, attrs: dict):
super().__init__(graph, {
'op': __class__.op,
'op': self.op,
'type': None,
'identity': True,
'infer': self.infer,
'in_ports_count': 1,
'out_ports_count': 1,
'infer': IdentityOp.shape_infer
}, attrs)
@staticmethod
def shape_infer(node):
copy_shape_infer(node, value_infer=copy_value)
def infer(node):
node.out_port(0).data.set_shape(node.in_port(0).data.get_shape())
if node.in_port(0).data.get_value() is not None:
node.out_port(0).data.set_value(node.in_port(0).data.get_value())
class IdentityN(Op):
op = 'IdentityN'
enabled = True
def __init__(self, graph: Graph, attrs: dict):
super().__init__(graph, {
'op': self.op,
'type': None,
}, attrs)

View File

@ -13,7 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
from extensions.ops.identity import IdentityOp
from extensions.ops.identity import Identity
from mo.front.extractor import FrontExtractorOp
@ -23,5 +23,5 @@ class ClipGradientComponentFrontExtractor(FrontExtractorOp):
@classmethod
def extract(cls, node):
IdentityOp.update_node_stat(node, {})
Identity.update_node_stat(node, {})
return cls.enabled

View File

@ -13,7 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
from extensions.ops.identity import IdentityOp
from extensions.ops.identity import Identity
from mo.front.extractor import FrontExtractorOp
@ -23,5 +23,5 @@ class NoOpFrontExtractor(FrontExtractorOp):
@classmethod
def extract(cls, node):
IdentityOp.update_node_stat(node)
Identity.update_node_stat(node)
return cls.enabled