Support TF ScatterND and CenterNet (#9257)
* add ScatterND for tf * Fix ApplyPermutation * fix description of TFScatterNDDecomposition * fix permutation * fix package_POM * fix pom file * fix BOM file * fix bom file * Added layer tests for TF ScatterND * fix comments * fix bom file * Add additional tests * Add additional tests * Added ConvertLike to ScatterND decomposition
This commit is contained in:
committed by
GitHub
parent
896532ace2
commit
00da13b058
@@ -290,6 +290,7 @@ Some TensorFlow\* operations do not match to any Inference Engine layer, but are
|
||||
| Round | |
|
||||
| Pow | |
|
||||
| Rsqrt | |
|
||||
| ScatterNd | |
|
||||
| Select | |
|
||||
| SelectV2 | |
|
||||
| Shape | |
|
||||
|
||||
72
tests/layer_tests/tensorflow_tests/test_tf_ScatterND.py
Normal file
72
tests/layer_tests/tensorflow_tests/test_tf_ScatterND.py
Normal file
@@ -0,0 +1,72 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from openvino.tools.mo.front.common.partial_infer.utils import int64_array, float32_array
|
||||
from unit_tests.utils.graph import build_graph, regular_op_with_shaped_data, connect, \
|
||||
shaped_data, connect_front
|
||||
|
||||
from common.layer_test_class import check_ir_version
|
||||
from common.tf_layer_test_class import CommonTFLayerTest
|
||||
from common.utils.tf_utils import permute_nchw_to_nhwc
|
||||
|
||||
|
||||
class TestTFScatterND(CommonTFLayerTest):
|
||||
def create_tf_scatternd_placeholder_const_net(self, x_shape, indices, updates, ir_version, use_new_frontend):
|
||||
#
|
||||
# Create Tensorflow model
|
||||
#
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
tf.compat.v1.reset_default_graph()
|
||||
|
||||
# Create the graph and model
|
||||
with tf.compat.v1.Session() as sess:
|
||||
tf_x_shape = x_shape.copy()
|
||||
|
||||
tf_x_shape = permute_nchw_to_nhwc(tf_x_shape, use_new_frontend)
|
||||
|
||||
x = tf.compat.v1.placeholder(tf.float32, tf_x_shape, 'Input')
|
||||
tf_indices = tf.constant(indices)
|
||||
tf_updates = tf.constant(updates)
|
||||
|
||||
scatter_nd = tf.scatter_nd(tf_indices, tf_updates, tf.shape(x), name="Operation")
|
||||
res = tf.add(x, scatter_nd)
|
||||
tf.compat.v1.global_variables_initializer()
|
||||
|
||||
tf_net = sess.graph_def
|
||||
|
||||
ref_net = None
|
||||
|
||||
return tf_net, ref_net
|
||||
|
||||
test_data = [
|
||||
pytest.param(dict(x_shape=[8], indices=[[4], [3], [1], [7]], updates=[9.0, 10.0, 11.0, 12.0]),
|
||||
marks=pytest.mark.precommit),
|
||||
pytest.param(dict(x_shape=[4, 4, 4], indices=[[0], [2]], updates= \
|
||||
[[[5.0, 5.0, 5.0, 5.0], [6.0, 6.0, 6.0, 6.0], [7.0, 7.0, 7.0, 7.0], [8.0, 8.0, 8.0, 8.0]],
|
||||
[[1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 2.0, 2.0], [3.0, 3.0, 3.0, 3.0], [4.0, 4.0, 4.0, 4.0]]])),
|
||||
pytest.param(dict(x_shape=[2, 2], indices=[[0]], updates=[[5.0, 3.0]])),
|
||||
pytest.param(dict(x_shape=[2, 2], indices=[[1, 1]], updates=[5.0])),
|
||||
dict(x_shape=[1], indices=[[0]], updates=[3.0]),
|
||||
dict(x_shape=[20], indices=[[0], [6], [9], [19], [13]], updates=[3.0, 7.0, -12.0, 4.0, -99.0]),
|
||||
dict(x_shape=[4, 2], indices=[[1], [2]], updates=[[9.0, 14.0], [-76.0, 0.0]]),
|
||||
dict(x_shape=[4, 4, 4], indices=[[0], [1], [3]], updates=[
|
||||
[[5.0, 1.0, 5.0, 13.0], [8.0, 6.0, 6.0, 8.0], [7.0, 0.0, 0.0, 7.0], [8.0, 8.0, 8.0, 8.0]],
|
||||
[[0.0, 0.0, 0.0, 0.0], [1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]],
|
||||
[[5.0, 5.0, 5.0, 5.0], [6.0, 6.0, 6.0, 6.0], [7.0, 7.0, 7.0, 7.0], [8.0, 8.0, 8.0, 8.0]]]),
|
||||
dict(x_shape=[2, 2, 2], indices=[[1, 1, 1], [0, 1, 0]], updates=[9.0, 6.3]),
|
||||
dict(x_shape=[2, 2, 2], indices=[[0, 0], [0, 1]], updates=[[6.7, 9.0], [45.0, 8.3]]),
|
||||
dict(x_shape=[2, 2, 2], indices=[[1]], updates=[[[6.7, 9.0], [45.0, 8.3]]]),
|
||||
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize("params", test_data)
|
||||
@pytest.mark.nightly
|
||||
def test_tf_scatter_nd(self, params, ie_device, precision, ir_version, temp_dir, use_new_frontend):
|
||||
self._test(*self.create_tf_scatternd_placeholder_const_net(**params, ir_version=ir_version,
|
||||
use_new_frontend=use_new_frontend),
|
||||
ie_device, precision, temp_dir=temp_dir, ir_version=ir_version,
|
||||
use_new_frontend=use_new_frontend, **params)
|
||||
@@ -624,6 +624,7 @@ openvino/tools/mo/front/tf/rfcn_support_api_v1.13.json
|
||||
openvino/tools/mo/front/tf/rfcn_support_api_v1.14.json
|
||||
openvino/tools/mo/front/tf/roll_ext.py
|
||||
openvino/tools/mo/front/tf/RollRealImagPack.py
|
||||
openvino/tools/mo/front/tf/scatter_nd_ext.py
|
||||
openvino/tools/mo/front/tf/select_ext.py
|
||||
openvino/tools/mo/front/tf/sign_ext.py
|
||||
openvino/tools/mo/front/tf/slice_ext.py
|
||||
@@ -656,6 +657,7 @@ openvino/tools/mo/front/tf/TensorArrayGatherV3.py
|
||||
openvino/tools/mo/front/tf/tensorflow_custom_operations_config_update.py
|
||||
openvino/tools/mo/front/tf/TFFFTToDFT.py
|
||||
openvino/tools/mo/front/tf/TFResizeToInterpolate.py
|
||||
openvino/tools/mo/front/tf/TFScatterNDDecomposition.py
|
||||
openvino/tools/mo/front/tf/TFSliceToSlice.py
|
||||
openvino/tools/mo/front/tf/tile_ext.py
|
||||
openvino/tools/mo/front/tf/topk_ext.py
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from openvino.tools.mo.front.common.partial_infer.utils import float32_array, int64_array
|
||||
from openvino.tools.mo.front.common.replacement import FrontReplacementSubgraph
|
||||
from openvino.tools.mo.graph.graph import Graph, rename_nodes
|
||||
from openvino.tools.mo.ops.broadcast import Broadcast
|
||||
from openvino.tools.mo.ops.const import Const
|
||||
from openvino.tools.mo.ops.scatternd import ScatterNDUpdate
|
||||
from openvino.tools.mo.ops.ConvertLike import ConvertLike
|
||||
|
||||
|
||||
class TFScatterNDDecomposition(FrontReplacementSubgraph):
|
||||
"""
|
||||
Replaces TensorFlow ScatterND with OpenVINO ScatterNDUpdate. TF ScatterND does not have input data, so
|
||||
instead of this argument it expects its shape
|
||||
|
||||
"""
|
||||
enabled = True
|
||||
|
||||
def find_and_replace_pattern(self, graph: Graph):
|
||||
for tf_scatter_nd in graph.get_op_nodes(op='TFScatterND'):
|
||||
if not tf_scatter_nd.is_in_port_connected(0) or not tf_scatter_nd.is_in_port_connected(1) \
|
||||
or not tf_scatter_nd.is_in_port_connected(2):
|
||||
continue
|
||||
name = tf_scatter_nd.soft_get('name', tf_scatter_nd.soft_get('id'))
|
||||
indices_port = tf_scatter_nd.in_port(0).get_source()
|
||||
updates_port = tf_scatter_nd.in_port(1).get_source()
|
||||
shape_port = tf_scatter_nd.in_port(2).get_source()
|
||||
# need get type of const type
|
||||
zero_const = Const(graph, {'value': int64_array(0.0), 'name': name + '/zero_const'}).create_node()
|
||||
|
||||
# Convert zero value to type of updates node
|
||||
convert_to_type = ConvertLike(graph, {'name': name + '/convert_like'}).create_node()
|
||||
convert_to_type.in_port(0).connect(zero_const.out_port(0))
|
||||
convert_to_type.in_port(1).connect(updates_port)
|
||||
|
||||
broad_cast_node = Broadcast(graph, {'name': name + '/broadcast'}).create_node()
|
||||
broad_cast_node.in_port(0).connect(convert_to_type.out_port(0))
|
||||
broad_cast_node.in_port(1).connect(shape_port)
|
||||
|
||||
scatter_nd_node = ScatterNDUpdate(graph, {'name': name + '/replaced'}).create_node()
|
||||
scatter_nd_node.in_port(0).connect(broad_cast_node.out_port(0))
|
||||
scatter_nd_node.in_port(1).connect(indices_port)
|
||||
scatter_nd_node.in_port(2).connect(updates_port)
|
||||
|
||||
rename_nodes([(tf_scatter_nd, name + '/TBD'), (scatter_nd_node, name)])
|
||||
|
||||
tf_scatter_nd.out_port(0).get_connection().set_source(scatter_nd_node.out_port(0))
|
||||
tf_scatter_nd.in_port(0).disconnect()
|
||||
tf_scatter_nd.in_port(1).disconnect()
|
||||
tf_scatter_nd.in_port(2).disconnect()
|
||||
15
tools/mo/openvino/tools/mo/front/tf/scatter_nd_ext.py
Normal file
15
tools/mo/openvino/tools/mo/front/tf/scatter_nd_ext.py
Normal file
@@ -0,0 +1,15 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from openvino.tools.mo.ops.scatternd import TFScatterND
|
||||
from openvino.tools.mo.front.extractor import FrontExtractorOp
|
||||
|
||||
|
||||
class ScatterNDExtractor(FrontExtractorOp):
|
||||
op = 'ScatterNd'
|
||||
enabled = True
|
||||
|
||||
@classmethod
|
||||
def extract(cls, node):
|
||||
TFScatterND.update_node_stat(node, {})
|
||||
return cls.enabled
|
||||
@@ -94,8 +94,8 @@ class InsertLayoutPropagationTranspose(MiddleReplacementPattern):
|
||||
mark_output_as_in_correct_layout(permute_node, 0)
|
||||
|
||||
# keep the reinterp_shape_node in NHWC layout
|
||||
mark_input_as_in_correct_layout(reinterp_shape_node, 0)
|
||||
mark_input_as_in_correct_layout(reinterp_shape_node, 1)
|
||||
for in_port_id, _ in reinterp_shape_node.in_ports().items():
|
||||
mark_input_as_in_correct_layout(reinterp_shape_node, in_port_id)
|
||||
|
||||
# reshape from ND -> 4D-5D. Insert Transpose(N(D)HWC->NC(D)HW) after Reshape
|
||||
for reinterp_shape_node_id in graph.get_nodes_with_attributes(reinterp_shape=True):
|
||||
@@ -118,7 +118,9 @@ class InsertLayoutPropagationTranspose(MiddleReplacementPattern):
|
||||
|
||||
# keep the reinterp_shape_node in NHWC layout
|
||||
mark_output_as_in_correct_layout(reinterp_shape_node, 0)
|
||||
mark_input_as_in_correct_layout(reinterp_shape_node, 1)
|
||||
for in_port_id in reinterp_shape_node.in_ports().keys():
|
||||
if in_port_id:
|
||||
mark_input_as_in_correct_layout(reinterp_shape_node, in_port_id)
|
||||
|
||||
# do not re-infer the Transpose node because it output data node should be in NHWC layout to make the
|
||||
# rest of the graph consistent
|
||||
|
||||
@@ -3,7 +3,8 @@
|
||||
|
||||
import numpy as np
|
||||
|
||||
from openvino.tools.mo.front.common.partial_infer.utils import compatible_shapes, strict_compare_tensors, is_fully_defined
|
||||
from openvino.tools.mo.front.common.partial_infer.utils import compatible_shapes, strict_compare_tensors, \
|
||||
is_fully_defined
|
||||
from openvino.tools.mo.graph.graph import Node, Graph
|
||||
from openvino.tools.mo.ops.op import Op
|
||||
|
||||
@@ -93,3 +94,20 @@ class ScatterNDUpdate(ScatterNDBase):
|
||||
output_value[indices_value[indx]] = updates_value[indx]
|
||||
|
||||
node.out_port(0).data.set_value(output_value)
|
||||
|
||||
|
||||
class TFScatterND(Op):
|
||||
"""
|
||||
TFScatterND operation comes from TensorFlow and will be replaced by TFScatterNDDecomposition.
|
||||
"""
|
||||
op = 'TFScatterND'
|
||||
enabled = False
|
||||
|
||||
def __init__(self, graph: Graph, attrs: dict):
|
||||
super().__init__(graph, {
|
||||
'type': None,
|
||||
'op': self.op,
|
||||
'in_ports_count': 3,
|
||||
'out_ports_count': 1,
|
||||
'infer': None
|
||||
}, attrs)
|
||||
|
||||
Reference in New Issue
Block a user