Hide extra header files of TF Frontend into src (#8086)

* Move extra TF Frontend headers to src

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Load a model by GraphIterator

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Apply code style

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Avoid use of InputModelTF in FrontEndTF API and correct a comment

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Apply code style

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Apply code style

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>
This commit is contained in:
Roman Kazantsev 2021-10-23 01:17:06 +03:00 committed by GitHub
parent a7dff0d0ec
commit 8b3a7cfc8e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 38 additions and 22 deletions

View File

@ -10,7 +10,6 @@
#include <map>
#include <openvino/core/node_vector.hpp>
#include <openvino/core/variant.hpp>
#include <tensorflow_frontend/model.hpp>
#include <tensorflow_frontend/utility.hpp>
namespace ov {
@ -74,7 +73,7 @@ protected:
const std::vector<std::shared_ptr<ov::Variant>>& variants) const override;
private:
void translate_graph(const std::shared_ptr<InputModelTF>& model,
void translate_graph(const ngraph::frontend::InputModel::Ptr& model,
const std::string& model_name,
bool fail_fast,
bool no_conversion,

View File

@ -4,6 +4,7 @@
#pragma once
#include <openvino/core/variant.hpp>
#include <tensorflow_frontend/decoder.hpp>
#include <tensorflow_frontend/utility.hpp>
@ -34,3 +35,15 @@ public:
};
} // namespace frontend
} // namespace ov
namespace ov {
/// Keep GraphIterator::Ptr object and type information for type-safe
/// dynamic conversions without using C++ RTTI
template <>
class TF_API VariantWrapper<::ov::frontend::GraphIterator::Ptr>
: public VariantImpl<::ov::frontend::GraphIterator::Ptr> {
public:
OPENVINO_RTTI("Variant::GraphIterator::Ptr");
VariantWrapper(const value_type& value) : VariantImpl<value_type>(value) {}
};
} // namespace ov

View File

@ -7,8 +7,6 @@
#include <ngraph/ngraph.hpp>
#include <string>
#include <tensorflow_frontend/decoder.hpp>
#include <tensorflow_frontend/frontend.hpp>
#include <tensorflow_frontend/place.hpp>
#include <vector>
#include "attr_value.pb.h"

View File

@ -2,7 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
//
#include <tensorflow_frontend/exceptions.hpp>
#include "exceptions.hpp"
#include "node_context.hpp"

View File

@ -4,8 +4,9 @@
#include <openvino/util/common_util.hpp>
#include <tensorflow_frontend/frontend.hpp>
#include <tensorflow_frontend/model.hpp>
#include <tensorflow_frontend/graph_iterator.hpp>
#include "model.hpp"
#include "op_table.hpp"
#include "pass/transpose_sinking.hpp"
#include "tf_framework_node.hpp"
@ -47,7 +48,7 @@ void translate_framework_node(const std::shared_ptr<TFFrameworkNode>& node,
FrontEndTF::FrontEndTF() : m_op_translators(tf::op::get_supported_ops()) {}
void FrontEndTF::translate_graph(const std::shared_ptr<InputModelTF>& model,
void FrontEndTF::translate_graph(const ngraph::frontend::InputModel::Ptr& model,
const std::string& model_name,
bool fail_fast,
bool no_conversion,
@ -57,10 +58,12 @@ void FrontEndTF::translate_graph(const std::shared_ptr<InputModelTF>& model,
ov::ParameterVector params;
ov::ResultVector results;
const auto& operation_places = model->get_op_places();
const auto& model_inputs = model->get_inputs();
const auto& model_outputs = model->get_outputs();
const auto& model_frozen_inputs = model->get_tensor_values();
const auto& model_tf = std::dynamic_pointer_cast<InputModelTF>(model);
FRONT_END_GENERAL_CHECK(model_tf, "nullptr for InputModel is given for translation into nGraph function");
const auto& operation_places = model_tf->get_op_places();
const auto& model_inputs = model_tf->get_inputs();
const auto& model_outputs = model_tf->get_outputs();
const auto& model_frozen_inputs = model_tf->get_tensor_values();
std::map<const std::string, const std::function<ov::OutputVector(const NodeContext&)>> translate_map;
@ -282,6 +285,8 @@ bool FrontEndTF::supported_impl(const std::vector<std::shared_ptr<ov::Variant>>&
if (ov::util::ends_with(model_path, suffix.c_str())) {
return true;
}
} else if (ov::is_type<VariantWrapper<GraphIterator::Ptr>>(variants[0])) {
return true;
}
return false;
}
@ -298,6 +303,9 @@ ngraph::frontend::InputModel::Ptr FrontEndTF::load_impl(
return std::make_shared<InputModelTF>(
std::make_shared<::ov::frontend::tf::GraphIteratorProto>(model_path));
}
} else if (ov::is_type<VariantWrapper<GraphIterator::Ptr>>(variants[0])) {
auto graph_iterator = ov::as_type_ptr<VariantWrapper<GraphIterator::Ptr>>(variants[0])->get();
return std::make_shared<InputModelTF>(graph_iterator);
}
}
return nullptr;

View File

@ -2,18 +2,18 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "model.hpp"
#include <frontend_manager/frontend_exceptions.hpp>
#include <fstream>
#include <openvino/opsets/opset7.hpp>
#include <queue>
#include <tensorflow_frontend/graph_iterator.hpp>
#include <tensorflow_frontend/model.hpp>
#include <tensorflow_frontend/place.hpp>
#include <tensorflow_frontend/utility.hpp>
#include "graph_iterator_proto.hpp"
#include "ngraph_conversions.hpp"
#include "node_context.hpp"
#include "place.hpp"
#include "utils.hpp"
using namespace google;

View File

@ -7,7 +7,6 @@
#include <frontend_manager/input_model.hpp>
#include <frontend_manager/place.hpp>
#include <tensorflow_frontend/graph_iterator.hpp>
#include <tensorflow_frontend/utility.hpp>
namespace ov {
namespace frontend {
@ -15,7 +14,7 @@ namespace frontend {
class OpPlaceTF;
class TensorPlaceTF;
class TF_API InputModelTF : public ngraph::frontend::InputModel {
class InputModelTF : public ngraph::frontend::InputModel {
friend class FrontEndTF;
class InputModelTFImpl;
std::shared_ptr<InputModelTFImpl> _impl;

View File

@ -4,10 +4,10 @@
#pragma once
#include <openvino/core/variant.hpp>
#include <tensorflow_frontend/exceptions.hpp>
#include <tensorflow_frontend/place.hpp>
#include <tensorflow_frontend/utility.hpp>
#include "exceptions.hpp"
#include "place.hpp"
#include "tensor.pb.h"
#include "types.pb.h"

View File

@ -2,8 +2,9 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "place.hpp"
#include <frontend_manager/frontend_exceptions.hpp>
#include <tensorflow_frontend/place.hpp>
#include "node_context.hpp"
#include "op_def.pb.h"

View File

@ -6,9 +6,7 @@
#include <algorithm>
#include <openvino/op/util/framework_node.hpp>
#include <tensorflow_frontend/place.hpp>
#include "graph_iterator_proto.hpp"
#include <tensorflow_frontend/decoder.hpp>
namespace ov {
namespace frontend {