Correct ReverseV2ToReverseSequence transformation (#8120)

* add a subgraph instead of a constant with a fixed shape to allow the model to have an undefined batch

* updated transformation (not checked yet)

* changed ReverseV2ToReverseSequence to support dynamic shapes/reshape;
added transformation to reverse_tensor_iterator to support the new subgraph produced by ReverseV2ToReverseSequence

* remove changes that should not be on this branch

* added tests;
fixed old transformation

* added deletion of ReverseSequence ops to avoid running the transformation twice

* fixed pattern check for the case of a dynamic value on the ReverseSequence input

* Revert "fixed pattern check for case with dynamic value for input of reversesequence"

This reverts commit 0c04164e

* Revert "added delete of reversesequences to avoid run of transformation twice"

This reverts commit fcb7de9c

* reverted changes in reverse_tensor_iterator for the Squeeze case;
updated reverse_tensor_iterator with the ShapeOf subgraph;
added permutations for attributes to pass the layer test

* minor fix for dynamic shape

* updated test;
fixed backward compatibility in reverse_tensor_iterator transformation

* review comments fixed:
added comments;
refactoring done;
fixed framework name saving for rank = 1

* minor review fixes

* small fix
Svetlana Dolinina 2021-11-29 15:29:00 +03:00 committed by GitHub
parent c084f8aa42
commit 980ad59ac4
6 changed files with 276 additions and 50 deletions
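The heart of the change: the old pass materialized seq_lengths as a constant computed from the static input shape, which froze the batch size into the IR; the new pass builds a ShapeOf -> Gather -> Broadcast subgraph so seq_lengths is computed at runtime. A minimal NumPy sketch of the computation that subgraph performs (names here are illustrative, not the MO API):

import numpy as np

def seq_lengths_subgraph(data, seq_axis):
    # Mirror of the ShapeOf -> Gather -> Broadcast chain built by the pass:
    # broadcast shape[seq_axis] to a vector of length shape[batch_axis].
    batch_axis = 0 if seq_axis != 0 else 1
    shape = np.array(data.shape)                    # ShapeOf
    seq_len = shape[[seq_axis]]                     # Gather(seq_axis)
    batch = shape[[batch_axis]]                     # Gather(batch_axis)
    return np.broadcast_to(seq_len, tuple(batch))   # Broadcast

# For an [N, 3, 227, 227] input reversed along axis 2 this yields [227] * N,
# evaluated at runtime, so N may stay dynamic in the IR.
print(seq_lengths_subgraph(np.zeros((2, 3, 227, 227)), seq_axis=2))  # [227 227]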


@@ -1,16 +1,27 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
from extensions.ops.reverse_sequence import ReverseSequence
from mo.front.common.partial_infer.utils import int64_array
from mo.front.tf.graph_utils import create_op_node_with_second_input
from mo.graph.graph import Graph, rename_node
from mo.middle.replacement import MiddleReplacementPattern
from mo.utils.error import Error
from mo.ops.broadcast import Broadcast
from mo.ops.shape import Shape
from mo.ops.squeeze import Squeeze
from mo.ops.unsqueeze import Unsqueeze
from mo.utils.shape import node_to_get_shape_value_of_indices
class ReverseToReverseSequence(MiddleReplacementPattern):
"""
Transformation converts Reverse to ReverseSequence operation.
Parameters for ReverseSequence are calculated in the following way:
* seq_axis - set to the axis value from the Reverse operation
* batch_axis - set to 0 if seq_axis is not 0, otherwise set to 1
* seq_lengths - take the shape[seq_axis] value and broadcast it to a vector of length shape[batch_axis]
If the input is a 1D tensor, we add one more dimension so that seq_axis and batch_axis differ.
"""
enabled = True
def run_after(self):
@@ -21,40 +32,57 @@ class ReverseToReverseSequence(MiddleReplacementPattern):
from extensions.middle.reverse_tensor_iterator import ReverseTensorIteratorLSTM
return [ReverseTensorIteratorLSTM]
@staticmethod
def pattern():
return dict(
nodes=[
('reverse', dict(kind='op', op='Reverse'))
],
edges=[]
)
def find_and_replace_pattern(self, graph: Graph):
reverse_nodes = graph.get_op_nodes(op='Reverse')
for reverse in reverse_nodes:
reverse_name = reverse.soft_get('name', reverse.id)
def replace_pattern(self, graph: Graph, match: dict):
reverse = match['reverse']
input_data_shape = reverse.in_node(0).shape
assert reverse.in_port(1).disconnected()
assert reverse.has_valid('axis')
if len(input_data_shape) == 1:
raise Error('Reverse operation name = {} isn\'t supported because of 1D input.'.format(reverse.name))
in_shape_rank = len(reverse.in_port(0).data.get_shape())
# 1. Add new dimension as batch for rank = 1 to have batch != seq_axis
if in_shape_rank == 1:
unsq_node = create_op_node_with_second_input(graph, Unsqueeze, int64_array([0]),
{'name': reverse_name+"/Unsqueeze"})
reverse.in_port(0).get_source().connect(unsq_node.in_port(0))
new_in = unsq_node.out_port(0)
batch_axis = 0
seq_axis = 1
else:
new_in = reverse.in_port(0).get_source()
seq_axis = reverse['axis']
batch_axis = 0 if seq_axis != 0 else 1
assert reverse.in_port(1).disconnected()
# 2. For ReverseSequence 1-port input is seq_lengths => create this input node as
# shape[seq_axis] broadcasted to shape[batch_axis]
# in ---> ShapeOf ----> Gather(seq_axis) ----> Broadcast----->
# | |
# | -------> Gather(batch_axis)----------|
shape_node = Shape(graph, {'name': reverse_name + "/Shape"}).create_node()
new_in.connect(shape_node.in_port(0))
seq_axis_node = node_to_get_shape_value_of_indices(shape_node, [seq_axis])
batch_node = node_to_get_shape_value_of_indices(shape_node, [batch_axis])
broadcast_node = Broadcast(graph, {'name': reverse_name + "/Broadcast"}).create_node()
broadcast_node.in_port(0).connect(seq_axis_node.out_port(0))
broadcast_node.in_port(1).connect(batch_node.out_port(0))
seq_axis = reverse['axis']
# We need to choose arbitrary batch_axis != sequence_axis
batch_axis = int(not seq_axis)
# 3. Create new ReverseSequence node and reconnect all inputs/outputs to it
rename_node(reverse, reverse_name + '/to_delete')
reverse_sequence = ReverseSequence(graph, {'name': reverse_name, 'seq_axis': seq_axis,
'batch_axis': batch_axis}).create_node()
reverse_sequence.in_port(0).connect(new_in)
reverse_sequence.in_port(1).connect(broadcast_node.out_port(0))
# 1. For ReverseSequence 1-port input is seq_lengths => create this input node
seq_lengths = np.ones(input_data_shape[batch_axis]) * input_data_shape[seq_axis]
# 4. remove added dimension for rank = 1
if in_shape_rank == 1:
rename_node(reverse_sequence, reverse_name + '/ReverseSequence')
squeeze_node = create_op_node_with_second_input(graph, Squeeze, int64_array([0]),
{'name': reverse_name})
squeeze_node.in_port(0).connect(reverse_sequence.out_port(0))
reverse.out_port(0).get_connection().set_source(squeeze_node.out_port(0))
else:
reverse.out_port(0).get_connection().set_source(reverse_sequence.out_port(0))
reverse_name = reverse.soft_get('name', reverse.id)
rename_node(reverse, reverse_name + '/to_delete')
# 2. Create new ReverseSequence node and reconnect all inputs/outputs to it
reverse_sequence = create_op_node_with_second_input(graph, ReverseSequence, seq_lengths,
{'name': reverse_name, 'seq_axis': seq_axis,
'batch_axis': batch_axis})
rename_node(reverse_sequence, reverse_name)
reverse.in_port(0).get_connection().set_destination(reverse_sequence.in_port(0))
reverse.out_port(0).get_connection().set_source(reverse_sequence.out_port(0))
# 3. Delete old Reverse node
graph.remove_node(reverse.id)
# 5. Delete old Reverse node
graph.remove_nodes_from([reverse.id for reverse in reverse_nodes])
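To make the rank = 1 branch above concrete: ReverseSequence requires seq_axis != batch_axis, which a 1D tensor cannot satisfy, so the pass wraps the input in Unsqueeze and restores the original shape with Squeeze afterwards. A NumPy sketch of the resulting data flow (illustrative only):

import numpy as np

x = np.array([1, 2, 3, 4, 5])

x2 = np.expand_dims(x, 0)        # Unsqueeze([0]): shape [5] -> [1, 5]
y2 = x2[:, ::-1]                 # ReverseSequence(batch_axis=0, seq_axis=1,
                                 #                 seq_lengths=[5]) on one row
y = np.squeeze(y2, 0)            # Squeeze([0]): shape [1, 5] -> [5]
print(y)                         # [5 4 3 2 1]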


@@ -5,6 +5,7 @@ import numpy as np
from extensions.middle.ONNXRNNSequenceNormalize import ONNXRNNSequenceNormalize
from extensions.middle.permute_tensor_iterator import TransposeTensorIteratorLSTM
from mo.front.common.partial_infer.utils import is_fully_defined
from mo.graph.graph import Graph, Node
from mo.middle.passes.eliminate import remove_op_node_with_data_node
from mo.middle.replacement import MiddleReplacementPattern
@@ -18,6 +19,7 @@ class ReverseTensorIteratorLSTM(MiddleReplacementPattern):
"""
enabled = True
force_clean_up = True
def run_after(self):
return [
@@ -32,39 +34,55 @@ class ReverseTensorIteratorLSTM(MiddleReplacementPattern):
@staticmethod
def is_fusable_reverse_sequence(node: Node):
sequence_lengths = node.in_port(1).data.get_value()
assert sequence_lengths is not None
input_shape = node.in_port(0).data.get_shape()
assert input_shape is not None
seq_len = input_shape[node.seq_axis]
return np.all(sequence_lengths == seq_len)
if sequence_lengths is not None and is_fully_defined(sequence_lengths) and is_fully_defined(seq_len):
return np.all(sequence_lengths == seq_len)
else:
# check that seq_lengths is taken from the input shape, i.e. matches the subgraph built by the ReverseV2ToReverseSequence transformation
broadcast_node = node.in_port(1).get_source().node
if broadcast_node.op != 'Broadcast':
return False
gather_node = broadcast_node.in_port(0).get_source().node
if gather_node.op != "Gather" or \
(np.all(gather_node.in_port(2).data.get_value() != [0]) or
np.all(gather_node.in_port(1).data.get_value() != [node.seq_axis])):
return False
gather_node_2 = broadcast_node.in_port(1).get_source().node
if gather_node_2.op != "Gather" or \
(np.all(gather_node_2.in_port(2).data.get_value() != [0]) or
np.all(gather_node_2.in_port(1).data.get_value() != [node.batch_axis])):
return False
shape_node = gather_node.in_port(0).get_source().node
if shape_node.op != "ShapeOf":
return False
if shape_node.in_port(0).get_source().node != node.in_port(0).get_source().node:
return False
return True
def pattern(self):
return dict(
nodes=[
('input', dict(kind='data')),
('const', dict(type='Const')),
('const_d', dict(kind='data')),
('direct_seq_len_d', dict(kind='data')),
('direct_reverse', dict(op='ReverseSequence')),
('input_reversed', dict(kind='data')),
('init_hidden', dict(kind='data')),
('ti', dict(kind='op', op='TensorIterator')),
('output_reversed', dict(kind='data')),
('const_1', dict(type='Const')),
('const_1_d', dict(kind='data')),
('inverse_seq_len_d', dict(kind='data')),
('inverse_reverse', dict(op='ReverseSequence')),
('output', dict(kind='data')),
],
edges=[
('input', 'direct_reverse', {'in': 0}),
('const', 'const_d'),
('const_d', 'direct_reverse', {'in': 1}),
('direct_seq_len_d', 'direct_reverse', {'in': 1}),
('direct_reverse', 'input_reversed'),
('input_reversed', 'ti', {'in': 0}),
@@ -72,8 +90,7 @@ class ReverseTensorIteratorLSTM(MiddleReplacementPattern):
('ti', 'output_reversed', {'out': 0}),
('output_reversed', 'inverse_reverse', {'in': 0}),
('const_1', 'const_1_d'),
('const_1_d', 'inverse_reverse', {'in': 1}),
('inverse_seq_len_d', 'inverse_reverse', {'in': 1}),
('inverse_reverse', 'output'),
]
)
@@ -106,9 +123,12 @@ class ReverseTensorIteratorLSTM(MiddleReplacementPattern):
port['start'] = -1
port['end'] = 0
elif port['stride'] == 1:
port['start'] = None
port['end'] = None
port['start'] = 0
port['end'] = -1
# disconnect subgraph for seq length calculation
direct_reverse.in_port(1).disconnect()
inverse_reverse.in_port(1).disconnect()
# Remove reverses
remove_op_node_with_data_node(graph, direct_reverse)
remove_op_node_with_data_node(graph, inverse_reverse)
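For the constant branch of is_fusable_reverse_sequence above, fusability reduces to checking that every sequence spans the whole seq_axis; the dynamic branch instead pattern-matches the exact ShapeOf/Gather/Broadcast subgraph emitted by ReverseV2ToReverseSequence. A condensed sketch of the constant case (plain NumPy, not the MO node API):

import numpy as np

def is_full_length(sequence_lengths, input_shape, seq_axis):
    # Fusable only if every sequence covers the whole reversed axis.
    seq_len = input_shape[seq_axis]
    return sequence_lengths is not None and bool(np.all(sequence_lengths == seq_len))

print(is_full_length(np.array([227, 227]), np.array([2, 3, 227, 227]), 2))  # True
print(is_full_length(np.array([100, 227]), np.array([2, 3, 227, 227]), 2))  # False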


@@ -2,7 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
from mo.graph.graph import Graph
from mo.ops.op import Op
from mo.ops.op import Op, PermuteAttrs
class ReverseSequence(Op):
@@ -35,3 +35,6 @@ class ReverseSequence(Op):
assert len(node.out_nodes()) == 1
node.out_port(0).data.set_shape(input_data_shape)
PermuteAttrs.create_permute_attrs(node, attrs=[('seq_axis', 'input:0')])
PermuteAttrs.create_permute_attrs(node, attrs=[('batch_axis', 'input:0')])


@@ -374,6 +374,8 @@ class PermuteAttrs:
'ellipsis_mask': slice_permutation,
'axes': common_permutation_inv,
'axis': common_permutation_inv,
'seq_axis': common_permutation_inv,
'batch_axis': common_permutation_inv,
'batch_dims': common_permutation_inv,
'channel_dims': common_permutation_inv,
'spatial_dims': common_permutation_inv,


@@ -0,0 +1,117 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import unittest
import numpy as np
from extensions.middle.reverse_tensor_iterator import ReverseTensorIteratorLSTM
from mo.utils.ir_engine.compare_graphs import compare_graphs
from unit_tests.utils.graph import build_graph, regular_op_with_shaped_data, connect, \
valued_const_with_data, regular_op_with_empty_data, result
nodes = {
**regular_op_with_shaped_data('parameter', [1, 3, 227, 227],
{'type': 'Parameter', 'op': 'Parameter', 'shape': [1, 3, 227, 227]}),
**valued_const_with_data('seq_len', np.array([227])),
**regular_op_with_empty_data('shapeof', {'type': 'ShapeOf', 'op': 'ShapeOf'}),
**valued_const_with_data('gather_axis', np.array([0])),
**valued_const_with_data('gather_batch_ind', np.array([0])),
**valued_const_with_data('gather_seq_ind', np.array([2])),
**regular_op_with_empty_data('gather_batch', {'type': 'Gather', 'op': 'Gather'}),
**regular_op_with_empty_data('gather_seq', {'type': 'Gather', 'op': 'Gather'}),
**regular_op_with_empty_data('broadcast', {'type': 'Broadcast', 'op': 'Broadcast'}),
**regular_op_with_shaped_data('direct_reverse', [1, 3, 227, 227], {'type': 'ReverseSequence',
'op': 'ReverseSequence',
'seq_axis': 2, 'batch_axis': 0}),
**regular_op_with_empty_data('init_hidden', {'type': 'Init', 'op': 'Init'}),
**regular_op_with_shaped_data('ti', [1, 2, 34, 56], {'type': 'TensorIterator', 'op': 'TensorIterator',
'output_port_map': [{'axis': 2, 'start': 0, 'end': -1,
'stride': 1, 'external_port_id': 0}],
'input_port_map': [{'axis': 2, 'start': -1, 'end': 0,
'stride': -1, 'external_port_id': 0}]}),
**valued_const_with_data('inverse_seq_len', np.array([34])),
**regular_op_with_empty_data('inverse_shapeof', {'type': 'ShapeOf', 'op': 'ShapeOf'}),
**regular_op_with_empty_data('inverse_gather_batch', {'type': 'Gather', 'op': 'Gather'}),
**regular_op_with_empty_data('inverse_gather_seq', {'type': 'Gather', 'op': 'Gather'}),
**regular_op_with_empty_data('inverse_broadcast', {'type': 'Broadcast', 'op': 'Broadcast'}),
**regular_op_with_shaped_data('inverse_reverse', [1, 2, 34, 56], {'type': 'ReverseSequence',
'op': 'ReverseSequence',
'seq_axis': 2, 'batch_axis': 0}),
**regular_op_with_empty_data('some_op', {'op': 'SomeOp'}),
**result()
}
ref_nodes = {
**regular_op_with_shaped_data('parameter', [1, 3, 227, 227],
{'type': 'Parameter', 'op': 'Parameter', 'shape': [1, 3, 227, 227]}),
**regular_op_with_empty_data('init_hidden', {'type': 'Init', 'op': 'Init'}),
**regular_op_with_empty_data('ti', {'type': 'TensorIterator', 'op': 'TensorIterator',
'output_port_map': [{'axis': 2, 'start': -1, 'end': 0, 'stride': -1,
'external_port_id': 0}],
'input_port_map': [{'axis': 2, 'start': 0, 'end': -1, 'stride': 1,
'external_port_id': 0}]}),
**regular_op_with_empty_data('some_op', {'op': 'SomeOp'}),
**result()
}
class ReverseTensorIteratorTest(unittest.TestCase):
def test_ti_reverse(self):
graph = build_graph(nodes, [*connect('parameter:0', '0:direct_reverse'),
*connect('parameter:0', 'shapeof', skip_data=True),
*connect('shapeof:0', '0:gather_batch'),
*connect('gather_batch_ind', '1:gather_batch'),
*connect('gather_axis', '2:gather_batch'),
*connect('shapeof:0', '0:gather_seq', skip_data=True),
*connect('gather_seq_ind', '1:gather_seq'),
*connect('gather_axis', '2:gather_seq'),
*connect('gather_seq', '0:broadcast'),
*connect('gather_batch', '1:broadcast'),
*connect('broadcast', '1:direct_reverse'),
*connect('direct_reverse', '0:ti'),
*connect('init_hidden', '1:ti'),
*connect('ti', 'inverse_shapeof'),
*connect('inverse_shapeof:0', '0:inverse_gather_batch'),
*connect('gather_batch_ind', '1:inverse_gather_batch'),
*connect('gather_axis', '2:inverse_gather_batch'),
*connect('inverse_shapeof:0', '0:inverse_gather_seq', skip_data=True),
*connect('gather_seq_ind', '1:inverse_gather_seq'),
*connect('gather_axis', '2:inverse_gather_seq'),
*connect('inverse_gather_seq', '0:inverse_broadcast'),
*connect('inverse_gather_batch', '1:inverse_broadcast'),
*connect('ti', '0:inverse_reverse', skip_data=True),
*connect('inverse_broadcast', '1:inverse_reverse'),
*connect('inverse_reverse', 'some_op'),
*connect('some_op', 'output')], nodes_with_edges_only=True)
ReverseTensorIteratorLSTM().find_and_replace_pattern(graph)
graph.clean_up()
ref_graph = build_graph(ref_nodes, [*connect('parameter', '0:ti'),
*connect('init_hidden', '1:ti'),
*connect('ti', 'some_op'),
*connect('some_op', 'output')])
flag, resp = compare_graphs(graph, ref_graph, 'output', check_op_attrs=True)
self.assertTrue(flag, resp)
def test_ti_reverse_const(self):
graph = build_graph(nodes, [*connect('parameter:0', '0:direct_reverse'),
*connect('seq_len', '1:direct_reverse'),
*connect('direct_reverse', '0:ti'),
*connect('init_hidden', '1:ti'),
*connect('ti', '0:inverse_reverse'),
*connect('inverse_seq_len', '1:inverse_reverse'),
*connect('inverse_reverse', 'some_op'),
*connect('some_op', 'output')], nodes_with_edges_only=True)
ReverseTensorIteratorLSTM().find_and_replace_pattern(graph)
graph.clean_up()
ref_graph = build_graph(ref_nodes, [*connect('parameter', '0:ti'),
*connect('init_hidden', '1:ti'),
*connect('ti', 'some_op'),
*connect('some_op', 'output')])
flag, resp = compare_graphs(graph, ref_graph, 'output', check_op_attrs=True)
self.assertTrue(flag, resp)


@@ -0,0 +1,56 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import pytest
import numpy as np
from common.tf_layer_test_class import CommonTFLayerTest
class TestReverseV2Ops(CommonTFLayerTest):
def _prepare_input(self, inputs_dict):
for input in inputs_dict.keys():
inputs_dict[input] = np.random.random(inputs_dict[input])
return inputs_dict
def create_reversev2_net(self, shape, keep_dims, axis, ir_version):
import tensorflow as tf
tf.compat.v1.reset_default_graph()
with tf.compat.v1.Session() as sess:
shapes = shape.copy()
if len(shapes) >= 4:
shapes.append(shapes.pop(1))
x = tf.compat.v1.placeholder(tf.float32, shapes, 'Input')
tf.compat.v1.reverse_v2(x, axis)
tf.compat.v1.global_variables_initializer()
tf_net = sess.graph_def
return tf_net, None
test_data = []
test_data.extend([
dict(shape=[5], axis=[0]),
dict(shape=[2, 3], axis=[1]),
dict(shape=[2, 3, 5], axis=[-2]),
dict(shape=[2, 3, 5, 7], axis=[0]),
])
@pytest.mark.parametrize("params", test_data)
@pytest.mark.parametrize("keep_dims", [True, False])
@pytest.mark.nightly
def test_reversev2(self, params, keep_dims, ie_device, precision, ir_version, temp_dir):
self._test(*self.create_reversev2_net(**params, keep_dims=keep_dims, ir_version=ir_version),
ie_device, precision, ir_version, temp_dir=temp_dir)
test_data_pre_commit = []
test_data_pre_commit.extend([dict(shape=[5], axis=[0]),
dict(shape=[2, 3, 5], axis=[-2])
])
@pytest.mark.parametrize("params", test_data_pre_commit)
@pytest.mark.parametrize("keep_dims", [True])
@pytest.mark.precommit
def test_reversev2_precommit(self, params, keep_dims, ie_device, precision, ir_version, temp_dir):
self._test(*self.create_reversev2_net(**params, keep_dims=keep_dims, ir_version=ir_version),
ie_device, precision, ir_version, temp_dir=temp_dir)
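For reference, the semantics under test: tf.reverse_v2(x, axis) flips x along each listed axis, so axis=[-2] on a rank-3 tensor matches x[:, ::-1, :]. A quick NumPy cross-check, independent of the test harness:

import numpy as np

x = np.arange(2 * 3 * 5).reshape(2, 3, 5)
assert np.array_equal(np.flip(x, axis=-2), x[:, ::-1, :])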