Correct ReverseV2ToReverseSequence transformation (#8120)
* add subgraph instead of constant with fixed shape to allow model have undefined batch * updated transformation (not checked yet) * changed ReverseV2ToReverseSequence to support dynamic shapes/reshape; added transformation to reverse_tensor_iterator to support new subgraph got from ReverseV2ToReverseSequence * remove changes that should not be on this branch * added tests; fixed old transformation * added delete of reversesequences to avoid run of transformation twice * fixed pattern check for case with dynamic value for input of reversesequence * Revert "fixed pattern check for case with dynamic value for input of reversesequence" This reverts commit0c04164e
* Revert "added delete of reversesequences to avoid run of transformation twice" This reverts commitfcb7de9c
* reversed changes in reverse_tensorr_iterator for Squeeze case; update reverse_tensor_iterator with shapeof subgraph added permutations for attributes to pass layer test * minor fix for dynamic shape * updated test; fixed backward compatibility in reverse_tensor_iterator transformation * revew comments fixed: added comments; refactoring done; fixed framework name saving for rank = 1 * minor review fixes * small fix
This commit is contained in:
parent
c084f8aa42
commit
980ad59ac4
@ -1,16 +1,27 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import numpy as np
|
||||
|
||||
from extensions.ops.reverse_sequence import ReverseSequence
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.front.tf.graph_utils import create_op_node_with_second_input
|
||||
from mo.graph.graph import Graph, rename_node
|
||||
from mo.middle.replacement import MiddleReplacementPattern
|
||||
from mo.utils.error import Error
|
||||
from mo.ops.broadcast import Broadcast
|
||||
from mo.ops.shape import Shape
|
||||
from mo.ops.squeeze import Squeeze
|
||||
from mo.ops.unsqueeze import Unsqueeze
|
||||
from mo.utils.shape import node_to_get_shape_value_of_indices
|
||||
|
||||
|
||||
class ReverseToReverseSequence(MiddleReplacementPattern):
|
||||
"""
|
||||
Transformation converts Reverse to ReverseSequence operation.
|
||||
Parameters for ReverseSequence calculates in the following way:
|
||||
* seq_axis - set axis value from Reverse operation
|
||||
* batch_axis - set 0 if seq_axis is not 0 otherwise set 1
|
||||
* seq_lengths - take from shape shape[seq_axis] value and broadcast it to vector with shape[batch_axis] length
|
||||
If input is 1D tensor then we add one more dimension to set different seq_axis and batch_axis.
|
||||
"""
|
||||
enabled = True
|
||||
|
||||
def run_after(self):
|
||||
@ -21,40 +32,57 @@ class ReverseToReverseSequence(MiddleReplacementPattern):
|
||||
from extensions.middle.reverse_tensor_iterator import ReverseTensorIteratorLSTM
|
||||
return [ReverseTensorIteratorLSTM]
|
||||
|
||||
@staticmethod
|
||||
def pattern():
|
||||
return dict(
|
||||
nodes=[
|
||||
('reverse', dict(kind='op', op='Reverse'))
|
||||
],
|
||||
edges=[]
|
||||
)
|
||||
def find_and_replace_pattern(self, graph: Graph):
|
||||
reverse_nodes = graph.get_op_nodes(op='Reverse')
|
||||
for reverse in reverse_nodes:
|
||||
reverse_name = reverse.soft_get('name', reverse.id)
|
||||
|
||||
def replace_pattern(self, graph: Graph, match: dict):
|
||||
reverse = match['reverse']
|
||||
input_data_shape = reverse.in_node(0).shape
|
||||
assert reverse.in_port(1).disconnected()
|
||||
assert reverse.has_valid('axis')
|
||||
|
||||
if len(input_data_shape) == 1:
|
||||
raise Error('Reverse operation name = {} is\'t supported because of 1D input.'.format(reverse.name))
|
||||
in_shape_rank = len(reverse.in_port(0).data.get_shape())
|
||||
# 1. Add new dimension as batch for rank = 1 to have batch != seq_axis
|
||||
if in_shape_rank == 1:
|
||||
unsq_node = create_op_node_with_second_input(graph, Unsqueeze, int64_array([0]),
|
||||
{'name': reverse_name+"/Unsqueeze"})
|
||||
reverse.in_port(0).get_source().connect(unsq_node.in_port(0))
|
||||
new_in = unsq_node.out_port(0)
|
||||
batch_axis = 0
|
||||
seq_axis = 1
|
||||
else:
|
||||
new_in = reverse.in_port(0).get_source()
|
||||
seq_axis = reverse['axis']
|
||||
batch_axis = 0 if seq_axis != 0 else 1
|
||||
|
||||
assert reverse.in_port(1).disconnected()
|
||||
# 2. For ReverseSequence 1-port input is seq_lengths => create this input node as
|
||||
# shape[seq_axis] broadcasted to shape[batch_axis]
|
||||
# in ---> ShapeOf ----> Gather(seq_axis) ----> Broadcast----->
|
||||
# | |
|
||||
# | -------> Gather(batch_axis)----------|
|
||||
shape_node = Shape(graph, {'name': reverse_name + "/Shape"}).create_node()
|
||||
new_in.connect(shape_node.in_port(0))
|
||||
seq_axis_node = node_to_get_shape_value_of_indices(shape_node, [seq_axis])
|
||||
batch_node = node_to_get_shape_value_of_indices(shape_node, [batch_axis])
|
||||
broadcast_node = Broadcast(graph, {'name': reverse_name + "/Broadcast"}).create_node()
|
||||
broadcast_node.in_port(0).connect(seq_axis_node.out_port(0))
|
||||
broadcast_node.in_port(1).connect(batch_node.out_port(0))
|
||||
|
||||
seq_axis = reverse['axis']
|
||||
# We need to choose arbitrary batch_axis != sequence_axis
|
||||
batch_axis = int(not seq_axis)
|
||||
# 3. Create new ReverseSequence node and reconnect all inputs/outputs to it
|
||||
rename_node(reverse, reverse_name + '/to_delete')
|
||||
reverse_sequence = ReverseSequence(graph, {'name': reverse_name, 'seq_axis': seq_axis,
|
||||
'batch_axis': batch_axis}).create_node()
|
||||
reverse_sequence.in_port(0).connect(new_in)
|
||||
reverse_sequence.in_port(1).connect(broadcast_node.out_port(0))
|
||||
|
||||
# 1. For ReverseSequence 1-port input is seq_lengths => create this input node
|
||||
seq_lengths = np.ones(input_data_shape[batch_axis]) * input_data_shape[seq_axis]
|
||||
# 4. remove added dimension for rank = 1
|
||||
if in_shape_rank == 1:
|
||||
rename_node(reverse_sequence, reverse_name + '/ReverseSequence')
|
||||
squeeze_node = create_op_node_with_second_input(graph, Squeeze, int64_array([0]),
|
||||
{'name': reverse_name})
|
||||
squeeze_node.in_port(0).connect(reverse_sequence.out_port(0))
|
||||
reverse.out_port(0).get_connection().set_source(squeeze_node.out_port(0))
|
||||
else:
|
||||
reverse.out_port(0).get_connection().set_source(reverse_sequence.out_port(0))
|
||||
|
||||
reverse_name = reverse.soft_get('name', reverse.id)
|
||||
rename_node(reverse, reverse_name + '/to_delete')
|
||||
# 2. Create new ReverseSequence node and reconnect all inputs/outputs to it
|
||||
reverse_sequence = create_op_node_with_second_input(graph, ReverseSequence, seq_lengths,
|
||||
{'name': reverse_name, 'seq_axis': seq_axis,
|
||||
'batch_axis': batch_axis})
|
||||
rename_node(reverse_sequence, reverse_name)
|
||||
reverse.in_port(0).get_connection().set_destination(reverse_sequence.in_port(0))
|
||||
reverse.out_port(0).get_connection().set_source(reverse_sequence.out_port(0))
|
||||
|
||||
# 3. Delete old Reverse node
|
||||
graph.remove_node(reverse.id)
|
||||
# 5. Delete old Reverse node
|
||||
graph.remove_nodes_from([reverse.id for reverse in reverse_nodes])
|
||||
|
@ -5,6 +5,7 @@ import numpy as np
|
||||
|
||||
from extensions.middle.ONNXRNNSequenceNormalize import ONNXRNNSequenceNormalize
|
||||
from extensions.middle.permute_tensor_iterator import TransposeTensorIteratorLSTM
|
||||
from mo.front.common.partial_infer.utils import is_fully_defined
|
||||
from mo.graph.graph import Graph, Node
|
||||
from mo.middle.passes.eliminate import remove_op_node_with_data_node
|
||||
from mo.middle.replacement import MiddleReplacementPattern
|
||||
@ -18,6 +19,7 @@ class ReverseTensorIteratorLSTM(MiddleReplacementPattern):
|
||||
"""
|
||||
|
||||
enabled = True
|
||||
force_clean_up = True
|
||||
|
||||
def run_after(self):
|
||||
return [
|
||||
@ -32,39 +34,55 @@ class ReverseTensorIteratorLSTM(MiddleReplacementPattern):
|
||||
@staticmethod
|
||||
def is_fusable_reverse_sequence(node: Node):
|
||||
sequence_lengths = node.in_port(1).data.get_value()
|
||||
assert sequence_lengths is not None
|
||||
input_shape = node.in_port(0).data.get_shape()
|
||||
assert input_shape is not None
|
||||
|
||||
seq_len = input_shape[node.seq_axis]
|
||||
return np.all(sequence_lengths == seq_len)
|
||||
if sequence_lengths is not None and is_fully_defined(sequence_lengths) and is_fully_defined(seq_len):
|
||||
return np.all(sequence_lengths == seq_len)
|
||||
else:
|
||||
# check that we take sequence_length from input shape based on ReverseV2ToReverseSequence transformation
|
||||
broadcast_node = node.in_port(1).get_source().node
|
||||
if broadcast_node.op != 'Broadcast':
|
||||
return False
|
||||
gather_node = broadcast_node.in_port(0).get_source().node
|
||||
if gather_node.op != "Gather" or \
|
||||
(np.all(gather_node.in_port(2).data.get_value() != [0]) or
|
||||
np.all(gather_node.in_port(1).data.get_value() != [node.seq_axis])):
|
||||
return False
|
||||
gather_node_2 = broadcast_node.in_port(1).get_source().node
|
||||
if gather_node_2.op != "Gather" or \
|
||||
(np.all(gather_node_2.in_port(2).data.get_value() != [0]) or
|
||||
np.all(gather_node_2.in_port(1).data.get_value() != [node.batch_axis])):
|
||||
return False
|
||||
shape_node = gather_node.in_port(0).get_source().node
|
||||
if shape_node.op != "ShapeOf":
|
||||
return False
|
||||
if shape_node.in_port(0).get_source().node != node.in_port(0).get_source().node:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def pattern(self):
|
||||
return dict(
|
||||
nodes=[
|
||||
('input', dict(kind='data')),
|
||||
|
||||
('const', dict(type='Const')),
|
||||
('const_d', dict(kind='data')),
|
||||
|
||||
('direct_seq_len_d', dict(kind='data')),
|
||||
('direct_reverse', dict(op='ReverseSequence')),
|
||||
('input_reversed', dict(kind='data')),
|
||||
('init_hidden', dict(kind='data')),
|
||||
|
||||
('ti', dict(kind='op', op='TensorIterator')),
|
||||
|
||||
('output_reversed', dict(kind='data')),
|
||||
|
||||
('const_1', dict(type='Const')),
|
||||
('const_1_d', dict(kind='data')),
|
||||
|
||||
('inverse_seq_len_d', dict(kind='data')),
|
||||
('inverse_reverse', dict(op='ReverseSequence')),
|
||||
('output', dict(kind='data')),
|
||||
],
|
||||
edges=[
|
||||
('input', 'direct_reverse', {'in': 0}),
|
||||
('const', 'const_d'),
|
||||
('const_d', 'direct_reverse', {'in': 1}),
|
||||
('direct_seq_len_d', 'direct_reverse', {'in': 1}),
|
||||
('direct_reverse', 'input_reversed'),
|
||||
|
||||
('input_reversed', 'ti', {'in': 0}),
|
||||
@ -72,8 +90,7 @@ class ReverseTensorIteratorLSTM(MiddleReplacementPattern):
|
||||
('ti', 'output_reversed', {'out': 0}),
|
||||
|
||||
('output_reversed', 'inverse_reverse', {'in': 0}),
|
||||
('const_1', 'const_1_d'),
|
||||
('const_1_d', 'inverse_reverse', {'in': 1}),
|
||||
('inverse_seq_len_d', 'inverse_reverse', {'in': 1}),
|
||||
('inverse_reverse', 'output'),
|
||||
]
|
||||
)
|
||||
@ -106,9 +123,12 @@ class ReverseTensorIteratorLSTM(MiddleReplacementPattern):
|
||||
port['start'] = -1
|
||||
port['end'] = 0
|
||||
elif port['stride'] == 1:
|
||||
port['start'] = None
|
||||
port['end'] = None
|
||||
port['start'] = 0
|
||||
port['end'] = -1
|
||||
|
||||
# disconnect subgraph for seq length calculation
|
||||
direct_reverse.in_port(1).disconnect()
|
||||
inverse_reverse.in_port(1).disconnect()
|
||||
# Remove reverses
|
||||
remove_op_node_with_data_node(graph, direct_reverse)
|
||||
remove_op_node_with_data_node(graph, inverse_reverse)
|
||||
|
@ -2,7 +2,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from mo.graph.graph import Graph
|
||||
from mo.ops.op import Op
|
||||
from mo.ops.op import Op, PermuteAttrs
|
||||
|
||||
|
||||
class ReverseSequence(Op):
|
||||
@ -35,3 +35,6 @@ class ReverseSequence(Op):
|
||||
|
||||
assert len(node.out_nodes()) == 1
|
||||
node.out_port(0).data.set_shape(input_data_shape)
|
||||
|
||||
PermuteAttrs.create_permute_attrs(node, attrs=[('seq_axis', 'input:0')])
|
||||
PermuteAttrs.create_permute_attrs(node, attrs=[('batch_axis', 'input:0')])
|
||||
|
@ -374,6 +374,8 @@ class PermuteAttrs:
|
||||
'ellipsis_mask': slice_permutation,
|
||||
'axes': common_permutation_inv,
|
||||
'axis': common_permutation_inv,
|
||||
'seq_axis': common_permutation_inv,
|
||||
'batch_axis': common_permutation_inv,
|
||||
'batch_dims': common_permutation_inv,
|
||||
'channel_dims': common_permutation_inv,
|
||||
'spatial_dims': common_permutation_inv,
|
||||
|
@ -0,0 +1,117 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from extensions.middle.reverse_tensor_iterator import ReverseTensorIteratorLSTM
|
||||
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from unit_tests.utils.graph import build_graph, regular_op_with_shaped_data, connect, \
|
||||
valued_const_with_data, regular_op_with_empty_data, result
|
||||
|
||||
nodes = {
|
||||
**regular_op_with_shaped_data('parameter', [1, 3, 227, 227],
|
||||
{'type': 'Parameter', 'op': 'Parameter', 'shape': [1, 3, 227, 227]}),
|
||||
**valued_const_with_data('seq_len', np.array([227])),
|
||||
**regular_op_with_empty_data('shapeof', {'type': 'ShapeOf', 'op': 'ShapeOf'}),
|
||||
**valued_const_with_data('gather_axis', np.array([0])),
|
||||
**valued_const_with_data('gather_batch_ind', np.array([0])),
|
||||
**valued_const_with_data('gather_seq_ind', np.array([2])),
|
||||
**regular_op_with_empty_data('gather_batch', {'type': 'Gather', 'op': 'Gather'}),
|
||||
**regular_op_with_empty_data('gather_seq', {'type': 'Gather', 'op': 'Gather'}),
|
||||
**regular_op_with_empty_data('broadcast', {'type': 'Broadcast', 'op': 'Broadcast'}),
|
||||
**regular_op_with_shaped_data('direct_reverse', [1, 3, 227, 227], {'type': 'ReverseSequence',
|
||||
'op': 'ReverseSequence',
|
||||
'seq_axis': 2, 'batch_axis': 0}),
|
||||
**regular_op_with_empty_data('init_hidden', {'type': 'Init', 'op': 'Init'}),
|
||||
|
||||
**regular_op_with_shaped_data('ti', [1, 2, 34, 56], {'type': 'TensorIterator', 'op': 'TensorIterator',
|
||||
'output_port_map': [{'axis': 2, 'start': 0, 'end': -1,
|
||||
'stride': 1, 'external_port_id': 0}],
|
||||
'input_port_map': [{'axis': 2, 'start': -1, 'end': 0,
|
||||
'stride': -1, 'external_port_id': 0}]}),
|
||||
**valued_const_with_data('inverse_seq_len', np.array([34])),
|
||||
**regular_op_with_empty_data('inverse_shapeof', {'type': 'ShapeOf', 'op': 'ShapeOf'}),
|
||||
**regular_op_with_empty_data('inverse_gather_batch', {'type': 'Gather', 'op': 'Gather'}),
|
||||
**regular_op_with_empty_data('inverse_gather_seq', {'type': 'Gather', 'op': 'Gather'}),
|
||||
**regular_op_with_empty_data('inverse_broadcast', {'type': 'Broadcast', 'op': 'Broadcast'}),
|
||||
**regular_op_with_shaped_data('inverse_reverse', [1, 2, 34, 56], {'type': 'ReverseSequence',
|
||||
'op': 'ReverseSequence',
|
||||
'seq_axis': 2, 'batch_axis': 0}),
|
||||
**regular_op_with_empty_data('some_op', {'op': 'SomeOp'}),
|
||||
**result()
|
||||
}
|
||||
|
||||
ref_nodes = {
|
||||
**regular_op_with_shaped_data('parameter', [1, 3, 227, 227],
|
||||
{'type': 'Parameter', 'op': 'Parameter', 'shape': [1, 3, 227, 227]}),
|
||||
**regular_op_with_empty_data('init_hidden', {'type': 'Init', 'op': 'Init'}),
|
||||
**regular_op_with_empty_data('ti', {'type': 'TensorIterator', 'op': 'TensorIterator',
|
||||
'output_port_map': [{'axis': 2, 'start': -1, 'end': 0, 'stride': -1,
|
||||
'external_port_id': 0}],
|
||||
'input_port_map': [{'axis': 2, 'start': 0, 'end': -1, 'stride': 1,
|
||||
'external_port_id': 0}]}),
|
||||
**regular_op_with_empty_data('some_op', {'op': 'SomeOp'}),
|
||||
**result()
|
||||
}
|
||||
|
||||
|
||||
class ReverseTensorIteratorTest(unittest.TestCase):
|
||||
def test_ti_reverse(self):
|
||||
graph = build_graph(nodes, [*connect('parameter:0', '0:direct_reverse'),
|
||||
*connect('parameter:0', 'shapeof', skip_data=True),
|
||||
*connect('shapeof:0', '0:gather_batch'),
|
||||
*connect('gather_batch_ind', '1:gather_batch'),
|
||||
*connect('gather_axis', '2:gather_batch'),
|
||||
*connect('shapeof:0', '0:gather_seq', skip_data=True),
|
||||
*connect('gather_seq_ind', '1:gather_seq'),
|
||||
*connect('gather_axis', '2:gather_seq'),
|
||||
*connect('gather_seq', '0:broadcast'),
|
||||
*connect('gather_batch', '1:broadcast'),
|
||||
*connect('broadcast', '1:direct_reverse'),
|
||||
*connect('direct_reverse', '0:ti'),
|
||||
*connect('init_hidden', '1:ti'),
|
||||
*connect('ti', 'inverse_shapeof'),
|
||||
*connect('inverse_shapeof:0', '0:inverse_gather_batch'),
|
||||
*connect('gather_batch_ind', '1:inverse_gather_batch'),
|
||||
*connect('gather_axis', '2:inverse_gather_batch'),
|
||||
*connect('inverse_shapeof:0', '0:inverse_gather_seq', skip_data=True),
|
||||
*connect('gather_seq_ind', '1:inverse_gather_seq'),
|
||||
*connect('gather_axis', '2:inverse_gather_seq'),
|
||||
*connect('inverse_gather_seq', '0:inverse_broadcast'),
|
||||
*connect('inverse_gather_batch', '1:inverse_broadcast'),
|
||||
*connect('ti', '0:inverse_reverse', skip_data=True),
|
||||
*connect('inverse_broadcast', '1:inverse_reverse'),
|
||||
*connect('inverse_reverse', 'some_op'),
|
||||
*connect('some_op', 'output')], nodes_with_edges_only=True)
|
||||
|
||||
ReverseTensorIteratorLSTM().find_and_replace_pattern(graph)
|
||||
graph.clean_up()
|
||||
|
||||
ref_graph = build_graph(ref_nodes, [*connect('parameter', '0:ti'),
|
||||
*connect('init_hidden', '1:ti'),
|
||||
*connect('ti', 'some_op'),
|
||||
*connect('some_op', 'output')])
|
||||
flag, resp = compare_graphs(graph, ref_graph, 'output', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_ti_reverse_const(self):
|
||||
graph = build_graph(nodes, [*connect('parameter:0', '0:direct_reverse'),
|
||||
*connect('seq_len', '1:direct_reverse'),
|
||||
*connect('direct_reverse', '0:ti'),
|
||||
*connect('init_hidden', '1:ti'),
|
||||
*connect('ti', '0:inverse_reverse'),
|
||||
*connect('inverse_seq_len', '1:inverse_reverse'),
|
||||
*connect('inverse_reverse', 'some_op'),
|
||||
*connect('some_op', 'output')], nodes_with_edges_only=True)
|
||||
|
||||
ReverseTensorIteratorLSTM().find_and_replace_pattern(graph)
|
||||
graph.clean_up()
|
||||
|
||||
ref_graph = build_graph(ref_nodes, [*connect('parameter', '0:ti'),
|
||||
*connect('init_hidden', '1:ti'),
|
||||
*connect('ti', 'some_op'),
|
||||
*connect('some_op', 'output')])
|
||||
flag, resp = compare_graphs(graph, ref_graph, 'output', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
56
tests/layer_tests/tensorflow_tests/test_tf_ReverseV2.py
Normal file
56
tests/layer_tests/tensorflow_tests/test_tf_ReverseV2.py
Normal file
@ -0,0 +1,56 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
from common.tf_layer_test_class import CommonTFLayerTest
|
||||
|
||||
|
||||
class TestReverseV2Ops(CommonTFLayerTest):
|
||||
def _prepare_input(self, inputs_dict):
|
||||
for input in inputs_dict.keys():
|
||||
inputs_dict[input] = np.random.random(inputs_dict[input])
|
||||
return inputs_dict
|
||||
|
||||
def create_reversev2_net(self, shape, keep_dims, axis, ir_version):
|
||||
import tensorflow as tf
|
||||
tf.compat.v1.reset_default_graph()
|
||||
with tf.compat.v1.Session() as sess:
|
||||
shapes = shape.copy()
|
||||
if len(shapes) >= 4:
|
||||
shapes.append(shapes.pop(1))
|
||||
|
||||
x = tf.compat.v1.placeholder(tf.float32, shapes, 'Input')
|
||||
tf.compat.v1.reverse_v2(x, axis)
|
||||
tf.compat.v1.global_variables_initializer()
|
||||
tf_net = sess.graph_def
|
||||
|
||||
return tf_net, None
|
||||
|
||||
test_data = []
|
||||
test_data.extend([
|
||||
dict(shape=[5], axis=[0]),
|
||||
dict(shape=[2, 3], axis=[1]),
|
||||
dict(shape=[2, 3, 5], axis=[-2]),
|
||||
dict(shape=[2, 3, 5, 7], axis=[0]),
|
||||
])
|
||||
|
||||
@pytest.mark.parametrize("params", test_data)
|
||||
@pytest.mark.parametrize("keep_dims", [True, False])
|
||||
@pytest.mark.nightly
|
||||
def test_reversev2(self, params, keep_dims, ie_device, precision, ir_version, temp_dir):
|
||||
self._test(*self.create_reversev2_net(**params, keep_dims=keep_dims, ir_version=ir_version),
|
||||
ie_device, precision, ir_version, temp_dir=temp_dir)
|
||||
|
||||
test_data_pre_commit = []
|
||||
test_data_pre_commit.extend([dict(shape=[5], axis=[0]),
|
||||
dict(shape=[2, 3, 5], axis=[-2])
|
||||
])
|
||||
|
||||
@pytest.mark.parametrize("params", test_data_pre_commit)
|
||||
@pytest.mark.parametrize("keep_dims", [True])
|
||||
@pytest.mark.precommit
|
||||
def test_reversev2_precommit(self, params, keep_dims, ie_device, precision, ir_version, temp_dir):
|
||||
self._test(*self.create_reversev2_net(**params, keep_dims=keep_dims, ir_version=ir_version),
|
||||
ie_device, precision, ir_version, temp_dir=temp_dir)
|
Loading…
Reference in New Issue
Block a user