Support TF ScatterND and CenterNet (#9257)
* add ScatterND for tf * Fix ApplyPermutation * fix description of TFScatterNDDecomposition * fix permutation * fix package_POM * fix pom file * fix BOM file * fix bom file * Added layer tests for TF ScatterND * fix comments * fix bom file * Add additional tests * Add additional tests * Added ConvertLike to ScatterND decomposition
This commit is contained in:
committed by
GitHub
parent
896532ace2
commit
00da13b058
@@ -624,6 +624,7 @@ openvino/tools/mo/front/tf/rfcn_support_api_v1.13.json
|
||||
openvino/tools/mo/front/tf/rfcn_support_api_v1.14.json
|
||||
openvino/tools/mo/front/tf/roll_ext.py
|
||||
openvino/tools/mo/front/tf/RollRealImagPack.py
|
||||
openvino/tools/mo/front/tf/scatter_nd_ext.py
|
||||
openvino/tools/mo/front/tf/select_ext.py
|
||||
openvino/tools/mo/front/tf/sign_ext.py
|
||||
openvino/tools/mo/front/tf/slice_ext.py
|
||||
@@ -656,6 +657,7 @@ openvino/tools/mo/front/tf/TensorArrayGatherV3.py
|
||||
openvino/tools/mo/front/tf/tensorflow_custom_operations_config_update.py
|
||||
openvino/tools/mo/front/tf/TFFFTToDFT.py
|
||||
openvino/tools/mo/front/tf/TFResizeToInterpolate.py
|
||||
openvino/tools/mo/front/tf/TFScatterNDDecomposition.py
|
||||
openvino/tools/mo/front/tf/TFSliceToSlice.py
|
||||
openvino/tools/mo/front/tf/tile_ext.py
|
||||
openvino/tools/mo/front/tf/topk_ext.py
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from openvino.tools.mo.front.common.partial_infer.utils import float32_array, int64_array
|
||||
from openvino.tools.mo.front.common.replacement import FrontReplacementSubgraph
|
||||
from openvino.tools.mo.graph.graph import Graph, rename_nodes
|
||||
from openvino.tools.mo.ops.broadcast import Broadcast
|
||||
from openvino.tools.mo.ops.const import Const
|
||||
from openvino.tools.mo.ops.scatternd import ScatterNDUpdate
|
||||
from openvino.tools.mo.ops.ConvertLike import ConvertLike
|
||||
|
||||
|
||||
class TFScatterNDDecomposition(FrontReplacementSubgraph):
|
||||
"""
|
||||
Replaces TensorFlow ScatterND with OpenVINO ScatterNDUpdate. TF ScatterND does not have input data, so
|
||||
instead of this argument it expects its shape
|
||||
|
||||
"""
|
||||
enabled = True
|
||||
|
||||
def find_and_replace_pattern(self, graph: Graph):
|
||||
for tf_scatter_nd in graph.get_op_nodes(op='TFScatterND'):
|
||||
if not tf_scatter_nd.is_in_port_connected(0) or not tf_scatter_nd.is_in_port_connected(1) \
|
||||
or not tf_scatter_nd.is_in_port_connected(2):
|
||||
continue
|
||||
name = tf_scatter_nd.soft_get('name', tf_scatter_nd.soft_get('id'))
|
||||
indices_port = tf_scatter_nd.in_port(0).get_source()
|
||||
updates_port = tf_scatter_nd.in_port(1).get_source()
|
||||
shape_port = tf_scatter_nd.in_port(2).get_source()
|
||||
# need get type of const type
|
||||
zero_const = Const(graph, {'value': int64_array(0.0), 'name': name + '/zero_const'}).create_node()
|
||||
|
||||
# Convert zero value to type of updates node
|
||||
convert_to_type = ConvertLike(graph, {'name': name + '/convert_like'}).create_node()
|
||||
convert_to_type.in_port(0).connect(zero_const.out_port(0))
|
||||
convert_to_type.in_port(1).connect(updates_port)
|
||||
|
||||
broad_cast_node = Broadcast(graph, {'name': name + '/broadcast'}).create_node()
|
||||
broad_cast_node.in_port(0).connect(convert_to_type.out_port(0))
|
||||
broad_cast_node.in_port(1).connect(shape_port)
|
||||
|
||||
scatter_nd_node = ScatterNDUpdate(graph, {'name': name + '/replaced'}).create_node()
|
||||
scatter_nd_node.in_port(0).connect(broad_cast_node.out_port(0))
|
||||
scatter_nd_node.in_port(1).connect(indices_port)
|
||||
scatter_nd_node.in_port(2).connect(updates_port)
|
||||
|
||||
rename_nodes([(tf_scatter_nd, name + '/TBD'), (scatter_nd_node, name)])
|
||||
|
||||
tf_scatter_nd.out_port(0).get_connection().set_source(scatter_nd_node.out_port(0))
|
||||
tf_scatter_nd.in_port(0).disconnect()
|
||||
tf_scatter_nd.in_port(1).disconnect()
|
||||
tf_scatter_nd.in_port(2).disconnect()
|
||||
15
tools/mo/openvino/tools/mo/front/tf/scatter_nd_ext.py
Normal file
15
tools/mo/openvino/tools/mo/front/tf/scatter_nd_ext.py
Normal file
@@ -0,0 +1,15 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from openvino.tools.mo.ops.scatternd import TFScatterND
|
||||
from openvino.tools.mo.front.extractor import FrontExtractorOp
|
||||
|
||||
|
||||
class ScatterNDExtractor(FrontExtractorOp):
|
||||
op = 'ScatterNd'
|
||||
enabled = True
|
||||
|
||||
@classmethod
|
||||
def extract(cls, node):
|
||||
TFScatterND.update_node_stat(node, {})
|
||||
return cls.enabled
|
||||
@@ -94,8 +94,8 @@ class InsertLayoutPropagationTranspose(MiddleReplacementPattern):
|
||||
mark_output_as_in_correct_layout(permute_node, 0)
|
||||
|
||||
# keep the reinterp_shape_node in NHWC layout
|
||||
mark_input_as_in_correct_layout(reinterp_shape_node, 0)
|
||||
mark_input_as_in_correct_layout(reinterp_shape_node, 1)
|
||||
for in_port_id, _ in reinterp_shape_node.in_ports().items():
|
||||
mark_input_as_in_correct_layout(reinterp_shape_node, in_port_id)
|
||||
|
||||
# reshape from ND -> 4D-5D. Insert Transpose(N(D)HWC->NC(D)HW) after Reshape
|
||||
for reinterp_shape_node_id in graph.get_nodes_with_attributes(reinterp_shape=True):
|
||||
@@ -118,7 +118,9 @@ class InsertLayoutPropagationTranspose(MiddleReplacementPattern):
|
||||
|
||||
# keep the reinterp_shape_node in NHWC layout
|
||||
mark_output_as_in_correct_layout(reinterp_shape_node, 0)
|
||||
mark_input_as_in_correct_layout(reinterp_shape_node, 1)
|
||||
for in_port_id in reinterp_shape_node.in_ports().keys():
|
||||
if in_port_id:
|
||||
mark_input_as_in_correct_layout(reinterp_shape_node, in_port_id)
|
||||
|
||||
# do not re-infer the Transpose node because it output data node should be in NHWC layout to make the
|
||||
# rest of the graph consistent
|
||||
|
||||
@@ -3,7 +3,8 @@
|
||||
|
||||
import numpy as np
|
||||
|
||||
from openvino.tools.mo.front.common.partial_infer.utils import compatible_shapes, strict_compare_tensors, is_fully_defined
|
||||
from openvino.tools.mo.front.common.partial_infer.utils import compatible_shapes, strict_compare_tensors, \
|
||||
is_fully_defined
|
||||
from openvino.tools.mo.graph.graph import Node, Graph
|
||||
from openvino.tools.mo.ops.op import Op
|
||||
|
||||
@@ -93,3 +94,20 @@ class ScatterNDUpdate(ScatterNDBase):
|
||||
output_value[indices_value[indx]] = updates_value[indx]
|
||||
|
||||
node.out_port(0).data.set_value(output_value)
|
||||
|
||||
|
||||
class TFScatterND(Op):
|
||||
"""
|
||||
TFScatterND operation comes from TensorFlow and will be replaced by TFScatterNDDecomposition.
|
||||
"""
|
||||
op = 'TFScatterND'
|
||||
enabled = False
|
||||
|
||||
def __init__(self, graph: Graph, attrs: dict):
|
||||
super().__init__(graph, {
|
||||
'type': None,
|
||||
'op': self.op,
|
||||
'in_ports_count': 3,
|
||||
'out_ports_count': 1,
|
||||
'infer': None
|
||||
}, attrs)
|
||||
|
||||
Reference in New Issue
Block a user