Files
openvino/model-optimizer/extensions/front/tf/identity_ext.py
Alexey Suhov 6478f1742a Align copyright notice in python scripts (CVS-51320) (#4974)
* Align copyright notice in python scripts (CVS-51320)
2021-03-26 17:54:28 +03:00

57 lines
1.5 KiB
Python

# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from extensions.ops.identity import Identity, IdentityN
from mo.front.extractor import FrontExtractorOp
from mo.front.tf.extractors.utils import tf_dtype_extractor
from mo.graph.graph import Node
class IdentityFrontExtractor(FrontExtractorOp):
    """Front extractor registered for nodes whose op is 'Identity'."""
    op = 'Identity'
    enabled = True

    @classmethod
    def extract(cls, node: Node):
        """Pass the data type decoded from the node's "T" attribute to Identity.update_node_stat.

        :param node: graph node; its ``pb.attr["T"].type`` value is read
        :return: the class-level ``enabled`` flag
        """
        type_code = node.pb.attr["T"].type
        extracted_attrs = {'data_type': tf_dtype_extractor(type_code)}
        Identity.update_node_stat(node, extracted_attrs)
        return cls.enabled
class IdentityNFrontExtractor(FrontExtractorOp):
    """Front extractor registered for nodes whose op is 'IdentityN'."""
    op = 'IdentityN'
    enabled = True

    @classmethod
    def extract(cls, node: Node):
        """Pass the per-entry data types and matching port counts to IdentityN.update_node_stat.

        Every entry of the node's ``pb.attr["T"].list.type`` is converted with
        ``tf_dtype_extractor``; the number of entries is used for both the
        input and the output port count.

        :param node: graph node whose "T" list attribute is read
        :return: the class-level ``enabled`` flag
        """
        type_codes = node.pb.attr["T"].list.type
        data_types = list(map(tf_dtype_extractor, type_codes))
        port_count = len(data_types)
        IdentityN.update_node_stat(node, {
            'data_types': data_types,
            'in_ports_count': port_count,
            'out_ports_count': port_count,
        })
        return cls.enabled
class ReadVariableOpFrontExtractor(FrontExtractorOp):
    """Front extractor registered for 'ReadVariableOp' nodes; it updates them through the Identity op."""
    op = 'ReadVariableOp'
    enabled = True

    @classmethod
    def extract(cls, node: Node):
        """Pass the node's decoded data type to Identity.update_node_stat.

        TensorFlow's ReadVariableOp declares its element type in the ``dtype``
        attribute rather than ``T``, so ``dtype`` is read when the node has it.
        ``T`` is kept as the fallback, which preserves the previous behaviour
        for nodes that do not carry ``dtype``.

        :param node: graph node whose ``pb.attr`` map is inspected
        :return: the class-level ``enabled`` flag
        """
        pb_attrs = node.pb.attr
        # Pick the key with a membership test so that the attribute actually
        # present on the node decides which one gets decoded.
        type_attr = 'dtype' if 'dtype' in pb_attrs else 'T'
        Identity.update_node_stat(node, {
            'data_type': tf_dtype_extractor(pb_attrs[type_attr].type),
        })
        return cls.enabled
class StopGradientExtractor(FrontExtractorOp):
    """Front extractor registered for nodes whose op is 'StopGradient'."""
    op = 'StopGradient'
    enabled = True

    @classmethod
    def extract(cls, node: Node):
        """Call Identity.update_node_stat with the 'op' attribute set to 'StopGradient'.

        :param node: graph node to update
        :return: the class-level ``enabled`` flag
        """
        overrides = {'op': 'StopGradient'}
        Identity.update_node_stat(node, overrides)
        return cls.enabled