[TF FE] Avoid reinterpret_cast for non-standard types (#21506)

Clean-up translator for Placeholder

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
Roman Kazantsev
2023-12-06 23:55:59 +04:00
committed by GitHub
parent cac6aadcc8
commit 148daf035b
2 changed files with 29 additions and 40 deletions

View File

@@ -4,10 +4,11 @@
#include "common_op_table.hpp"
#include "input_model.hpp"
#include "openvino/opsets/opset8.hpp"
#include "openvino/op/parameter.hpp"
#include "utils.hpp"
using namespace std;
using namespace ov::opset8;
using namespace ov::op;
using namespace ov;
namespace ov {
@@ -16,40 +17,10 @@ namespace tensorflow {
namespace op {
OutputVector translate_placeholder_linked_op(const NodeContext& node) {
auto dtype = node.get_attribute<ov::element::Type>("dtype");
auto shape = node.get_attribute<ov::PartialShape>("shape", ov::PartialShape::dynamic());
auto translate_session = node.get_translate_session();
TENSORFLOW_OP_VALIDATION(node,
translate_session,
"[TensorFlow Frontend] Internal error: Translate session is nullptr.");
auto model = reinterpret_cast<ov::frontend::tensorflow::InputModel*>(translate_session->get_input_model().get());
auto tensor_places = model->get_tensor_places();
auto saved_model_input_names = model->get_saved_model_input_names();
if (saved_model_input_names.get() && saved_model_input_names->size() > 0) {
auto input_name = saved_model_input_names->find(node.get_name());
if (input_name == saved_model_input_names->end()) {
input_name = saved_model_input_names->find(node.get_name() + ":0");
}
if (input_name != saved_model_input_names->end()) {
auto tensor_place = tensor_places.find(input_name->second);
if (tensor_place != tensor_places.end()) {
shape = tensor_place->second->get_partial_shape();
}
}
}
if (shape.rank().is_static() && shape.rank().get_length() == 0 && node.has_attribute("_output_shapes")) {
// we know some cases when Placeholder operation has empty scalar `shape` attribute value
// and non-empty `_output_shapes` attribute value.
// `_output_shapes` attribute value turns to be correct in this case
auto output_shapes = node.get_attribute<std::vector<ov::PartialShape>>("_output_shapes");
if (output_shapes.size() == 1 && output_shapes[0].rank().is_static()) {
shape = output_shapes[0];
}
}
auto res = std::make_shared<Parameter>(dtype, shape);
default_op_checks(node, 0, {"Placeholder"});
auto dtype = node.get_attribute<element::Type>("dtype");
auto shape = node.get_attribute<PartialShape>("shape", PartialShape::dynamic());
auto res = std::make_shared<v0::Parameter>(dtype, shape);
set_node_name(node.get_name(), res);
return res->outputs();
}

View File

@@ -69,7 +69,11 @@ OutputVector translate_varhandle_op(const NodeContext& node) {
TENSORFLOW_OP_VALIDATION(node,
translate_session,
"[TensorFlow Frontend] Internal error: Translate session is nullptr.");
auto model = reinterpret_cast<ov::frontend::tensorflow::InputModel*>(translate_session->get_input_model().get());
auto model = dynamic_cast<ov::frontend::tensorflow::InputModel*>(translate_session->get_input_model().get());
TENSORFLOW_OP_VALIDATION(
node,
model,
"[TensorFlow Frontend] internal error: cannot cast a pointer to ov::frontend::tensorflow::InputModel*");
auto var_index = model->get_variables_index();
auto ov_type = node.get_attribute<element::Type>("dtype");
std::shared_ptr<Node> const_node;
@@ -188,10 +192,24 @@ OutputVector translate_restorev2_op(const NodeContext& node) {
TENSORFLOW_OP_VALIDATION(node,
translate_session,
"[TensorFlow Frontend] Internal error: Translate session is nullptr.");
auto model = reinterpret_cast<ov::frontend::tensorflow::InputModel*>(translate_session->get_input_model().get());
auto model = dynamic_cast<ov::frontend::tensorflow::InputModel*>(translate_session->get_input_model().get());
TENSORFLOW_OP_VALIDATION(
node,
model,
"[TensorFlow Frontend] internal error: cannot cast a pointer to ov::frontend::tensorflow::InputModel*");
auto var_index = model->get_variables_index();
auto tensor_names =
reinterpret_cast<StringConstant*>(node.get_input(1).get_node())->get_data().as<std::vector<std::string>>();
auto string_constant_node = dynamic_pointer_cast<StringConstant>(node.get_input(1).get_node_shared_ptr());
TENSORFLOW_OP_VALIDATION(
node,
string_constant_node,
"[TensorFlow Frontend] internal error: cannot cast a node pointer to StringConstant pointer");
TENSORFLOW_OP_VALIDATION(
node,
string_constant_node->get_data().is<std::vector<std::string>>(),
"[TensorFlow Frontend] internal error: cannot cast data of StringConstant into std::vector<std::string>");
auto tensor_names = string_constant_node->get_data().as<std::vector<std::string>>();
auto tensor_types = node.get_attribute<std::vector<ov::element::Type>>("dtypes");
OutputVector outs = {};