[ MO TF ] IdentityN support (#529)
This commit is contained in:
parent
507c06c8bc
commit
b6a05c232e
@ -379,6 +379,7 @@ extensions/front/tf/gather_ext.py
|
||||
extensions/front/tf/GatherTree_ext.py
|
||||
extensions/front/tf/GNMT_DynamicSequenceLengths.py
|
||||
extensions/front/tf/identity_ext.py
|
||||
extensions/front/tf/identityN_to_identity.py
|
||||
extensions/front/tf/InterpolateTransposes.py
|
||||
extensions/front/tf/IteratorGetNext_ext.py
|
||||
extensions/front/tf/LoopCond_ext.py
|
||||
|
@ -13,7 +13,7 @@
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
from extensions.ops.identity import IdentityOp
|
||||
from extensions.ops.identity import Identity
|
||||
from mo.front.common.replacement import FrontReplacementOp
|
||||
from mo.graph.graph import Graph
|
||||
|
||||
@ -33,7 +33,7 @@ class SplitToIdentity(FrontReplacementOp):
|
||||
def replace_sub_graph(self, graph: Graph, match: dict):
|
||||
node = match['op']
|
||||
|
||||
identity = IdentityOp(graph, {'name': node.soft_get('name', node.id)}).create_node()
|
||||
identity = Identity(graph, {'name': node.soft_get('name', node.id)}).create_node()
|
||||
node.in_port(0).get_connection().set_destination(identity.in_port(0))
|
||||
|
||||
for idx, port in node.out_ports().items():
|
||||
|
@ -14,7 +14,7 @@
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
from extensions.ops.identity import IdentityOp
|
||||
from extensions.ops.identity import Identity
|
||||
from mo.front.extractor import FrontExtractorOp
|
||||
from mo.graph.graph import Node
|
||||
|
||||
@ -25,5 +25,5 @@ class BlockGradExt(FrontExtractorOp):
|
||||
|
||||
@classmethod
|
||||
def extract(cls, node: Node):
|
||||
IdentityOp.update_node_stat(node, {})
|
||||
Identity.update_node_stat(node, {})
|
||||
return cls.enabled
|
||||
|
@ -14,7 +14,7 @@
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
from extensions.ops.identity import IdentityOp
|
||||
from extensions.ops.identity import Identity
|
||||
from mo.front.extractor import FrontExtractorOp
|
||||
from mo.graph.graph import Node
|
||||
|
||||
@ -25,5 +25,5 @@ class CopyExt(FrontExtractorOp):
|
||||
|
||||
@classmethod
|
||||
def extract(cls, node: Node):
|
||||
IdentityOp.update_node_stat(node, {})
|
||||
Identity.update_node_stat(node, {})
|
||||
return cls.enabled
|
||||
|
@ -14,7 +14,7 @@
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
from extensions.ops.identity import IdentityOp
|
||||
from extensions.ops.identity import Identity
|
||||
from mo.front.extractor import FrontExtractorOp
|
||||
from mo.graph.graph import Node
|
||||
|
||||
@ -25,5 +25,5 @@ class DropoutExt(FrontExtractorOp):
|
||||
|
||||
@classmethod
|
||||
def extract(cls, node: Node):
|
||||
IdentityOp.update_node_stat(node, {})
|
||||
Identity.update_node_stat(node, {})
|
||||
return cls.enabled
|
||||
|
@ -14,7 +14,7 @@
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
from extensions.ops.identity import IdentityOp
|
||||
from extensions.ops.identity import Identity
|
||||
from mo.front.extractor import FrontExtractorOp
|
||||
from mo.front.onnx.extractors.utils import onnx_attr
|
||||
from mo.utils.error import Error
|
||||
@ -32,5 +32,5 @@ class DropoutFrontExtractor(FrontExtractorOp):
|
||||
raise Error('Dropout node {} has more than one consumer. Unsupported.', node.name)
|
||||
if not is_test:
|
||||
raise Error('Dropout node {} has is_test: 0. This means training mode which is not supported.', node.name)
|
||||
IdentityOp.update_node_stat(node)
|
||||
Identity.update_node_stat(node)
|
||||
return cls.enabled
|
||||
|
52
model-optimizer/extensions/front/tf/identityN_to_identity.py
Normal file
52
model-optimizer/extensions/front/tf/identityN_to_identity.py
Normal file
@ -0,0 +1,52 @@
|
||||
"""
|
||||
Copyright (C) 2018-2020 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
from extensions.ops.identity import Identity
|
||||
from mo.front.common.replacement import FrontReplacementPattern
|
||||
from mo.graph.graph import Graph, Node
|
||||
|
||||
|
||||
class IdentityN_to_Identity(FrontReplacementPattern):
|
||||
"""
|
||||
Replaces IdentityN op with several Identity ops.
|
||||
|
||||
Example:
|
||||
input_0 input_1 input_0 input_1
|
||||
\ / | |
|
||||
IdentityN Identity Identity
|
||||
/ \ | |
|
||||
output_0 output_1 output_0 output_1
|
||||
"""
|
||||
enabled = True
|
||||
|
||||
@staticmethod
|
||||
def replace_identityN(node: Node):
|
||||
graph = node.graph
|
||||
name = node.soft_get('name', node.id)
|
||||
|
||||
assert node.has_valid('data_types'), 'IdentityN {} has no `data_types` attribute'.format(name)
|
||||
dtypes = node.data_types
|
||||
|
||||
for idx, port in node.in_ports().items():
|
||||
assert node.is_out_port_connected(idx), 'IdentityN {} has inconsistent input and output ports'.format(name)
|
||||
assert idx < len(dtypes), 'IdentityN {} has inconsistent `data_types` attribute {}'.format(name, dtypes)
|
||||
identity = Identity(graph, {'name': '{}/{}_port'.format(name, idx), 'data_type': dtypes[idx]}).create_node()
|
||||
port.get_connection().set_destination(identity.in_port(0))
|
||||
node.out_port(idx).get_connection().set_source(identity.out_port(0))
|
||||
|
||||
def find_and_replace_pattern(self, graph: Graph):
|
||||
for identityN in graph.get_op_nodes(op='IdentityN'):
|
||||
self.replace_identityN(identityN)
|
@ -0,0 +1,63 @@
|
||||
"""
|
||||
Copyright (C) 2018-2020 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from extensions.front.tf.identityN_to_identity import IdentityN_to_Identity
|
||||
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from mo.utils.unittest.graph import result, regular_op_with_shaped_data, \
|
||||
regular_op_with_empty_data, build_graph, connect, empty_data
|
||||
|
||||
nodes = {
|
||||
**regular_op_with_shaped_data('placeholder_0', [1, 227, 227, 3], {'type': 'Parameter'}),
|
||||
**regular_op_with_shaped_data('placeholder_1', [1, 227, 227, 3], {'type': 'Parameter'}),
|
||||
|
||||
**regular_op_with_empty_data('identityN', {'op': 'IdentityN', 'type': None, 'data_types': [np.int32, np.float],
|
||||
'name': 'my_identity'}),
|
||||
**empty_data('identityN_1_d'),
|
||||
**regular_op_with_empty_data('identity0', {'op': 'Identity', 'type': None, 'data_type': np.int32,
|
||||
'name': 'my_identity/0_port'}),
|
||||
**regular_op_with_empty_data('identity1', {'op': 'Identity', 'type': None, 'data_type': np.float,
|
||||
'name': 'my_identity/1_port'}),
|
||||
|
||||
**result('output0'),
|
||||
**result('output1'),
|
||||
}
|
||||
|
||||
|
||||
class TestIdentityN(unittest.TestCase):
|
||||
def test_identityN(self):
|
||||
graph = build_graph(nodes, [
|
||||
*connect('placeholder_0', '0:identityN'),
|
||||
*connect('placeholder_1', '1:identityN'),
|
||||
*connect('identityN:0', 'output0'),
|
||||
('identityN', 'identityN_1_d', {'out': 1}),
|
||||
('identityN_1_d', 'output1', {'out': 1}),
|
||||
], nodes_with_edges_only=True)
|
||||
|
||||
IdentityN_to_Identity().find_and_replace_pattern(graph)
|
||||
|
||||
graph_ref = build_graph(nodes, [
|
||||
*connect('placeholder_0', 'identity0'),
|
||||
*connect('placeholder_1', 'identity1'),
|
||||
*connect('identity0', 'output0'),
|
||||
*connect('identity1', 'output1'),
|
||||
], nodes_with_edges_only=True)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'output0', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
@ -13,7 +13,7 @@
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
from extensions.ops.identity import IdentityOp
|
||||
from extensions.ops.identity import Identity, IdentityN
|
||||
from mo.front.extractor import FrontExtractorOp
|
||||
from mo.front.tf.extractors.utils import tf_dtype_extractor
|
||||
from mo.graph.graph import Node
|
||||
@ -25,19 +25,34 @@ class IdentityFrontExtractor(FrontExtractorOp):
|
||||
|
||||
@classmethod
|
||||
def extract(cls, node: Node):
|
||||
IdentityOp.update_node_stat(node, {
|
||||
Identity.update_node_stat(node, {
|
||||
'data_type': tf_dtype_extractor(node.pb.attr["T"].type),
|
||||
})
|
||||
return cls.enabled
|
||||
|
||||
|
||||
class IdentityNFrontExtractor(FrontExtractorOp):
|
||||
op = 'IdentityN'
|
||||
enabled = True
|
||||
|
||||
@classmethod
|
||||
def extract(cls, node: Node):
|
||||
dtypes = [tf_dtype_extractor(t) for t in node.pb.attr["T"].list.type]
|
||||
IdentityN.update_node_stat(node, {
|
||||
'data_types': dtypes,
|
||||
'in_ports_count': len(dtypes),
|
||||
'out_ports_count': len(dtypes),
|
||||
})
|
||||
return cls.enabled
|
||||
|
||||
|
||||
class ReadVariableOpFrontExtractor(FrontExtractorOp):
|
||||
op = 'ReadVariableOp'
|
||||
enabled = True
|
||||
|
||||
@classmethod
|
||||
def extract(cls, node: Node):
|
||||
IdentityOp.update_node_stat(node, {
|
||||
Identity.update_node_stat(node, {
|
||||
'data_type': tf_dtype_extractor(node.pb.attr["T"].type),
|
||||
})
|
||||
return cls.enabled
|
||||
@ -49,5 +64,5 @@ class StopGradientExtractor(FrontExtractorOp):
|
||||
|
||||
@classmethod
|
||||
def extract(cls, node: Node):
|
||||
IdentityOp.update_node_stat(node, {'op': 'StopGradient'})
|
||||
Identity.update_node_stat(node, {'op': 'StopGradient'})
|
||||
return cls.enabled
|
||||
|
@ -13,24 +13,39 @@
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
from mo.front.common.partial_infer.elemental import copy_shape_infer, copy_value
|
||||
from mo.graph.graph import Graph
|
||||
from mo.ops.op import Op
|
||||
|
||||
|
||||
class IdentityOp(Op):
|
||||
class Identity(Op):
|
||||
op = 'Identity'
|
||||
enabled = True
|
||||
|
||||
def __init__(self, graph: Graph, attrs: dict):
|
||||
super().__init__(graph, {
|
||||
'op': __class__.op,
|
||||
'op': self.op,
|
||||
'type': None,
|
||||
|
||||
'identity': True,
|
||||
'infer': self.infer,
|
||||
|
||||
'in_ports_count': 1,
|
||||
'out_ports_count': 1,
|
||||
'infer': IdentityOp.shape_infer
|
||||
}, attrs)
|
||||
|
||||
@staticmethod
|
||||
def shape_infer(node):
|
||||
copy_shape_infer(node, value_infer=copy_value)
|
||||
def infer(node):
|
||||
node.out_port(0).data.set_shape(node.in_port(0).data.get_shape())
|
||||
if node.in_port(0).data.get_value() is not None:
|
||||
node.out_port(0).data.set_value(node.in_port(0).data.get_value())
|
||||
|
||||
|
||||
class IdentityN(Op):
|
||||
op = 'IdentityN'
|
||||
enabled = True
|
||||
|
||||
def __init__(self, graph: Graph, attrs: dict):
|
||||
super().__init__(graph, {
|
||||
'op': self.op,
|
||||
'type': None,
|
||||
}, attrs)
|
||||
|
@ -13,7 +13,7 @@
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
from extensions.ops.identity import IdentityOp
|
||||
from extensions.ops.identity import Identity
|
||||
from mo.front.extractor import FrontExtractorOp
|
||||
|
||||
|
||||
@ -23,5 +23,5 @@ class ClipGradientComponentFrontExtractor(FrontExtractorOp):
|
||||
|
||||
@classmethod
|
||||
def extract(cls, node):
|
||||
IdentityOp.update_node_stat(node, {})
|
||||
Identity.update_node_stat(node, {})
|
||||
return cls.enabled
|
||||
|
@ -13,7 +13,7 @@
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
from extensions.ops.identity import IdentityOp
|
||||
from extensions.ops.identity import Identity
|
||||
from mo.front.extractor import FrontExtractorOp
|
||||
|
||||
|
||||
@ -23,5 +23,5 @@ class NoOpFrontExtractor(FrontExtractorOp):
|
||||
|
||||
@classmethod
|
||||
def extract(cls, node):
|
||||
IdentityOp.update_node_stat(node)
|
||||
Identity.update_node_stat(node)
|
||||
return cls.enabled
|
||||
|
Loading…
Reference in New Issue
Block a user