Support TF ScatterND and CenterNet (#9257)

* add ScatterND for tf

* Fix ApplyPermutation

* fix description of TFScatterNDDecomposition

* fix permutation

* fix package_POM

* fix pom file

* fix BOM file

* fix bom file

* Added layer tests for TF ScatterND

* fix comments

* fix bom file

* Add additional tests

* Add additional tests

* Added ConvertLike to ScatterND decomposition
This commit is contained in:
Eugeny Volosenkov
2022-01-12 23:15:45 +03:00
committed by GitHub
parent 896532ace2
commit 00da13b058
7 changed files with 166 additions and 4 deletions

View File

@@ -624,6 +624,7 @@ openvino/tools/mo/front/tf/rfcn_support_api_v1.13.json
openvino/tools/mo/front/tf/rfcn_support_api_v1.14.json
openvino/tools/mo/front/tf/roll_ext.py
openvino/tools/mo/front/tf/RollRealImagPack.py
openvino/tools/mo/front/tf/scatter_nd_ext.py
openvino/tools/mo/front/tf/select_ext.py
openvino/tools/mo/front/tf/sign_ext.py
openvino/tools/mo/front/tf/slice_ext.py
@@ -656,6 +657,7 @@ openvino/tools/mo/front/tf/TensorArrayGatherV3.py
openvino/tools/mo/front/tf/tensorflow_custom_operations_config_update.py
openvino/tools/mo/front/tf/TFFFTToDFT.py
openvino/tools/mo/front/tf/TFResizeToInterpolate.py
openvino/tools/mo/front/tf/TFScatterNDDecomposition.py
openvino/tools/mo/front/tf/TFSliceToSlice.py
openvino/tools/mo/front/tf/tile_ext.py
openvino/tools/mo/front/tf/topk_ext.py

View File

@@ -0,0 +1,52 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from openvino.tools.mo.front.common.partial_infer.utils import float32_array, int64_array
from openvino.tools.mo.front.common.replacement import FrontReplacementSubgraph
from openvino.tools.mo.graph.graph import Graph, rename_nodes
from openvino.tools.mo.ops.broadcast import Broadcast
from openvino.tools.mo.ops.const import Const
from openvino.tools.mo.ops.scatternd import ScatterNDUpdate
from openvino.tools.mo.ops.ConvertLike import ConvertLike
class TFScatterNDDecomposition(FrontReplacementSubgraph):
"""
Replaces TensorFlow ScatterND with OpenVINO ScatterNDUpdate. TF ScatterND does not have input data, so
instead of this argument it expects its shape
"""
enabled = True
def find_and_replace_pattern(self, graph: Graph):
for tf_scatter_nd in graph.get_op_nodes(op='TFScatterND'):
if not tf_scatter_nd.is_in_port_connected(0) or not tf_scatter_nd.is_in_port_connected(1) \
or not tf_scatter_nd.is_in_port_connected(2):
continue
name = tf_scatter_nd.soft_get('name', tf_scatter_nd.soft_get('id'))
indices_port = tf_scatter_nd.in_port(0).get_source()
updates_port = tf_scatter_nd.in_port(1).get_source()
shape_port = tf_scatter_nd.in_port(2).get_source()
# need get type of const type
zero_const = Const(graph, {'value': int64_array(0.0), 'name': name + '/zero_const'}).create_node()
# Convert zero value to type of updates node
convert_to_type = ConvertLike(graph, {'name': name + '/convert_like'}).create_node()
convert_to_type.in_port(0).connect(zero_const.out_port(0))
convert_to_type.in_port(1).connect(updates_port)
broad_cast_node = Broadcast(graph, {'name': name + '/broadcast'}).create_node()
broad_cast_node.in_port(0).connect(convert_to_type.out_port(0))
broad_cast_node.in_port(1).connect(shape_port)
scatter_nd_node = ScatterNDUpdate(graph, {'name': name + '/replaced'}).create_node()
scatter_nd_node.in_port(0).connect(broad_cast_node.out_port(0))
scatter_nd_node.in_port(1).connect(indices_port)
scatter_nd_node.in_port(2).connect(updates_port)
rename_nodes([(tf_scatter_nd, name + '/TBD'), (scatter_nd_node, name)])
tf_scatter_nd.out_port(0).get_connection().set_source(scatter_nd_node.out_port(0))
tf_scatter_nd.in_port(0).disconnect()
tf_scatter_nd.in_port(1).disconnect()
tf_scatter_nd.in_port(2).disconnect()

View File

@@ -0,0 +1,15 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from openvino.tools.mo.ops.scatternd import TFScatterND
from openvino.tools.mo.front.extractor import FrontExtractorOp
class ScatterNDExtractor(FrontExtractorOp):
op = 'ScatterNd'
enabled = True
@classmethod
def extract(cls, node):
TFScatterND.update_node_stat(node, {})
return cls.enabled

View File

@@ -94,8 +94,8 @@ class InsertLayoutPropagationTranspose(MiddleReplacementPattern):
mark_output_as_in_correct_layout(permute_node, 0)
# keep the reinterp_shape_node in NHWC layout
mark_input_as_in_correct_layout(reinterp_shape_node, 0)
mark_input_as_in_correct_layout(reinterp_shape_node, 1)
for in_port_id, _ in reinterp_shape_node.in_ports().items():
mark_input_as_in_correct_layout(reinterp_shape_node, in_port_id)
# reshape from ND -> 4D-5D. Insert Transpose(N(D)HWC->NC(D)HW) after Reshape
for reinterp_shape_node_id in graph.get_nodes_with_attributes(reinterp_shape=True):
@@ -118,7 +118,9 @@ class InsertLayoutPropagationTranspose(MiddleReplacementPattern):
# keep the reinterp_shape_node in NHWC layout
mark_output_as_in_correct_layout(reinterp_shape_node, 0)
mark_input_as_in_correct_layout(reinterp_shape_node, 1)
for in_port_id in reinterp_shape_node.in_ports().keys():
if in_port_id:
mark_input_as_in_correct_layout(reinterp_shape_node, in_port_id)
# do not re-infer the Transpose node because it output data node should be in NHWC layout to make the
# rest of the graph consistent

View File

@@ -3,7 +3,8 @@
import numpy as np
from openvino.tools.mo.front.common.partial_infer.utils import compatible_shapes, strict_compare_tensors, is_fully_defined
from openvino.tools.mo.front.common.partial_infer.utils import compatible_shapes, strict_compare_tensors, \
is_fully_defined
from openvino.tools.mo.graph.graph import Node, Graph
from openvino.tools.mo.ops.op import Op
@@ -93,3 +94,20 @@ class ScatterNDUpdate(ScatterNDBase):
output_value[indices_value[indx]] = updates_value[indx]
node.out_port(0).data.set_value(output_value)
class TFScatterND(Op):
"""
TFScatterND operation comes from TensorFlow and will be replaced by TFScatterNDDecomposition.
"""
op = 'TFScatterND'
enabled = False
def __init__(self, graph: Graph, attrs: dict):
super().__init__(graph, {
'type': None,
'op': self.op,
'in_ports_count': 3,
'out_ports_count': 1,
'infer': None
}, attrs)