From 148daf035bd2baa51c5189a9abfa326c9664ea79 Mon Sep 17 00:00:00 2001 From: Roman Kazantsev Date: Wed, 6 Dec 2023 23:55:59 +0400 Subject: [PATCH] [TF FE] Avoid reinterpret_cast for non-standard types (#21506) Clean-up translator for Placeholder Signed-off-by: Kazantsev, Roman --- .../tensorflow/src/op/placeholder.cpp | 43 +++---------------- .../tensorflow/src/op/var_handle.cpp | 26 +++++++++-- 2 files changed, 29 insertions(+), 40 deletions(-) diff --git a/src/frontends/tensorflow/src/op/placeholder.cpp b/src/frontends/tensorflow/src/op/placeholder.cpp index 62babcf826f..9b4d54b07df 100644 --- a/src/frontends/tensorflow/src/op/placeholder.cpp +++ b/src/frontends/tensorflow/src/op/placeholder.cpp @@ -4,10 +4,11 @@ #include "common_op_table.hpp" #include "input_model.hpp" -#include "openvino/opsets/opset8.hpp" +#include "openvino/op/parameter.hpp" +#include "utils.hpp" using namespace std; -using namespace ov::opset8; +using namespace ov::op; using namespace ov; namespace ov { @@ -16,40 +17,10 @@ namespace tensorflow { namespace op { OutputVector translate_placeholder_linked_op(const NodeContext& node) { - auto dtype = node.get_attribute("dtype"); - auto shape = node.get_attribute("shape", ov::PartialShape::dynamic()); - auto translate_session = node.get_translate_session(); - TENSORFLOW_OP_VALIDATION(node, - translate_session, - "[TensorFlow Frontend] Internal error: Translate session is nullptr."); - auto model = reinterpret_cast(translate_session->get_input_model().get()); - auto tensor_places = model->get_tensor_places(); - auto saved_model_input_names = model->get_saved_model_input_names(); - - if (saved_model_input_names.get() && saved_model_input_names->size() > 0) { - auto input_name = saved_model_input_names->find(node.get_name()); - if (input_name == saved_model_input_names->end()) { - input_name = saved_model_input_names->find(node.get_name() + ":0"); - } - if (input_name != saved_model_input_names->end()) { - auto tensor_place = tensor_places.find(input_name->second); - if (tensor_place != tensor_places.end()) { - shape = tensor_place->second->get_partial_shape(); - } - } - } - - if (shape.rank().is_static() && shape.rank().get_length() == 0 && node.has_attribute("_output_shapes")) { - // we know some cases when Placeholder operation has empty scalar `shape` attribute value - // and non-empty `_output_shapes` attribute value. - // `_output_shapes` attribute value turns to be correct in this case - auto output_shapes = node.get_attribute>("_output_shapes"); - if (output_shapes.size() == 1 && output_shapes[0].rank().is_static()) { - shape = output_shapes[0]; - } - } - - auto res = std::make_shared(dtype, shape); + default_op_checks(node, 0, {"Placeholder"}); + auto dtype = node.get_attribute("dtype"); + auto shape = node.get_attribute("shape", PartialShape::dynamic()); + auto res = std::make_shared(dtype, shape); set_node_name(node.get_name(), res); return res->outputs(); } diff --git a/src/frontends/tensorflow/src/op/var_handle.cpp b/src/frontends/tensorflow/src/op/var_handle.cpp index e9eb0eaa181..7aed518c5dc 100644 --- a/src/frontends/tensorflow/src/op/var_handle.cpp +++ b/src/frontends/tensorflow/src/op/var_handle.cpp @@ -69,7 +69,11 @@ OutputVector translate_varhandle_op(const NodeContext& node) { TENSORFLOW_OP_VALIDATION(node, translate_session, "[TensorFlow Frontend] Internal error: Translate session is nullptr."); - auto model = reinterpret_cast(translate_session->get_input_model().get()); + auto model = dynamic_cast(translate_session->get_input_model().get()); + TENSORFLOW_OP_VALIDATION( + node, + model, + "[TensorFlow Frontend] internal error: cannot cast a pointer to ov::frontend::tensorflow::InputModel*"); auto var_index = model->get_variables_index(); auto ov_type = node.get_attribute("dtype"); std::shared_ptr const_node; @@ -188,10 +192,24 @@ OutputVector translate_restorev2_op(const NodeContext& node) { TENSORFLOW_OP_VALIDATION(node, translate_session, "[TensorFlow Frontend] Internal error: Translate session is nullptr."); - auto model = reinterpret_cast(translate_session->get_input_model().get()); + auto model = dynamic_cast(translate_session->get_input_model().get()); + TENSORFLOW_OP_VALIDATION( + node, + model, + "[TensorFlow Frontend] internal error: cannot cast a pointer to ov::frontend::tensorflow::InputModel*"); auto var_index = model->get_variables_index(); - auto tensor_names = - reinterpret_cast(node.get_input(1).get_node())->get_data().as>(); + + auto string_constant_node = dynamic_pointer_cast(node.get_input(1).get_node_shared_ptr()); + TENSORFLOW_OP_VALIDATION( + node, + string_constant_node, + "[TensorFlow Frontend] internal error: cannot cast a node pointer to StringConstant pointer"); + TENSORFLOW_OP_VALIDATION( + node, + string_constant_node->get_data().is>(), + "[TensorFlow Frontend] internal error: cannot cast data of StringConstant into std::vector"); + auto tensor_names = string_constant_node->get_data().as>(); + auto tensor_types = node.get_attribute>("dtypes"); OutputVector outs = {};