Fix squeeze shape infer to not change axis value (#13975)
* Fix squeeze shape infer to not change axis value * Fix TensorIterator * Revert Unsqueeze changes
This commit is contained in:
@@ -75,7 +75,7 @@ class TestSqueeze(CommonTFLayerTest):
|
||||
test_data_3D = [
|
||||
dict(shape=[1, 1, 3], axis=[]),
|
||||
dict(shape=[1, 1, 3], axis=[0]),
|
||||
dict(shape=[1, 1, 3], axis=[-1])
|
||||
dict(shape=[1, 3, 1], axis=[-1])
|
||||
]
|
||||
|
||||
# TODO mark as precommit (after successfully passing in nightly)
|
||||
|
||||
@@ -7,6 +7,7 @@ from openvino.tools.mo.middle.LSTMRNNSequenceToTensorIterator import LSTMToTenso
|
||||
from openvino.tools.mo.middle.ONNXRNNSequenceNormalize import ONNXRNNSequenceNormalize
|
||||
from openvino.tools.mo.middle.SwapAxesMiddleReplacer import SwapAxisMiddleReplacer
|
||||
from openvino.tools.mo.middle.TensorIteratorMerge import TensorIteratorMerge
|
||||
from openvino.tools.mo.ops.const import Const
|
||||
from openvino.tools.mo.ops.gather import Gather
|
||||
from openvino.tools.mo.front.common.partial_infer.utils import int64_array
|
||||
from openvino.tools.mo.front.tf.graph_utils import create_op_with_const_inputs
|
||||
@@ -197,10 +198,16 @@ class TransposeTensorIteratorLSTM(MiddleReplacementPattern):
|
||||
|
||||
isomorphism['input_unsqueezed_i'].shape = isomorphism['input_unsqueezed_i'].shape[[1, 0, 2]]
|
||||
isomorphism['input_unsqueezed_i'].infer(isomorphism['input_unsqueezed_i'])
|
||||
isomorphism['squeeze_dim'].value = ti.input_port_map[data_input_port]['axis']
|
||||
isomorphism['squeeze_dim'].infer(isomorphism['squeeze_dim'])
|
||||
isomorphism['squeeze']['need_shape_inference'] = True
|
||||
squeeze = isomorphism['squeeze']
|
||||
squeeze_dim = Const(squeeze.graph, {'value': ti.input_port_map[data_input_port]['axis'],
|
||||
'need_shape_inference': True,
|
||||
'override_output_shape': True}).create_node()
|
||||
squeeze.in_port(1).get_connection().set_source(squeeze_dim.out_port(0))
|
||||
squeeze['need_shape_inference'] = True
|
||||
|
||||
isomorphism['unsqueeze_dim'].value = ti.output_port_map[data_output_port]['axis']
|
||||
isomorphism['unsqueeze_dim'].infer(isomorphism['unsqueeze_dim'])
|
||||
isomorphism['unsqueeze'].infer(isomorphism['unsqueeze'])
|
||||
unsqueeze = isomorphism['unsqueeze']
|
||||
unsqueeze_dim = Const(unsqueeze.graph, {'value': ti.output_port_map[data_output_port]['axis'],
|
||||
'need_shape_inference': True,
|
||||
'override_output_shape': True}).create_node()
|
||||
unsqueeze.in_port(1).get_connection().set_source(unsqueeze_dim.out_port(0))
|
||||
unsqueeze['need_shape_inference'] = True
|
||||
|
||||
@@ -62,10 +62,6 @@ class Squeeze(Op):
|
||||
output_shape = shape_delete(output_shape, real_squeeze_dims)
|
||||
node.out_port(0).data.set_shape(output_shape)
|
||||
|
||||
# make dimensions positive to correctly translate from NHWC to NCHW layout
|
||||
if node.in_port(1).get_source().node.op == 'Const':
|
||||
node.in_port(1).data.set_value(real_squeeze_dims)
|
||||
|
||||
if node.in_port(0).data.get_value() is not None:
|
||||
node.out_port(0).data.set_value(node.in_port(0).data.get_value().reshape(output_shape))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user