[TF FE] Avoid reinterpret_cast for non-standard types (#21506)
Clean-up translator for Placeholder Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
@@ -4,10 +4,11 @@
|
||||
|
||||
#include "common_op_table.hpp"
|
||||
#include "input_model.hpp"
|
||||
#include "openvino/opsets/opset8.hpp"
|
||||
#include "openvino/op/parameter.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ov::opset8;
|
||||
using namespace ov::op;
|
||||
using namespace ov;
|
||||
|
||||
namespace ov {
|
||||
@@ -16,40 +17,10 @@ namespace tensorflow {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_placeholder_linked_op(const NodeContext& node) {
|
||||
auto dtype = node.get_attribute<ov::element::Type>("dtype");
|
||||
auto shape = node.get_attribute<ov::PartialShape>("shape", ov::PartialShape::dynamic());
|
||||
auto translate_session = node.get_translate_session();
|
||||
TENSORFLOW_OP_VALIDATION(node,
|
||||
translate_session,
|
||||
"[TensorFlow Frontend] Internal error: Translate session is nullptr.");
|
||||
auto model = reinterpret_cast<ov::frontend::tensorflow::InputModel*>(translate_session->get_input_model().get());
|
||||
auto tensor_places = model->get_tensor_places();
|
||||
auto saved_model_input_names = model->get_saved_model_input_names();
|
||||
|
||||
if (saved_model_input_names.get() && saved_model_input_names->size() > 0) {
|
||||
auto input_name = saved_model_input_names->find(node.get_name());
|
||||
if (input_name == saved_model_input_names->end()) {
|
||||
input_name = saved_model_input_names->find(node.get_name() + ":0");
|
||||
}
|
||||
if (input_name != saved_model_input_names->end()) {
|
||||
auto tensor_place = tensor_places.find(input_name->second);
|
||||
if (tensor_place != tensor_places.end()) {
|
||||
shape = tensor_place->second->get_partial_shape();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (shape.rank().is_static() && shape.rank().get_length() == 0 && node.has_attribute("_output_shapes")) {
|
||||
// we know some cases when Placeholder operation has empty scalar `shape` attribute value
|
||||
// and non-empty `_output_shapes` attribute value.
|
||||
// `_output_shapes` attribute value turns to be correct in this case
|
||||
auto output_shapes = node.get_attribute<std::vector<ov::PartialShape>>("_output_shapes");
|
||||
if (output_shapes.size() == 1 && output_shapes[0].rank().is_static()) {
|
||||
shape = output_shapes[0];
|
||||
}
|
||||
}
|
||||
|
||||
auto res = std::make_shared<Parameter>(dtype, shape);
|
||||
default_op_checks(node, 0, {"Placeholder"});
|
||||
auto dtype = node.get_attribute<element::Type>("dtype");
|
||||
auto shape = node.get_attribute<PartialShape>("shape", PartialShape::dynamic());
|
||||
auto res = std::make_shared<v0::Parameter>(dtype, shape);
|
||||
set_node_name(node.get_name(), res);
|
||||
return res->outputs();
|
||||
}
|
||||
|
||||
@@ -69,7 +69,11 @@ OutputVector translate_varhandle_op(const NodeContext& node) {
|
||||
TENSORFLOW_OP_VALIDATION(node,
|
||||
translate_session,
|
||||
"[TensorFlow Frontend] Internal error: Translate session is nullptr.");
|
||||
auto model = reinterpret_cast<ov::frontend::tensorflow::InputModel*>(translate_session->get_input_model().get());
|
||||
auto model = dynamic_cast<ov::frontend::tensorflow::InputModel*>(translate_session->get_input_model().get());
|
||||
TENSORFLOW_OP_VALIDATION(
|
||||
node,
|
||||
model,
|
||||
"[TensorFlow Frontend] internal error: cannot cast a pointer to ov::frontend::tensorflow::InputModel*");
|
||||
auto var_index = model->get_variables_index();
|
||||
auto ov_type = node.get_attribute<element::Type>("dtype");
|
||||
std::shared_ptr<Node> const_node;
|
||||
@@ -188,10 +192,24 @@ OutputVector translate_restorev2_op(const NodeContext& node) {
|
||||
TENSORFLOW_OP_VALIDATION(node,
|
||||
translate_session,
|
||||
"[TensorFlow Frontend] Internal error: Translate session is nullptr.");
|
||||
auto model = reinterpret_cast<ov::frontend::tensorflow::InputModel*>(translate_session->get_input_model().get());
|
||||
auto model = dynamic_cast<ov::frontend::tensorflow::InputModel*>(translate_session->get_input_model().get());
|
||||
TENSORFLOW_OP_VALIDATION(
|
||||
node,
|
||||
model,
|
||||
"[TensorFlow Frontend] internal error: cannot cast a pointer to ov::frontend::tensorflow::InputModel*");
|
||||
auto var_index = model->get_variables_index();
|
||||
auto tensor_names =
|
||||
reinterpret_cast<StringConstant*>(node.get_input(1).get_node())->get_data().as<std::vector<std::string>>();
|
||||
|
||||
auto string_constant_node = dynamic_pointer_cast<StringConstant>(node.get_input(1).get_node_shared_ptr());
|
||||
TENSORFLOW_OP_VALIDATION(
|
||||
node,
|
||||
string_constant_node,
|
||||
"[TensorFlow Frontend] internal error: cannot cast a node pointer to StringConstant pointer");
|
||||
TENSORFLOW_OP_VALIDATION(
|
||||
node,
|
||||
string_constant_node->get_data().is<std::vector<std::string>>(),
|
||||
"[TensorFlow Frontend] internal error: cannot cast data of StringConstant into std::vector<std::string>");
|
||||
auto tensor_names = string_constant_node->get_data().as<std::vector<std::string>>();
|
||||
|
||||
auto tensor_types = node.get_attribute<std::vector<ov::element::Type>>("dtypes");
|
||||
|
||||
OutputVector outs = {};
|
||||
|
||||
Reference in New Issue
Block a user