[TF FE] Support unknown and dynamic rank Placeholder shape (#14211)

* [TF FE] Support unknown and undefined Placeholder shape

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

* Fix mistake in the test

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
Roman Kazantsev 2022-11-25 09:22:42 +03:00 committed by GitHub
parent 21b09e0ac9
commit db81e50a02
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 149 additions and 3 deletions

View File

@ -101,8 +101,11 @@ ov::Any DecoderProto::get_attribute(const std::string& name) const {
case ::tensorflow::AttrValue::ValueCase::kI: case ::tensorflow::AttrValue::ValueCase::kI:
return attrs[0].i(); return attrs[0].i();
case ::tensorflow::AttrValue::ValueCase::kShape: { case ::tensorflow::AttrValue::ValueCase::kShape: {
std::vector<ov::Dimension> dims;
const auto& tf_shape = attrs[0].shape(); const auto& tf_shape = attrs[0].shape();
if (tf_shape.unknown_rank()) {
return ov::PartialShape::dynamic();
}
std::vector<ov::Dimension> dims;
for (int i = 0; i < tf_shape.dim_size(); i++) { for (int i = 0; i < tf_shape.dim_size(); i++) {
dims.emplace_back(tf_shape.dim(i).size()); dims.emplace_back(tf_shape.dim(i).size());
} }

View File

@ -115,8 +115,19 @@ void InputModel::InputModelTFImpl::loadPlaces() {
m_op_places.push_back(op_place); m_op_places.push_back(op_place);
m_op_places_map[op_name] = op_place; m_op_places_map[op_name] = op_place;
if (op_type == "Placeholder") { if (op_type == "Placeholder") {
auto pshape = node_decoder->get_attribute("shape").as<ov::PartialShape>(); auto pshape = ov::PartialShape::dynamic();
auto type = node_decoder->get_attribute("dtype").as<ov::element::Type>(); auto shape_any = node_decoder->get_attribute("shape");
if (shape_any.is<ov::PartialShape>()) {
// sometimes shape attribute can be absent in the graph
// so we need to check if Any object is initialized first
pshape = shape_any.as<ov::PartialShape>();
}
auto dtype_any = node_decoder->get_attribute("dtype");
auto placeholder_name = node_decoder->get_op_name();
FRONT_END_GENERAL_CHECK(
dtype_any.is<ov::element::Type>(),
"Incorrect input model: Placeholder node " + placeholder_name + " has unspecified type.");
auto type = dtype_any.as<ov::element::Type>();
std::vector<std::string> names = {op_name}; std::vector<std::string> names = {op_name};
auto tensor_place = std::make_shared<TensorPlace>(m_input_model, pshape, type, names); auto tensor_place = std::make_shared<TensorPlace>(m_input_model, pshape, type, names);
m_tensor_places[op_name] = tensor_place; m_tensor_places[op_name] = tensor_place;

View File

@ -0,0 +1,40 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <openvino/frontend/exception.hpp>
#include <openvino/frontend/manager.hpp>
#include <openvino/op/util/framework_node.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
#include "tf_utils.hpp"
#include "utils.hpp"
using namespace std;
using namespace ngraph;
using namespace ov::frontend;
TEST(FrontEndConvertModelTest, test_undefined_input_shape) {
FrontEndManager fem;
FrontEnd::Ptr frontEnd;
InputModel::Ptr inputModel;
ASSERT_NO_THROW(frontEnd = fem.load_by_framework(TF_FE));
ASSERT_NE(frontEnd, nullptr);
auto model_filename = FrontEndTestUtils::make_model_path(string(TEST_TENSORFLOW_MODELS_DIRNAME) +
string("undefined_input_shape/undefined_input_shape.pb"));
ASSERT_NO_THROW(inputModel = frontEnd->load(model_filename));
ASSERT_NE(inputModel, nullptr);
shared_ptr<ngraph::Function> function;
ASSERT_NO_THROW(function = frontEnd->convert(inputModel));
ASSERT_NE(function, nullptr);
for (auto& node : function->get_ordered_ops()) {
if (node->get_friendly_name() == "x") {
ASSERT_TRUE(node->get_output_partial_shape(0).same_scheme(ov::PartialShape::dynamic()));
} else if (node->get_friendly_name() == "y") {
ASSERT_TRUE(node->get_output_partial_shape(0).same_scheme(ov::PartialShape{2, 3}));
} else if (node->get_friendly_name() == "z") {
ASSERT_TRUE(node->get_output_partial_shape(0).same_scheme(ov::PartialShape::dynamic()));
}
}
}

View File

@ -0,0 +1,75 @@
node {
name: "x"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
unknown_rank: true
}
}
}
}
node {
name: "y"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 2
}
dim {
size: 3
}
}
}
}
}
node {
name: "z"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
}
node {
name: "add"
op: "AddV2"
input: "x"
input: "y"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
node {
name: "Mul"
op: "Mul"
input: "add"
input: "z"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}

View File

@ -0,0 +1,17 @@
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import tensorflow.compat.v1 as tf
tf.reset_default_graph()
with tf.Session() as sess:
x = tf.placeholder(dtype=tf.float32, shape=None, name='x')
y = tf.placeholder(dtype=tf.float32, shape=[2, 3], name='y')
z = tf.placeholder(dtype=tf.float32, shape=None, name='z')
add = tf.add(x, y, name="add")
tf.multiply(add, z)
tf.global_variables_initializer()
tf.io.write_graph(sess.graph, '.', 'undefined_input_shape.pbtxt', as_text=True)