[MO]Implement additional shape extraction for TF Placeholder operation (#8817)
* Implement additional shape extraction for TF placeholder operation * Update checks and add comments * Use more accurate check * Extract _output_shapes attr only if we have problems with shape attr * Add missed import * Apply suggestions from code review Update code comments Co-authored-by: Anastasia Popova <anastasia.popova@intel.com> * Add update condition to avoid unnecessary checks * Remove unnecessary checks Co-authored-by: Anastasia Popova <anastasia.popova@intel.com>
This commit is contained in:
parent
526fe3098d
commit
6b0d1525fd
@ -1,10 +1,13 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from openvino.tools.mo.ops.parameter import Parameter
|
||||
import logging as log
|
||||
|
||||
from openvino.tools.mo.front.common.partial_infer.utils import shape_array
|
||||
from openvino.tools.mo.front.extractor import FrontExtractorOp
|
||||
from openvino.tools.mo.front.tf.extractors.utils import tf_dtype_extractor, tf_tensor_shape
|
||||
from openvino.tools.mo.ops.op import PermuteAttrs
|
||||
from openvino.tools.mo.ops.parameter import Parameter
|
||||
|
||||
|
||||
class PlaceholderFrontExtractor(FrontExtractorOp):
|
||||
@ -13,9 +16,29 @@ class PlaceholderFrontExtractor(FrontExtractorOp):
|
||||
|
||||
@classmethod
|
||||
def extract(cls, node):
|
||||
shape = shape_array([])
|
||||
# Extract output shape from `shape` attribute
|
||||
extracted_shape = tf_tensor_shape(node.pb.attr["shape"].shape)
|
||||
if len(extracted_shape) != 0:
|
||||
shape = extracted_shape
|
||||
else:
|
||||
# Extract output shape from `_output_shapes` attribute if it is possible
|
||||
extracted_output_shapes = node.pb.attr["_output_shapes"].list.shape
|
||||
if len(extracted_output_shapes) == 1: # check if attribute not empty
|
||||
extracted_output_shapes = tf_tensor_shape(extracted_output_shapes[0])
|
||||
|
||||
# Check equality of extracted shapes. We know some cases then Placeholder operation has empty `shape`
|
||||
# attribute value and non-empty `_output_shapes` attribute value and need co handle and support it.
|
||||
if len(extracted_output_shapes) > len(extracted_shape):
|
||||
log.warning('Extracted shapes for Placeholder operation {} have different lengths: `shape` {} and '
|
||||
'`_output_shapes` {}. Please, check if model is consistent'.format(
|
||||
node.pb.name, extracted_shape, extracted_output_shapes))
|
||||
if len(extracted_output_shapes) != 0:
|
||||
shape = extracted_output_shapes
|
||||
|
||||
attrs = {
|
||||
'data_type': tf_dtype_extractor(node.pb.attr["dtype"].type),
|
||||
'shape': tf_tensor_shape(node.pb.attr["shape"].shape),
|
||||
'shape': shape,
|
||||
'permute_attrs': PermuteAttrs().update_attrs(attrs=[('shape', 'output:0')])
|
||||
}
|
||||
if node.pb.attr["shape"].shape.unknown_rank:
|
||||
|
Loading…
Reference in New Issue
Block a user