[POC][TF FE] Support SavedModel format (with compression) (#16317)
* Added Saved Model proto descriptors * Included Google's protobuf repository * Added wstring version of ov::util::directory_exists * Added initial implementation of Saved Model iterator # Conflicts: # src/frontends/tensorflow/src/frontend.cpp * Added missing proto files to repository * Implemented reading of variables index and data files # Conflicts: # src/frontends/tensorflow/src/frontend.cpp * Renamed class # Conflicts: # src/frontends/tensorflow/src/frontend.cpp * Fix for cross-platform directory_exists * Fixed codestyle and simplified code * CI fixes * Separeted Saved Model iterator from Proto iterator * Moved variables index into separate class * Added initial implementation of reading a variables from saved model # Conflicts: # src/frontends/tensorflow/src/frontend.cpp * Added external variable mapping * Code cleanup * Commit is for discussion purposes!!! Implemented RestoreV2 with a workaround for strings Not optimized, includes mem leak * In progress... * Added DT_STRING coverage into decoder_proto * m_variables_index moved into underlying class * Updated copyrgihts, added space between license and code * Moved string constant to separate class * Added AssignVariableOp operation * Changed behavior of RestoreV2 Updated stubs for other ops * Second working implementation, enabled: Program-only models Variables reading from data files * Extended docs * Fixed dynamic type * Fixed naming * Added Snappy submodule to support compression in TF FE * Enabled Snappy Compression for TF FE * Make static linkage of Snappy Changing Warning as error behavior for 3rd party * CI fixes * Added Snappy copyright info * Aligned behavior of StringConstant with UnsupportedConstant * Added correct naming and removing unused inputs/outputs
This commit is contained in:
parent
9eab122952
commit
c5b348dd4f
3
.gitmodules
vendored
3
.gitmodules
vendored
|
@ -66,3 +66,6 @@
|
|||
[submodule "thirdparty/flatbuffers/flatbuffers"]
|
||||
path = thirdparty/flatbuffers/flatbuffers
|
||||
url = https://github.com/google/flatbuffers.git
|
||||
[submodule "thirdparty/snappy"]
|
||||
path = thirdparty/snappy
|
||||
url = https://github.com/google/snappy.git
|
||||
|
|
|
@ -156,6 +156,8 @@ ie_option(ENABLE_OV_TF_FRONTEND "Enable TensorFlow FrontEnd" ON)
|
|||
ie_option(ENABLE_OV_TF_LITE_FRONTEND "Enable TensorFlow Lite FrontEnd" ON)
|
||||
ie_dependent_option(ENABLE_SYSTEM_PROTOBUF "Use system protobuf" OFF
|
||||
"ENABLE_OV_ONNX_FRONTEND OR ENABLE_OV_PADDLE_FRONTEND OR ENABLE_OV_TF_FRONTEND;BUILD_SHARED_LIBS" OFF)
|
||||
ie_dependent_option(ENABLE_SNAPPY_COMPRESSION "Enables compression support for TF FE" ON
|
||||
"ENABLE_OV_TF_FRONTEND" ON)
|
||||
ie_option(ENABLE_OV_IR_FRONTEND "Enable IR FrontEnd" ON)
|
||||
ie_dependent_option(ENABLE_SYSTEM_FLATBUFFERS "Use system flatbuffers" ON
|
||||
"ENABLE_OV_TF_LITE_FRONTEND" OFF)
|
||||
|
|
|
@ -1547,3 +1547,61 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
||||
-------------------------------------------------------------
|
||||
|
||||
28. Snappy (https://github.com/google/snappy/)
|
||||
|
||||
Copyright 2011, Google Inc.
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above
|
||||
copyright notice, this list of conditions and the following disclaimer
|
||||
in the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
* Neither the name of Google Inc. nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
===
|
||||
|
||||
Some of the benchmark data in testdata/ is licensed differently:
|
||||
|
||||
- fireworks.jpeg is Copyright 2013 Steinar H. Gunderson, and
|
||||
is licensed under the Creative Commons Attribution 3.0 license
|
||||
(CC-BY-3.0). See https://creativecommons.org/licenses/by/3.0/
|
||||
for more information.
|
||||
|
||||
- kppkn.gtb is taken from the Gaviota chess tablebase set, and
|
||||
is licensed under the MIT License. See
|
||||
https://sites.google.com/site/gaviotachessengine/Home/endgame-tablebases-1
|
||||
for more information.
|
||||
|
||||
- paper-100k.pdf is an excerpt (bytes 92160 to 194560) from the paper
|
||||
“Combinatorial Modeling of Chromatin Features Quantitatively Predicts DNA
|
||||
Replication Timing in _Drosophila_” by Federico Comoglio and Renato Paro,
|
||||
which is licensed under the CC-BY license. See
|
||||
http://www.ploscompbiol.org/static/license for more ifnormation.
|
||||
|
||||
- alice29.txt, asyoulik.txt, plrabn12.txt and lcet10.txt are from Project
|
||||
Gutenberg. The first three have expired copyrights and are in the public
|
||||
domain; the latter does not have expired copyright, but is still in the
|
||||
public domain according to the license information
|
||||
(http://www.gutenberg.org/ebooks/53).
|
||||
|
|
|
@ -123,6 +123,15 @@ void create_directory_recursive(const std::string& path);
|
|||
*/
|
||||
bool directory_exists(const std::string& path);
|
||||
|
||||
#ifdef OPENVINO_ENABLE_UNICODE_PATH_SUPPORT
|
||||
/**
|
||||
* @brief Interface function to check if directory exists for given path
|
||||
* @param path - path to directory wide-string
|
||||
* @return true if directory exists, false otherwise
|
||||
*/
|
||||
bool directory_exists(const std::wstring& path);
|
||||
#endif
|
||||
|
||||
/**
|
||||
* @brief Returns file size for file
|
||||
* @param[in] path The file name
|
||||
|
|
|
@ -27,6 +27,9 @@
|
|||
# define get_absolute_path(result, path) _fullpath(result, path.c_str(), MAX_ABS_PATH)
|
||||
/// @brief Windows-specific 'stat' wrapper
|
||||
# define stat _stat
|
||||
# ifdef OPENVINO_ENABLE_UNICODE_PATH_SUPPORT
|
||||
# define wstat _wstat
|
||||
# endif
|
||||
/// @brief Windows-specific 'mkdir' wrapper
|
||||
# define makedir(dir) _mkdir(dir)
|
||||
// Copied from linux libc sys/stat.h:
|
||||
|
@ -403,6 +406,21 @@ bool ov::util::directory_exists(const std::string& path) {
|
|||
return false;
|
||||
}
|
||||
|
||||
#ifdef OPENVINO_ENABLE_UNICODE_PATH_SUPPORT
|
||||
bool ov::util::directory_exists(const std::wstring& path) {
|
||||
# ifdef _WIN32
|
||||
struct stat sb;
|
||||
|
||||
if (wstat(path.c_str(), &sb) == 0 && S_ISDIR(sb.st_mode)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
# else
|
||||
return directory_exists(wstring_to_string(path));
|
||||
# endif
|
||||
}
|
||||
#endif
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename C,
|
||||
|
|
|
@ -2,7 +2,22 @@
|
|||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
|
||||
list(APPEND CUSTOM_LINK_LIBRARIES
|
||||
openvino::core::dev
|
||||
openvino::frontend::tensorflow_common
|
||||
)
|
||||
|
||||
if(ENABLE_SNAPPY_COMPRESSION)
|
||||
list(APPEND CUSTOM_LINK_LIBRARIES
|
||||
snappy
|
||||
)
|
||||
endif()
|
||||
|
||||
ov_add_frontend(NAME tensorflow
|
||||
LINKABLE_FRONTEND
|
||||
FILEDESCRIPTION "FrontEnd to load and convert TensorFlow file format"
|
||||
LINK_LIBRARIES openvino::core::dev openvino::frontend::tensorflow_common)
|
||||
LINK_LIBRARIES ${CUSTOM_LINK_LIBRARIES})
|
||||
|
||||
if(ENABLE_SNAPPY_COMPRESSION)
|
||||
target_compile_definitions(openvino_tensorflow_frontend PUBLIC ENABLE_SNAPPY_COMPRESSION)
|
||||
endif()
|
||||
|
|
|
@ -129,7 +129,12 @@ ov::Any DecoderProto::get_attribute(const std::string& name) const {
|
|||
}
|
||||
|
||||
case ::tensorflow::AttrValue::ValueCase::kType: {
|
||||
return get_ov_type(attrs[0].type());
|
||||
auto atype = attrs[0].type();
|
||||
if (atype != ::tensorflow::DT_STRING) {
|
||||
return get_ov_type(attrs[0].type());
|
||||
} else {
|
||||
return ov::Any("DT_STRING");
|
||||
}
|
||||
}
|
||||
|
||||
case ::tensorflow::AttrValue::ValueCase::kList: {
|
||||
|
@ -168,7 +173,11 @@ ov::Any DecoderProto::get_attribute(const std::string& name) const {
|
|||
if (list.type_size()) {
|
||||
std::vector<ov::element::Type> res;
|
||||
for (int idx = 0; idx < list.type_size(); ++idx) {
|
||||
res.emplace_back(get_ov_type(list.type(idx)));
|
||||
if (list.type(idx) != ::tensorflow::DataType::DT_STRING) {
|
||||
res.emplace_back(get_ov_type(list.type(idx)));
|
||||
} else {
|
||||
res.emplace_back(ov::element::undefined);
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
@ -194,9 +203,22 @@ ov::Any DecoderProto::get_attribute(const std::string& name) const {
|
|||
FRONT_END_GENERAL_CHECK(pshape.is_static(), "Dynamic shapes are not supported for Tensor attribute.");
|
||||
const auto& tf_type = tensor_proto.dtype();
|
||||
auto ov_type = get_ov_type(tf_type);
|
||||
FRONT_END_GENERAL_CHECK(
|
||||
ov_type.is_static(),
|
||||
"Encountered unknown element type " + DataType_Name(tf_type) + " on an empty tensor_proto");
|
||||
if (tf_type != ::tensorflow::DataType::DT_STRING) {
|
||||
FRONT_END_GENERAL_CHECK(
|
||||
ov_type.is_static(),
|
||||
"Encountered unknown element type " + DataType_Name(tf_type) + " on an empty tensor_proto");
|
||||
} else {
|
||||
ov_type = ov::element::u64;
|
||||
pshape.resize(0);
|
||||
pshape.push_back(tensor_proto.string_val_size());
|
||||
}
|
||||
if (tf_type == ::tensorflow::DataType::DT_STRING) {
|
||||
auto data = std::vector<std::string>();
|
||||
for (auto& item : tensor_proto.string_val()) {
|
||||
data.push_back(item);
|
||||
}
|
||||
return data;
|
||||
}
|
||||
ov::Tensor res(ov_type, pshape.get_shape());
|
||||
auto tensor_content = tensor_proto.tensor_content();
|
||||
if (!tensor_content.empty() && tensor_proto.has_tensor_shape()) {
|
||||
|
|
|
@ -5,10 +5,12 @@
|
|||
#include "openvino/frontend/tensorflow/frontend.hpp"
|
||||
|
||||
#include "graph_iterator_proto.hpp"
|
||||
#include "graph_iterator_saved_model.hpp"
|
||||
#include "helper_transforms/block_lstm_replacer.hpp"
|
||||
#include "helper_transforms/const_to_result_remover.hpp"
|
||||
#include "helper_transforms/embedding_segments_feature_fusing.hpp"
|
||||
#include "helper_transforms/gru_block_cell_replacer.hpp"
|
||||
#include "helper_transforms/saved_model_unused_remover.hpp"
|
||||
#include "input_model.hpp"
|
||||
#include "op_table.hpp"
|
||||
#include "openvino/frontend/tensorflow/extension/conversion.hpp"
|
||||
|
@ -86,6 +88,8 @@ bool FrontEnd::supported_impl(const std::vector<ov::Any>& variants) const {
|
|||
// for automatic deduction of the frontend to convert the model
|
||||
// we have more strict rule that is to have `.pb` extension in the path
|
||||
return true;
|
||||
} else if (GraphIteratorSavedModel::is_supported(model_path)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
#if defined(OPENVINO_ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
|
||||
|
@ -97,6 +101,8 @@ bool FrontEnd::supported_impl(const std::vector<ov::Any>& variants) const {
|
|||
// for automatic deduction of the frontend to convert the model
|
||||
// we have more strict rule that is to have `.pb` extension in the path
|
||||
return true;
|
||||
} else if (GraphIteratorSavedModel::is_supported(model_path)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
@ -118,6 +124,18 @@ ov::frontend::InputModel::Ptr FrontEnd::load_impl(const std::vector<ov::Any>& va
|
|||
if (GraphIteratorProto::is_supported(model_path)) {
|
||||
// handle binary protobuf format
|
||||
return std::make_shared<InputModel>(std::make_shared<GraphIteratorProto>(model_path), m_telemetry);
|
||||
} else if (GraphIteratorSavedModel::is_supported(model_path)) {
|
||||
std::shared_ptr<GraphIteratorSavedModel> graph_iterator;
|
||||
if (variants.size() > 1 && variants[1].is<std::string>()) {
|
||||
graph_iterator = std::make_shared<GraphIteratorSavedModel>(model_path, variants[1].as<std::string>());
|
||||
} else {
|
||||
graph_iterator = std::make_shared<GraphIteratorSavedModel>(model_path, std::string("serve"));
|
||||
}
|
||||
return std::make_shared<InputModel>(graph_iterator,
|
||||
m_telemetry,
|
||||
graph_iterator->get_variables_index(),
|
||||
graph_iterator->get_saved_model_input_names(),
|
||||
graph_iterator->get_saved_model_output_names());
|
||||
}
|
||||
}
|
||||
#if defined(OPENVINO_ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
|
||||
|
@ -126,6 +144,20 @@ ov::frontend::InputModel::Ptr FrontEnd::load_impl(const std::vector<ov::Any>& va
|
|||
if (GraphIteratorProto::is_supported(model_path)) {
|
||||
// handle binary protobuf format with a path in Unicode
|
||||
return std::make_shared<InputModel>(std::make_shared<GraphIteratorProto>(model_path), m_telemetry);
|
||||
} else if (GraphIteratorSavedModel::is_supported(model_path)) {
|
||||
std::shared_ptr<GraphIteratorSavedModel> graph_iterator;
|
||||
if (variants.size() > 1 && variants[1].is<std::string>()) {
|
||||
graph_iterator = std::make_shared<GraphIteratorSavedModel>(
|
||||
model_path,
|
||||
ov::util::wstring_to_string(variants[1].as<std::wstring>()));
|
||||
} else {
|
||||
graph_iterator = std::make_shared<GraphIteratorSavedModel>(model_path, std::string("serve"));
|
||||
}
|
||||
return std::make_shared<InputModel>(graph_iterator,
|
||||
m_telemetry,
|
||||
graph_iterator->get_variables_index(),
|
||||
graph_iterator->get_saved_model_input_names(),
|
||||
graph_iterator->get_saved_model_output_names());
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
@ -232,6 +264,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
|
|||
// run transformations to convert sub-graphs with intermediate (or FrameworkNode) operations
|
||||
// into sub-graphs with only OpenVINO operations
|
||||
ov::pass::Manager manager;
|
||||
manager.register_pass<pass::SavedModelUnusedRemover>();
|
||||
manager.register_pass<pass::EmbeddingSegmentSingleFeatureFusion>();
|
||||
manager.register_pass<pass::BlockLSTMReplacer>();
|
||||
manager.register_pass<pass::GRUBlockCellReplacer>();
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
#pragma once
|
||||
|
||||
#include <fstream>
|
||||
#include <vector>
|
||||
|
||||
#include "decoder_argdef.hpp"
|
||||
#include "decoder_proto.hpp"
|
||||
|
@ -18,6 +19,7 @@ namespace frontend {
|
|||
namespace tensorflow {
|
||||
|
||||
class GraphIteratorProto : public GraphIterator {
|
||||
protected:
|
||||
std::shared_ptr<::tensorflow::GraphDef> m_graph_def;
|
||||
std::shared_ptr<::tensorflow::FunctionDef> m_func_def;
|
||||
|
||||
|
@ -27,6 +29,11 @@ class GraphIteratorProto : public GraphIterator {
|
|||
std::vector<std::string> m_input_names;
|
||||
std::vector<std::string> m_output_names;
|
||||
|
||||
GraphIteratorProto()
|
||||
: m_graph_def(std::make_shared<::tensorflow::GraphDef>()),
|
||||
m_func_def(nullptr),
|
||||
m_library_map() {}
|
||||
|
||||
public:
|
||||
GraphIteratorProto(const std::shared_ptr<::tensorflow::GraphDef>& graph_def,
|
||||
const std::shared_ptr<::tensorflow::FunctionDef>& func_def,
|
||||
|
@ -150,6 +157,7 @@ public:
|
|||
return m_output_names;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
||||
|
|
291
src/frontends/tensorflow/src/graph_iterator_saved_model.hpp
Normal file
291
src/frontends/tensorflow/src/graph_iterator_saved_model.hpp
Normal file
|
@ -0,0 +1,291 @@
|
|||
// Copyright (C) 2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <map>
|
||||
|
||||
#include "graph_iterator_proto.hpp"
|
||||
#include "openvino/util/file_util.hpp"
|
||||
#include "saved_model.pb.h"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace tensorflow {
|
||||
|
||||
struct VIBlock;
|
||||
|
||||
template <typename T>
|
||||
std::basic_string<T> get_saved_model_name() {}
|
||||
template <typename T>
|
||||
std::basic_string<T> get_variables_index_name() {}
|
||||
|
||||
template <>
|
||||
std::basic_string<char> get_saved_model_name<char>();
|
||||
template <>
|
||||
std::basic_string<char> get_variables_index_name<char>();
|
||||
|
||||
#if defined(OPENVINO_ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
|
||||
template <>
|
||||
std::basic_string<wchar_t> get_saved_model_name<wchar_t>();
|
||||
template <>
|
||||
std::basic_string<wchar_t> get_variables_index_name<wchar_t>();
|
||||
#endif
|
||||
|
||||
// Stores information about variables index
|
||||
class SavedModelVariablesIndex {
|
||||
// Contains maximum amount of shards, used for creating corrext extension
|
||||
int32_t m_total_shards;
|
||||
// Contains BundleEntryProto variables list, readed from .index file
|
||||
std::map<std::string, std::vector<char>> m_variables_index;
|
||||
// List of opened data files for using with BundleEntryProto
|
||||
std::map<int32_t, std::shared_ptr<std::ifstream>> m_data_files;
|
||||
// List of mapped variables which could be read using TrackableObjectGraph
|
||||
std::map<std::string, std::string> m_variables_map;
|
||||
|
||||
public:
|
||||
/// \brief Reads variables from opened variable index file. Can cause an asserts in case of issues.
|
||||
/// \param vi_stream Opened stream file, file pointer doesn't matter, it will be rewind internally.
|
||||
/// \param path A path to file with variables data
|
||||
/// \returns Returns true in case of everything loads successfully, false otherwise
|
||||
bool read_variables(std::ifstream& vi_stream, const std::string& path);
|
||||
#if defined(OPENVINO_ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
|
||||
/// \brief Reads variables from opened variable index file. Can cause an asserts in case of issues.
|
||||
/// \param vi_stream Opened stream file, file pointer doesn't matter, it will be rewind internally.
|
||||
/// \param path A path to file with variables data
|
||||
/// \returns Returns true in case of everything loads successfully, false otherwise
|
||||
bool read_variables(std::ifstream& vi_stream, const std::wstring& path);
|
||||
#endif
|
||||
|
||||
/// \brief Returns data and size of data of stored variable
|
||||
/// \param name Name of variable
|
||||
/// \param data Pointer on a pointer where data pointer will be returned
|
||||
/// \param size Pointer on a variable which will stores data size
|
||||
/// \returns Returns true in case variable was found, false otherwise (data and size will be untouched)
|
||||
bool get_variable(const std::string& name, const char** data, size_t* size) const {
|
||||
auto varItem = m_variables_index.find(name);
|
||||
if (varItem == m_variables_index.end()) {
|
||||
return false;
|
||||
}
|
||||
if (data != nullptr) {
|
||||
*data = varItem->second.data();
|
||||
}
|
||||
if (size != nullptr) {
|
||||
*size = varItem->second.size();
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/// \brief Returns data and size of data of mapped variable from trackable object graph to variables index
|
||||
/// \param name Name of a mapping variable
|
||||
/// \param data Pointer on a pointer where data pointer will be returned
|
||||
/// \param size Pointer on a variable which will stores data size
|
||||
/// \returns Returns true in case variable was found, false otherwise (data and size will be untouched)
|
||||
bool get_mapped_variable(const std::string& name, const char** data, size_t* size) const {
|
||||
auto mapItem = m_variables_map.find(name);
|
||||
if (mapItem == m_variables_map.end()) {
|
||||
return false;
|
||||
}
|
||||
return get_variable(mapItem->second, data, size);
|
||||
}
|
||||
|
||||
/// \brief Checks if variable has a mapped pair
|
||||
/// \param name Name of variable for checking existance
|
||||
/// \returns True in case variable has mapped value and false otherwise
|
||||
bool has_mapped_variable(const std::string& name) const {
|
||||
auto mapItem = m_variables_map.find(name);
|
||||
return mapItem != m_variables_map.end();
|
||||
}
|
||||
|
||||
/// \brief Returns shared pointer to a requested shard_id, or nullptr in case of shard_id isn't found
|
||||
/// \param shard_id Requested shard_id
|
||||
/// \returns Valid shared_ptr with ifstream or with nullptr if shard isn't found
|
||||
std::shared_ptr<std::ifstream> get_data_file(const int32_t shard_id) const {
|
||||
auto result = m_data_files.find(shard_id);
|
||||
return result != m_data_files.end() ? result->second : nullptr;
|
||||
}
|
||||
|
||||
/// \brief Adds variable mapping to the variables map
|
||||
/// \param var_name Variable full name (from .index file)
|
||||
/// \param map_name Mapped name
|
||||
/// \param rewrite Rewrite mapped value in case it exists
|
||||
/// \returns True if map updated. False if nothing changed (if variable exists and rewrite is false).
|
||||
bool map_variable(const std::string& var_name, const std::string& map_name, bool rewrite = false) {
|
||||
if (m_variables_map.find(var_name) != m_variables_map.end() && rewrite == false) {
|
||||
return false;
|
||||
}
|
||||
|
||||
m_variables_map[var_name] = map_name;
|
||||
return true;
|
||||
}
|
||||
|
||||
private:
|
||||
/// \brief Reads block structure of .index file
|
||||
/// \param[in,out] fs Filestream of .index file, position in file will be updated
|
||||
/// \param[in] index Variables index block which stores information about block
|
||||
/// \param[out] data Block data will be readed
|
||||
/// \param[out] offset Offset of block start
|
||||
/// \param[out] offset_end Offset of block end
|
||||
void read_variables_index_block(std::ifstream& fs,
|
||||
const VIBlock& index,
|
||||
std::vector<char>& data,
|
||||
uint32_t& offset,
|
||||
uint32_t& offset_end);
|
||||
/// \brief Reads key=value pair from provided pointer
|
||||
/// \param[in,out] ptr Actual pointer, will be moved to the end of readed pair (to read next)
|
||||
/// \param[in] ptr_end End of memory which shouldn't be passed in case of broken structure
|
||||
/// \param[out] key Key name
|
||||
/// \param[out] value Stored value for key (isn't a pure string, data block)
|
||||
/// \param[out] val_lenght Length of readed value
|
||||
void read_variables_index_pair(char*& ptr,
|
||||
const char* ptr_end,
|
||||
std::string& key,
|
||||
char*& value,
|
||||
uint32_t& val_length);
|
||||
/// \brief Reads .index file and stores key=value map in provided varIndex
|
||||
/// \param[in,out] fs Filestream should be parsed. Position in file will be updated
|
||||
/// \param[out] varIndex Variables indx (key=value) from given filestream
|
||||
void read_variables_index(std::ifstream& fs, std::map<std::string, std::vector<char>>& varIndex);
|
||||
/// \brief Reads bundle header if it is available. Checks version and saves info about amount of shards
|
||||
void read_bundle_header();
|
||||
/// \brief Reads key=value map from storef _CHECKPOINTABLE_OBJECT_GRAPH variable
|
||||
void read_checkpointable_object_graph();
|
||||
};
|
||||
|
||||
// Loads graph from Tensorflow Saved Model file (saved_model.pb)
|
||||
class GraphIteratorSavedModel : public GraphIteratorProto {
|
||||
std::shared_ptr<::tensorflow::SavedModel> m_saved_model;
|
||||
std::shared_ptr<SavedModelVariablesIndex> m_variables_index;
|
||||
std::shared_ptr<std::map<std::string, std::string>> m_inputs_map;
|
||||
std::shared_ptr<std::map<std::string, std::string>> m_outputs_map;
|
||||
|
||||
public:
|
||||
template <typename T>
|
||||
GraphIteratorSavedModel(const std::basic_string<T>& path, const std::string& tags)
|
||||
: m_saved_model(std::make_shared<::tensorflow::SavedModel>()) {
|
||||
this->read_saved_model(path, tags);
|
||||
}
|
||||
|
||||
static bool is_supported(const std::string& path);
|
||||
#if defined(OPENVINO_ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
|
||||
static bool is_supported(const std::wstring& path);
|
||||
#endif
|
||||
|
||||
std::shared_ptr<SavedModelVariablesIndex> get_variables_index() {
|
||||
return m_variables_index;
|
||||
}
|
||||
|
||||
std::shared_ptr<std::map<std::string, std::string>> get_saved_model_input_names() const {
|
||||
return m_inputs_map;
|
||||
}
|
||||
|
||||
std::shared_ptr<std::map<std::string, std::string>> get_saved_model_output_names() const {
|
||||
return m_outputs_map;
|
||||
}
|
||||
|
||||
private:
|
||||
bool is_valid_signature(const ::tensorflow::SignatureDef& signature) const;
|
||||
|
||||
template <typename T>
|
||||
bool read_saved_model(const std::basic_string<T>& path, const std::string& tags) {
|
||||
std::ifstream sm_stream{path + get_saved_model_name<T>(), std::ifstream::in | std::ifstream::binary};
|
||||
FRONT_END_GENERAL_CHECK(sm_stream && sm_stream.is_open(), "Model file does not exist");
|
||||
|
||||
std::basic_string<T> varIndexPath = path + get_variables_index_name<T>();
|
||||
if (ov::util::file_exists(varIndexPath)) {
|
||||
m_variables_index = std::make_shared<SavedModelVariablesIndex>();
|
||||
std::ifstream vi_stream{varIndexPath, std::ifstream::in | std::ifstream::binary};
|
||||
FRONT_END_GENERAL_CHECK(vi_stream && vi_stream.is_open(),
|
||||
"Saved Model's variable index file does not exist");
|
||||
FRONT_END_GENERAL_CHECK(m_variables_index->read_variables(vi_stream, path),
|
||||
"Saved Model's variable index file cannot be parsed");
|
||||
}
|
||||
|
||||
bool res = m_saved_model->ParseFromIstream(&sm_stream);
|
||||
FRONT_END_GENERAL_CHECK(res && m_saved_model->meta_graphs_size(), "Saved Model cannot be parsed");
|
||||
|
||||
for (const auto& meta_graph : m_saved_model->meta_graphs()) {
|
||||
if (!meta_graph.has_graph_def()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (m_saved_model->meta_graphs_size() > 1) {
|
||||
bool tag_found = false;
|
||||
for (const auto& tag : meta_graph.meta_info_def().tags()) {
|
||||
if (tags.find(tag) != std::string::npos) {
|
||||
tag_found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!tag_found) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
std::map<std::string, const ::tensorflow::SignatureDef*> validSignatures = {};
|
||||
for (const auto& sit : meta_graph.signature_def()) {
|
||||
const std::string& key = sit.first;
|
||||
const ::tensorflow::SignatureDef& val = sit.second;
|
||||
if (is_valid_signature(val)) {
|
||||
validSignatures[key] = &val;
|
||||
}
|
||||
}
|
||||
|
||||
auto serving_default = validSignatures.find("serving_default");
|
||||
|
||||
if (serving_default != validSignatures.end()) {
|
||||
m_inputs_map = std::make_shared<std::map<std::string, std::string>>();
|
||||
m_outputs_map = std::make_shared<std::map<std::string, std::string>>();
|
||||
for (const auto& input : serving_default->second->inputs()) {
|
||||
(*m_inputs_map)[input.second.name()] = input.first;
|
||||
}
|
||||
for (const auto& output : serving_default->second->outputs()) {
|
||||
(*m_outputs_map)[output.second.name()] = output.first;
|
||||
}
|
||||
}
|
||||
|
||||
m_graph_def = std::make_shared<::tensorflow::GraphDef>(meta_graph.graph_def());
|
||||
|
||||
// Update variables map using information by resolving AssignVariableOp graph nodes
|
||||
std::map<std::string, std::string> var_map;
|
||||
map_assignvariable(m_graph_def, var_map);
|
||||
for (auto var : var_map) {
|
||||
m_variables_index->map_variable(var.first, var.second);
|
||||
}
|
||||
|
||||
auto nodes_size = m_graph_def->node_size();
|
||||
m_decoders.resize(static_cast<size_t>(nodes_size));
|
||||
for (int node_ind = 0; node_ind < nodes_size; ++node_ind) {
|
||||
m_decoders[node_ind] = std::make_shared<DecoderProto>(&m_graph_def->node(node_ind), m_graph_def);
|
||||
}
|
||||
|
||||
// initialize a library map
|
||||
auto num_funcs = m_graph_def->library().function_size();
|
||||
for (int func_ind = 0; func_ind < num_funcs; ++func_ind) {
|
||||
auto func = m_graph_def->library().function(func_ind);
|
||||
auto func_name = func.signature().name();
|
||||
m_library_map.insert(std::pair<std::string, int>(func_name, func_ind));
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
FRONT_END_GENERAL_CHECK(false, "Saved Model doesn't contain MetaGraph with requested tag");
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
/// \brief Reads relationship between VarHandleOp - RestoreV2 - AssignVariableOp and
|
||||
/// stores this information in a provided key=value map. Where key - name of VarHandleOp,
|
||||
/// value - long variable name which is stored in RestoreV2.
|
||||
/// It needs to map VarHandleOp to right place in .index file.
|
||||
/// \param[in] graph_def GraphDef object for analysis
|
||||
/// \param[out] variables_map Map of variables found in graph_def
|
||||
void map_assignvariable(const std::shared_ptr<::tensorflow::GraphDef> graph_def,
|
||||
std::map<std::string, std::string>& variables_map) const;
|
||||
}; // GraphIteratorSavedModel
|
||||
|
||||
} // namespace tensorflow
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
|
@ -55,7 +55,10 @@ public:
|
|||
InputModelTFImpl(const GraphIterator::Ptr& graph_iterator, const ov::frontend::InputModel& input_model);
|
||||
InputModelTFImpl(const GraphIterator::Ptr& graph_iterator,
|
||||
const ov::frontend::InputModel& input_model,
|
||||
const std::shared_ptr<TelemetryExtension>& telemetry);
|
||||
const std::shared_ptr<TelemetryExtension>& telemetry,
|
||||
const std::shared_ptr<SavedModelVariablesIndex>& variables_index,
|
||||
const std::shared_ptr<std::map<std::string, std::string>> saved_model_input_names,
|
||||
const std::shared_ptr<std::map<std::string, std::string>> saved_model_output_names);
|
||||
std::vector<ov::frontend::Place::Ptr> get_inputs() const;
|
||||
std::vector<ov::frontend::Place::Ptr> get_outputs() const;
|
||||
ov::frontend::Place::Ptr get_place_by_tensor_name(const std::string& tensorName) const;
|
||||
|
@ -79,6 +82,9 @@ public:
|
|||
std::shared_ptr<InputModel> get_body_input_model(const std::string& body_model_name) const;
|
||||
std::vector<std::string> get_input_names() const;
|
||||
std::vector<std::string> get_output_names() const;
|
||||
std::shared_ptr<SavedModelVariablesIndex> get_variables_index() const;
|
||||
std::shared_ptr<std::map<std::string, std::string>> get_saved_model_input_names() const;
|
||||
std::shared_ptr<std::map<std::string, std::string>> get_saved_model_output_names() const;
|
||||
|
||||
private:
|
||||
void load_places();
|
||||
|
@ -99,6 +105,10 @@ private:
|
|||
|
||||
std::shared_ptr<TelemetryExtension> m_telemetry;
|
||||
|
||||
std::shared_ptr<SavedModelVariablesIndex> m_variables_index;
|
||||
std::shared_ptr<std::map<std::string, std::string>> m_saved_model_input_names;
|
||||
std::shared_ptr<std::map<std::string, std::string>> m_saved_model_output_names;
|
||||
|
||||
// shows if some nodes might be deleted from graph
|
||||
bool m_graph_changed = false;
|
||||
};
|
||||
|
@ -152,10 +162,10 @@ void InputModel::InputModelTFImpl::load_places() {
|
|||
}
|
||||
auto dtype_any = node_decoder->get_attribute("dtype");
|
||||
auto placeholder_name = node_decoder->get_op_name();
|
||||
FRONT_END_GENERAL_CHECK(
|
||||
dtype_any.is<ov::element::Type>(),
|
||||
"Incorrect input model: Placeholder node " + placeholder_name + " has unspecified type.");
|
||||
auto type = dtype_any.as<ov::element::Type>();
|
||||
ov::element::Type type = ov::element::dynamic;
|
||||
if (dtype_any.is<ov::element::Type>()) {
|
||||
type = dtype_any.as<ov::element::Type>();
|
||||
}
|
||||
std::vector<std::string> names = {op_name};
|
||||
auto tensor_place = std::make_shared<TensorPlace>(m_input_model, pshape, type, names);
|
||||
m_tensor_places[op_name] = tensor_place;
|
||||
|
@ -202,6 +212,17 @@ void InputModel::InputModelTFImpl::load_places() {
|
|||
m_outputs.push_back(output_place);
|
||||
}
|
||||
}
|
||||
std::shared_ptr<SavedModelVariablesIndex> InputModel::InputModelTFImpl::get_variables_index() const {
|
||||
return m_variables_index;
|
||||
}
|
||||
|
||||
std::shared_ptr<std::map<std::string, std::string>> InputModel::InputModelTFImpl::get_saved_model_input_names() const {
|
||||
return m_saved_model_input_names;
|
||||
}
|
||||
|
||||
std::shared_ptr<std::map<std::string, std::string>> InputModel::InputModelTFImpl::get_saved_model_output_names() const {
|
||||
return m_saved_model_output_names;
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<OpPlace>> InputModel::InputModelTFImpl::get_op_places() const {
|
||||
return topologically_sort_op_nodes();
|
||||
|
@ -337,12 +358,19 @@ std::shared_ptr<InputModel> InputModel::InputModelTFImpl::get_body_input_model(
|
|||
return std::make_shared<InputModel>(body_graph_iterator, m_telemetry);
|
||||
}
|
||||
|
||||
InputModel::InputModelTFImpl::InputModelTFImpl(const GraphIterator::Ptr& graph_iterator,
|
||||
const ov::frontend::InputModel& input_model,
|
||||
const std::shared_ptr<TelemetryExtension>& telemetry)
|
||||
InputModel::InputModelTFImpl::InputModelTFImpl(
|
||||
const GraphIterator::Ptr& graph_iterator,
|
||||
const ov::frontend::InputModel& input_model,
|
||||
const std::shared_ptr<TelemetryExtension>& telemetry,
|
||||
const std::shared_ptr<SavedModelVariablesIndex>& variables_index,
|
||||
const std::shared_ptr<std::map<std::string, std::string>> saved_model_input_names,
|
||||
const std::shared_ptr<std::map<std::string, std::string>> saved_model_output_names)
|
||||
: m_graph_iterator(graph_iterator),
|
||||
m_input_model(input_model),
|
||||
m_telemetry(telemetry) {
|
||||
m_telemetry(telemetry),
|
||||
m_variables_index(variables_index),
|
||||
m_saved_model_input_names(saved_model_input_names),
|
||||
m_saved_model_output_names(saved_model_output_names) {
|
||||
FRONT_END_GENERAL_CHECK(m_graph_iterator, "Null pointer specified for GraphIterator");
|
||||
m_input_names = graph_iterator->get_input_names();
|
||||
m_output_names = graph_iterator->get_output_names();
|
||||
|
@ -445,8 +473,29 @@ void InputModel::InputModelTFImpl::set_tensor_value(ov::frontend::Place::Ptr pla
|
|||
m_tensor_values[name] = constant;
|
||||
}
|
||||
|
||||
InputModel::InputModel(const GraphIterator::Ptr& graph_iterator, const std::shared_ptr<TelemetryExtension>& telemetry)
|
||||
: _impl{std::make_shared<InputModelTFImpl>(graph_iterator, *this, telemetry)} {}
|
||||
InputModel::InputModel(const GraphIterator::Ptr& graph_iterator,
|
||||
const std::shared_ptr<TelemetryExtension>& telemetry,
|
||||
const std::shared_ptr<SavedModelVariablesIndex>& variables_index,
|
||||
const std::shared_ptr<std::map<std::string, std::string>> saved_model_input_names,
|
||||
const std::shared_ptr<std::map<std::string, std::string>> saved_model_output_names)
|
||||
: _impl{std::make_shared<InputModelTFImpl>(graph_iterator,
|
||||
*this,
|
||||
telemetry,
|
||||
variables_index,
|
||||
saved_model_input_names,
|
||||
saved_model_output_names)} {}
|
||||
|
||||
std::shared_ptr<SavedModelVariablesIndex> InputModel::get_variables_index() {
|
||||
return _impl->get_variables_index();
|
||||
}
|
||||
|
||||
std::shared_ptr<std::map<std::string, std::string>> InputModel::get_saved_model_input_names() const {
|
||||
return _impl->get_saved_model_input_names();
|
||||
}
|
||||
|
||||
std::shared_ptr<std::map<std::string, std::string>> InputModel::get_saved_model_output_names() const {
|
||||
return _impl->get_saved_model_output_names();
|
||||
}
|
||||
|
||||
std::vector<std::string> InputModel::get_input_names() const {
|
||||
return _impl->get_input_names();
|
||||
|
|
|
@ -16,6 +16,7 @@ namespace tensorflow {
|
|||
|
||||
class OpPlace;
|
||||
class TensorPlace;
|
||||
class SavedModelVariablesIndex;
|
||||
|
||||
class InputModel : public ov::frontend::InputModel {
|
||||
friend class TranslateSession;
|
||||
|
@ -31,7 +32,10 @@ class InputModel : public ov::frontend::InputModel {
|
|||
|
||||
public:
|
||||
explicit InputModel(const GraphIterator::Ptr& graph_iterator,
|
||||
const std::shared_ptr<TelemetryExtension>& telemetry = {});
|
||||
const std::shared_ptr<TelemetryExtension>& telemetry = {},
|
||||
const std::shared_ptr<SavedModelVariablesIndex>& variables_index = {},
|
||||
const std::shared_ptr<std::map<std::string, std::string>> saved_model_input_names = nullptr,
|
||||
const std::shared_ptr<std::map<std::string, std::string>> saved_model_output_names = nullptr);
|
||||
|
||||
std::vector<ov::frontend::Place::Ptr> get_inputs() const override;
|
||||
std::vector<ov::frontend::Place::Ptr> get_outputs() const override;
|
||||
|
@ -45,6 +49,9 @@ public:
|
|||
void set_element_type(const ov::frontend::Place::Ptr& place, const ov::element::Type&) override;
|
||||
ov::element::Type get_element_type(const ov::frontend::Place::Ptr& place) const override;
|
||||
void set_tensor_value(const ov::frontend::Place::Ptr& place, const void* value) override;
|
||||
std::shared_ptr<SavedModelVariablesIndex> get_variables_index();
|
||||
std::shared_ptr<std::map<std::string, std::string>> get_saved_model_input_names() const;
|
||||
std::shared_ptr<std::map<std::string, std::string>> get_saved_model_output_names() const;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
|
203
src/frontends/tensorflow/src/op/var_handle.cpp
Normal file
203
src/frontends/tensorflow/src/op/var_handle.cpp
Normal file
|
@ -0,0 +1,203 @@
|
|||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "common_op_table.hpp"
|
||||
#include "graph_iterator_saved_model.hpp"
|
||||
#include "helper_ops/string_constant.hpp"
|
||||
#include "helper_ops/unsupported_constant.hpp"
|
||||
#include "input_model.hpp"
|
||||
#include "openvino/opsets/opset8.hpp"
|
||||
#include "tensor_bundle.pb.h"
|
||||
|
||||
using namespace std;
|
||||
using namespace ov::opset8;
|
||||
using namespace ov;
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace tensorflow {
|
||||
namespace op {
|
||||
|
||||
// Reading variable from shard file
|
||||
template <typename T>
|
||||
static std::shared_ptr<ov::Node> read_variable(std::shared_ptr<SavedModelVariablesIndex> var_index,
|
||||
const ov::element::Type ov_type,
|
||||
const ov::Shape shape,
|
||||
const ::tensorflow::BundleEntryProto& entry,
|
||||
const NodeContext& node) {
|
||||
std::vector<T> var_data;
|
||||
google::protobuf::int64 size = 1;
|
||||
for (uint64_t i = 0; i < shape.size(); ++i) {
|
||||
size *= static_cast<google::protobuf::int64>(shape[i]);
|
||||
}
|
||||
var_data.resize(size);
|
||||
TENSORFLOW_OP_VALIDATION(node,
|
||||
size == static_cast<google::protobuf::int64>(entry.size() / sizeof(T)),
|
||||
"[TensorFlow Frontend] Internal error: Available data size isn't equal to calculated.");
|
||||
auto fs = var_index->get_data_file(entry.shard_id());
|
||||
if (!fs.get()) {
|
||||
TENSORFLOW_OP_VALIDATION(node, var_index, "[TensorFlow Frontend] Internal error: Cannot get shard file.");
|
||||
}
|
||||
fs->seekg(entry.offset(), std::ios::beg);
|
||||
fs->read(reinterpret_cast<char*>(var_data.data()), entry.size());
|
||||
return std::make_shared<Constant>(ov_type, shape, var_data);
|
||||
}
|
||||
|
||||
OutputVector translate_varhandle_op(const NodeContext& node) {
|
||||
default_op_checks(node, 0, {"VarHandleOp"});
|
||||
auto translate_session = node.get_translate_session();
|
||||
TENSORFLOW_OP_VALIDATION(node,
|
||||
translate_session,
|
||||
"[TensorFlow Frontend] Internal error: Translate session is nullptr.");
|
||||
auto model = reinterpret_cast<ov::frontend::tensorflow::InputModel*>(translate_session->get_input_model().get());
|
||||
auto var_index = model->get_variables_index();
|
||||
auto ov_type = node.get_attribute<element::Type>("dtype");
|
||||
std::shared_ptr<Node> const_node;
|
||||
if (ov_type == element::undefined) {
|
||||
const_node = std::make_shared<UnsupportedConstant>();
|
||||
} else {
|
||||
// Getting variable description from variables index
|
||||
const char* entry_data = nullptr;
|
||||
size_t entry_size = 0;
|
||||
auto var_name = node.get_name();
|
||||
auto shape = node.get_attribute<::ov::PartialShape>("shape").get_shape();
|
||||
bool result = var_index->get_mapped_variable(var_name, &entry_data, &entry_size);
|
||||
|
||||
TENSORFLOW_OP_VALIDATION(node, result, "[TensorFlow Frontend] Internal error: Cannot find requested variable.");
|
||||
|
||||
::tensorflow::BundleEntryProto entry;
|
||||
TENSORFLOW_OP_VALIDATION(node,
|
||||
entry.ParseFromArray(entry_data, static_cast<int>(entry_size)),
|
||||
"[TensorFlow Frontend] Internal error: Cannot get read bundle entry.");
|
||||
|
||||
switch (ov_type) {
|
||||
case ov::element::u8:
|
||||
const_node = read_variable<uint8_t>(var_index, ov_type, shape, entry, node);
|
||||
break;
|
||||
case ov::element::i8:
|
||||
const_node = read_variable<int8_t>(var_index, ov_type, shape, entry, node);
|
||||
break;
|
||||
case ov::element::i16:
|
||||
const_node = read_variable<int16_t>(var_index, ov_type, shape, entry, node);
|
||||
break;
|
||||
case ov::element::i32:
|
||||
const_node = read_variable<int32_t>(var_index, ov_type, shape, entry, node);
|
||||
break;
|
||||
case ov::element::i64:
|
||||
const_node = read_variable<int64_t>(var_index, ov_type, shape, entry, node);
|
||||
break;
|
||||
case ov::element::f16:
|
||||
const_node = read_variable<float16>(var_index, ov_type, shape, entry, node);
|
||||
break;
|
||||
case ov::element::f32:
|
||||
const_node = read_variable<float>(var_index, ov_type, shape, entry, node);
|
||||
break;
|
||||
case ov::element::f64:
|
||||
const_node = read_variable<double>(var_index, ov_type, shape, entry, node);
|
||||
break;
|
||||
case ov::element::bf16:
|
||||
const_node = read_variable<bfloat16>(var_index, ov_type, shape, entry, node);
|
||||
break;
|
||||
default:
|
||||
FRONT_END_THROW("Encountered unknown element type " + ov_type.get_type_name());
|
||||
}
|
||||
}
|
||||
set_node_name(node.get_name(), const_node);
|
||||
return {const_node};
|
||||
}
|
||||
|
||||
OutputVector translate_varisinitialized_op(const NodeContext& node) {
|
||||
auto const_node = std::make_shared<Constant>(::ov::element::boolean, Shape{}, true);
|
||||
set_node_name(node.get_name(), const_node);
|
||||
return {const_node};
|
||||
}
|
||||
|
||||
OutputVector translate_readvariable_op(const NodeContext& node) {
|
||||
default_op_checks(node, 1, {"ReadVariableOp"});
|
||||
// Documentation says it should return only one tensor with dtype, but
|
||||
// _output_shapes in a vector of shapes and it means it could have multiple outputs
|
||||
// https://www.tensorflow.org/api_docs/python/tf/raw_ops/ReadVariableOp
|
||||
auto output_shapes = node.get_attribute<std::vector<::ov::PartialShape>>("_output_shapes");
|
||||
|
||||
OutputVector outs = {};
|
||||
|
||||
for (size_t i = 0; i < output_shapes.size(); ++i) {
|
||||
std::shared_ptr<ov::Node> output_node;
|
||||
if (node.get_input(0).get_partial_shape().is_static() &&
|
||||
output_shapes[i].get_shape() != node.get_input(0).get_shape()) {
|
||||
auto reshape_shape = make_shared<Constant>(ov::element::i32, output_shapes[i].get_shape());
|
||||
output_node = make_shared<Reshape>(node.get_input(0), reshape_shape, false);
|
||||
} else {
|
||||
output_node = node.get_input(0).get_node_shared_ptr();
|
||||
}
|
||||
if (i == 0) {
|
||||
set_out_name(node.get_name(), output_node);
|
||||
set_out_name(node.get_name() + ":" + "0", output_node);
|
||||
} else {
|
||||
set_node_name(node.get_name() + ":" + std::to_string(i), output_node);
|
||||
}
|
||||
outs.push_back(output_node);
|
||||
}
|
||||
return outs;
|
||||
}
|
||||
|
||||
OutputVector translate_assignvariable_op(const NodeContext& node) {
|
||||
default_op_checks(node, 2, {"AssignVariableOp"});
|
||||
auto assignvariableop_node = std::make_shared<UnsupportedConstant>();
|
||||
set_node_name(node.get_name(), assignvariableop_node);
|
||||
return {assignvariableop_node};
|
||||
}
|
||||
|
||||
OutputVector translate_restorev2_op(const NodeContext& node) {
|
||||
default_op_checks(node, 3, {"RestoreV2"});
|
||||
auto translate_session = node.get_translate_session();
|
||||
TENSORFLOW_OP_VALIDATION(node,
|
||||
translate_session,
|
||||
"[TensorFlow Frontend] Internal error: Translate session is nullptr.");
|
||||
auto model = reinterpret_cast<ov::frontend::tensorflow::InputModel*>(translate_session->get_input_model().get());
|
||||
auto var_index = model->get_variables_index();
|
||||
auto tensor_names =
|
||||
reinterpret_cast<StringConstant*>(node.get_input(1).get_node())->get_data().as<std::vector<std::string>>();
|
||||
auto tensor_types = node.get_attribute<std::vector<ov::element::Type>>("dtypes");
|
||||
|
||||
OutputVector outs = {};
|
||||
|
||||
for (size_t i = 0; i < tensor_names.size(); ++i) {
|
||||
auto const_node = std::make_shared<UnsupportedConstant>();
|
||||
if (i == 0)
|
||||
set_node_name(node.get_name(), const_node);
|
||||
else
|
||||
set_node_name(node.get_name() + ":" + std::to_string(i), const_node);
|
||||
outs.push_back(const_node);
|
||||
}
|
||||
|
||||
return outs;
|
||||
}
|
||||
|
||||
OutputVector translate_staticregexfullmatch_op(const NodeContext& node) {
|
||||
default_op_checks(node, 1, {"StaticRegexFullMatch"});
|
||||
// auto pattern = node.get_attribute_as_any("pattern").as<std::string>();
|
||||
auto const_node = std::make_shared<Constant>(ov::element::boolean, ov::Shape{}, true);
|
||||
set_node_name(node.get_name(), const_node);
|
||||
return {const_node};
|
||||
}
|
||||
|
||||
OutputVector translate_stringjoin_op(const NodeContext& node) {
|
||||
default_op_checks(node, 1, {"StringJoin"});
|
||||
auto const_node = std::make_shared<UnsupportedConstant>();
|
||||
set_node_name(node.get_name(), const_node);
|
||||
return {const_node};
|
||||
}
|
||||
|
||||
OutputVector translate_mergev2checkpoint_op(const NodeContext& node) {
|
||||
default_op_checks(node, 1, {"MergeV2Checkpoint"});
|
||||
auto const_node = std::make_shared<UnsupportedConstant>();
|
||||
set_node_name(node.get_name(), const_node);
|
||||
return {const_node};
|
||||
}
|
||||
|
||||
} // namespace op
|
||||
} // namespace tensorflow
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
|
@ -32,6 +32,14 @@ TF_OP_CONVERTER(translate_queue_dequeue_many_op);
|
|||
TF_OP_CONVERTER(translate_sparse_fill_empty_rows_op);
|
||||
TF_OP_CONVERTER(translate_sparse_reshape_op);
|
||||
TF_OP_CONVERTER(translate_sparse_segment_sum_op);
|
||||
TF_OP_CONVERTER(translate_varisinitialized_op);
|
||||
TF_OP_CONVERTER(translate_readvariable_op);
|
||||
TF_OP_CONVERTER(translate_assignvariable_op);
|
||||
TF_OP_CONVERTER(translate_varhandle_op);
|
||||
TF_OP_CONVERTER(translate_restorev2_op);
|
||||
TF_OP_CONVERTER(translate_staticregexfullmatch_op);
|
||||
TF_OP_CONVERTER(translate_stringjoin_op);
|
||||
TF_OP_CONVERTER(translate_mergev2checkpoint_op);
|
||||
TF_OP_CONVERTER(translate_while_op);
|
||||
|
||||
const std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
|
@ -246,6 +254,15 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
|||
{"TopK", translate_top_k_op},
|
||||
{"TopKV2", translate_top_k_v2_op},
|
||||
{"Transpose", translate_transpose_op},
|
||||
{"ReadVariableOp", translate_readvariable_op},
|
||||
{"AssignVariableOp", translate_assignvariable_op},
|
||||
{"VarIsInitializedOp", translate_varisinitialized_op},
|
||||
{"VarHandleOp", translate_varhandle_op},
|
||||
{"RestoreV2", translate_restorev2_op},
|
||||
{"StaticRegexFullMatch", translate_staticregexfullmatch_op},
|
||||
{"StringJoin", translate_stringjoin_op},
|
||||
{"ShardedFilename", translate_identity_op},
|
||||
{"MergeV2Checkpoints", translate_identity_op},
|
||||
{"Unpack", translate_unpack_op},
|
||||
{"While", translate_while_op},
|
||||
{"Where", translate_where_op},
|
||||
|
|
159
src/frontends/tensorflow/src/proto/any.proto
Normal file
159
src/frontends/tensorflow/src/proto/any.proto
Normal file
|
@ -0,0 +1,159 @@
|
|||
// Protocol Buffers - Google's data interchange format
|
||||
// Copyright 2008 Google Inc. All rights reserved.
|
||||
// https://developers.google.com/protocol-buffers/
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions are
|
||||
// met:
|
||||
//
|
||||
// * Redistributions of source code must retain the above copyright
|
||||
// notice, this list of conditions and the following disclaimer.
|
||||
// * Redistributions in binary form must reproduce the above
|
||||
// copyright notice, this list of conditions and the following disclaimer
|
||||
// in the documentation and/or other materials provided with the
|
||||
// distribution.
|
||||
// * Neither the name of Google Inc. nor the names of its
|
||||
// contributors may be used to endorse or promote products derived from
|
||||
// this software without specific prior written permission.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
// Modification Copyright (C) 2023 Intel Corporation
|
||||
|
||||
syntax = "proto3";
|
||||
|
||||
package google.protobuf;
|
||||
|
||||
option csharp_namespace = "Google.Protobuf.WellKnownTypes";
|
||||
option go_package = "google.golang.org/protobuf/types/known/anypb";
|
||||
option java_package = "com.google.protobuf";
|
||||
option java_outer_classname = "AnyProto";
|
||||
option java_multiple_files = true;
|
||||
option objc_class_prefix = "GPB";
|
||||
|
||||
// `Any` contains an arbitrary serialized protocol buffer message along with a
|
||||
// URL that describes the type of the serialized message.
|
||||
//
|
||||
// Protobuf library provides support to pack/unpack Any values in the form
|
||||
// of utility functions or additional generated methods of the Any type.
|
||||
//
|
||||
// Example 1: Pack and unpack a message in C++.
|
||||
//
|
||||
// Foo foo = ...;
|
||||
// Any any;
|
||||
// any.PackFrom(foo);
|
||||
// ...
|
||||
// if (any.UnpackTo(&foo)) {
|
||||
// ...
|
||||
// }
|
||||
//
|
||||
// Example 2: Pack and unpack a message in Java.
|
||||
//
|
||||
// Foo foo = ...;
|
||||
// Any any = Any.pack(foo);
|
||||
// ...
|
||||
// if (any.is(Foo.class)) {
|
||||
// foo = any.unpack(Foo.class);
|
||||
// }
|
||||
//
|
||||
// Example 3: Pack and unpack a message in Python.
|
||||
//
|
||||
// foo = Foo(...)
|
||||
// any = Any()
|
||||
// any.Pack(foo)
|
||||
// ...
|
||||
// if any.Is(Foo.DESCRIPTOR):
|
||||
// any.Unpack(foo)
|
||||
// ...
|
||||
//
|
||||
// Example 4: Pack and unpack a message in Go
|
||||
//
|
||||
// foo := &pb.Foo{...}
|
||||
// any, err := anypb.New(foo)
|
||||
// if err != nil {
|
||||
// ...
|
||||
// }
|
||||
// ...
|
||||
// foo := &pb.Foo{}
|
||||
// if err := any.UnmarshalTo(foo); err != nil {
|
||||
// ...
|
||||
// }
|
||||
//
|
||||
// The pack methods provided by protobuf library will by default use
|
||||
// 'type.googleapis.com/full.type.name' as the type URL and the unpack
|
||||
// methods only use the fully qualified type name after the last '/'
|
||||
// in the type URL, for example "foo.bar.com/x/y.z" will yield type
|
||||
// name "y.z".
|
||||
//
|
||||
//
|
||||
// JSON
|
||||
// ====
|
||||
// The JSON representation of an `Any` value uses the regular
|
||||
// representation of the deserialized, embedded message, with an
|
||||
// additional field `@type` which contains the type URL. Example:
|
||||
//
|
||||
// package google.profile;
|
||||
// message Person {
|
||||
// string first_name = 1;
|
||||
// string last_name = 2;
|
||||
// }
|
||||
//
|
||||
// {
|
||||
// "@type": "type.googleapis.com/google.profile.Person",
|
||||
// "firstName": <string>,
|
||||
// "lastName": <string>
|
||||
// }
|
||||
//
|
||||
// If the embedded message type is well-known and has a custom JSON
|
||||
// representation, that representation will be embedded adding a field
|
||||
// `value` which holds the custom JSON in addition to the `@type`
|
||||
// field. Example (for message [google.protobuf.Duration][]):
|
||||
//
|
||||
// {
|
||||
// "@type": "type.googleapis.com/google.protobuf.Duration",
|
||||
// "value": "1.212s"
|
||||
// }
|
||||
//
|
||||
message Any {
|
||||
// A URL/resource name that uniquely identifies the type of the serialized
|
||||
// protocol buffer message. This string must contain at least
|
||||
// one "/" character. The last segment of the URL's path must represent
|
||||
// the fully qualified name of the type (as in
|
||||
// `path/google.protobuf.Duration`). The name should be in a canonical form
|
||||
// (e.g., leading "." is not accepted).
|
||||
//
|
||||
// In practice, teams usually precompile into the binary all types that they
|
||||
// expect it to use in the context of Any. However, for URLs which use the
|
||||
// scheme `http`, `https`, or no scheme, one can optionally set up a type
|
||||
// server that maps type URLs to message definitions as follows:
|
||||
//
|
||||
// * If no scheme is provided, `https` is assumed.
|
||||
// * An HTTP GET on the URL must yield a [google.protobuf.Type][]
|
||||
// value in binary format, or produce an error.
|
||||
// * Applications are allowed to cache lookup results based on the
|
||||
// URL, or have them precompiled into a binary to avoid any
|
||||
// lookup. Therefore, binary compatibility needs to be preserved
|
||||
// on changes to types. (Use versioned type names to manage
|
||||
// breaking changes.)
|
||||
//
|
||||
// Note: this functionality is not currently available in the official
|
||||
// protobuf release, and it is not used for type URLs beginning with
|
||||
// type.googleapis.com.
|
||||
//
|
||||
// Schemes other than `http`, `https` (or the empty scheme) might be
|
||||
// used with implementation specific semantics.
|
||||
//
|
||||
string type_url = 1;
|
||||
|
||||
// Must be a valid serialized protocol buffer of the above specified type.
|
||||
bytes value = 2;
|
||||
}
|
351
src/frontends/tensorflow/src/proto/meta_graph.proto
Normal file
351
src/frontends/tensorflow/src/proto/meta_graph.proto
Normal file
|
@ -0,0 +1,351 @@
|
|||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.*/
|
||||
// Modification Copyright (C) 2018-2023 Intel Corporation
|
||||
|
||||
syntax = "proto3";
|
||||
|
||||
package tensorflow;
|
||||
|
||||
import "any.proto";
|
||||
import "graph.proto";
|
||||
import "op_def.proto";
|
||||
import "tensor_shape.proto";
|
||||
import "types.proto";
|
||||
import "saved_object_graph.proto";
|
||||
import "saver.proto";
|
||||
import "struct.proto";
|
||||
|
||||
option cc_enable_arenas = true;
|
||||
option java_outer_classname = "MetaGraphProtos";
|
||||
option java_multiple_files = true;
|
||||
option java_package = "org.tensorflow.framework";
|
||||
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto";
|
||||
|
||||
// Protocol buffer containing the following which are necessary to restart
|
||||
// training, run inference. It can be used to serialize/de-serialize memory
|
||||
// objects necessary for running computation in a graph when crossing the
|
||||
// process boundary. It can be used for long term storage of graphs,
|
||||
// cross-language execution of graphs, etc.
|
||||
// MetaInfoDef
|
||||
// GraphDef
|
||||
// SaverDef
|
||||
// CollectionDef
|
||||
// TensorInfo
|
||||
// SignatureDef
|
||||
message MetaGraphDef {
|
||||
// Meta information regarding the graph to be exported. To be used by users
|
||||
// of this protocol buffer to encode information regarding their meta graph.
|
||||
message MetaInfoDef {
|
||||
// User specified Version string. Can be the name of the model and revision,
|
||||
// steps this model has been trained to, etc.
|
||||
string meta_graph_version = 1;
|
||||
|
||||
// A copy of the OpDefs used by the producer of this graph_def.
|
||||
// Descriptions and Ops not used in graph_def are stripped out.
|
||||
OpList stripped_op_list = 2;
|
||||
|
||||
// A serialized protobuf. Can be the time this meta graph is created, or
|
||||
// modified, or name of the model.
|
||||
google.protobuf.Any any_info = 3;
|
||||
|
||||
// User supplied tag(s) on the meta_graph and included graph_def.
|
||||
//
|
||||
// MetaGraphDefs should be tagged with their capabilities or use-cases.
|
||||
// Examples: "train", "serve", "gpu", "tpu", etc.
|
||||
// These tags enable loaders to access the MetaGraph(s) appropriate for a
|
||||
// specific use-case or runtime environment.
|
||||
repeated string tags = 4;
|
||||
|
||||
// The __version__ string of the tensorflow build used to write this graph.
|
||||
// This will be populated by the framework, which will overwrite any user
|
||||
// supplied value.
|
||||
string tensorflow_version = 5;
|
||||
|
||||
// The __git_version__ string of the tensorflow build used to write this
|
||||
// graph. This will be populated by the framework, which will overwrite any
|
||||
// user supplied value.
|
||||
string tensorflow_git_version = 6;
|
||||
|
||||
// A flag to denote whether default-valued attrs have been stripped from
|
||||
// the nodes in this graph_def.
|
||||
bool stripped_default_attrs = 7;
|
||||
|
||||
// FunctionDef name to aliases mapping.
|
||||
map<string, string> function_aliases = 8;
|
||||
}
|
||||
MetaInfoDef meta_info_def = 1;
|
||||
|
||||
// GraphDef.
|
||||
GraphDef graph_def = 2;
|
||||
|
||||
// SaverDef.
|
||||
SaverDef saver_def = 3;
|
||||
|
||||
// collection_def: Map from collection name to collections.
|
||||
// See CollectionDef section for details.
|
||||
map<string, CollectionDef> collection_def = 4;
|
||||
|
||||
// signature_def: Map from user supplied key for a signature to a single
|
||||
// SignatureDef.
|
||||
map<string, SignatureDef> signature_def = 5;
|
||||
|
||||
// Asset file def to be used with the defined graph.
|
||||
repeated AssetFileDef asset_file_def = 6;
|
||||
|
||||
// Extra information about the structure of functions and stateful objects.
|
||||
SavedObjectGraph object_graph_def = 7;
|
||||
}
|
||||
|
||||
// CollectionDef should cover most collections.
|
||||
// To add a user-defined collection, do one of the following:
|
||||
// 1. For simple data types, such as string, int, float:
|
||||
// tf.add_to_collection("your_collection_name", your_simple_value)
|
||||
// strings will be stored as bytes_list.
|
||||
//
|
||||
// 2. For Protobuf types, there are three ways to add them:
|
||||
// 1) tf.add_to_collection("your_collection_name",
|
||||
// your_proto.SerializeToString())
|
||||
//
|
||||
// collection_def {
|
||||
// key: "user_defined_bytes_collection"
|
||||
// value {
|
||||
// bytes_list {
|
||||
// value: "queue_name: \"test_queue\"\n"
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// or
|
||||
//
|
||||
// 2) tf.add_to_collection("your_collection_name", str(your_proto))
|
||||
//
|
||||
// collection_def {
|
||||
// key: "user_defined_string_collection"
|
||||
// value {
|
||||
// bytes_list {
|
||||
// value: "\n\ntest_queue"
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// or
|
||||
//
|
||||
// 3) any_buf = any_pb2.Any()
|
||||
// tf.add_to_collection("your_collection_name",
|
||||
// any_buf.Pack(your_proto))
|
||||
//
|
||||
// collection_def {
|
||||
// key: "user_defined_any_collection"
|
||||
// value {
|
||||
// any_list {
|
||||
// value {
|
||||
// type_url: "type.googleapis.com/tensorflow.QueueRunnerDef"
|
||||
// value: "\n\ntest_queue"
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// 3. For Python objects, implement to_proto() and from_proto(), and register
|
||||
// them in the following manner:
|
||||
// ops.register_proto_function("your_collection_name",
|
||||
// proto_type,
|
||||
// to_proto=YourPythonObject.to_proto,
|
||||
// from_proto=YourPythonObject.from_proto)
|
||||
// These functions will be invoked to serialize and de-serialize the
|
||||
// collection. For example,
|
||||
// ops.register_proto_function(ops.GraphKeys.GLOBAL_VARIABLES,
|
||||
// proto_type=variable_pb2.VariableDef,
|
||||
// to_proto=Variable.to_proto,
|
||||
// from_proto=Variable.from_proto)
|
||||
message CollectionDef {
|
||||
// NodeList is used for collecting nodes in graph. For example
|
||||
// collection_def {
|
||||
// key: "summaries"
|
||||
// value {
|
||||
// node_list {
|
||||
// value: "input_producer/ScalarSummary:0"
|
||||
// value: "shuffle_batch/ScalarSummary:0"
|
||||
// value: "ImageSummary:0"
|
||||
// }
|
||||
// }
|
||||
message NodeList {
|
||||
repeated string value = 1;
|
||||
}
|
||||
|
||||
// BytesList is used for collecting strings and serialized protobufs. For
|
||||
// example:
|
||||
// collection_def {
|
||||
// key: "trainable_variables"
|
||||
// value {
|
||||
// bytes_list {
|
||||
// value: "\n\017conv1/weights:0\022\024conv1/weights/Assign
|
||||
// \032\024conv1/weights/read:0"
|
||||
// value: "\n\016conv1/biases:0\022\023conv1/biases/Assign\032
|
||||
// \023conv1/biases/read:0"
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
message BytesList {
|
||||
repeated bytes value = 1;
|
||||
}
|
||||
|
||||
// Int64List is used for collecting int, int64 and long values.
|
||||
message Int64List {
|
||||
repeated int64 value = 1 [packed = true];
|
||||
}
|
||||
|
||||
// FloatList is used for collecting float values.
|
||||
message FloatList {
|
||||
repeated float value = 1 [packed = true];
|
||||
}
|
||||
|
||||
// AnyList is used for collecting Any protos.
|
||||
message AnyList {
|
||||
repeated google.protobuf.Any value = 1;
|
||||
}
|
||||
|
||||
oneof kind {
|
||||
NodeList node_list = 1;
|
||||
BytesList bytes_list = 2;
|
||||
Int64List int64_list = 3;
|
||||
FloatList float_list = 4;
|
||||
AnyList any_list = 5;
|
||||
}
|
||||
}
|
||||
|
||||
// Information about a Tensor necessary for feeding or retrieval.
|
||||
message TensorInfo {
|
||||
// For sparse tensors, The COO encoding stores a triple of values, indices,
|
||||
// and shape.
|
||||
message CooSparse {
|
||||
// The shape of the values Tensor is [?]. Its dtype must be the dtype of
|
||||
// the SparseTensor as a whole, given in the enclosing TensorInfo.
|
||||
string values_tensor_name = 1;
|
||||
|
||||
// The indices Tensor must have dtype int64 and shape [?, ?].
|
||||
string indices_tensor_name = 2;
|
||||
|
||||
// The dynamic logical shape represented by the SparseTensor is recorded in
|
||||
// the Tensor referenced here. It must have dtype int64 and shape [?].
|
||||
string dense_shape_tensor_name = 3;
|
||||
}
|
||||
|
||||
// Generic encoding for composite tensors.
|
||||
message CompositeTensor {
|
||||
// The serialized TypeSpec for the composite tensor.
|
||||
TypeSpecProto type_spec = 1;
|
||||
|
||||
// A TensorInfo for each flattened component tensor.
|
||||
repeated TensorInfo components = 2;
|
||||
}
|
||||
|
||||
oneof encoding {
|
||||
// For dense `Tensor`s, the name of the tensor in the graph.
|
||||
string name = 1;
|
||||
// There are many possible encodings of sparse matrices
|
||||
// (https://en.wikipedia.org/wiki/Sparse_matrix). Currently, TensorFlow
|
||||
// uses only the COO encoding. This is supported and documented in the
|
||||
// SparseTensor Python class.
|
||||
CooSparse coo_sparse = 4;
|
||||
// Generic encoding for CompositeTensors.
|
||||
CompositeTensor composite_tensor = 5;
|
||||
}
|
||||
DataType dtype = 2;
|
||||
// The static shape should be recorded here, to the extent that it can
|
||||
// be known in advance. In the case of a SparseTensor, this field describes
|
||||
// the logical shape of the represented tensor (aka dense_shape).
|
||||
TensorShapeProto tensor_shape = 3;
|
||||
}
|
||||
|
||||
// SignatureDef defines the signature of a computation supported by a TensorFlow
|
||||
// graph.
|
||||
//
|
||||
// For example, a model with two loss computations, sharing a single input,
|
||||
// might have the following signature_def map, in a MetaGraphDef message.
|
||||
//
|
||||
// Note that across the two SignatureDefs "loss_A" and "loss_B", the input key,
|
||||
// output key, and method_name are identical, and will be used by system(s) that
|
||||
// implement or rely upon this particular loss method. The output tensor names
|
||||
// differ, demonstrating how different outputs can exist for the same method.
|
||||
//
|
||||
// signature_def {
|
||||
// key: "loss_A"
|
||||
// value {
|
||||
// inputs {
|
||||
// key: "input"
|
||||
// value {
|
||||
// name: "input:0"
|
||||
// dtype: DT_STRING
|
||||
// tensor_shape: ...
|
||||
// }
|
||||
// }
|
||||
// outputs {
|
||||
// key: "loss_output"
|
||||
// value {
|
||||
// name: "loss_output_A:0"
|
||||
// dtype: DT_FLOAT
|
||||
// tensor_shape: ...
|
||||
// }
|
||||
// }
|
||||
// method_name: "some/package/compute_loss"
|
||||
// }
|
||||
// ...
|
||||
// }
|
||||
// signature_def {
|
||||
// key: "loss_B"
|
||||
// value {
|
||||
// inputs {
|
||||
// key: "input"
|
||||
// value {
|
||||
// name: "input:0"
|
||||
// dtype: DT_STRING
|
||||
// tensor_shape: ...
|
||||
// }
|
||||
// }
|
||||
// outputs {
|
||||
// key: "loss_output"
|
||||
// value {
|
||||
// name: "loss_output_B:0"
|
||||
// dtype: DT_FLOAT
|
||||
// tensor_shape: ...
|
||||
// }
|
||||
// }
|
||||
// method_name: "some/package/compute_loss"
|
||||
// }
|
||||
// ...
|
||||
// }
|
||||
message SignatureDef {
|
||||
// Named input parameters.
|
||||
map<string, TensorInfo> inputs = 1;
|
||||
// Named output parameters.
|
||||
map<string, TensorInfo> outputs = 2;
|
||||
// Extensible method_name information enabling third-party users to mark a
|
||||
// SignatureDef as supporting a particular method. This enables producers and
|
||||
// consumers of SignatureDefs, e.g. a model definition library and a serving
|
||||
// library to have a clear hand-off regarding the semantics of a computation.
|
||||
//
|
||||
// Note that multiple SignatureDefs in a single MetaGraphDef may have the same
|
||||
// method_name. This is commonly used to support multi-headed computation,
|
||||
// where a single graph computation may return multiple results.
|
||||
string method_name = 3;
|
||||
}
|
||||
|
||||
// An asset file def for a single file or a set of sharded files with the same
|
||||
// name.
|
||||
message AssetFileDef {
|
||||
// The tensor to bind the asset filename to.
|
||||
TensorInfo tensor_info = 1;
|
||||
// The filename within an assets directory. Note: does not include the path
|
||||
// prefix, i.e. directories. For an asset at /tmp/path/vocab.txt, the filename
|
||||
// would be "vocab.txt".
|
||||
string filename = 2;
|
||||
}
|
35
src/frontends/tensorflow/src/proto/saved_model.proto
Normal file
35
src/frontends/tensorflow/src/proto/saved_model.proto
Normal file
|
@ -0,0 +1,35 @@
|
|||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.*/
|
||||
// Modification Copyright (C) 2023 Intel Corporation
|
||||
|
||||
syntax = "proto3";
|
||||
|
||||
package tensorflow;
|
||||
|
||||
import "meta_graph.proto";
|
||||
|
||||
option cc_enable_arenas = true;
|
||||
option java_outer_classname = "SavedModelProtos";
|
||||
option java_multiple_files = true;
|
||||
option java_package = "org.tensorflow.framework";
|
||||
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto";
|
||||
|
||||
// SavedModel is the high level serialization format for TensorFlow Models.
|
||||
// See [todo: doc links, similar to session_bundle] for more information.
|
||||
message SavedModel {
|
||||
// The schema version of the SavedModel instance. Used for versioning when
|
||||
// making future changes to the specification/implementation. Initial value
|
||||
// at release will be 1.
|
||||
int64 saved_model_schema_version = 1;
|
||||
|
||||
// One or more MetaGraphs.
|
||||
repeated MetaGraphDef meta_graphs = 2;
|
||||
}
|
263
src/frontends/tensorflow/src/proto/saved_object_graph.proto
Normal file
263
src/frontends/tensorflow/src/proto/saved_object_graph.proto
Normal file
|
@ -0,0 +1,263 @@
|
|||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.*/
|
||||
// Modification Copyright (C) 2023 Intel Corporation
|
||||
|
||||
syntax = "proto3";
|
||||
|
||||
package tensorflow;
|
||||
|
||||
import "any.proto";
|
||||
import "tensor_shape.proto";
|
||||
import "types.proto";
|
||||
import "variable.proto";
|
||||
import "versions.proto";
|
||||
import "struct.proto";
|
||||
import "trackable_object_graph.proto";
|
||||
|
||||
option cc_enable_arenas = true;
|
||||
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto";
|
||||
|
||||
// A SavedObjectGraph is part of object-based SavedModels in TF 2.0. It
|
||||
// describes the directed graph of Python objects (or equivalent in other
|
||||
// languages) that make up a model, with nodes[0] at the root.
|
||||
|
||||
// SavedObjectGraph shares some structure with TrackableObjectGraph, but
|
||||
// SavedObjectGraph belongs to the MetaGraph and contains pointers to functions
|
||||
// and type information, while TrackableObjectGraph lives in the checkpoint
|
||||
// and contains pointers only to variable values.
|
||||
|
||||
message SavedObjectGraph {
|
||||
// Flattened list of objects in the object graph.
|
||||
//
|
||||
// The position of the object in this list indicates its id.
|
||||
// Nodes[0] is considered the root node.
|
||||
repeated SavedObject nodes = 1;
|
||||
|
||||
// Information about captures and output structures in concrete functions.
|
||||
// Referenced from SavedBareConcreteFunction and SavedFunction.
|
||||
map<string, SavedConcreteFunction> concrete_functions = 2;
|
||||
}
|
||||
|
||||
message SavedObject {
|
||||
// Objects which this object depends on: named edges in the dependency
|
||||
// graph.
|
||||
//
|
||||
// Note: All kinds of SavedObject may have children, except
|
||||
// "constant" and "captured_tensor".
|
||||
repeated TrackableObjectGraph.TrackableObject.ObjectReference children = 1;
|
||||
|
||||
// Ordered list of dependencies that must be loaded before this object.
|
||||
// SavedModel loads with the bottom-up approach, by first creating all objects
|
||||
// (in the order defined by the dependencies), then connecting the edges.
|
||||
repeated TrackableObjectGraph.TrackableObject.ObjectReference dependencies =
|
||||
15;
|
||||
|
||||
// Removed when forking SavedObject from TrackableObjectGraph.
|
||||
reserved "attributes";
|
||||
reserved 2;
|
||||
|
||||
// Slot variables owned by this object. This describes the three-way
|
||||
// (optimizer, variable, slot variable) relationship; none of the three
|
||||
// depend on the others directly.
|
||||
//
|
||||
// Note: currently only valid if kind == "user_object".
|
||||
repeated TrackableObjectGraph.TrackableObject.SlotVariableReference
|
||||
slot_variables = 3;
|
||||
|
||||
oneof kind {
|
||||
SavedUserObject user_object = 4;
|
||||
SavedAsset asset = 5;
|
||||
SavedFunction function = 6;
|
||||
SavedVariable variable = 7;
|
||||
SavedBareConcreteFunction bare_concrete_function = 8;
|
||||
SavedConstant constant = 9;
|
||||
SavedResource resource = 10;
|
||||
CapturedTensor captured_tensor = 12;
|
||||
}
|
||||
|
||||
// Stores the functions used to save and restore this object. At most one of
|
||||
// `saveable_objects` or `registered_saver` is defined for each SavedObject.
|
||||
// See the comment below for the difference between SaveableObject and
|
||||
// registered savers.
|
||||
map<string, SaveableObject> saveable_objects = 11;
|
||||
|
||||
// The fields below are filled when the user serializes a registered Trackable
|
||||
// class or an object with a registered saver function.
|
||||
//
|
||||
// Registered classes may save additional metadata and supersede the
|
||||
// default loading process where nodes are recreated from the proto.
|
||||
// If the registered class cannot be found, then the object will load as one
|
||||
// one of the default trackable objects: Autotrackable (a class similar to
|
||||
// tf.Module), tf.function, or tf.Variable.
|
||||
//
|
||||
// Unlike SaveableObjects, which store the functions for saving and restoring
|
||||
// from tensors, registered savers allow Trackables to write checkpoint shards
|
||||
// directly (e.g. for performance or coordination reasons).
|
||||
// *All registered savers must be available when loading the SavedModel.*
|
||||
|
||||
// The name of the registered class of the form "{package}.{class_name}".
|
||||
// This field is used to search for the registered class at loading time.
|
||||
string registered_name = 13;
|
||||
// The user-generated proto storing metadata for this object, to be passed to
|
||||
// the registered classes's _deserialize_from_proto method when this object is
|
||||
// loaded from the SavedModel.
|
||||
google.protobuf.Any serialized_user_proto = 14;
|
||||
|
||||
// String name of the registered saver. At most one of `saveable_objects` or
|
||||
// `registered_saver` is defined for each SavedObject.
|
||||
string registered_saver = 16;
|
||||
}
|
||||
|
||||
// A SavedUserObject is an object (in the object-oriented language of the
|
||||
// TensorFlow program) of some user- or framework-defined class other than
|
||||
// those handled specifically by the other kinds of SavedObjects.
|
||||
//
|
||||
// This object cannot be evaluated as a tensor, and therefore cannot be bound
|
||||
// to an input of a function.
|
||||
message SavedUserObject {
|
||||
// Corresponds to a registration of the type to use in the loading program.
|
||||
string identifier = 1;
|
||||
// Version information from the producer of this SavedUserObject.
|
||||
VersionDef version = 2;
|
||||
// Metadata for deserializing this object.
|
||||
//
|
||||
// Deprecated! At the time of deprecation, Keras was the only user of this
|
||||
// field, and its saving and loading code will be updated shortly.
|
||||
// Please save your application-specific metadata to a separate file.
|
||||
string metadata = 3 [deprecated = true];
|
||||
}
|
||||
|
||||
// A SavedAsset points to an asset in the MetaGraph.
|
||||
//
|
||||
// When bound to a function this object evaluates to a tensor with the absolute
|
||||
// filename. Users should not depend on a particular part of the filename to
|
||||
// remain stable (e.g. basename could be changed).
|
||||
message SavedAsset {
|
||||
// Index into `MetaGraphDef.asset_file_def[]` that describes the Asset.
|
||||
//
|
||||
// Only the field `AssetFileDef.filename` is used. Other fields, such as
|
||||
// `AssetFileDef.tensor_info`, MUST be ignored.
|
||||
int32 asset_file_def_index = 1;
|
||||
}
|
||||
|
||||
// A function with multiple signatures, possibly with non-Tensor arguments.
|
||||
message SavedFunction {
|
||||
repeated string concrete_functions = 1;
|
||||
FunctionSpec function_spec = 2;
|
||||
}
|
||||
|
||||
message CapturedTensor {
|
||||
// Name of captured tensor
|
||||
string name = 1;
|
||||
|
||||
// Name of concrete function which contains the computed graph tensor.
|
||||
string concrete_function = 2;
|
||||
}
|
||||
|
||||
// Stores low-level information about a concrete function. Referenced in either
|
||||
// a SavedFunction or a SavedBareConcreteFunction.
|
||||
message SavedConcreteFunction {
|
||||
repeated int32 bound_inputs = 2;
|
||||
|
||||
// Input in canonicalized form that was received to create this concrete
|
||||
// function.
|
||||
StructuredValue canonicalized_input_signature = 3;
|
||||
// Output that was the return value of this function after replacing all
|
||||
// Tensors with TensorSpecs. This can be an arbitrary nested function and will
|
||||
// be used to reconstruct the full structure from pure tensors.
|
||||
StructuredValue output_signature = 4;
|
||||
}
|
||||
|
||||
message SavedBareConcreteFunction {
|
||||
// Identifies a SavedConcreteFunction.
|
||||
string concrete_function_name = 1;
|
||||
|
||||
// A sequence of unique strings, one per Tensor argument.
|
||||
repeated string argument_keywords = 2;
|
||||
// The prefix of `argument_keywords` which may be identified by position.
|
||||
int64 allowed_positional_arguments = 3;
|
||||
// The spec of the function that this ConcreteFunction is traced from. This
|
||||
// allows the ConcreteFunction to be called with nest structure inputs. This
|
||||
// field may not be populated. If this field is absent, the concrete function
|
||||
// can only be called with flat inputs.
|
||||
// TODO(b/169361281): support calling saved ConcreteFunction with structured
|
||||
// inputs in C++ SavedModel API.
|
||||
FunctionSpec function_spec = 4;
|
||||
}
|
||||
|
||||
message SavedConstant {
|
||||
// An Operation name for a ConstantOp in this SavedObjectGraph's MetaGraph.
|
||||
string operation = 1;
|
||||
}
|
||||
|
||||
// Represents a Variable that is initialized by loading the contents from the
|
||||
// checkpoint.
|
||||
message SavedVariable {
|
||||
DataType dtype = 1;
|
||||
TensorShapeProto shape = 2;
|
||||
bool trainable = 3;
|
||||
VariableSynchronization synchronization = 4;
|
||||
VariableAggregation aggregation = 5;
|
||||
string name = 6;
|
||||
string device = 7;
|
||||
// List of component variables for a distributed variable.
|
||||
//
|
||||
// When this field is non-empty, the SavedVariable will be assumed
|
||||
// to be a distributed variable defined by the components listed here.
|
||||
//
|
||||
// This is only supported by experimental loaders at the moment.
|
||||
repeated SavedVariable experimental_distributed_variable_components = 8;
|
||||
}
|
||||
|
||||
// Represents `FunctionSpec` used in `Function`. This represents a
|
||||
// function that has been wrapped as a TensorFlow `Function`.
|
||||
message FunctionSpec {
|
||||
// Full arg spec from inspect.getfullargspec().
|
||||
StructuredValue fullargspec = 1;
|
||||
// Whether this represents a class method.
|
||||
bool is_method = 2;
|
||||
// The input signature, if specified.
|
||||
StructuredValue input_signature = 5;
|
||||
|
||||
// Whether the function should be compiled by XLA.
|
||||
//
|
||||
// The public interface to `tf.function` uses an optional boolean to
|
||||
// represent three distinct states for this field. Unfortunately, proto3
|
||||
// removes the ability to explicitly check for the presence or absence of a
|
||||
// field, so we instead map to an enum.
|
||||
//
|
||||
// See `tf.function` for details.
|
||||
enum JitCompile {
|
||||
DEFAULT = 0;
|
||||
ON = 1;
|
||||
OFF = 2;
|
||||
}
|
||||
JitCompile jit_compile = 6;
|
||||
|
||||
reserved 3, 4;
|
||||
}
|
||||
|
||||
// A SavedResource represents a TF object that holds state during its lifetime.
|
||||
// An object of this type can have a reference to a:
|
||||
// create_resource() and an initialize() function.
|
||||
message SavedResource {
|
||||
// A device specification indicating a required placement for the resource
|
||||
// creation function, e.g. "CPU". An empty string allows the user to select a
|
||||
// device.
|
||||
string device = 1;
|
||||
}
|
||||
|
||||
message SaveableObject {
|
||||
// Node ids of concrete functions for saving and loading from a checkpoint.
|
||||
// These functions save and restore directly from tensors.
|
||||
int32 save_function = 2;
|
||||
int32 restore_function = 3;
|
||||
}
|
96
src/frontends/tensorflow/src/proto/saved_tensor_slice.proto
Normal file
96
src/frontends/tensorflow/src/proto/saved_tensor_slice.proto
Normal file
|
@ -0,0 +1,96 @@
|
|||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.*/
|
||||
// Modification Copyright (C) 2018-2023 Intel Corporation
|
||||
// Protocol buffers for saved tensor slices. It's used for the brain tensor
|
||||
// ops checkpoints and the V3 checkpoints in dist_belief.
|
||||
|
||||
// A checkpoint file is an sstable. The value for each record is a serialized
|
||||
// SavedTensorSlices message (defined below).
|
||||
//
|
||||
// Each checkpoint file has a record with the empty key (""), which corresponds
|
||||
// to a SavedTensorSlices message that contains a "meta", that serves as a
|
||||
// table of contents on all the tensor slices saved in this file. Since the key
|
||||
// is "", it's always the first record in each file.
|
||||
//
|
||||
// Each of the rest of the records in a checkpoint stores the raw data of a
|
||||
// particular tensor slice, in SavedSlice format. The corresponding key is an
|
||||
// ordered code that encodes the name of the tensor and the slice
|
||||
// information. The name is also stored in the SaveSlice message for ease of
|
||||
// debugging and manual examination.
|
||||
// Modification Copyright (C) 2023 Intel Corporation
|
||||
|
||||
syntax = "proto3";
|
||||
|
||||
package tensorflow;
|
||||
option cc_enable_arenas = true;
|
||||
option java_outer_classname = "SavedTensorSliceProtos";
|
||||
option java_multiple_files = true;
|
||||
option java_package = "org.tensorflow.util";
|
||||
|
||||
import "tensor_shape.proto";
|
||||
import "tensor_slice.proto";
|
||||
import "tensor.proto";
|
||||
import "types.proto";
|
||||
import "versions.proto";
|
||||
|
||||
// Metadata describing the set of slices of the same tensor saved in a
|
||||
// checkpoint file.
|
||||
message SavedSliceMeta {
|
||||
// Name of the tensor.
|
||||
string name = 1;
|
||||
|
||||
// Shape of the tensor
|
||||
TensorShapeProto shape = 2;
|
||||
|
||||
// Type of the tensor
|
||||
DataType type = 3;
|
||||
|
||||
// Explicit list of slices saved in the checkpoint file.
|
||||
repeated TensorSliceProto slice = 4;
|
||||
};
|
||||
|
||||
// Metadata describing the set of tensor slices saved in a checkpoint file.
|
||||
// It is always stored at the beginning of each checkpoint file.
|
||||
message SavedTensorSliceMeta {
|
||||
// Each SavedSliceMeta describes the slices for one tensor.
|
||||
repeated SavedSliceMeta tensor = 1;
|
||||
|
||||
// Compatibility version of this checkpoint. See core/public/version.h
|
||||
// for version history.
|
||||
VersionDef versions = 2;
|
||||
};
|
||||
|
||||
// Saved tensor slice: it stores the name of the tensors, the slice, and the
|
||||
// raw data.
|
||||
message SavedSlice {
|
||||
// Name of the tensor that this slice belongs to. This must be identical to
|
||||
// the name used to encode the key for this record.
|
||||
string name = 1;
|
||||
|
||||
// Extent of the slice. Must have one entry for each of the dimension of the
|
||||
// tensor that this slice belongs to.
|
||||
TensorSliceProto slice = 2;
|
||||
|
||||
// The raw data of the slice is stored as a TensorProto. Only raw data are
|
||||
// stored (we don't fill in fields such as dtype or tensor_shape).
|
||||
TensorProto data = 3;
|
||||
};
|
||||
|
||||
// Each record in a v3 checkpoint file is a serialized SavedTensorSlices
|
||||
// message.
|
||||
message SavedTensorSlices {
|
||||
// This is only present at the first item of each checkpoint file and serves
|
||||
// as a table of contents, listing all the tensor slices saved in this file.
|
||||
SavedTensorSliceMeta meta = 1;
|
||||
|
||||
// This exists in all but the first item of each checkpoint file.
|
||||
SavedSlice data = 2;
|
||||
};
|
60
src/frontends/tensorflow/src/proto/saver.proto
Normal file
60
src/frontends/tensorflow/src/proto/saver.proto
Normal file
|
@ -0,0 +1,60 @@
|
|||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.*/
|
||||
// Modification Copyright (C) 2023 Intel Corporation
|
||||
|
||||
syntax = "proto3";
|
||||
|
||||
package tensorflow;
|
||||
|
||||
option cc_enable_arenas = true;
|
||||
option java_outer_classname = "SaverProtos";
|
||||
option java_multiple_files = true;
|
||||
option java_package = "org.tensorflow.util";
|
||||
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto";
|
||||
|
||||
// Protocol buffer representing the configuration of a Saver.
|
||||
message SaverDef {
|
||||
// The name of the tensor in which to specify the filename when saving or
|
||||
// restoring a model checkpoint.
|
||||
string filename_tensor_name = 1;
|
||||
|
||||
// The operation to run when saving a model checkpoint.
|
||||
string save_tensor_name = 2;
|
||||
|
||||
// The operation to run when restoring a model checkpoint.
|
||||
string restore_op_name = 3;
|
||||
|
||||
// Maximum number of checkpoints to keep. If 0, no checkpoints are deleted.
|
||||
int32 max_to_keep = 4;
|
||||
|
||||
// Shard the save files, one per device that has Variable nodes.
|
||||
bool sharded = 5;
|
||||
|
||||
// How often to keep an additional checkpoint. If not specified, only the last
|
||||
// "max_to_keep" checkpoints are kept; if specified, in addition to keeping
|
||||
// the last "max_to_keep" checkpoints, an additional checkpoint will be kept
|
||||
// for every n hours of training.
|
||||
float keep_checkpoint_every_n_hours = 6;
|
||||
|
||||
// A version number that identifies a different on-disk checkpoint format.
|
||||
// Usually, each subclass of BaseSaverBuilder works with a particular
|
||||
// version/format. However, it is possible that the same builder may be
|
||||
// upgraded to support a newer checkpoint format in the future.
|
||||
enum CheckpointFormatVersion {
|
||||
// Internal legacy format.
|
||||
LEGACY = 0;
|
||||
// Deprecated format: tf.Saver() which works with tensorflow::table::Table.
|
||||
V1 = 1;
|
||||
// Current format: more efficient.
|
||||
V2 = 2;
|
||||
}
|
||||
CheckpointFormatVersion version = 7;
|
||||
}
|
172
src/frontends/tensorflow/src/proto/struct.proto
Normal file
172
src/frontends/tensorflow/src/proto/struct.proto
Normal file
|
@ -0,0 +1,172 @@
|
|||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.*/
|
||||
// Modification Copyright (C) 2023 Intel Corporation
|
||||
|
||||
syntax = "proto3";
|
||||
|
||||
package tensorflow;
|
||||
|
||||
import "tensor.proto";
|
||||
import "tensor_shape.proto";
|
||||
import "types.proto";
|
||||
|
||||
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto";
|
||||
|
||||
// `StructuredValue` represents a dynamically typed value representing various
|
||||
// data structures that are inspired by Python data structures typically used in
|
||||
// TensorFlow functions as inputs and outputs.
|
||||
//
|
||||
// For example when saving a Layer there may be a `training` argument. If the
|
||||
// user passes a boolean True/False, that switches between two concrete
|
||||
// TensorFlow functions. In order to switch between them in the same way after
|
||||
// loading the SavedModel, we need to represent "True" and "False".
|
||||
//
|
||||
// A more advanced example might be a function which takes a list of
|
||||
// dictionaries mapping from strings to Tensors. In order to map from
|
||||
// user-specified arguments `[{"a": tf.constant(1.)}, {"q": tf.constant(3.)}]`
|
||||
// after load to the right saved TensorFlow function, we need to represent the
|
||||
// nested structure and the strings, recording that we have a trace for anything
|
||||
// matching `[{"a": tf.TensorSpec(None, tf.float32)}, {"q": tf.TensorSpec([],
|
||||
// tf.float64)}]` as an example.
|
||||
//
|
||||
// Likewise functions may return nested structures of Tensors, for example
|
||||
// returning a dictionary mapping from strings to Tensors. In order for the
|
||||
// loaded function to return the same structure we need to serialize it.
|
||||
//
|
||||
// This is an ergonomic aid for working with loaded SavedModels, not a promise
|
||||
// to serialize all possible function signatures. For example we do not expect
|
||||
// to pickle generic Python objects, and ideally we'd stay language-agnostic.
|
||||
message StructuredValue {
|
||||
// The kind of value.
|
||||
oneof kind {
|
||||
// Represents None.
|
||||
NoneValue none_value = 1;
|
||||
|
||||
// Represents a double-precision floating-point value (a Python `float`).
|
||||
double float64_value = 11;
|
||||
// Represents a signed integer value, limited to 64 bits.
|
||||
// Larger values from Python's arbitrary-precision integers are unsupported.
|
||||
sint64 int64_value = 12;
|
||||
// Represents a string of Unicode characters stored in a Python `str`.
|
||||
// In Python 3, this is exactly what type `str` is.
|
||||
// In Python 2, this is the UTF-8 encoding of the characters.
|
||||
// For strings with ASCII characters only (as often used in TensorFlow code)
|
||||
// there is effectively no difference between the language versions.
|
||||
// The obsolescent `unicode` type of Python 2 is not supported here.
|
||||
string string_value = 13;
|
||||
// Represents a boolean value.
|
||||
bool bool_value = 14;
|
||||
|
||||
// Represents a TensorShape.
|
||||
tensorflow.TensorShapeProto tensor_shape_value = 31;
|
||||
// Represents an enum value for dtype.
|
||||
tensorflow.DataType tensor_dtype_value = 32;
|
||||
// Represents a value for tf.TensorSpec.
|
||||
TensorSpecProto tensor_spec_value = 33;
|
||||
// Represents a value for tf.TypeSpec.
|
||||
TypeSpecProto type_spec_value = 34;
|
||||
// Represents a value for tf.BoundedTensorSpec.
|
||||
BoundedTensorSpecProto bounded_tensor_spec_value = 35;
|
||||
|
||||
// Represents a list of `Value`.
|
||||
ListValue list_value = 51;
|
||||
// Represents a tuple of `Value`.
|
||||
TupleValue tuple_value = 52;
|
||||
// Represents a dict `Value`.
|
||||
DictValue dict_value = 53;
|
||||
// Represents Python's namedtuple.
|
||||
NamedTupleValue named_tuple_value = 54;
|
||||
}
|
||||
}
|
||||
|
||||
// Represents None.
|
||||
message NoneValue {}
|
||||
|
||||
// Represents a Python list.
|
||||
message ListValue {
|
||||
repeated StructuredValue values = 1;
|
||||
}
|
||||
|
||||
// Represents a Python tuple.
|
||||
message TupleValue {
|
||||
repeated StructuredValue values = 1;
|
||||
}
|
||||
|
||||
// Represents a Python dict keyed by `str`.
|
||||
// The comment on Unicode from Value.string_value applies analogously.
|
||||
message DictValue {
|
||||
map<string, StructuredValue> fields = 1;
|
||||
}
|
||||
|
||||
// Represents a (key, value) pair.
|
||||
message PairValue {
|
||||
string key = 1;
|
||||
StructuredValue value = 2;
|
||||
}
|
||||
|
||||
// Represents Python's namedtuple.
|
||||
message NamedTupleValue {
|
||||
string name = 1;
|
||||
repeated PairValue values = 2;
|
||||
}
|
||||
|
||||
// A protobuf to represent tf.TensorSpec.
|
||||
message TensorSpecProto {
|
||||
string name = 1;
|
||||
tensorflow.TensorShapeProto shape = 2;
|
||||
tensorflow.DataType dtype = 3;
|
||||
}
|
||||
|
||||
// A protobuf to represent tf.BoundedTensorSpec.
|
||||
message BoundedTensorSpecProto {
|
||||
string name = 1;
|
||||
tensorflow.TensorShapeProto shape = 2;
|
||||
tensorflow.DataType dtype = 3;
|
||||
tensorflow.TensorProto minimum = 4;
|
||||
tensorflow.TensorProto maximum = 5;
|
||||
}
|
||||
|
||||
// Represents a tf.TypeSpec
|
||||
message TypeSpecProto {
|
||||
enum TypeSpecClass {
|
||||
UNKNOWN = 0;
|
||||
SPARSE_TENSOR_SPEC = 1; // tf.SparseTensorSpec
|
||||
INDEXED_SLICES_SPEC = 2; // tf.IndexedSlicesSpec
|
||||
RAGGED_TENSOR_SPEC = 3; // tf.RaggedTensorSpec
|
||||
TENSOR_ARRAY_SPEC = 4; // tf.TensorArraySpec
|
||||
DATA_DATASET_SPEC = 5; // tf.data.DatasetSpec
|
||||
DATA_ITERATOR_SPEC = 6; // IteratorSpec from data/ops/iterator_ops.py
|
||||
OPTIONAL_SPEC = 7; // tf.OptionalSpec
|
||||
PER_REPLICA_SPEC = 8; // PerReplicaSpec from distribute/values.py
|
||||
VARIABLE_SPEC = 9; // tf.VariableSpec
|
||||
ROW_PARTITION_SPEC = 10; // RowPartitionSpec from ragged/row_partition.py
|
||||
reserved 11;
|
||||
REGISTERED_TYPE_SPEC = 12; // The type registered as type_spec_class_name.
|
||||
EXTENSION_TYPE_SPEC = 13; // Subclasses of tf.ExtensionType
|
||||
}
|
||||
TypeSpecClass type_spec_class = 1;
|
||||
|
||||
// The value returned by TypeSpec._serialize().
|
||||
StructuredValue type_state = 2;
|
||||
|
||||
// The name of the TypeSpec class.
|
||||
// * If type_spec_class == REGISTERED_TYPE_SPEC, the TypeSpec class is
|
||||
// the one registered under this name. For types registered outside
|
||||
// core TensorFlow by an add-on library, that library must be loaded
|
||||
// before this value can be deserialized by nested_structure_coder.
|
||||
// * If type_spec_class specifies a particular TypeSpec class, this field is
|
||||
// redundant with the type_spec_class enum, and is only used for error
|
||||
// reporting in older binaries that do not know the tupe_spec_class enum.
|
||||
string type_spec_class_name = 3;
|
||||
|
||||
// The number of flat tensor components required by this TypeSpec.
|
||||
int32 num_flat_components = 4;
|
||||
}
|
78
src/frontends/tensorflow/src/proto/tensor_bundle.proto
Normal file
78
src/frontends/tensorflow/src/proto/tensor_bundle.proto
Normal file
|
@ -0,0 +1,78 @@
|
|||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.*/
|
||||
// Modification Copyright (C) 2023 Intel Corporation
|
||||
|
||||
syntax = "proto3";
|
||||
|
||||
package tensorflow;
|
||||
|
||||
import "tensor_shape.proto";
|
||||
import "tensor_slice.proto";
|
||||
import "types.proto";
|
||||
import "versions.proto";
|
||||
|
||||
option cc_enable_arenas = true;
|
||||
option java_outer_classname = "TensorBundleProtos";
|
||||
option java_multiple_files = true;
|
||||
option java_package = "org.tensorflow.util";
|
||||
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto";
|
||||
|
||||
// Protos used in the tensor bundle module (tf/core/util/tensor_bundle/).
|
||||
|
||||
// Special header that is associated with a bundle.
|
||||
//
|
||||
// TODO(zongheng,zhifengc): maybe in the future, we can add information about
|
||||
// which binary produced this checkpoint, timestamp, etc. Sometime, these can be
|
||||
// valuable debugging information. And if needed, these can be used as defensive
|
||||
// information ensuring reader (binary version) of the checkpoint and the writer
|
||||
// (binary version) must match within certain range, etc.
|
||||
message BundleHeaderProto {
|
||||
// Number of data files in the bundle.
|
||||
int32 num_shards = 1;
|
||||
|
||||
// An enum indicating the endianness of the platform that produced this
|
||||
// bundle. A bundle can only be read by a platform with matching endianness.
|
||||
// Defaults to LITTLE, as most modern platforms are little-endian.
|
||||
//
|
||||
// Affects the binary tensor data bytes only, not the metadata in protobufs.
|
||||
enum Endianness {
|
||||
LITTLE = 0;
|
||||
BIG = 1;
|
||||
}
|
||||
Endianness endianness = 2;
|
||||
|
||||
// Versioning of the tensor bundle format.
|
||||
VersionDef version = 3;
|
||||
}
|
||||
|
||||
// Describes the metadata related to a checkpointed tensor.
|
||||
message BundleEntryProto {
|
||||
// The tensor dtype and shape.
|
||||
DataType dtype = 1;
|
||||
TensorShapeProto shape = 2;
|
||||
// The binary content of the tensor lies in:
|
||||
// File "shard_id": bytes [offset, offset + size).
|
||||
int32 shard_id = 3;
|
||||
int64 offset = 4;
|
||||
int64 size = 5;
|
||||
|
||||
// The CRC32C checksum of the tensor bytes.
|
||||
fixed32 crc32c = 6;
|
||||
|
||||
// Iff present, this entry represents a partitioned tensor. The previous
|
||||
// fields are interpreted as follows:
|
||||
//
|
||||
// "dtype", "shape": describe the full tensor.
|
||||
// "shard_id", "offset", "size", "crc32c": all IGNORED.
|
||||
// These information for each slice can be looked up in their own
|
||||
// BundleEntryProto, keyed by each "slice_name".
|
||||
repeated TensorSliceProto slices = 7;
|
||||
}
|
|
@ -0,0 +1,92 @@
|
|||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.*/
|
||||
// Modification Copyright (C) 2023 Intel Corporation
|
||||
|
||||
syntax = "proto3";
|
||||
|
||||
package tensorflow;
|
||||
|
||||
import "wrappers.proto";
|
||||
|
||||
option cc_enable_arenas = true;
|
||||
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto";
|
||||
|
||||
// A TensorBundle addition which saves extra information about the objects which
|
||||
// own variables, allowing for more robust checkpoint loading into modified
|
||||
// programs.
|
||||
|
||||
message TrackableObjectGraph {
|
||||
message TrackableObject {
|
||||
message ObjectReference {
|
||||
// An index into `TrackableObjectGraph.nodes`, indicating the object
|
||||
// being referenced.
|
||||
int32 node_id = 1;
|
||||
// A user-provided name for the edge.
|
||||
string local_name = 2;
|
||||
}
|
||||
|
||||
message SerializedTensor {
|
||||
// A name for the Tensor. Simple variables have only one
|
||||
// `SerializedTensor` named "VARIABLE_VALUE" by convention. This value may
|
||||
// be restored on object creation as an optimization.
|
||||
string name = 1;
|
||||
// The full name of the variable/tensor, if applicable. Used to allow
|
||||
// name-based loading of checkpoints which were saved using an
|
||||
// object-based API. Should match the checkpoint key which would have been
|
||||
// assigned by tf.train.Saver.
|
||||
string full_name = 2;
|
||||
// The generated name of the Tensor in the checkpoint.
|
||||
string checkpoint_key = 3;
|
||||
// Deprecated bool field for optional restore. This field has never been
|
||||
// set to True.
|
||||
reserved "optional_restore";
|
||||
reserved 4;
|
||||
}
|
||||
|
||||
message SlotVariableReference {
|
||||
// An index into `TrackableObjectGraph.nodes`, indicating the
|
||||
// variable object this slot was created for.
|
||||
int32 original_variable_node_id = 1;
|
||||
// The name of the slot (e.g. "m"/"v").
|
||||
string slot_name = 2;
|
||||
// An index into `TrackableObjectGraph.nodes`, indicating the
|
||||
// `Object` with the value of the slot variable.
|
||||
int32 slot_variable_node_id = 3;
|
||||
}
|
||||
|
||||
// Objects which this object depends on.
|
||||
repeated ObjectReference children = 1;
|
||||
// Serialized data specific to this object.
|
||||
repeated SerializedTensor attributes = 2;
|
||||
// Slot variables owned by this object.
|
||||
repeated SlotVariableReference slot_variables = 3;
|
||||
|
||||
// The registered saver used to save this object. If this saver is not
|
||||
// present when loading the checkpoint, then loading will fail.
|
||||
RegisteredSaver registered_saver = 4;
|
||||
|
||||
// Whether this object has checkpoint values or descendants with checkpoint
|
||||
// values. This is computed at save time to avoid traversing the entire
|
||||
// object graph proto when restoring (which also has to traverse the live
|
||||
// object graph).
|
||||
google.protobuf.BoolValue has_checkpoint_values = 5;
|
||||
}
|
||||
|
||||
repeated TrackableObject nodes = 1;
|
||||
}
|
||||
|
||||
message RegisteredSaver {
|
||||
// The name of the registered saver/restore function.
|
||||
string name = 1;
|
||||
|
||||
// Unique auto-generated name of the object.
|
||||
string object_name = 2;
|
||||
}
|
124
src/frontends/tensorflow/src/proto/wrappers.proto
Normal file
124
src/frontends/tensorflow/src/proto/wrappers.proto
Normal file
|
@ -0,0 +1,124 @@
|
|||
// Protocol Buffers - Google's data interchange format
|
||||
// Copyright 2008 Google Inc. All rights reserved.
|
||||
// https://developers.google.com/protocol-buffers/
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions are
|
||||
// met:
|
||||
//
|
||||
// * Redistributions of source code must retain the above copyright
|
||||
// notice, this list of conditions and the following disclaimer.
|
||||
// * Redistributions in binary form must reproduce the above
|
||||
// copyright notice, this list of conditions and the following disclaimer
|
||||
// in the documentation and/or other materials provided with the
|
||||
// distribution.
|
||||
// * Neither the name of Google Inc. nor the names of its
|
||||
// contributors may be used to endorse or promote products derived from
|
||||
// this software without specific prior written permission.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
// Modification Copyright (C) 2023 Intel Corporation
|
||||
|
||||
// Wrappers for primitive (non-message) types. These types are useful
|
||||
// for embedding primitives in the `google.protobuf.Any` type and for places
|
||||
// where we need to distinguish between the absence of a primitive
|
||||
// typed field and its default value.
|
||||
//
|
||||
// These wrappers have no meaningful use within repeated fields as they lack
|
||||
// the ability to detect presence on individual elements.
|
||||
// These wrappers have no meaningful use within a map or a oneof since
|
||||
// individual entries of a map or fields of a oneof can already detect presence.
|
||||
|
||||
syntax = "proto3";
|
||||
|
||||
package google.protobuf;
|
||||
|
||||
option csharp_namespace = "Google.Protobuf.WellKnownTypes";
|
||||
option cc_enable_arenas = true;
|
||||
option go_package = "google.golang.org/protobuf/types/known/wrapperspb";
|
||||
option java_package = "com.google.protobuf";
|
||||
option java_outer_classname = "WrappersProto";
|
||||
option java_multiple_files = true;
|
||||
option objc_class_prefix = "GPB";
|
||||
|
||||
// Wrapper message for `double`.
|
||||
//
|
||||
// The JSON representation for `DoubleValue` is JSON number.
|
||||
message DoubleValue {
|
||||
// The double value.
|
||||
double value = 1;
|
||||
}
|
||||
|
||||
// Wrapper message for `float`.
|
||||
//
|
||||
// The JSON representation for `FloatValue` is JSON number.
|
||||
message FloatValue {
|
||||
// The float value.
|
||||
float value = 1;
|
||||
}
|
||||
|
||||
// Wrapper message for `int64`.
|
||||
//
|
||||
// The JSON representation for `Int64Value` is JSON string.
|
||||
message Int64Value {
|
||||
// The int64 value.
|
||||
int64 value = 1;
|
||||
}
|
||||
|
||||
// Wrapper message for `uint64`.
|
||||
//
|
||||
// The JSON representation for `UInt64Value` is JSON string.
|
||||
message UInt64Value {
|
||||
// The uint64 value.
|
||||
uint64 value = 1;
|
||||
}
|
||||
|
||||
// Wrapper message for `int32`.
|
||||
//
|
||||
// The JSON representation for `Int32Value` is JSON number.
|
||||
message Int32Value {
|
||||
// The int32 value.
|
||||
int32 value = 1;
|
||||
}
|
||||
|
||||
// Wrapper message for `uint32`.
|
||||
//
|
||||
// The JSON representation for `UInt32Value` is JSON number.
|
||||
message UInt32Value {
|
||||
// The uint32 value.
|
||||
uint32 value = 1;
|
||||
}
|
||||
|
||||
// Wrapper message for `bool`.
|
||||
//
|
||||
// The JSON representation for `BoolValue` is JSON `true` and `false`.
|
||||
message BoolValue {
|
||||
// The bool value.
|
||||
bool value = 1;
|
||||
}
|
||||
|
||||
// Wrapper message for `string`.
|
||||
//
|
||||
// The JSON representation for `StringValue` is JSON string.
|
||||
message StringValue {
|
||||
// The string value.
|
||||
string value = 1;
|
||||
}
|
||||
|
||||
// Wrapper message for `bytes`.
|
||||
//
|
||||
// The JSON representation for `BytesValue` is JSON string.
|
||||
message BytesValue {
|
||||
// The bytes value.
|
||||
bytes value = 1;
|
||||
}
|
482
src/frontends/tensorflow/src/saved_model.cpp
Normal file
482
src/frontends/tensorflow/src/saved_model.cpp
Normal file
|
@ -0,0 +1,482 @@
|
|||
// Copyright (C) 2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <stdlib.h>
|
||||
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
|
||||
#include "graph_iterator_saved_model.hpp"
|
||||
#include "openvino/core/type/element_type.hpp"
|
||||
#include "tensor_bundle.pb.h"
|
||||
#include "trackable_object_graph.pb.h"
|
||||
|
||||
#ifdef ENABLE_SNAPPY_COMPRESSION
|
||||
# include "snappy.h"
|
||||
#endif
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace tensorflow {
|
||||
|
||||
template <typename T>
|
||||
static T smReadFixed(const char* ptr) {
|
||||
T result = 0;
|
||||
for (uint8_t i = 0; i < sizeof(T); ++i) {
|
||||
result |= ptr[i] << (i * 8);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static T smUnpack(char*& ptr, const char* ptr_end) {
|
||||
T result = 0;
|
||||
for (uint8_t i = 0; i < sizeof(T) * 7 && ptr < ptr_end; i += 7) {
|
||||
T byte = *(ptr++);
|
||||
if (byte & 0x80) {
|
||||
result |= ((byte & 0x7F) << i);
|
||||
} else {
|
||||
result |= byte << i;
|
||||
return result;
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
struct VIBlock {
|
||||
uint64_t m_size;
|
||||
uint64_t m_offset;
|
||||
|
||||
void read(char*& ptr, const char* ptr_end) {
|
||||
m_offset = smUnpack<uint64_t>(ptr, ptr_end);
|
||||
m_size = smUnpack<uint64_t>(ptr, ptr_end);
|
||||
}
|
||||
};
|
||||
|
||||
struct VIFooter {
|
||||
VIBlock m_metaIndex;
|
||||
VIBlock m_index;
|
||||
|
||||
void read(char*& ptr, const char* ptr_end) {
|
||||
m_index.read(ptr, ptr_end);
|
||||
m_metaIndex.read(ptr, ptr_end);
|
||||
}
|
||||
|
||||
void read(std::ifstream& fs) {
|
||||
fs.seekg(0, std::ios::end);
|
||||
size_t size = fs.tellg();
|
||||
char footerData[48] = {}, *ptr = &footerData[0];
|
||||
fs.seekg(size - sizeof(footerData));
|
||||
fs.read(ptr, sizeof(footerData));
|
||||
|
||||
// https://github.com/tensorflow/tensorflow/blob/9659b7bdca80a8ef8240eb021d4da089034eeb00/tensorflow/tsl/lib/io/format.cc#L59
|
||||
ptr += sizeof(footerData) - 8;
|
||||
uint32_t magic_lo = *reinterpret_cast<const uint32_t*>(ptr);
|
||||
uint32_t magic_hi = *reinterpret_cast<const uint32_t*>(ptr + 4);
|
||||
uint64_t magic_no = (static_cast<uint64_t>(magic_hi) << 32) | static_cast<uint64_t>(magic_lo);
|
||||
|
||||
FRONT_END_GENERAL_CHECK(magic_no == 0xdb4775248b80fb57ull, "Wrong index file, magic number mismatch detected");
|
||||
|
||||
ptr = &footerData[0];
|
||||
m_metaIndex.read(ptr, ptr + sizeof(footerData));
|
||||
m_index.read(ptr, ptr + sizeof(footerData));
|
||||
}
|
||||
};
|
||||
|
||||
void SavedModelVariablesIndex::read_variables_index_block(std::ifstream& fs,
|
||||
const VIBlock& index,
|
||||
std::vector<char>& data,
|
||||
uint32_t& offset,
|
||||
uint32_t& offset_end) {
|
||||
size_t block_size = index.m_size;
|
||||
data.clear();
|
||||
data.resize(block_size + 5 /*kBlockTrailerSize*/);
|
||||
fs.seekg(index.m_offset, std::ios::beg);
|
||||
fs.read(data.data(), data.size());
|
||||
#ifndef ENABLE_SNAPPY_COMPRESSION
|
||||
FRONT_END_GENERAL_CHECK(data[block_size] == 0, "Compressed files aren't supported");
|
||||
#else
|
||||
FRONT_END_GENERAL_CHECK(data[block_size] == 0 || data[block_size] == 1, "Compression method isn't supported");
|
||||
if (data[block_size] == 1) {
|
||||
size_t uncompressed_length = 0;
|
||||
FRONT_END_GENERAL_CHECK(snappy::GetUncompressedLength(data.data(), data.size(), &uncompressed_length),
|
||||
"Cannot retrieve uncompressed block length");
|
||||
std::string uncompressed_string;
|
||||
uncompressed_string.reserve(uncompressed_length);
|
||||
snappy::Uncompress(data.data(), data.size(), &uncompressed_string);
|
||||
data.resize(uncompressed_length);
|
||||
std::copy(uncompressed_string.begin(), uncompressed_string.end(), data.begin());
|
||||
block_size = uncompressed_length;
|
||||
}
|
||||
#endif
|
||||
uint32_t numRestarts = smReadFixed<uint32_t>(data.data() + block_size - sizeof(uint32_t));
|
||||
size_t maxRestarts = (block_size - sizeof(uint32_t)) / sizeof(uint32_t);
|
||||
FRONT_END_GENERAL_CHECK(maxRestarts >= numRestarts, "Wrong restarts value");
|
||||
offset_end = static_cast<uint32_t>(block_size) - ((numRestarts + 1) * sizeof(uint32_t));
|
||||
offset = smReadFixed<uint32_t>(data.data() + offset_end);
|
||||
}
|
||||
|
||||
void SavedModelVariablesIndex::read_variables_index_pair(char*& ptr,
|
||||
const char* ptr_end,
|
||||
std::string& key,
|
||||
char*& value,
|
||||
uint32_t& val_length) {
|
||||
uint32_t shared, nonShared;
|
||||
shared = smUnpack<uint32_t>(ptr, ptr_end);
|
||||
nonShared = smUnpack<uint32_t>(ptr, ptr_end);
|
||||
val_length = smUnpack<uint32_t>(ptr, ptr_end);
|
||||
|
||||
// Key inherits last part of string (shared-size bytes) and appends new string
|
||||
// shared_part_key1 //resize(0) + append(shared_part_key1)
|
||||
// ............key2 //resize(12) + append(key2)
|
||||
// ............key3 //resize(12) + append(key3)
|
||||
// new_shared_key4 //resize(0) + append(new_shared_key4)
|
||||
// ...........key5 //resize(11) + append(key5)
|
||||
key.resize(shared);
|
||||
key.append(ptr, nonShared);
|
||||
|
||||
value = ptr + nonShared;
|
||||
ptr = value + val_length;
|
||||
}
|
||||
|
||||
void SavedModelVariablesIndex::read_variables_index(std::ifstream& fs,
|
||||
std::map<std::string, std::vector<char>>& varIndex) {
|
||||
VIFooter footer;
|
||||
|
||||
footer.read(fs);
|
||||
|
||||
std::vector<VIBlock> secondLevel;
|
||||
std::vector<char> blockData;
|
||||
|
||||
uint32_t offset = 0, offset_end = 0;
|
||||
|
||||
read_variables_index_block(fs, footer.m_index, blockData, offset, offset_end);
|
||||
char *ptr = blockData.data() + offset, *ptr_end = blockData.data() + offset_end, *value = nullptr;
|
||||
std::string key = "";
|
||||
uint32_t valLength;
|
||||
|
||||
while (ptr < ptr_end) {
|
||||
read_variables_index_pair(ptr, ptr_end, key, value, valLength);
|
||||
|
||||
VIBlock valBlock;
|
||||
valBlock.read(value, value + valLength);
|
||||
secondLevel.push_back(valBlock);
|
||||
ptr = value + valLength;
|
||||
}
|
||||
|
||||
for (auto& block : secondLevel) {
|
||||
read_variables_index_block(fs, block, blockData, offset, offset_end);
|
||||
|
||||
key = "";
|
||||
ptr = blockData.data() + offset;
|
||||
ptr_end = blockData.data() + offset_end;
|
||||
while (ptr < ptr_end) {
|
||||
read_variables_index_pair(ptr, ptr_end, key, value, valLength);
|
||||
varIndex[key] = std::vector<char>(value, value + valLength);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void SavedModelVariablesIndex::read_bundle_header() {
|
||||
auto item = m_variables_index.find("");
|
||||
FRONT_END_GENERAL_CHECK(item != m_variables_index.end(), "Bundle Header isn't found in index");
|
||||
|
||||
::tensorflow::BundleHeaderProto bundleHeader;
|
||||
FRONT_END_GENERAL_CHECK(bundleHeader.ParseFromString(item->second.data()),
|
||||
"Bundle Header: Cannot parse Bundle Header");
|
||||
FRONT_END_GENERAL_CHECK(bundleHeader.version().producer() == 1, "Bundle Header: Unsupported producer version");
|
||||
FRONT_END_GENERAL_CHECK(bundleHeader.version().min_consumer() == 0, "Bundle Header: Unsupported consumer version");
|
||||
FRONT_END_GENERAL_CHECK(bundleHeader.endianness() == 0, "Bundle Header: BIG endian isn't supported");
|
||||
|
||||
m_total_shards = bundleHeader.num_shards();
|
||||
}
|
||||
|
||||
void SavedModelVariablesIndex::read_checkpointable_object_graph() {
|
||||
m_variables_map.clear();
|
||||
|
||||
auto item = m_variables_index.find("_CHECKPOINTABLE_OBJECT_GRAPH");
|
||||
FRONT_END_GENERAL_CHECK(item != m_variables_index.end(), "Checkpointable Object Graph isn't found in index");
|
||||
|
||||
::tensorflow::BundleEntryProto entry;
|
||||
FRONT_END_GENERAL_CHECK(entry.ParseFromArray(item->second.data(), static_cast<int>(item->second.size())),
|
||||
"CMO: Cannot parse Bundle Entry");
|
||||
|
||||
FRONT_END_GENERAL_CHECK(entry.slices().empty(), "CMO: Slices are not supported");
|
||||
|
||||
auto shard = m_data_files.find(entry.shard_id());
|
||||
FRONT_END_GENERAL_CHECK(shard != m_data_files.end(), "CMO: data files isn't found");
|
||||
|
||||
std::vector<char> data(entry.size());
|
||||
::tensorflow::TrackableObjectGraph tog;
|
||||
|
||||
// TODO: have to understand this offset
|
||||
// It looks like reinterpret_cast artifact
|
||||
// https://github.com/tensorflow/tensorflow/blob/d90f1947ebcf510b23c238f43c2191e5b3817cb3/tensorflow/cc/experimental/libexport/load.cc#L70
|
||||
int chg = 6;
|
||||
shard->second->seekg(entry.offset() + chg);
|
||||
shard->second->read(data.data(), entry.size() - chg);
|
||||
|
||||
// Might be need to remove this verification:
|
||||
// https://github.com/tensorflow/tensorflow/blob/d90f1947ebcf510b23c238f43c2191e5b3817cb3/tensorflow/cc/experimental/libexport/load.cc#L73
|
||||
// FRONT_END_GENERAL_CHECK(tog.ParseFromArray(data.data(), static_cast<int>(data.size()) - chg), "CMO: Trackable
|
||||
// Object Graph couldn't be read");
|
||||
|
||||
tog.ParseFromArray(data.data(), static_cast<int>(data.size()) - chg);
|
||||
|
||||
for (const auto& node : tog.nodes()) {
|
||||
for (const auto& attr : node.attributes()) {
|
||||
m_variables_map[attr.full_name()] = attr.checkpoint_key();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool GraphIteratorSavedModel::is_valid_signature(const ::tensorflow::SignatureDef& signature) const {
|
||||
const std::map<::tensorflow::DataType, ov::element::Type> types{
|
||||
{::tensorflow::DataType::DT_BOOL, ov::element::boolean},
|
||||
{::tensorflow::DataType::DT_INT16, ov::element::i16},
|
||||
{::tensorflow::DataType::DT_INT32, ov::element::i32},
|
||||
{::tensorflow::DataType::DT_INT64, ov::element::i64},
|
||||
{::tensorflow::DataType::DT_HALF, ov::element::f16},
|
||||
{::tensorflow::DataType::DT_FLOAT, ov::element::f32},
|
||||
{::tensorflow::DataType::DT_DOUBLE, ov::element::f64},
|
||||
{::tensorflow::DataType::DT_UINT8, ov::element::u8},
|
||||
{::tensorflow::DataType::DT_INT8, ov::element::i8},
|
||||
{::tensorflow::DataType::DT_BFLOAT16, ov::element::bf16},
|
||||
{::tensorflow::DataType::DT_STRING, ov::element::undefined}};
|
||||
|
||||
for (const auto& it : signature.inputs()) {
|
||||
if (it.second.name().empty() || types.find(it.second.dtype()) == types.end())
|
||||
return false;
|
||||
}
|
||||
for (const auto& it : signature.outputs()) {
|
||||
if (it.second.name().empty() || types.find(it.second.dtype()) == types.end())
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool SavedModelVariablesIndex::read_variables(std::ifstream& vi_stream, const std::string& path) {
|
||||
m_variables_index.clear();
|
||||
read_variables_index(vi_stream, m_variables_index);
|
||||
read_bundle_header();
|
||||
|
||||
std::vector<char> suffix(20);
|
||||
for (int32_t shard = 0; shard < m_total_shards; ++shard) {
|
||||
std::snprintf(suffix.data(), suffix.size(), "data-%05d-of-%05d", shard, m_total_shards);
|
||||
std::string fullPath = ov::util::path_join({path, "variables", std::string("variables.") + suffix.data()});
|
||||
m_data_files[shard] =
|
||||
std::shared_ptr<std::ifstream>(new std::ifstream(fullPath, std::ifstream::in | std::ifstream::binary));
|
||||
FRONT_END_GENERAL_CHECK(m_data_files[shard]->is_open(), "Saved Model's variable index file does not exist");
|
||||
}
|
||||
|
||||
read_checkpointable_object_graph();
|
||||
return true;
|
||||
}
|
||||
|
||||
#if defined(OPENVINO_ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
|
||||
bool SavedModelVariablesIndex::read_variables(std::ifstream& vi_stream, const std::wstring& path) {
|
||||
m_variables_index.clear();
|
||||
read_variables_index(vi_stream, m_variables_index);
|
||||
read_bundle_header();
|
||||
|
||||
std::vector<wchar_t> suffix(20);
|
||||
for (int32_t shard = 0; shard < m_total_shards; ++shard) {
|
||||
swprintf_s(suffix.data(), suffix.size(), L"data-%05d-of-%05d", shard, m_total_shards);
|
||||
std::wstring fullPath =
|
||||
ov::util::path_join_w({path, L"variables", std::wstring(L"variables.") + suffix.data()});
|
||||
m_data_files[shard] =
|
||||
std::shared_ptr<std::ifstream>(new std::ifstream(fullPath, std::ifstream::in | std::ifstream::binary));
|
||||
FRONT_END_GENERAL_CHECK(m_data_files[shard]->is_open(), "Saved Model's variable index file does not exist");
|
||||
}
|
||||
|
||||
read_checkpointable_object_graph();
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
|
||||
struct PtrNode {
|
||||
const ::tensorflow::NodeDef* node;
|
||||
std::vector<PtrNode*> inputs;
|
||||
std::vector<PtrNode*> outputs;
|
||||
|
||||
PtrNode() : node(nullptr), inputs(), outputs() {}
|
||||
|
||||
PtrNode(const ::tensorflow::NodeDef& src_node, const std::map<std::string, PtrNode*>& node_dictionary) {
|
||||
node = &src_node;
|
||||
std::vector<std::string> parsedName;
|
||||
for (const auto& input_name : node->input()) {
|
||||
parse_node_name(input_name, parsedName);
|
||||
|
||||
auto input_node = node_dictionary.find(parsedName[0]);
|
||||
if (input_node == node_dictionary.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
input_node->second->outputs.push_back(this);
|
||||
inputs.push_back(input_node->second);
|
||||
}
|
||||
}
|
||||
|
||||
void find_parent_by_op(const std::string& op, std::vector<PtrNode*>& result) const {
|
||||
for (auto input : inputs) {
|
||||
if (input->op() == op) {
|
||||
result.push_back(input);
|
||||
}
|
||||
input->find_parent_by_op(op, result);
|
||||
}
|
||||
}
|
||||
|
||||
static void parse_node_name(const std::string& name, std::vector<std::string>& result) {
|
||||
result.clear();
|
||||
size_t left_pos = name.find_first_of('^'), right_pos = name.find(':');
|
||||
if (left_pos != std::string::npos && left_pos < right_pos) {
|
||||
++left_pos;
|
||||
} else {
|
||||
left_pos = 0;
|
||||
}
|
||||
while (right_pos != std::string::npos && right_pos > left_pos) {
|
||||
result.push_back(name.substr(left_pos, right_pos - left_pos));
|
||||
left_pos = right_pos + 1;
|
||||
right_pos = name.find(':', left_pos);
|
||||
}
|
||||
result.push_back(name.substr(left_pos, name.length() - left_pos));
|
||||
}
|
||||
|
||||
const std::string& op() const {
|
||||
return node->op();
|
||||
}
|
||||
};
|
||||
|
||||
static void read_stateful_partitioned_call(const std::shared_ptr<::tensorflow::GraphDef> graph_def,
|
||||
const ::tensorflow::NodeDef& partCall,
|
||||
std::map<std::string, PtrNode*>& node_dictionary) {
|
||||
FRONT_END_GENERAL_CHECK(partCall.op() == "StatefulPartitionedCall", "Passed node isn't StatefulPartitionedCall");
|
||||
|
||||
std::string func_name = partCall.attr().at("f").func().name();
|
||||
|
||||
const ::tensorflow::FunctionDef* func_def = nullptr;
|
||||
for (const auto& func : graph_def->library().function()) {
|
||||
if (func.signature().name() == func_name) {
|
||||
func_def = &func;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
FRONT_END_GENERAL_CHECK(func_def, "Function isn't found in the library");
|
||||
FRONT_END_GENERAL_CHECK(graph_def->has_library(), "GraphDef contains functions, but doesn't have the library");
|
||||
|
||||
std::map<std::string, PtrNode*> nodes;
|
||||
|
||||
// Filling temporary input nodes for exact function
|
||||
for (int i = 0; i < func_def->signature().input_arg_size(); ++i) {
|
||||
const auto& input_arg = func_def->signature().input_arg(i).name();
|
||||
const auto& parent_input = partCall.input(i);
|
||||
auto input_node = node_dictionary.find(parent_input);
|
||||
if (input_node != node_dictionary.end()) {
|
||||
nodes[input_arg] = input_node->second;
|
||||
}
|
||||
}
|
||||
|
||||
// Parsing nodes and inline partitioned calls
|
||||
for (const auto& node : func_def->node_def()) {
|
||||
nodes[node.name()] = new PtrNode(node, nodes);
|
||||
|
||||
if (node.op() == "StatefulPartitionedCall") {
|
||||
read_stateful_partitioned_call(graph_def, node, nodes);
|
||||
}
|
||||
}
|
||||
|
||||
// Removing temporary input nodes
|
||||
for (int i = 0; i < func_def->signature().input_arg_size(); ++i) {
|
||||
const auto& input_arg = func_def->signature().input_arg(i).name();
|
||||
auto input_node = nodes.find(input_arg);
|
||||
if (input_node != nodes.end()) {
|
||||
nodes.erase(input_node);
|
||||
}
|
||||
}
|
||||
|
||||
// Moving nodes to the global dictionary
|
||||
for (const auto& node : nodes) {
|
||||
std::string global_name = partCall.name() + "/" + node.first;
|
||||
node_dictionary[global_name] = node.second;
|
||||
}
|
||||
}
|
||||
|
||||
void GraphIteratorSavedModel::map_assignvariable(const std::shared_ptr<::tensorflow::GraphDef> graph_def,
|
||||
std::map<std::string, std::string>& variables_map) const {
|
||||
std::map<std::string, PtrNode*> nodes;
|
||||
|
||||
for (const auto& node : graph_def->node()) {
|
||||
nodes[node.name()] = new PtrNode(node, nodes);
|
||||
|
||||
if (node.op() == "StatefulPartitionedCall") {
|
||||
read_stateful_partitioned_call(graph_def, node, nodes);
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto& node : nodes) {
|
||||
if (node.second->op() != "AssignVariableOp") {
|
||||
continue;
|
||||
}
|
||||
|
||||
// TODO: assets reading
|
||||
|
||||
std::vector<PtrNode*> restorev2_nodes;
|
||||
std::vector<PtrNode*> varhandle_nodes;
|
||||
|
||||
node.second->find_parent_by_op("RestoreV2", restorev2_nodes);
|
||||
node.second->find_parent_by_op("VarHandleOp", varhandle_nodes);
|
||||
|
||||
FRONT_END_GENERAL_CHECK(restorev2_nodes.size() == 1, "Found unexpected amount of RestoreV2 nodes");
|
||||
FRONT_END_GENERAL_CHECK(varhandle_nodes.size() == 1, "Found unexpected amount of VarHandleOp nodes");
|
||||
|
||||
std::vector<std::string> restore_output;
|
||||
// Expected path is: RestoreV2 -(output_index)-(0)-> Identity -(0)-(1)-> AssignVariableOp
|
||||
PtrNode::parse_node_name(node.second->inputs[1]->node->input(0), restore_output);
|
||||
|
||||
int output_index = std::atoi(restore_output[restore_output.size() - 1].c_str());
|
||||
|
||||
// Expected path is: Const(tensor_names) -(0)-(1)-> RestoreV2
|
||||
const auto& variable_name =
|
||||
restorev2_nodes[0]->inputs[1]->node->attr().at("value").tensor().string_val(output_index);
|
||||
|
||||
variables_map[varhandle_nodes[0]->node->name()] = variable_name;
|
||||
}
|
||||
|
||||
nodes.clear();
|
||||
}
|
||||
|
||||
bool GraphIteratorSavedModel::is_supported(const std::string& path) {
|
||||
return ov::util::directory_exists(path) && ov::util::file_exists(ov::util::path_join({path, "saved_model.pb"}));
|
||||
}
|
||||
|
||||
#if defined(OPENVINO_ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
|
||||
bool GraphIteratorSavedModel::is_supported(const std::wstring& path) {
|
||||
return ov::util::directory_exists(path) && ov::util::file_exists(ov::util::path_join_w({path, L"saved_model.pb"}));
|
||||
}
|
||||
#endif
|
||||
|
||||
template <>
|
||||
std::basic_string<char> get_saved_model_name<char>() {
|
||||
return "/saved_model.pb";
|
||||
}
|
||||
template <>
|
||||
std::basic_string<char> get_variables_index_name<char>() {
|
||||
return "/variables/variables.index";
|
||||
}
|
||||
|
||||
#if defined(OPENVINO_ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
|
||||
template <>
|
||||
std::basic_string<wchar_t> get_saved_model_name<wchar_t>() {
|
||||
return L"/saved_model.pb";
|
||||
}
|
||||
template <>
|
||||
std::basic_string<wchar_t> get_variables_index_name<wchar_t>() {
|
||||
return L"/variables/variables.index";
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace tensorflow
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
|
@ -36,6 +36,25 @@ std::vector<T> reorder_ops_by_names(const std::vector<std::string>& names, const
|
|||
}
|
||||
return resulted_ops;
|
||||
};
|
||||
|
||||
/// \brief Adds known input names from Saved Model file format
|
||||
/// \param[in] node Node which should be updated
|
||||
/// \param[in] saved_model_names Map of names from saved model
|
||||
/// \returns True if node was updated, false otherwise
|
||||
static bool apply_saved_model_names(std::shared_ptr<ov::Node> node,
|
||||
const std::shared_ptr<std::map<std::string, std::string>>& saved_model_names) {
|
||||
for (size_t i = 0; i < node->get_output_size(); ++i) {
|
||||
const auto& node_names = node->get_output_tensor(i).get_names();
|
||||
for (const auto& name : node_names) {
|
||||
const auto& saved_model_name = saved_model_names->find(name);
|
||||
if (saved_model_name != saved_model_names->end()) {
|
||||
node->set_friendly_name(saved_model_name->second);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
TranslateSession::TranslateSession(const ov::frontend::InputModel::Ptr& input_model,
|
||||
|
@ -94,6 +113,8 @@ void TranslateSession::translate_graph(const ov::frontend::InputModel::Ptr& inpu
|
|||
const auto& model_inputs = model_tf->get_inputs();
|
||||
const auto& model_outputs = model_tf->get_outputs();
|
||||
const auto& model_frozen_inputs = model_tf->get_tensor_values();
|
||||
const auto& saved_model_inputs = model_tf->get_saved_model_input_names();
|
||||
const auto& saved_model_outputs = model_tf->get_saved_model_output_names();
|
||||
|
||||
// fill ng_op_map with Constant outputs for frozen inputs
|
||||
for (const auto& frozen_input : model_frozen_inputs) {
|
||||
|
@ -123,6 +144,11 @@ void TranslateSession::translate_graph(const ov::frontend::InputModel::Ptr& inpu
|
|||
|
||||
auto param = std::make_shared<ov::opset8::Parameter>(input_type, input_shape);
|
||||
set_node_name(input_name, param);
|
||||
if (saved_model_inputs.get() && saved_model_inputs->size() > 0) {
|
||||
if (!apply_saved_model_names(param, saved_model_inputs)) {
|
||||
param->get_output_tensor(0).add_names({"saved_model_unused"});
|
||||
}
|
||||
}
|
||||
params.push_back(param);
|
||||
ng_op_map[input_name] = {param};
|
||||
}
|
||||
|
@ -273,10 +299,30 @@ void TranslateSession::translate_graph(const ov::frontend::InputModel::Ptr& inpu
|
|||
if (port_type == "none") {
|
||||
for (const auto& node_output : ng_op_map[operation_name]) {
|
||||
auto result_node = std::make_shared<ov::opset8::Result>(node_output);
|
||||
// to be aligned with Legacy Frontend we set a name along with output port index
|
||||
// though, the Result name is not used in the OV API 2.0 but it is checked in MO args tests
|
||||
result_node->set_friendly_name(model_output_name + ":0");
|
||||
results.push_back(result_node);
|
||||
// Customize output name in case we have mapping from Saved Model format
|
||||
if (saved_model_outputs.get() && saved_model_outputs->size() > 0) {
|
||||
bool isUsed = true;
|
||||
for (const auto& name : model_output_tensor_place->get_names()) {
|
||||
auto saved_model_name = saved_model_outputs->find(name);
|
||||
if (saved_model_name == saved_model_outputs->end()) {
|
||||
saved_model_name = saved_model_outputs->find(name + ":0");
|
||||
}
|
||||
if (saved_model_name != saved_model_outputs->end()) {
|
||||
result_node->set_friendly_name(saved_model_name->second);
|
||||
results.push_back(result_node);
|
||||
isUsed = false;
|
||||
break;
|
||||
}
|
||||
if (!isUsed) {
|
||||
result_node->get_input_tensor(0).add_names({"saved_model_unused"});
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// to be aligned with Legacy Frontend we set a name along with output port index
|
||||
// though, the Result name is not used in the OV API 2.0 but it is checked in MO args tests
|
||||
result_node->set_friendly_name(model_output_name + ":0");
|
||||
results.push_back(result_node);
|
||||
}
|
||||
}
|
||||
} else if (port_type == "out") {
|
||||
const auto& node_outputs = ng_op_map[operation_name];
|
||||
|
|
|
@ -31,6 +31,10 @@ public:
|
|||
|
||||
std::shared_ptr<ov::Model> get_body_ov_model(const std::string& body_graph_name);
|
||||
|
||||
ov::frontend::InputModel::Ptr get_input_model(void) const {
|
||||
return m_input_model;
|
||||
}
|
||||
|
||||
private:
|
||||
const ov::frontend::InputModel::Ptr m_input_model;
|
||||
const std::shared_ptr<TranslatorDictionaryType> m_translator_map;
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "internal_operation.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace tensorflow {
|
||||
|
||||
/// Pseudo-entity for storing strings
|
||||
class StringConstant : public InternalOperation {
|
||||
public:
|
||||
OPENVINO_OP("StringConstant", "ov::frontend::tensorflow::util", InternalOperation);
|
||||
|
||||
StringConstant(ov::Any data, const std::shared_ptr<DecoderBase>& decoder = std::make_shared<DecoderFake>())
|
||||
: InternalOperation(decoder, {}, 1),
|
||||
m_data(data) {
|
||||
validate_and_infer_types();
|
||||
}
|
||||
|
||||
StringConstant(std::string& str, const std::shared_ptr<DecoderBase>& decoder = std::make_shared<DecoderFake>())
|
||||
: InternalOperation(decoder, {}, 1),
|
||||
m_data({str}) {
|
||||
validate_and_infer_types();
|
||||
}
|
||||
|
||||
StringConstant(const std::shared_ptr<DecoderBase>& decoder = std::make_shared<DecoderFake>())
|
||||
: InternalOperation(decoder, {}, 1) {
|
||||
validate_and_infer_types();
|
||||
}
|
||||
|
||||
void validate_and_infer_types() override {
|
||||
set_output_type(0, ov::element::dynamic, ov::PartialShape::dynamic());
|
||||
}
|
||||
|
||||
ov::Any get_data() {
|
||||
return m_data;
|
||||
}
|
||||
|
||||
std::string& get_string() {
|
||||
return m_data.as<std::vector<std::string>>()[0];
|
||||
}
|
||||
|
||||
private:
|
||||
ov::Any m_data;
|
||||
ov::Shape m_shape;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
|
@ -0,0 +1,27 @@
|
|||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/pass/pass.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace tensorflow {
|
||||
namespace pass {
|
||||
|
||||
// This transformation removes isolated subgraph unused Parameters and
|
||||
// Results marked as unused by Saved Model settings
|
||||
class SavedModelUnusedRemover : public ov::pass::ModelPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::frontend::tensorflow::pass::SavedModelUnusedRemover");
|
||||
SavedModelUnusedRemover() {}
|
||||
|
||||
bool run_on_model(const std::shared_ptr<ov::Model>& m) override;
|
||||
};
|
||||
|
||||
} // namespace pass
|
||||
} // namespace tensorflow
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
|
@ -4,6 +4,7 @@
|
|||
|
||||
#include "helper_transforms/const_to_result_remover.hpp"
|
||||
|
||||
#include "helper_ops/string_constant.hpp"
|
||||
#include "helper_ops/unsupported_constant.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
|
||||
|
@ -22,7 +23,8 @@ bool ConstToResultRemover::run_on_model(const std::shared_ptr<ov::Model>& m) {
|
|||
for (const auto& result : m->get_results()) {
|
||||
auto unsupported_const = as_type_ptr<UnsupportedConstant>(result->get_input_node_shared_ptr(0));
|
||||
auto const_node = as_type_ptr<Constant>(result->get_input_node_shared_ptr(0));
|
||||
if (unsupported_const || const_node) {
|
||||
auto string_const = as_type_ptr<StringConstant>(result->get_input_node_shared_ptr(0));
|
||||
if (unsupported_const || const_node || string_const) {
|
||||
results_to_remove.push_back(result);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,74 @@
|
|||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "helper_transforms/saved_model_unused_remover.hpp"
|
||||
|
||||
#include "openvino/opsets/opset8.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ov::opset8;
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace tensorflow {
|
||||
namespace pass {
|
||||
|
||||
bool SavedModelUnusedRemover::run_on_model(const std::shared_ptr<ov::Model>& m) {
|
||||
ParameterVector params_to_remove;
|
||||
ResultVector results_to_remove;
|
||||
|
||||
// There is two cases
|
||||
// 1. When we found unused result with/without unused parameter
|
||||
// 2. When we found unused parameter
|
||||
|
||||
for (const auto& result : m->get_results()) {
|
||||
bool isUsed = false;
|
||||
for (size_t i = 0; i < result->get_input_size(); ++i) {
|
||||
const auto& node_names = result->get_input_tensor(i).get_names();
|
||||
isUsed = std::find(node_names.begin(), node_names.end(), "saved_model_unused") == node_names.end();
|
||||
}
|
||||
if (!isUsed) {
|
||||
results_to_remove.push_back(result);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto param = as_type_ptr<Parameter>(result->get_input_node_shared_ptr(0));
|
||||
if (param) {
|
||||
for (size_t i = 0; i < param->get_output_size(); ++i) {
|
||||
const auto& node_names = param->get_output_tensor(i).get_names();
|
||||
isUsed = std::find(node_names.begin(), node_names.end(), "saved_model_unused") == node_names.end();
|
||||
}
|
||||
if (!isUsed) {
|
||||
results_to_remove.push_back(result);
|
||||
params_to_remove.push_back(param);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto& param : m->get_parameters()) {
|
||||
bool isUsed = false;
|
||||
for (size_t i = 0; i < param->get_output_size(); ++i) {
|
||||
const auto& node_names = param->get_output_tensor(i).get_names();
|
||||
isUsed = std::find(node_names.begin(), node_names.end(), "saved_model_unused") == node_names.end();
|
||||
}
|
||||
if (!isUsed && std::find(params_to_remove.begin(), params_to_remove.end(), param) == params_to_remove.end()) {
|
||||
params_to_remove.push_back(param);
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto& result : results_to_remove) {
|
||||
m->remove_result(result);
|
||||
}
|
||||
|
||||
for (const auto& param : params_to_remove) {
|
||||
m->remove_parameter(param);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace pass
|
||||
} // namespace tensorflow
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
|
@ -3,6 +3,7 @@
|
|||
//
|
||||
|
||||
#include "common_op_table.hpp"
|
||||
#include "helper_ops/string_constant.hpp"
|
||||
#include "helper_ops/unsupported_constant.hpp"
|
||||
#include "openvino/opsets/opset8.hpp"
|
||||
|
||||
|
@ -16,10 +17,15 @@ namespace tensorflow {
|
|||
namespace op {
|
||||
|
||||
OutputVector translate_const_op(const NodeContext& node) {
|
||||
auto ov_type = node.get_attribute<element::Type>("dtype");
|
||||
auto ov_type = node.get_attribute_as_any("dtype");
|
||||
std::shared_ptr<Node> const_node;
|
||||
if (ov_type == element::dynamic) {
|
||||
const_node = std::make_shared<UnsupportedConstant>();
|
||||
if (!ov_type.is<ov::element::Type>() || ov_type.as<ov::element::Type>() == ov::element::dynamic ||
|
||||
ov_type.as<ov::element::Type>() == ov::element::undefined) {
|
||||
if (ov_type.is<std::string>() && ov_type.as<std::string>() == "DT_STRING") {
|
||||
const_node = std::make_shared<StringConstant>(node.get_attribute_as_any("value"));
|
||||
} else {
|
||||
const_node = std::make_shared<UnsupportedConstant>();
|
||||
}
|
||||
} else {
|
||||
auto tensor = node.get_attribute<Tensor>("value");
|
||||
const_node = std::make_shared<Constant>(tensor.get_element_type(), tensor.get_shape(), tensor.data());
|
||||
|
|
|
@ -14,7 +14,13 @@ namespace tensorflow {
|
|||
namespace op {
|
||||
|
||||
OutputVector translate_identity_op(const NodeContext& node) {
|
||||
vector<string> supported_ops = {"Identity", "PreventGradient", "Snapshot", "StopGradient"};
|
||||
vector<string> supported_ops = {"Identity",
|
||||
"PreventGradient",
|
||||
"Snapshot",
|
||||
"StopGradient",
|
||||
"ReadVariableOp",
|
||||
"ShardedFilename",
|
||||
"MergeV2Checkpoints"};
|
||||
default_op_checks(node, 1, supported_ops);
|
||||
auto input = node.get_input(0);
|
||||
|
||||
|
|
26
thirdparty/CMakeLists.txt
vendored
26
thirdparty/CMakeLists.txt
vendored
|
@ -416,6 +416,32 @@ if(ENABLE_OV_TF_LITE_FRONTEND)
|
|||
set(flatbuffers_DEPENDENCY ${flatbuffers_DEPENDENCY} PARENT_SCOPE)
|
||||
endif()
|
||||
|
||||
#
|
||||
# Snappy Compression
|
||||
#
|
||||
if(ENABLE_SNAPPY_COMPRESSION)
|
||||
function(tf_build_snappy)
|
||||
set(BUILD_SHARED_LIBS OFF)
|
||||
set(SNAPPY_BUILD_BENCHMARKS OFF)
|
||||
set(SNAPPY_BUILD_TESTS OFF)
|
||||
set(INSTALL_GTEST OFF)
|
||||
set(CMAKE_COMPILE_WARNING_AS_ERROR OFF)
|
||||
set(CMAKE_CXX_STANDARD 14)
|
||||
if(NOT CMAKE_CXX_COMPILER_ID STREQUAL "MSVC")
|
||||
# Removes 3rd party errors which may affect OpenVINO CI
|
||||
if(CMAKE_CXX_COMPILER_ID MATCHES "Clang")
|
||||
if(NOT CMAKE_CXX_FLAGS MATCHES "-Werror=return-type")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror=return-type")
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
add_subdirectory(snappy EXCLUDE_FROM_ALL)
|
||||
endfunction()
|
||||
|
||||
tf_build_snappy()
|
||||
ov_install_static_lib(snappy ${OV_CPACK_COMP_CORE})
|
||||
endif()
|
||||
|
||||
#
|
||||
# ONNX
|
||||
#
|
||||
|
|
1
thirdparty/snappy
vendored
Submodule
1
thirdparty/snappy
vendored
Submodule
|
@ -0,0 +1 @@
|
|||
Subproject commit dc05e026488865bc69313a68bcc03ef2e4ea8e83
|
Loading…
Reference in New Issue
Block a user