diff --git a/ngraph/frontend/tensorflow/include/tensorflow_frontend/frontend.hpp b/ngraph/frontend/tensorflow/include/tensorflow_frontend/frontend.hpp index 301ec254346..5f0996035b0 100644 --- a/ngraph/frontend/tensorflow/include/tensorflow_frontend/frontend.hpp +++ b/ngraph/frontend/tensorflow/include/tensorflow_frontend/frontend.hpp @@ -10,7 +10,6 @@ #include #include #include -#include #include namespace ov { @@ -74,7 +73,7 @@ protected: const std::vector>& variants) const override; private: - void translate_graph(const std::shared_ptr& model, + void translate_graph(const ngraph::frontend::InputModel::Ptr& model, const std::string& model_name, bool fail_fast, bool no_conversion, diff --git a/ngraph/frontend/tensorflow/include/tensorflow_frontend/graph_iterator.hpp b/ngraph/frontend/tensorflow/include/tensorflow_frontend/graph_iterator.hpp index 43f26f04d77..aa5f7399bd5 100644 --- a/ngraph/frontend/tensorflow/include/tensorflow_frontend/graph_iterator.hpp +++ b/ngraph/frontend/tensorflow/include/tensorflow_frontend/graph_iterator.hpp @@ -4,6 +4,7 @@ #pragma once +#include #include #include @@ -34,3 +35,15 @@ public: }; } // namespace frontend } // namespace ov + +namespace ov { +/// Keep GraphIterator::Ptr object and type information for type-safe +/// dynamic conversions without using C++ RTTI +template <> +class TF_API VariantWrapper<::ov::frontend::GraphIterator::Ptr> + : public VariantImpl<::ov::frontend::GraphIterator::Ptr> { +public: + OPENVINO_RTTI("Variant::GraphIterator::Ptr"); + VariantWrapper(const value_type& value) : VariantImpl(value) {} +}; +} // namespace ov diff --git a/ngraph/frontend/tensorflow/src/decoder_proto.hpp b/ngraph/frontend/tensorflow/src/decoder_proto.hpp index e1d620c4efb..a446cdecf79 100644 --- a/ngraph/frontend/tensorflow/src/decoder_proto.hpp +++ b/ngraph/frontend/tensorflow/src/decoder_proto.hpp @@ -7,8 +7,6 @@ #include #include #include -#include -#include #include #include "attr_value.pb.h" diff --git a/ngraph/frontend/tensorflow/src/exceptions.cpp b/ngraph/frontend/tensorflow/src/exceptions.cpp index f74796a3974..83b7c5716db 100644 --- a/ngraph/frontend/tensorflow/src/exceptions.cpp +++ b/ngraph/frontend/tensorflow/src/exceptions.cpp @@ -2,7 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 // -#include +#include "exceptions.hpp" #include "node_context.hpp" diff --git a/ngraph/frontend/tensorflow/include/tensorflow_frontend/exceptions.hpp b/ngraph/frontend/tensorflow/src/exceptions.hpp similarity index 100% rename from ngraph/frontend/tensorflow/include/tensorflow_frontend/exceptions.hpp rename to ngraph/frontend/tensorflow/src/exceptions.hpp diff --git a/ngraph/frontend/tensorflow/src/frontend.cpp b/ngraph/frontend/tensorflow/src/frontend.cpp index 7475814ffd8..5ff7358cf02 100644 --- a/ngraph/frontend/tensorflow/src/frontend.cpp +++ b/ngraph/frontend/tensorflow/src/frontend.cpp @@ -4,8 +4,9 @@ #include #include -#include +#include +#include "model.hpp" #include "op_table.hpp" #include "pass/transpose_sinking.hpp" #include "tf_framework_node.hpp" @@ -47,7 +48,7 @@ void translate_framework_node(const std::shared_ptr& node, FrontEndTF::FrontEndTF() : m_op_translators(tf::op::get_supported_ops()) {} -void FrontEndTF::translate_graph(const std::shared_ptr& model, +void FrontEndTF::translate_graph(const ngraph::frontend::InputModel::Ptr& model, const std::string& model_name, bool fail_fast, bool no_conversion, @@ -57,10 +58,12 @@ void FrontEndTF::translate_graph(const std::shared_ptr& model, ov::ParameterVector params; ov::ResultVector results; - const auto& operation_places = model->get_op_places(); - const auto& model_inputs = model->get_inputs(); - const auto& model_outputs = model->get_outputs(); - const auto& model_frozen_inputs = model->get_tensor_values(); + const auto& model_tf = std::dynamic_pointer_cast(model); + FRONT_END_GENERAL_CHECK(model_tf, "nullptr for InputModel is given for translation into nGraph function"); + const auto& operation_places = model_tf->get_op_places(); + const auto& model_inputs = model_tf->get_inputs(); + const auto& model_outputs = model_tf->get_outputs(); + const auto& model_frozen_inputs = model_tf->get_tensor_values(); std::map> translate_map; @@ -282,6 +285,8 @@ bool FrontEndTF::supported_impl(const std::vector>& if (ov::util::ends_with(model_path, suffix.c_str())) { return true; } + } else if (ov::is_type>(variants[0])) { + return true; } return false; } @@ -298,6 +303,9 @@ ngraph::frontend::InputModel::Ptr FrontEndTF::load_impl( return std::make_shared( std::make_shared<::ov::frontend::tf::GraphIteratorProto>(model_path)); } + } else if (ov::is_type>(variants[0])) { + auto graph_iterator = ov::as_type_ptr>(variants[0])->get(); + return std::make_shared(graph_iterator); } } return nullptr; diff --git a/ngraph/frontend/tensorflow/src/model.cpp b/ngraph/frontend/tensorflow/src/model.cpp index 6ab98998236..effc82b029d 100644 --- a/ngraph/frontend/tensorflow/src/model.cpp +++ b/ngraph/frontend/tensorflow/src/model.cpp @@ -2,18 +2,18 @@ // SPDX-License-Identifier: Apache-2.0 // +#include "model.hpp" + #include #include #include #include #include -#include -#include -#include #include "graph_iterator_proto.hpp" #include "ngraph_conversions.hpp" #include "node_context.hpp" +#include "place.hpp" #include "utils.hpp" using namespace google; diff --git a/ngraph/frontend/tensorflow/include/tensorflow_frontend/model.hpp b/ngraph/frontend/tensorflow/src/model.hpp similarity index 94% rename from ngraph/frontend/tensorflow/include/tensorflow_frontend/model.hpp rename to ngraph/frontend/tensorflow/src/model.hpp index e1dfad29a6f..e3b9d49a599 100644 --- a/ngraph/frontend/tensorflow/include/tensorflow_frontend/model.hpp +++ b/ngraph/frontend/tensorflow/src/model.hpp @@ -7,7 +7,6 @@ #include #include #include -#include namespace ov { namespace frontend { @@ -15,7 +14,7 @@ namespace frontend { class OpPlaceTF; class TensorPlaceTF; -class TF_API InputModelTF : public ngraph::frontend::InputModel { +class InputModelTF : public ngraph::frontend::InputModel { friend class FrontEndTF; class InputModelTFImpl; std::shared_ptr _impl; diff --git a/ngraph/frontend/tensorflow/src/node_context.hpp b/ngraph/frontend/tensorflow/src/node_context.hpp index 52a01560505..f5a44e78b5c 100644 --- a/ngraph/frontend/tensorflow/src/node_context.hpp +++ b/ngraph/frontend/tensorflow/src/node_context.hpp @@ -4,10 +4,10 @@ #pragma once #include -#include -#include #include +#include "exceptions.hpp" +#include "place.hpp" #include "tensor.pb.h" #include "types.pb.h" diff --git a/ngraph/frontend/tensorflow/src/place.cpp b/ngraph/frontend/tensorflow/src/place.cpp index 4d7ba1e17d6..25f6126c1a8 100644 --- a/ngraph/frontend/tensorflow/src/place.cpp +++ b/ngraph/frontend/tensorflow/src/place.cpp @@ -2,8 +2,9 @@ // SPDX-License-Identifier: Apache-2.0 // +#include "place.hpp" + #include -#include #include "node_context.hpp" #include "op_def.pb.h" diff --git a/ngraph/frontend/tensorflow/include/tensorflow_frontend/place.hpp b/ngraph/frontend/tensorflow/src/place.hpp similarity index 100% rename from ngraph/frontend/tensorflow/include/tensorflow_frontend/place.hpp rename to ngraph/frontend/tensorflow/src/place.hpp diff --git a/ngraph/frontend/tensorflow/src/tf_framework_node.hpp b/ngraph/frontend/tensorflow/src/tf_framework_node.hpp index dceb1fc1378..8074ea43f20 100644 --- a/ngraph/frontend/tensorflow/src/tf_framework_node.hpp +++ b/ngraph/frontend/tensorflow/src/tf_framework_node.hpp @@ -6,9 +6,7 @@ #include #include -#include - -#include "graph_iterator_proto.hpp" +#include namespace ov { namespace frontend {