[TF FE] Support unknown and dynamic rank Placeholder shape (#14211)
* [TF FE] Support unknown and undefined Placeholder shape Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Fix mistake in the test Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
parent
21b09e0ac9
commit
db81e50a02
@ -101,8 +101,11 @@ ov::Any DecoderProto::get_attribute(const std::string& name) const {
|
|||||||
case ::tensorflow::AttrValue::ValueCase::kI:
|
case ::tensorflow::AttrValue::ValueCase::kI:
|
||||||
return attrs[0].i();
|
return attrs[0].i();
|
||||||
case ::tensorflow::AttrValue::ValueCase::kShape: {
|
case ::tensorflow::AttrValue::ValueCase::kShape: {
|
||||||
std::vector<ov::Dimension> dims;
|
|
||||||
const auto& tf_shape = attrs[0].shape();
|
const auto& tf_shape = attrs[0].shape();
|
||||||
|
if (tf_shape.unknown_rank()) {
|
||||||
|
return ov::PartialShape::dynamic();
|
||||||
|
}
|
||||||
|
std::vector<ov::Dimension> dims;
|
||||||
for (int i = 0; i < tf_shape.dim_size(); i++) {
|
for (int i = 0; i < tf_shape.dim_size(); i++) {
|
||||||
dims.emplace_back(tf_shape.dim(i).size());
|
dims.emplace_back(tf_shape.dim(i).size());
|
||||||
}
|
}
|
||||||
|
@ -115,8 +115,19 @@ void InputModel::InputModelTFImpl::loadPlaces() {
|
|||||||
m_op_places.push_back(op_place);
|
m_op_places.push_back(op_place);
|
||||||
m_op_places_map[op_name] = op_place;
|
m_op_places_map[op_name] = op_place;
|
||||||
if (op_type == "Placeholder") {
|
if (op_type == "Placeholder") {
|
||||||
auto pshape = node_decoder->get_attribute("shape").as<ov::PartialShape>();
|
auto pshape = ov::PartialShape::dynamic();
|
||||||
auto type = node_decoder->get_attribute("dtype").as<ov::element::Type>();
|
auto shape_any = node_decoder->get_attribute("shape");
|
||||||
|
if (shape_any.is<ov::PartialShape>()) {
|
||||||
|
// sometimes shape attribute can be absent in the graph
|
||||||
|
// so we need to check if Any object is initialized first
|
||||||
|
pshape = shape_any.as<ov::PartialShape>();
|
||||||
|
}
|
||||||
|
auto dtype_any = node_decoder->get_attribute("dtype");
|
||||||
|
auto placeholder_name = node_decoder->get_op_name();
|
||||||
|
FRONT_END_GENERAL_CHECK(
|
||||||
|
dtype_any.is<ov::element::Type>(),
|
||||||
|
"Incorrect input model: Placeholder node " + placeholder_name + " has unspecified type.");
|
||||||
|
auto type = dtype_any.as<ov::element::Type>();
|
||||||
std::vector<std::string> names = {op_name};
|
std::vector<std::string> names = {op_name};
|
||||||
auto tensor_place = std::make_shared<TensorPlace>(m_input_model, pshape, type, names);
|
auto tensor_place = std::make_shared<TensorPlace>(m_input_model, pshape, type, names);
|
||||||
m_tensor_places[op_name] = tensor_place;
|
m_tensor_places[op_name] = tensor_place;
|
||||||
|
40
src/frontends/tensorflow/tests/convert_tricky_models.cpp
Normal file
40
src/frontends/tensorflow/tests/convert_tricky_models.cpp
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
// Copyright (C) 2018-2022 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <openvino/frontend/exception.hpp>
|
||||||
|
#include <openvino/frontend/manager.hpp>
|
||||||
|
#include <openvino/op/util/framework_node.hpp>
|
||||||
|
|
||||||
|
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||||
|
#include "tf_utils.hpp"
|
||||||
|
#include "utils.hpp"
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
using namespace ngraph;
|
||||||
|
using namespace ov::frontend;
|
||||||
|
|
||||||
|
TEST(FrontEndConvertModelTest, test_undefined_input_shape) {
|
||||||
|
FrontEndManager fem;
|
||||||
|
FrontEnd::Ptr frontEnd;
|
||||||
|
InputModel::Ptr inputModel;
|
||||||
|
ASSERT_NO_THROW(frontEnd = fem.load_by_framework(TF_FE));
|
||||||
|
ASSERT_NE(frontEnd, nullptr);
|
||||||
|
auto model_filename = FrontEndTestUtils::make_model_path(string(TEST_TENSORFLOW_MODELS_DIRNAME) +
|
||||||
|
string("undefined_input_shape/undefined_input_shape.pb"));
|
||||||
|
ASSERT_NO_THROW(inputModel = frontEnd->load(model_filename));
|
||||||
|
ASSERT_NE(inputModel, nullptr);
|
||||||
|
shared_ptr<ngraph::Function> function;
|
||||||
|
ASSERT_NO_THROW(function = frontEnd->convert(inputModel));
|
||||||
|
ASSERT_NE(function, nullptr);
|
||||||
|
|
||||||
|
for (auto& node : function->get_ordered_ops()) {
|
||||||
|
if (node->get_friendly_name() == "x") {
|
||||||
|
ASSERT_TRUE(node->get_output_partial_shape(0).same_scheme(ov::PartialShape::dynamic()));
|
||||||
|
} else if (node->get_friendly_name() == "y") {
|
||||||
|
ASSERT_TRUE(node->get_output_partial_shape(0).same_scheme(ov::PartialShape{2, 3}));
|
||||||
|
} else if (node->get_friendly_name() == "z") {
|
||||||
|
ASSERT_TRUE(node->get_output_partial_shape(0).same_scheme(ov::PartialShape::dynamic()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,75 @@
|
|||||||
|
node {
|
||||||
|
name: "x"
|
||||||
|
op: "Placeholder"
|
||||||
|
attr {
|
||||||
|
key: "dtype"
|
||||||
|
value {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "shape"
|
||||||
|
value {
|
||||||
|
shape {
|
||||||
|
unknown_rank: true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "y"
|
||||||
|
op: "Placeholder"
|
||||||
|
attr {
|
||||||
|
key: "dtype"
|
||||||
|
value {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "shape"
|
||||||
|
value {
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
size: 2
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
size: 3
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "z"
|
||||||
|
op: "Placeholder"
|
||||||
|
attr {
|
||||||
|
key: "dtype"
|
||||||
|
value {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "add"
|
||||||
|
op: "AddV2"
|
||||||
|
input: "x"
|
||||||
|
input: "y"
|
||||||
|
attr {
|
||||||
|
key: "T"
|
||||||
|
value {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "Mul"
|
||||||
|
op: "Mul"
|
||||||
|
input: "add"
|
||||||
|
input: "z"
|
||||||
|
attr {
|
||||||
|
key: "T"
|
||||||
|
value {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,17 @@
|
|||||||
|
# Copyright (C) 2018-2022 Intel Corporation
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
|
||||||
|
import tensorflow.compat.v1 as tf
|
||||||
|
|
||||||
|
tf.reset_default_graph()
|
||||||
|
|
||||||
|
with tf.Session() as sess:
|
||||||
|
x = tf.placeholder(dtype=tf.float32, shape=None, name='x')
|
||||||
|
y = tf.placeholder(dtype=tf.float32, shape=[2, 3], name='y')
|
||||||
|
z = tf.placeholder(dtype=tf.float32, shape=None, name='z')
|
||||||
|
add = tf.add(x, y, name="add")
|
||||||
|
tf.multiply(add, z)
|
||||||
|
|
||||||
|
tf.global_variables_initializer()
|
||||||
|
tf.io.write_graph(sess.graph, '.', 'undefined_input_shape.pbtxt', as_text=True)
|
Loading…
Reference in New Issue
Block a user