[POC][TF FE] Support SavedModel format (with compression) (#16317)

* Added Saved Model proto descriptors

* Included Google's protobuf repository

* Added wstring version of ov::util::directory_exists

* Added initial implementation of Saved Model iterator

# Conflicts:
#	src/frontends/tensorflow/src/frontend.cpp

* Added missing proto files to repository

* Implemented reading of variables index and data files


* Renamed class


* Fix for cross-platform directory_exists

* Fixed codestyle and simplified code

* CI fixes

* Separated SavedModel iterator from Proto iterator

* Moved variables index into separate class

* Added initial implementation of reading variables from a saved model


* Added external variable mapping

* Code cleanup

* Commit is for discussion purposes:
Implemented RestoreV2 with a workaround for strings
(not optimized, includes a memory leak)

* In progress...

* Added DT_STRING coverage into decoder_proto

* m_variables_index moved into underlying class

* Updated copyrights, added space between license and code

* Moved string constant to separate class

* Added AssignVariableOp operation

* Changed behavior of RestoreV2
Updated stubs for other ops

* Second working implementation, enabling:
program-only models,
reading variables from data files

* Extended docs

* Fixed dynamic type

* Fixed naming

* Added Snappy submodule to support compression in TF FE

* Enabled Snappy Compression for TF FE

* Made Snappy linkage static
Changed warnings-as-errors behavior for 3rd-party code

* CI fixes

* Added Snappy copyright info

* Aligned behavior of StringConstant with UnsupportedConstant

* Added correct naming and removed unused inputs/outputs
Author: Georgy Krivoruchko, 2023-03-24 15:07:16 +04:00 (committed via GitHub)
parent 9eab122952
commit c5b348dd4f
35 changed files with 2923 additions and 27 deletions

.gitmodules

@ -66,3 +66,6 @@
[submodule "thirdparty/flatbuffers/flatbuffers"]
path = thirdparty/flatbuffers/flatbuffers
url = https://github.com/google/flatbuffers.git
[submodule "thirdparty/snappy"]
path = thirdparty/snappy
url = https://github.com/google/snappy.git


@ -156,6 +156,8 @@ ie_option(ENABLE_OV_TF_FRONTEND "Enable TensorFlow FrontEnd" ON)
ie_option(ENABLE_OV_TF_LITE_FRONTEND "Enable TensorFlow Lite FrontEnd" ON)
ie_dependent_option(ENABLE_SYSTEM_PROTOBUF "Use system protobuf" OFF
"ENABLE_OV_ONNX_FRONTEND OR ENABLE_OV_PADDLE_FRONTEND OR ENABLE_OV_TF_FRONTEND;BUILD_SHARED_LIBS" OFF)
ie_dependent_option(ENABLE_SNAPPY_COMPRESSION "Enables compression support for TF FE" ON
"ENABLE_OV_TF_FRONTEND" ON)
ie_option(ENABLE_OV_IR_FRONTEND "Enable IR FrontEnd" ON)
ie_dependent_option(ENABLE_SYSTEM_FLATBUFFERS "Use system flatbuffers" ON
"ENABLE_OV_TF_LITE_FRONTEND" OFF)


@ -1547,3 +1547,61 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
See the License for the specific language governing permissions and
limitations under the License.
-------------------------------------------------------------
28. Snappy (https://github.com/google/snappy/)
Copyright 2011, Google Inc.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
===
Some of the benchmark data in testdata/ is licensed differently:
- fireworks.jpeg is Copyright 2013 Steinar H. Gunderson, and
is licensed under the Creative Commons Attribution 3.0 license
(CC-BY-3.0). See https://creativecommons.org/licenses/by/3.0/
for more information.
- kppkn.gtb is taken from the Gaviota chess tablebase set, and
is licensed under the MIT License. See
https://sites.google.com/site/gaviotachessengine/Home/endgame-tablebases-1
for more information.
- paper-100k.pdf is an excerpt (bytes 92160 to 194560) from the paper
“Combinatorial Modeling of Chromatin Features Quantitatively Predicts DNA
Replication Timing in _Drosophila_” by Federico Comoglio and Renato Paro,
which is licensed under the CC-BY license. See
http://www.ploscompbiol.org/static/license for more ifnormation.
- alice29.txt, asyoulik.txt, plrabn12.txt and lcet10.txt are from Project
Gutenberg. The first three have expired copyrights and are in the public
domain; the latter does not have expired copyright, but is still in the
public domain according to the license information
(http://www.gutenberg.org/ebooks/53).


@ -123,6 +123,15 @@ void create_directory_recursive(const std::string& path);
*/
bool directory_exists(const std::string& path);
#ifdef OPENVINO_ENABLE_UNICODE_PATH_SUPPORT
/**
* @brief Interface function to check if directory exists for given path
* @param path - path to directory wide-string
* @return true if directory exists, false otherwise
*/
bool directory_exists(const std::wstring& path);
#endif
/**
* @brief Returns file size for file
* @param[in] path The file name


@ -27,6 +27,9 @@
# define get_absolute_path(result, path) _fullpath(result, path.c_str(), MAX_ABS_PATH)
/// @brief Windows-specific 'stat' wrapper
# define stat _stat
# ifdef OPENVINO_ENABLE_UNICODE_PATH_SUPPORT
# define wstat _wstat
# endif
/// @brief Windows-specific 'mkdir' wrapper
# define makedir(dir) _mkdir(dir)
// Copied from linux libc sys/stat.h:
@ -403,6 +406,21 @@ bool ov::util::directory_exists(const std::string& path) {
return false;
}
#ifdef OPENVINO_ENABLE_UNICODE_PATH_SUPPORT
bool ov::util::directory_exists(const std::wstring& path) {
# ifdef _WIN32
struct stat sb;
if (wstat(path.c_str(), &sb) == 0 && S_ISDIR(sb.st_mode)) {
return true;
}
return false;
# else
return directory_exists(wstring_to_string(path));
# endif
}
#endif
namespace {
template <typename C,

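The narrow/wide pair added above can be sketched in isolation. The following is a minimal stand-alone version of the same idea (the names mirror `ov::util::directory_exists`, but this is an illustrative sketch, not the exact build configuration of the patch):

```cpp
#include <sys/stat.h>

#include <filesystem>
#include <string>

// Narrow-string check: same stat/S_ISDIR approach as the patched function.
bool directory_exists(const std::string& path) {
    struct stat sb;
    return stat(path.c_str(), &sb) == 0 && S_ISDIR(sb.st_mode);
}

// Wide-string overload added by this patch.
bool directory_exists(const std::wstring& path) {
#ifdef _WIN32
    struct _stat sb;
    return _wstat(path.c_str(), &sb) == 0 && (sb.st_mode & _S_IFDIR);
#else
    // POSIX has no wide 'stat'; convert and reuse the narrow overload
    // (the patch uses wstring_to_string for this conversion).
    return directory_exists(std::filesystem::path(path).string());
#endif
}
```

The key design point is that only Windows gets a genuinely wide code path; everywhere else the wide overload degrades to a string conversion plus the existing narrow implementation.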

@ -2,7 +2,22 @@
# SPDX-License-Identifier: Apache-2.0
#
list(APPEND CUSTOM_LINK_LIBRARIES
openvino::core::dev
openvino::frontend::tensorflow_common
)
if(ENABLE_SNAPPY_COMPRESSION)
list(APPEND CUSTOM_LINK_LIBRARIES
snappy
)
endif()
ov_add_frontend(NAME tensorflow
LINKABLE_FRONTEND
FILEDESCRIPTION "FrontEnd to load and convert TensorFlow file format"
LINK_LIBRARIES openvino::core::dev openvino::frontend::tensorflow_common)
LINK_LIBRARIES ${CUSTOM_LINK_LIBRARIES})
if(ENABLE_SNAPPY_COMPRESSION)
target_compile_definitions(openvino_tensorflow_frontend PUBLIC ENABLE_SNAPPY_COMPRESSION)
endif()


@ -129,7 +129,12 @@ ov::Any DecoderProto::get_attribute(const std::string& name) const {
}
case ::tensorflow::AttrValue::ValueCase::kType: {
return get_ov_type(attrs[0].type());
auto atype = attrs[0].type();
if (atype != ::tensorflow::DT_STRING) {
return get_ov_type(attrs[0].type());
} else {
return ov::Any("DT_STRING");
}
}
case ::tensorflow::AttrValue::ValueCase::kList: {
@ -168,7 +173,11 @@ ov::Any DecoderProto::get_attribute(const std::string& name) const {
if (list.type_size()) {
std::vector<ov::element::Type> res;
for (int idx = 0; idx < list.type_size(); ++idx) {
res.emplace_back(get_ov_type(list.type(idx)));
if (list.type(idx) != ::tensorflow::DataType::DT_STRING) {
res.emplace_back(get_ov_type(list.type(idx)));
} else {
res.emplace_back(ov::element::undefined);
}
}
return res;
}
@ -194,9 +203,22 @@ ov::Any DecoderProto::get_attribute(const std::string& name) const {
FRONT_END_GENERAL_CHECK(pshape.is_static(), "Dynamic shapes are not supported for Tensor attribute.");
const auto& tf_type = tensor_proto.dtype();
auto ov_type = get_ov_type(tf_type);
FRONT_END_GENERAL_CHECK(
ov_type.is_static(),
"Encountered unknown element type " + DataType_Name(tf_type) + " on an empty tensor_proto");
if (tf_type != ::tensorflow::DataType::DT_STRING) {
FRONT_END_GENERAL_CHECK(
ov_type.is_static(),
"Encountered unknown element type " + DataType_Name(tf_type) + " on an empty tensor_proto");
} else {
ov_type = ov::element::u64;
pshape.resize(0);
pshape.push_back(tensor_proto.string_val_size());
}
if (tf_type == ::tensorflow::DataType::DT_STRING) {
auto data = std::vector<std::string>();
for (auto& item : tensor_proto.string_val()) {
data.push_back(item);
}
return data;
}
ov::Tensor res(ov_type, pshape.get_shape());
auto tensor_content = tensor_proto.tensor_content();
if (!tensor_content.empty() && tensor_proto.has_tensor_shape()) {

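The `kType` branch above now changes the payload type of the returned `Any` depending on the dtype, so consumers must handle two cases. A toy stand-in (using `std::any` in place of `ov::Any` and an illustrative enum in place of `tensorflow::DataType`) shows the new contract:

```cpp
#include <any>
#include <string>

// Toy stand-ins for this sketch only.
enum class DataType { DT_FLOAT, DT_INT32, DT_STRING };

// Mirrors the patched kType branch: non-string dtypes are mapped to a real
// element type (represented here by the enum itself), while DT_STRING has no
// OpenVINO element type and becomes a marker string instead.
std::any get_type_attribute(DataType dtype) {
    if (dtype != DataType::DT_STRING) {
        return dtype;  // the real code returns get_ov_type(dtype) here
    }
    return std::string("DT_STRING");
}
```

Downstream code (e.g. the `StringConstant` handling added in this PR) checks for the marker before assuming the attribute is an element type.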

@ -5,10 +5,12 @@
#include "openvino/frontend/tensorflow/frontend.hpp"
#include "graph_iterator_proto.hpp"
#include "graph_iterator_saved_model.hpp"
#include "helper_transforms/block_lstm_replacer.hpp"
#include "helper_transforms/const_to_result_remover.hpp"
#include "helper_transforms/embedding_segments_feature_fusing.hpp"
#include "helper_transforms/gru_block_cell_replacer.hpp"
#include "helper_transforms/saved_model_unused_remover.hpp"
#include "input_model.hpp"
#include "op_table.hpp"
#include "openvino/frontend/tensorflow/extension/conversion.hpp"
@ -86,6 +88,8 @@ bool FrontEnd::supported_impl(const std::vector<ov::Any>& variants) const {
// for automatic deduction of the frontend to convert the model
// we have more strict rule that is to have `.pb` extension in the path
return true;
} else if (GraphIteratorSavedModel::is_supported(model_path)) {
return true;
}
}
#if defined(OPENVINO_ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
@ -97,6 +101,8 @@ bool FrontEnd::supported_impl(const std::vector<ov::Any>& variants) const {
// for automatic deduction of the frontend to convert the model
// we have more strict rule that is to have `.pb` extension in the path
return true;
} else if (GraphIteratorSavedModel::is_supported(model_path)) {
return true;
}
}
#endif
@ -118,6 +124,18 @@ ov::frontend::InputModel::Ptr FrontEnd::load_impl(const std::vector<ov::Any>& va
if (GraphIteratorProto::is_supported(model_path)) {
// handle binary protobuf format
return std::make_shared<InputModel>(std::make_shared<GraphIteratorProto>(model_path), m_telemetry);
} else if (GraphIteratorSavedModel::is_supported(model_path)) {
std::shared_ptr<GraphIteratorSavedModel> graph_iterator;
if (variants.size() > 1 && variants[1].is<std::string>()) {
graph_iterator = std::make_shared<GraphIteratorSavedModel>(model_path, variants[1].as<std::string>());
} else {
graph_iterator = std::make_shared<GraphIteratorSavedModel>(model_path, std::string("serve"));
}
return std::make_shared<InputModel>(graph_iterator,
m_telemetry,
graph_iterator->get_variables_index(),
graph_iterator->get_saved_model_input_names(),
graph_iterator->get_saved_model_output_names());
}
}
#if defined(OPENVINO_ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
@ -126,6 +144,20 @@ ov::frontend::InputModel::Ptr FrontEnd::load_impl(const std::vector<ov::Any>& va
if (GraphIteratorProto::is_supported(model_path)) {
// handle binary protobuf format with a path in Unicode
return std::make_shared<InputModel>(std::make_shared<GraphIteratorProto>(model_path), m_telemetry);
} else if (GraphIteratorSavedModel::is_supported(model_path)) {
std::shared_ptr<GraphIteratorSavedModel> graph_iterator;
if (variants.size() > 1 && variants[1].is<std::string>()) {
graph_iterator = std::make_shared<GraphIteratorSavedModel>(
model_path,
ov::util::wstring_to_string(variants[1].as<std::wstring>()));
} else {
graph_iterator = std::make_shared<GraphIteratorSavedModel>(model_path, std::string("serve"));
}
return std::make_shared<InputModel>(graph_iterator,
m_telemetry,
graph_iterator->get_variables_index(),
graph_iterator->get_saved_model_input_names(),
graph_iterator->get_saved_model_output_names());
}
}
#endif
@ -232,6 +264,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
// run transformations to convert sub-graphs with intermediate (or FrameworkNode) operations
// into sub-graphs with only OpenVINO operations
ov::pass::Manager manager;
manager.register_pass<pass::SavedModelUnusedRemover>();
manager.register_pass<pass::EmbeddingSegmentSingleFeatureFusion>();
manager.register_pass<pass::BlockLSTMReplacer>();
manager.register_pass<pass::GRUBlockCellReplacer>();

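When a SavedModel contains several MetaGraphs, `load_impl` passes a tag string (defaulting to `"serve"`) and `read_saved_model` selects a MetaGraph whose tags occur inside that string. The check is substring-based; the following self-contained sketch mirrors that logic (function name is illustrative):

```cpp
#include <string>
#include <vector>

// Mirrors the tag test in read_saved_model: a MetaGraph is accepted if any of
// its tags appears as a substring of the requested tag string.
bool matches_tags(const std::string& requested_tags, const std::vector<std::string>& graph_tags) {
    for (const auto& tag : graph_tags) {
        if (requested_tags.find(tag) != std::string::npos) {
            return true;
        }
    }
    return false;
}
```

Note that because this is a substring match, a short tag can match inside a longer requested string (e.g. a tag `"s"` matches a request of `"serve"`); exact set comparison would require splitting the requested string first.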

@ -5,6 +5,7 @@
#pragma once
#include <fstream>
#include <vector>
#include "decoder_argdef.hpp"
#include "decoder_proto.hpp"
@ -18,6 +19,7 @@ namespace frontend {
namespace tensorflow {
class GraphIteratorProto : public GraphIterator {
protected:
std::shared_ptr<::tensorflow::GraphDef> m_graph_def;
std::shared_ptr<::tensorflow::FunctionDef> m_func_def;
@ -27,6 +29,11 @@ class GraphIteratorProto : public GraphIterator {
std::vector<std::string> m_input_names;
std::vector<std::string> m_output_names;
GraphIteratorProto()
: m_graph_def(std::make_shared<::tensorflow::GraphDef>()),
m_func_def(nullptr),
m_library_map() {}
public:
GraphIteratorProto(const std::shared_ptr<::tensorflow::GraphDef>& graph_def,
const std::shared_ptr<::tensorflow::FunctionDef>& func_def,
@ -150,6 +157,7 @@ public:
return m_output_names;
}
};
} // namespace tensorflow
} // namespace frontend
} // namespace ov


@ -0,0 +1,291 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <map>
#include "graph_iterator_proto.hpp"
#include "openvino/util/file_util.hpp"
#include "saved_model.pb.h"
namespace ov {
namespace frontend {
namespace tensorflow {
struct VIBlock;
template <typename T>
std::basic_string<T> get_saved_model_name() {}
template <typename T>
std::basic_string<T> get_variables_index_name() {}
template <>
std::basic_string<char> get_saved_model_name<char>();
template <>
std::basic_string<char> get_variables_index_name<char>();
#if defined(OPENVINO_ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
template <>
std::basic_string<wchar_t> get_saved_model_name<wchar_t>();
template <>
std::basic_string<wchar_t> get_variables_index_name<wchar_t>();
#endif
// Stores information about variables index
class SavedModelVariablesIndex {
// Contains the maximum number of shards, used for building the correct data-file extension
int32_t m_total_shards;
// BundleEntryProto variables list, read from the .index file
std::map<std::string, std::vector<char>> m_variables_index;
// Opened data files used together with BundleEntryProto entries
std::map<int32_t, std::shared_ptr<std::ifstream>> m_data_files;
// Mapped variables which can be read using the TrackableObjectGraph
std::map<std::string, std::string> m_variables_map;
public:
/// \brief Reads variables from an opened variables index file. May trigger asserts on malformed input.
/// \param vi_stream Opened file stream; the read position does not matter, it is rewound internally.
/// \param path Path to the file with variables data
/// \returns Returns true if everything loads successfully, false otherwise
bool read_variables(std::ifstream& vi_stream, const std::string& path);
#if defined(OPENVINO_ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
/// \brief Reads variables from an opened variables index file. May trigger asserts on malformed input.
/// \param vi_stream Opened file stream; the read position does not matter, it is rewound internally.
/// \param path Wide-string path to the file with variables data
/// \returns Returns true if everything loads successfully, false otherwise
bool read_variables(std::ifstream& vi_stream, const std::wstring& path);
#endif
/// \brief Returns a pointer to the data of a stored variable and its size
/// \param name Name of the variable
/// \param data Output pointer that receives the data pointer
/// \param size Output pointer that receives the data size
/// \returns Returns true if the variable was found, false otherwise (data and size are left untouched)
bool get_variable(const std::string& name, const char** data, size_t* size) const {
auto varItem = m_variables_index.find(name);
if (varItem == m_variables_index.end()) {
return false;
}
if (data != nullptr) {
*data = varItem->second.data();
}
if (size != nullptr) {
*size = varItem->second.size();
}
return true;
}
/// \brief Returns the data and size of a variable mapped from the trackable object graph to the variables index
/// \param name Name of the mapped variable
/// \param data Output pointer that receives the data pointer
/// \param size Output pointer that receives the data size
/// \returns Returns true if the variable was found, false otherwise (data and size are left untouched)
bool get_mapped_variable(const std::string& name, const char** data, size_t* size) const {
auto mapItem = m_variables_map.find(name);
if (mapItem == m_variables_map.end()) {
return false;
}
return get_variable(mapItem->second, data, size);
}
/// \brief Checks if a variable has a mapped pair
/// \param name Name of the variable to check for existence
/// \returns True if the variable has a mapped value, false otherwise
bool has_mapped_variable(const std::string& name) const {
auto mapItem = m_variables_map.find(name);
return mapItem != m_variables_map.end();
}
/// \brief Returns a shared pointer to the data stream for the requested shard_id, or nullptr if it isn't found
/// \param shard_id Requested shard_id
/// \returns Valid shared_ptr holding an ifstream, or nullptr if the shard isn't found
std::shared_ptr<std::ifstream> get_data_file(const int32_t shard_id) const {
auto result = m_data_files.find(shard_id);
return result != m_data_files.end() ? result->second : nullptr;
}
/// \brief Adds a variable mapping to the variables map
/// \param var_name Full variable name (from the .index file)
/// \param map_name Mapped name
/// \param rewrite Rewrite the mapped value if it already exists
/// \returns True if the map was updated; false if nothing changed (the variable exists and rewrite is false)
bool map_variable(const std::string& var_name, const std::string& map_name, bool rewrite = false) {
if (m_variables_map.find(var_name) != m_variables_map.end() && rewrite == false) {
return false;
}
m_variables_map[var_name] = map_name;
return true;
}
private:
/// \brief Reads the block structure of a .index file
/// \param[in,out] fs File stream of the .index file; position in the file will be updated
/// \param[in] index Variables index block which stores information about the block
/// \param[out] data Buffer that receives the block data
/// \param[out] offset Offset of the block start
/// \param[out] offset_end Offset of the block end
void read_variables_index_block(std::ifstream& fs,
const VIBlock& index,
std::vector<char>& data,
uint32_t& offset,
uint32_t& offset_end);
/// \brief Reads a key=value pair from the provided pointer
/// \param[in,out] ptr Current pointer; advanced past the read pair (to read the next one)
/// \param[in] ptr_end End of the memory region which must not be passed, even for broken structures
/// \param[out] key Key name
/// \param[out] value Stored value for the key (a data block, not a plain string)
/// \param[out] val_length Length of the read value
void read_variables_index_pair(char*& ptr,
const char* ptr_end,
std::string& key,
char*& value,
uint32_t& val_length);
/// \brief Reads a .index file and stores the key=value map in the provided varIndex
/// \param[in,out] fs File stream to be parsed; position in the file will be updated
/// \param[out] varIndex Variables index (key=value) from the given file stream
void read_variables_index(std::ifstream& fs, std::map<std::string, std::vector<char>>& varIndex);
/// \brief Reads the bundle header if it is available. Checks the version and saves the number of shards
void read_bundle_header();
/// \brief Reads the key=value map from the stored _CHECKPOINTABLE_OBJECT_GRAPH variable
void read_checkpointable_object_graph();
};
// Loads a graph from a TensorFlow SavedModel directory (saved_model.pb)
class GraphIteratorSavedModel : public GraphIteratorProto {
std::shared_ptr<::tensorflow::SavedModel> m_saved_model;
std::shared_ptr<SavedModelVariablesIndex> m_variables_index;
std::shared_ptr<std::map<std::string, std::string>> m_inputs_map;
std::shared_ptr<std::map<std::string, std::string>> m_outputs_map;
public:
template <typename T>
GraphIteratorSavedModel(const std::basic_string<T>& path, const std::string& tags)
: m_saved_model(std::make_shared<::tensorflow::SavedModel>()) {
this->read_saved_model(path, tags);
}
static bool is_supported(const std::string& path);
#if defined(OPENVINO_ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
static bool is_supported(const std::wstring& path);
#endif
std::shared_ptr<SavedModelVariablesIndex> get_variables_index() {
return m_variables_index;
}
std::shared_ptr<std::map<std::string, std::string>> get_saved_model_input_names() const {
return m_inputs_map;
}
std::shared_ptr<std::map<std::string, std::string>> get_saved_model_output_names() const {
return m_outputs_map;
}
private:
bool is_valid_signature(const ::tensorflow::SignatureDef& signature) const;
template <typename T>
bool read_saved_model(const std::basic_string<T>& path, const std::string& tags) {
std::ifstream sm_stream{path + get_saved_model_name<T>(), std::ifstream::in | std::ifstream::binary};
FRONT_END_GENERAL_CHECK(sm_stream && sm_stream.is_open(), "Model file does not exist");
std::basic_string<T> varIndexPath = path + get_variables_index_name<T>();
if (ov::util::file_exists(varIndexPath)) {
m_variables_index = std::make_shared<SavedModelVariablesIndex>();
std::ifstream vi_stream{varIndexPath, std::ifstream::in | std::ifstream::binary};
FRONT_END_GENERAL_CHECK(vi_stream && vi_stream.is_open(),
"Saved Model's variable index file does not exist");
FRONT_END_GENERAL_CHECK(m_variables_index->read_variables(vi_stream, path),
"Saved Model's variable index file cannot be parsed");
}
bool res = m_saved_model->ParseFromIstream(&sm_stream);
FRONT_END_GENERAL_CHECK(res && m_saved_model->meta_graphs_size(), "Saved Model cannot be parsed");
for (const auto& meta_graph : m_saved_model->meta_graphs()) {
if (!meta_graph.has_graph_def()) {
continue;
}
if (m_saved_model->meta_graphs_size() > 1) {
bool tag_found = false;
for (const auto& tag : meta_graph.meta_info_def().tags()) {
if (tags.find(tag) != std::string::npos) {
tag_found = true;
break;
}
}
if (!tag_found) {
continue;
}
}
std::map<std::string, const ::tensorflow::SignatureDef*> validSignatures = {};
for (const auto& sit : meta_graph.signature_def()) {
const std::string& key = sit.first;
const ::tensorflow::SignatureDef& val = sit.second;
if (is_valid_signature(val)) {
validSignatures[key] = &val;
}
}
auto serving_default = validSignatures.find("serving_default");
if (serving_default != validSignatures.end()) {
m_inputs_map = std::make_shared<std::map<std::string, std::string>>();
m_outputs_map = std::make_shared<std::map<std::string, std::string>>();
for (const auto& input : serving_default->second->inputs()) {
(*m_inputs_map)[input.second.name()] = input.first;
}
for (const auto& output : serving_default->second->outputs()) {
(*m_outputs_map)[output.second.name()] = output.first;
}
}
m_graph_def = std::make_shared<::tensorflow::GraphDef>(meta_graph.graph_def());
// Update the variables map by resolving AssignVariableOp graph nodes
std::map<std::string, std::string> var_map;
map_assignvariable(m_graph_def, var_map);
for (auto var : var_map) {
m_variables_index->map_variable(var.first, var.second);
}
auto nodes_size = m_graph_def->node_size();
m_decoders.resize(static_cast<size_t>(nodes_size));
for (int node_ind = 0; node_ind < nodes_size; ++node_ind) {
m_decoders[node_ind] = std::make_shared<DecoderProto>(&m_graph_def->node(node_ind), m_graph_def);
}
// initialize a library map
auto num_funcs = m_graph_def->library().function_size();
for (int func_ind = 0; func_ind < num_funcs; ++func_ind) {
auto func = m_graph_def->library().function(func_ind);
auto func_name = func.signature().name();
m_library_map.insert(std::pair<std::string, int>(func_name, func_ind));
}
return true;
}
FRONT_END_GENERAL_CHECK(false, "Saved Model doesn't contain MetaGraph with requested tag");
return false;
}
/// \brief Reads the relationship between VarHandleOp - RestoreV2 - AssignVariableOp and
/// stores this information in the provided key=value map, where the key is the name of a VarHandleOp
/// and the value is the long variable name stored in RestoreV2.
/// This is needed to map a VarHandleOp to the right place in the .index file.
/// \param[in] graph_def GraphDef object for analysis
/// \param[out] variables_map Map of variables found in graph_def
void map_assignvariable(const std::shared_ptr<::tensorflow::GraphDef> graph_def,
std::map<std::string, std::string>& variables_map) const;
}; // GraphIteratorSavedModel
} // namespace tensorflow
} // namespace frontend
} // namespace ov
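The two-level lookup in `SavedModelVariablesIndex` (short VarHandleOp name → long checkpoint key → raw data) can be shown with a reduced, self-contained sketch. The class and method names below mirror the header above, but this is an illustrative model of the map semantics only, with the shard-file reading stripped out:

```cpp
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Reduced model of SavedModelVariablesIndex: only the mapping logic.
class VariablesIndexSketch {
    std::map<std::string, std::vector<char>> m_variables_index;  // checkpoint key -> raw data
    std::map<std::string, std::string> m_variables_map;          // short name -> checkpoint key

public:
    void add_index_entry(const std::string& key, std::vector<char> data) {
        m_variables_index[key] = std::move(data);
    }

    // Mirrors map_variable(): refuses to overwrite unless rewrite == true.
    bool map_variable(const std::string& var_name, const std::string& map_name, bool rewrite = false) {
        if (m_variables_map.count(var_name) && !rewrite) {
            return false;
        }
        m_variables_map[var_name] = map_name;
        return true;
    }

    // Mirrors get_mapped_variable(): resolve the short name, then look up the data.
    bool get_mapped_variable(const std::string& name, const char** data, size_t* size) const {
        auto map_it = m_variables_map.find(name);
        if (map_it == m_variables_map.end()) {
            return false;
        }
        auto var_it = m_variables_index.find(map_it->second);
        if (var_it == m_variables_index.end()) {
            return false;
        }
        if (data) *data = var_it->second.data();
        if (size) *size = var_it->second.size();
        return true;
    }
};
```

In the real frontend the short→long mapping is populated by `map_assignvariable`, which walks the VarHandleOp - RestoreV2 - AssignVariableOp chain in the GraphDef.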


@ -55,7 +55,10 @@ public:
InputModelTFImpl(const GraphIterator::Ptr& graph_iterator, const ov::frontend::InputModel& input_model);
InputModelTFImpl(const GraphIterator::Ptr& graph_iterator,
const ov::frontend::InputModel& input_model,
const std::shared_ptr<TelemetryExtension>& telemetry);
const std::shared_ptr<TelemetryExtension>& telemetry,
const std::shared_ptr<SavedModelVariablesIndex>& variables_index,
const std::shared_ptr<std::map<std::string, std::string>> saved_model_input_names,
const std::shared_ptr<std::map<std::string, std::string>> saved_model_output_names);
std::vector<ov::frontend::Place::Ptr> get_inputs() const;
std::vector<ov::frontend::Place::Ptr> get_outputs() const;
ov::frontend::Place::Ptr get_place_by_tensor_name(const std::string& tensorName) const;
@ -79,6 +82,9 @@ public:
std::shared_ptr<InputModel> get_body_input_model(const std::string& body_model_name) const;
std::vector<std::string> get_input_names() const;
std::vector<std::string> get_output_names() const;
std::shared_ptr<SavedModelVariablesIndex> get_variables_index() const;
std::shared_ptr<std::map<std::string, std::string>> get_saved_model_input_names() const;
std::shared_ptr<std::map<std::string, std::string>> get_saved_model_output_names() const;
private:
void load_places();
@ -99,6 +105,10 @@ private:
std::shared_ptr<TelemetryExtension> m_telemetry;
std::shared_ptr<SavedModelVariablesIndex> m_variables_index;
std::shared_ptr<std::map<std::string, std::string>> m_saved_model_input_names;
std::shared_ptr<std::map<std::string, std::string>> m_saved_model_output_names;
// shows if some nodes might be deleted from graph
bool m_graph_changed = false;
};
@ -152,10 +162,10 @@ void InputModel::InputModelTFImpl::load_places() {
}
auto dtype_any = node_decoder->get_attribute("dtype");
auto placeholder_name = node_decoder->get_op_name();
FRONT_END_GENERAL_CHECK(
dtype_any.is<ov::element::Type>(),
"Incorrect input model: Placeholder node " + placeholder_name + " has unspecified type.");
auto type = dtype_any.as<ov::element::Type>();
ov::element::Type type = ov::element::dynamic;
if (dtype_any.is<ov::element::Type>()) {
type = dtype_any.as<ov::element::Type>();
}
std::vector<std::string> names = {op_name};
auto tensor_place = std::make_shared<TensorPlace>(m_input_model, pshape, type, names);
m_tensor_places[op_name] = tensor_place;
@ -202,6 +212,17 @@ void InputModel::InputModelTFImpl::load_places() {
m_outputs.push_back(output_place);
}
}
std::shared_ptr<SavedModelVariablesIndex> InputModel::InputModelTFImpl::get_variables_index() const {
return m_variables_index;
}
std::shared_ptr<std::map<std::string, std::string>> InputModel::InputModelTFImpl::get_saved_model_input_names() const {
return m_saved_model_input_names;
}
std::shared_ptr<std::map<std::string, std::string>> InputModel::InputModelTFImpl::get_saved_model_output_names() const {
return m_saved_model_output_names;
}
std::vector<std::shared_ptr<OpPlace>> InputModel::InputModelTFImpl::get_op_places() const {
return topologically_sort_op_nodes();
@ -337,12 +358,19 @@ std::shared_ptr<InputModel> InputModel::InputModelTFImpl::get_body_input_model(
return std::make_shared<InputModel>(body_graph_iterator, m_telemetry);
}
InputModel::InputModelTFImpl::InputModelTFImpl(const GraphIterator::Ptr& graph_iterator,
const ov::frontend::InputModel& input_model,
const std::shared_ptr<TelemetryExtension>& telemetry)
InputModel::InputModelTFImpl::InputModelTFImpl(
const GraphIterator::Ptr& graph_iterator,
const ov::frontend::InputModel& input_model,
const std::shared_ptr<TelemetryExtension>& telemetry,
const std::shared_ptr<SavedModelVariablesIndex>& variables_index,
const std::shared_ptr<std::map<std::string, std::string>> saved_model_input_names,
const std::shared_ptr<std::map<std::string, std::string>> saved_model_output_names)
: m_graph_iterator(graph_iterator),
m_input_model(input_model),
m_telemetry(telemetry) {
m_telemetry(telemetry),
m_variables_index(variables_index),
m_saved_model_input_names(saved_model_input_names),
m_saved_model_output_names(saved_model_output_names) {
FRONT_END_GENERAL_CHECK(m_graph_iterator, "Null pointer specified for GraphIterator");
m_input_names = graph_iterator->get_input_names();
m_output_names = graph_iterator->get_output_names();
@ -445,8 +473,29 @@ void InputModel::InputModelTFImpl::set_tensor_value(ov::frontend::Place::Ptr pla
m_tensor_values[name] = constant;
}
InputModel::InputModel(const GraphIterator::Ptr& graph_iterator, const std::shared_ptr<TelemetryExtension>& telemetry)
: _impl{std::make_shared<InputModelTFImpl>(graph_iterator, *this, telemetry)} {}
InputModel::InputModel(const GraphIterator::Ptr& graph_iterator,
const std::shared_ptr<TelemetryExtension>& telemetry,
const std::shared_ptr<SavedModelVariablesIndex>& variables_index,
const std::shared_ptr<std::map<std::string, std::string>> saved_model_input_names,
const std::shared_ptr<std::map<std::string, std::string>> saved_model_output_names)
: _impl{std::make_shared<InputModelTFImpl>(graph_iterator,
*this,
telemetry,
variables_index,
saved_model_input_names,
saved_model_output_names)} {}
std::shared_ptr<SavedModelVariablesIndex> InputModel::get_variables_index() {
return _impl->get_variables_index();
}
std::shared_ptr<std::map<std::string, std::string>> InputModel::get_saved_model_input_names() const {
return _impl->get_saved_model_input_names();
}
std::shared_ptr<std::map<std::string, std::string>> InputModel::get_saved_model_output_names() const {
return _impl->get_saved_model_output_names();
}
std::vector<std::string> InputModel::get_input_names() const {
return _impl->get_input_names();


@ -16,6 +16,7 @@ namespace tensorflow {
class OpPlace;
class TensorPlace;
class SavedModelVariablesIndex;
class InputModel : public ov::frontend::InputModel {
friend class TranslateSession;
@ -31,7 +32,10 @@ class InputModel : public ov::frontend::InputModel {
public:
explicit InputModel(const GraphIterator::Ptr& graph_iterator,
const std::shared_ptr<TelemetryExtension>& telemetry = {});
const std::shared_ptr<TelemetryExtension>& telemetry = {},
const std::shared_ptr<SavedModelVariablesIndex>& variables_index = {},
const std::shared_ptr<std::map<std::string, std::string>> saved_model_input_names = nullptr,
const std::shared_ptr<std::map<std::string, std::string>> saved_model_output_names = nullptr);
std::vector<ov::frontend::Place::Ptr> get_inputs() const override;
std::vector<ov::frontend::Place::Ptr> get_outputs() const override;
@ -45,6 +49,9 @@ public:
void set_element_type(const ov::frontend::Place::Ptr& place, const ov::element::Type&) override;
ov::element::Type get_element_type(const ov::frontend::Place::Ptr& place) const override;
void set_tensor_value(const ov::frontend::Place::Ptr& place, const void* value) override;
std::shared_ptr<SavedModelVariablesIndex> get_variables_index();
std::shared_ptr<std::map<std::string, std::string>> get_saved_model_input_names() const;
std::shared_ptr<std::map<std::string, std::string>> get_saved_model_output_names() const;
};
} // namespace tensorflow


@ -0,0 +1,203 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "common_op_table.hpp"
#include "graph_iterator_saved_model.hpp"
#include "helper_ops/string_constant.hpp"
#include "helper_ops/unsupported_constant.hpp"
#include "input_model.hpp"
#include "openvino/opsets/opset8.hpp"
#include "tensor_bundle.pb.h"
using namespace std;
using namespace ov::opset8;
using namespace ov;
namespace ov {
namespace frontend {
namespace tensorflow {
namespace op {
// Reading variable from shard file
template <typename T>
static std::shared_ptr<ov::Node> read_variable(std::shared_ptr<SavedModelVariablesIndex> var_index,
const ov::element::Type ov_type,
const ov::Shape shape,
const ::tensorflow::BundleEntryProto& entry,
const NodeContext& node) {
std::vector<T> var_data;
google::protobuf::int64 size = 1;
for (uint64_t i = 0; i < shape.size(); ++i) {
size *= static_cast<google::protobuf::int64>(shape[i]);
}
var_data.resize(size);
TENSORFLOW_OP_VALIDATION(node,
size == static_cast<google::protobuf::int64>(entry.size() / sizeof(T)),
"[TensorFlow Frontend] Internal error: Available data size doesn't match the calculated size.");
auto fs = var_index->get_data_file(entry.shard_id());
TENSORFLOW_OP_VALIDATION(node, fs.get(), "[TensorFlow Frontend] Internal error: Cannot get shard file.");
fs->seekg(entry.offset(), std::ios::beg);
fs->read(reinterpret_cast<char*>(var_data.data()), entry.size());
return std::make_shared<Constant>(ov_type, shape, var_data);
}
OutputVector translate_varhandle_op(const NodeContext& node) {
default_op_checks(node, 0, {"VarHandleOp"});
auto translate_session = node.get_translate_session();
TENSORFLOW_OP_VALIDATION(node,
translate_session,
"[TensorFlow Frontend] Internal error: Translate session is nullptr.");
auto model = reinterpret_cast<ov::frontend::tensorflow::InputModel*>(translate_session->get_input_model().get());
auto var_index = model->get_variables_index();
auto ov_type = node.get_attribute<element::Type>("dtype");
std::shared_ptr<Node> const_node;
if (ov_type == element::undefined) {
const_node = std::make_shared<UnsupportedConstant>();
} else {
// Getting variable description from variables index
const char* entry_data = nullptr;
size_t entry_size = 0;
auto var_name = node.get_name();
auto shape = node.get_attribute<::ov::PartialShape>("shape").get_shape();
bool result = var_index->get_mapped_variable(var_name, &entry_data, &entry_size);
TENSORFLOW_OP_VALIDATION(node, result, "[TensorFlow Frontend] Internal error: Cannot find requested variable.");
::tensorflow::BundleEntryProto entry;
TENSORFLOW_OP_VALIDATION(node,
entry.ParseFromArray(entry_data, static_cast<int>(entry_size)),
"[TensorFlow Frontend] Internal error: Cannot read bundle entry.");
switch (ov_type) {
case ov::element::u8:
const_node = read_variable<uint8_t>(var_index, ov_type, shape, entry, node);
break;
case ov::element::i8:
const_node = read_variable<int8_t>(var_index, ov_type, shape, entry, node);
break;
case ov::element::i16:
const_node = read_variable<int16_t>(var_index, ov_type, shape, entry, node);
break;
case ov::element::i32:
const_node = read_variable<int32_t>(var_index, ov_type, shape, entry, node);
break;
case ov::element::i64:
const_node = read_variable<int64_t>(var_index, ov_type, shape, entry, node);
break;
case ov::element::f16:
const_node = read_variable<float16>(var_index, ov_type, shape, entry, node);
break;
case ov::element::f32:
const_node = read_variable<float>(var_index, ov_type, shape, entry, node);
break;
case ov::element::f64:
const_node = read_variable<double>(var_index, ov_type, shape, entry, node);
break;
case ov::element::bf16:
const_node = read_variable<bfloat16>(var_index, ov_type, shape, entry, node);
break;
default:
FRONT_END_THROW("Encountered unknown element type " + ov_type.get_type_name());
}
}
set_node_name(node.get_name(), const_node);
return {const_node};
}
OutputVector translate_varisinitialized_op(const NodeContext& node) {
auto const_node = std::make_shared<Constant>(::ov::element::boolean, Shape{}, true);
set_node_name(node.get_name(), const_node);
return {const_node};
}
OutputVector translate_readvariable_op(const NodeContext& node) {
default_op_checks(node, 1, {"ReadVariableOp"});
// Documentation says it should return only one tensor with dtype, but
// _output_shapes is a vector of shapes, which means it could have multiple outputs
// https://www.tensorflow.org/api_docs/python/tf/raw_ops/ReadVariableOp
auto output_shapes = node.get_attribute<std::vector<::ov::PartialShape>>("_output_shapes");
OutputVector outs = {};
for (size_t i = 0; i < output_shapes.size(); ++i) {
std::shared_ptr<ov::Node> output_node;
if (node.get_input(0).get_partial_shape().is_static() &&
output_shapes[i].get_shape() != node.get_input(0).get_shape()) {
auto target_shape = output_shapes[i].get_shape();
auto reshape_shape = make_shared<Constant>(ov::element::i32, ov::Shape{target_shape.size()}, target_shape);
output_node = make_shared<Reshape>(node.get_input(0), reshape_shape, false);
} else {
output_node = node.get_input(0).get_node_shared_ptr();
}
if (i == 0) {
set_out_name(node.get_name(), output_node);
set_out_name(node.get_name() + ":0", output_node);
} else {
set_node_name(node.get_name() + ":" + std::to_string(i), output_node);
}
outs.push_back(output_node);
}
return outs;
}
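The Reshape target shape is supplied as a 1-D i32 tensor, so the `size_t` dimensions of the static shape have to be narrowed. A minimal sketch of that conversion (helper name hypothetical, not part of the frontend API):

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <stdexcept>
#include <vector>

// Convert a static shape (vector of size_t dims) into the flat i32 vector
// that a Reshape-style op expects as its 1-D shape input, guarding against
// dimensions that do not fit into 32 bits.
std::vector<int32_t> shape_to_i32(const std::vector<size_t>& shape) {
    std::vector<int32_t> out;
    out.reserve(shape.size());
    for (size_t dim : shape) {
        if (dim > static_cast<size_t>(std::numeric_limits<int32_t>::max()))
            throw std::overflow_error("dimension does not fit into int32");
        out.push_back(static_cast<int32_t>(dim));
    }
    return out;
}
```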
OutputVector translate_assignvariable_op(const NodeContext& node) {
default_op_checks(node, 2, {"AssignVariableOp"});
auto assignvariableop_node = std::make_shared<UnsupportedConstant>();
set_node_name(node.get_name(), assignvariableop_node);
return {assignvariableop_node};
}
OutputVector translate_restorev2_op(const NodeContext& node) {
default_op_checks(node, 3, {"RestoreV2"});
auto translate_session = node.get_translate_session();
TENSORFLOW_OP_VALIDATION(node,
translate_session,
"[TensorFlow Frontend] Internal error: Translate session is nullptr.");
auto model = reinterpret_cast<ov::frontend::tensorflow::InputModel*>(translate_session->get_input_model().get());
auto var_index = model->get_variables_index();
auto tensor_names =
reinterpret_cast<StringConstant*>(node.get_input(1).get_node())->get_data().as<std::vector<std::string>>();
auto tensor_types = node.get_attribute<std::vector<ov::element::Type>>("dtypes");
OutputVector outs = {};
for (size_t i = 0; i < tensor_names.size(); ++i) {
auto const_node = std::make_shared<UnsupportedConstant>();
if (i == 0)
set_node_name(node.get_name(), const_node);
else
set_node_name(node.get_name() + ":" + std::to_string(i), const_node);
outs.push_back(const_node);
}
return outs;
}
OutputVector translate_staticregexfullmatch_op(const NodeContext& node) {
default_op_checks(node, 1, {"StaticRegexFullMatch"});
// auto pattern = node.get_attribute_as_any("pattern").as<std::string>();
auto const_node = std::make_shared<Constant>(ov::element::boolean, ov::Shape{}, true);
set_node_name(node.get_name(), const_node);
return {const_node};
}
OutputVector translate_stringjoin_op(const NodeContext& node) {
default_op_checks(node, 1, {"StringJoin"});
auto const_node = std::make_shared<UnsupportedConstant>();
set_node_name(node.get_name(), const_node);
return {const_node};
}
OutputVector translate_mergev2checkpoint_op(const NodeContext& node) {
default_op_checks(node, 1, {"MergeV2Checkpoint"});
auto const_node = std::make_shared<UnsupportedConstant>();
set_node_name(node.get_name(), const_node);
return {const_node};
}
} // namespace op
} // namespace tensorflow
} // namespace frontend
} // namespace ov


@ -32,6 +32,14 @@ TF_OP_CONVERTER(translate_queue_dequeue_many_op);
TF_OP_CONVERTER(translate_sparse_fill_empty_rows_op);
TF_OP_CONVERTER(translate_sparse_reshape_op);
TF_OP_CONVERTER(translate_sparse_segment_sum_op);
TF_OP_CONVERTER(translate_varisinitialized_op);
TF_OP_CONVERTER(translate_readvariable_op);
TF_OP_CONVERTER(translate_assignvariable_op);
TF_OP_CONVERTER(translate_varhandle_op);
TF_OP_CONVERTER(translate_restorev2_op);
TF_OP_CONVERTER(translate_staticregexfullmatch_op);
TF_OP_CONVERTER(translate_stringjoin_op);
TF_OP_CONVERTER(translate_mergev2checkpoint_op);
TF_OP_CONVERTER(translate_while_op);
const std::map<std::string, CreatorFunction> get_supported_ops() {
@ -246,6 +254,15 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"TopK", translate_top_k_op},
{"TopKV2", translate_top_k_v2_op},
{"Transpose", translate_transpose_op},
{"ReadVariableOp", translate_readvariable_op},
{"AssignVariableOp", translate_assignvariable_op},
{"VarIsInitializedOp", translate_varisinitialized_op},
{"VarHandleOp", translate_varhandle_op},
{"RestoreV2", translate_restorev2_op},
{"StaticRegexFullMatch", translate_staticregexfullmatch_op},
{"StringJoin", translate_stringjoin_op},
{"ShardedFilename", translate_identity_op},
{"MergeV2Checkpoints", translate_identity_op},
{"Unpack", translate_unpack_op},
{"While", translate_while_op},
{"Where", translate_where_op},


@ -0,0 +1,159 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Modification Copyright (C) 2023 Intel Corporation
syntax = "proto3";
package google.protobuf;
option csharp_namespace = "Google.Protobuf.WellKnownTypes";
option go_package = "google.golang.org/protobuf/types/known/anypb";
option java_package = "com.google.protobuf";
option java_outer_classname = "AnyProto";
option java_multiple_files = true;
option objc_class_prefix = "GPB";
// `Any` contains an arbitrary serialized protocol buffer message along with a
// URL that describes the type of the serialized message.
//
// Protobuf library provides support to pack/unpack Any values in the form
// of utility functions or additional generated methods of the Any type.
//
// Example 1: Pack and unpack a message in C++.
//
// Foo foo = ...;
// Any any;
// any.PackFrom(foo);
// ...
// if (any.UnpackTo(&foo)) {
// ...
// }
//
// Example 2: Pack and unpack a message in Java.
//
// Foo foo = ...;
// Any any = Any.pack(foo);
// ...
// if (any.is(Foo.class)) {
// foo = any.unpack(Foo.class);
// }
//
// Example 3: Pack and unpack a message in Python.
//
// foo = Foo(...)
// any = Any()
// any.Pack(foo)
// ...
// if any.Is(Foo.DESCRIPTOR):
// any.Unpack(foo)
// ...
//
// Example 4: Pack and unpack a message in Go
//
// foo := &pb.Foo{...}
// any, err := anypb.New(foo)
// if err != nil {
// ...
// }
// ...
// foo := &pb.Foo{}
// if err := any.UnmarshalTo(foo); err != nil {
// ...
// }
//
// The pack methods provided by protobuf library will by default use
// 'type.googleapis.com/full.type.name' as the type URL and the unpack
// methods only use the fully qualified type name after the last '/'
// in the type URL, for example "foo.bar.com/x/y.z" will yield type
// name "y.z".
//
//
// JSON
// ====
// The JSON representation of an `Any` value uses the regular
// representation of the deserialized, embedded message, with an
// additional field `@type` which contains the type URL. Example:
//
// package google.profile;
// message Person {
// string first_name = 1;
// string last_name = 2;
// }
//
// {
// "@type": "type.googleapis.com/google.profile.Person",
// "firstName": <string>,
// "lastName": <string>
// }
//
// If the embedded message type is well-known and has a custom JSON
// representation, that representation will be embedded adding a field
// `value` which holds the custom JSON in addition to the `@type`
// field. Example (for message [google.protobuf.Duration][]):
//
// {
// "@type": "type.googleapis.com/google.protobuf.Duration",
// "value": "1.212s"
// }
//
message Any {
// A URL/resource name that uniquely identifies the type of the serialized
// protocol buffer message. This string must contain at least
// one "/" character. The last segment of the URL's path must represent
// the fully qualified name of the type (as in
// `path/google.protobuf.Duration`). The name should be in a canonical form
// (e.g., leading "." is not accepted).
//
// In practice, teams usually precompile into the binary all types that they
// expect it to use in the context of Any. However, for URLs which use the
// scheme `http`, `https`, or no scheme, one can optionally set up a type
// server that maps type URLs to message definitions as follows:
//
// * If no scheme is provided, `https` is assumed.
// * An HTTP GET on the URL must yield a [google.protobuf.Type][]
// value in binary format, or produce an error.
// * Applications are allowed to cache lookup results based on the
// URL, or have them precompiled into a binary to avoid any
// lookup. Therefore, binary compatibility needs to be preserved
// on changes to types. (Use versioned type names to manage
// breaking changes.)
//
// Note: this functionality is not currently available in the official
// protobuf release, and it is not used for type URLs beginning with
// type.googleapis.com.
//
// Schemes other than `http`, `https` (or the empty scheme) might be
// used with implementation specific semantics.
//
string type_url = 1;
// Must be a valid serialized protocol buffer of the above specified type.
bytes value = 2;
}
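The type-URL convention described above — unpack methods use only the fully qualified type name after the last `/`, so `"foo.bar.com/x/y.z"` yields `"y.z"` — can be sketched as a small C++ helper (name hypothetical):

```cpp
#include <cassert>
#include <string>

// Extract the fully qualified type name from an Any type URL: everything
// after the last '/', or the whole string if no '/' is present.
std::string type_name_from_url(const std::string& type_url) {
    auto pos = type_url.find_last_of('/');
    return pos == std::string::npos ? type_url : type_url.substr(pos + 1);
}
```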


@ -0,0 +1,351 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.*/
// Modification Copyright (C) 2018-2023 Intel Corporation
syntax = "proto3";
package tensorflow;
import "any.proto";
import "graph.proto";
import "op_def.proto";
import "tensor_shape.proto";
import "types.proto";
import "saved_object_graph.proto";
import "saver.proto";
import "struct.proto";
option cc_enable_arenas = true;
option java_outer_classname = "MetaGraphProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto";
// Protocol buffer containing the following which are necessary to restart
// training, run inference. It can be used to serialize/de-serialize memory
// objects necessary for running computation in a graph when crossing the
// process boundary. It can be used for long term storage of graphs,
// cross-language execution of graphs, etc.
// MetaInfoDef
// GraphDef
// SaverDef
// CollectionDef
// TensorInfo
// SignatureDef
message MetaGraphDef {
// Meta information regarding the graph to be exported. To be used by users
// of this protocol buffer to encode information regarding their meta graph.
message MetaInfoDef {
// User specified Version string. Can be the name of the model and revision,
// steps this model has been trained to, etc.
string meta_graph_version = 1;
// A copy of the OpDefs used by the producer of this graph_def.
// Descriptions and Ops not used in graph_def are stripped out.
OpList stripped_op_list = 2;
// A serialized protobuf. Can be the time this meta graph is created, or
// modified, or name of the model.
google.protobuf.Any any_info = 3;
// User supplied tag(s) on the meta_graph and included graph_def.
//
// MetaGraphDefs should be tagged with their capabilities or use-cases.
// Examples: "train", "serve", "gpu", "tpu", etc.
// These tags enable loaders to access the MetaGraph(s) appropriate for a
// specific use-case or runtime environment.
repeated string tags = 4;
// The __version__ string of the tensorflow build used to write this graph.
// This will be populated by the framework, which will overwrite any user
// supplied value.
string tensorflow_version = 5;
// The __git_version__ string of the tensorflow build used to write this
// graph. This will be populated by the framework, which will overwrite any
// user supplied value.
string tensorflow_git_version = 6;
// A flag to denote whether default-valued attrs have been stripped from
// the nodes in this graph_def.
bool stripped_default_attrs = 7;
// FunctionDef name to aliases mapping.
map<string, string> function_aliases = 8;
}
MetaInfoDef meta_info_def = 1;
// GraphDef.
GraphDef graph_def = 2;
// SaverDef.
SaverDef saver_def = 3;
// collection_def: Map from collection name to collections.
// See CollectionDef section for details.
map<string, CollectionDef> collection_def = 4;
// signature_def: Map from user supplied key for a signature to a single
// SignatureDef.
map<string, SignatureDef> signature_def = 5;
// Asset file def to be used with the defined graph.
repeated AssetFileDef asset_file_def = 6;
// Extra information about the structure of functions and stateful objects.
SavedObjectGraph object_graph_def = 7;
}
// CollectionDef should cover most collections.
// To add a user-defined collection, do one of the following:
// 1. For simple data types, such as string, int, float:
// tf.add_to_collection("your_collection_name", your_simple_value)
// strings will be stored as bytes_list.
//
// 2. For Protobuf types, there are three ways to add them:
// 1) tf.add_to_collection("your_collection_name",
// your_proto.SerializeToString())
//
// collection_def {
// key: "user_defined_bytes_collection"
// value {
// bytes_list {
// value: "queue_name: \"test_queue\"\n"
// }
// }
// }
//
// or
//
// 2) tf.add_to_collection("your_collection_name", str(your_proto))
//
// collection_def {
// key: "user_defined_string_collection"
// value {
// bytes_list {
// value: "\n\ntest_queue"
// }
// }
// }
//
// or
//
// 3) any_buf = any_pb2.Any()
// tf.add_to_collection("your_collection_name",
// any_buf.Pack(your_proto))
//
// collection_def {
// key: "user_defined_any_collection"
// value {
// any_list {
// value {
// type_url: "type.googleapis.com/tensorflow.QueueRunnerDef"
// value: "\n\ntest_queue"
// }
// }
// }
// }
//
// 3. For Python objects, implement to_proto() and from_proto(), and register
// them in the following manner:
// ops.register_proto_function("your_collection_name",
// proto_type,
// to_proto=YourPythonObject.to_proto,
// from_proto=YourPythonObject.from_proto)
// These functions will be invoked to serialize and de-serialize the
// collection. For example,
// ops.register_proto_function(ops.GraphKeys.GLOBAL_VARIABLES,
// proto_type=variable_pb2.VariableDef,
// to_proto=Variable.to_proto,
// from_proto=Variable.from_proto)
message CollectionDef {
// NodeList is used for collecting nodes in graph. For example
// collection_def {
// key: "summaries"
// value {
// node_list {
// value: "input_producer/ScalarSummary:0"
// value: "shuffle_batch/ScalarSummary:0"
// value: "ImageSummary:0"
// }
// }
message NodeList {
repeated string value = 1;
}
// BytesList is used for collecting strings and serialized protobufs. For
// example:
// collection_def {
// key: "trainable_variables"
// value {
// bytes_list {
// value: "\n\017conv1/weights:0\022\024conv1/weights/Assign
// \032\024conv1/weights/read:0"
// value: "\n\016conv1/biases:0\022\023conv1/biases/Assign\032
// \023conv1/biases/read:0"
// }
// }
// }
message BytesList {
repeated bytes value = 1;
}
// Int64List is used for collecting int, int64 and long values.
message Int64List {
repeated int64 value = 1 [packed = true];
}
// FloatList is used for collecting float values.
message FloatList {
repeated float value = 1 [packed = true];
}
// AnyList is used for collecting Any protos.
message AnyList {
repeated google.protobuf.Any value = 1;
}
oneof kind {
NodeList node_list = 1;
BytesList bytes_list = 2;
Int64List int64_list = 3;
FloatList float_list = 4;
AnyList any_list = 5;
}
}
// Information about a Tensor necessary for feeding or retrieval.
message TensorInfo {
// For sparse tensors, The COO encoding stores a triple of values, indices,
// and shape.
message CooSparse {
// The shape of the values Tensor is [?]. Its dtype must be the dtype of
// the SparseTensor as a whole, given in the enclosing TensorInfo.
string values_tensor_name = 1;
// The indices Tensor must have dtype int64 and shape [?, ?].
string indices_tensor_name = 2;
// The dynamic logical shape represented by the SparseTensor is recorded in
// the Tensor referenced here. It must have dtype int64 and shape [?].
string dense_shape_tensor_name = 3;
}
// Generic encoding for composite tensors.
message CompositeTensor {
// The serialized TypeSpec for the composite tensor.
TypeSpecProto type_spec = 1;
// A TensorInfo for each flattened component tensor.
repeated TensorInfo components = 2;
}
oneof encoding {
// For dense `Tensor`s, the name of the tensor in the graph.
string name = 1;
// There are many possible encodings of sparse matrices
// (https://en.wikipedia.org/wiki/Sparse_matrix). Currently, TensorFlow
// uses only the COO encoding. This is supported and documented in the
// SparseTensor Python class.
CooSparse coo_sparse = 4;
// Generic encoding for CompositeTensors.
CompositeTensor composite_tensor = 5;
}
DataType dtype = 2;
// The static shape should be recorded here, to the extent that it can
// be known in advance. In the case of a SparseTensor, this field describes
// the logical shape of the represented tensor (aka dense_shape).
TensorShapeProto tensor_shape = 3;
}
// SignatureDef defines the signature of a computation supported by a TensorFlow
// graph.
//
// For example, a model with two loss computations, sharing a single input,
// might have the following signature_def map, in a MetaGraphDef message.
//
// Note that across the two SignatureDefs "loss_A" and "loss_B", the input key,
// output key, and method_name are identical, and will be used by system(s) that
// implement or rely upon this particular loss method. The output tensor names
// differ, demonstrating how different outputs can exist for the same method.
//
// signature_def {
// key: "loss_A"
// value {
// inputs {
// key: "input"
// value {
// name: "input:0"
// dtype: DT_STRING
// tensor_shape: ...
// }
// }
// outputs {
// key: "loss_output"
// value {
// name: "loss_output_A:0"
// dtype: DT_FLOAT
// tensor_shape: ...
// }
// }
// method_name: "some/package/compute_loss"
// }
// ...
// }
// signature_def {
// key: "loss_B"
// value {
// inputs {
// key: "input"
// value {
// name: "input:0"
// dtype: DT_STRING
// tensor_shape: ...
// }
// }
// outputs {
// key: "loss_output"
// value {
// name: "loss_output_B:0"
// dtype: DT_FLOAT
// tensor_shape: ...
// }
// }
// method_name: "some/package/compute_loss"
// }
// ...
// }
message SignatureDef {
// Named input parameters.
map<string, TensorInfo> inputs = 1;
// Named output parameters.
map<string, TensorInfo> outputs = 2;
// Extensible method_name information enabling third-party users to mark a
// SignatureDef as supporting a particular method. This enables producers and
// consumers of SignatureDefs, e.g. a model definition library and a serving
// library to have a clear hand-off regarding the semantics of a computation.
//
// Note that multiple SignatureDefs in a single MetaGraphDef may have the same
// method_name. This is commonly used to support multi-headed computation,
// where a single graph computation may return multiple results.
string method_name = 3;
}
// An asset file def for a single file or a set of sharded files with the same
// name.
message AssetFileDef {
// The tensor to bind the asset filename to.
TensorInfo tensor_info = 1;
// The filename within an assets directory. Note: does not include the path
// prefix, i.e. directories. For an asset at /tmp/path/vocab.txt, the filename
// would be "vocab.txt".
string filename = 2;
}


@ -0,0 +1,35 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.*/
// Modification Copyright (C) 2023 Intel Corporation
syntax = "proto3";
package tensorflow;
import "meta_graph.proto";
option cc_enable_arenas = true;
option java_outer_classname = "SavedModelProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto";
// SavedModel is the high level serialization format for TensorFlow Models.
// See [todo: doc links, similar to session_bundle] for more information.
message SavedModel {
// The schema version of the SavedModel instance. Used for versioning when
// making future changes to the specification/implementation. Initial value
// at release will be 1.
int64 saved_model_schema_version = 1;
// One or more MetaGraphs.
repeated MetaGraphDef meta_graphs = 2;
}


@ -0,0 +1,263 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.*/
// Modification Copyright (C) 2023 Intel Corporation
syntax = "proto3";
package tensorflow;
import "any.proto";
import "tensor_shape.proto";
import "types.proto";
import "variable.proto";
import "versions.proto";
import "struct.proto";
import "trackable_object_graph.proto";
option cc_enable_arenas = true;
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto";
// A SavedObjectGraph is part of object-based SavedModels in TF 2.0. It
// describes the directed graph of Python objects (or equivalent in other
// languages) that make up a model, with nodes[0] at the root.
// SavedObjectGraph shares some structure with TrackableObjectGraph, but
// SavedObjectGraph belongs to the MetaGraph and contains pointers to functions
// and type information, while TrackableObjectGraph lives in the checkpoint
// and contains pointers only to variable values.
message SavedObjectGraph {
// Flattened list of objects in the object graph.
//
// The position of the object in this list indicates its id.
// Nodes[0] is considered the root node.
repeated SavedObject nodes = 1;
// Information about captures and output structures in concrete functions.
// Referenced from SavedBareConcreteFunction and SavedFunction.
map<string, SavedConcreteFunction> concrete_functions = 2;
}
message SavedObject {
// Objects which this object depends on: named edges in the dependency
// graph.
//
// Note: All kinds of SavedObject may have children, except
// "constant" and "captured_tensor".
repeated TrackableObjectGraph.TrackableObject.ObjectReference children = 1;
// Ordered list of dependencies that must be loaded before this object.
// SavedModel loads with the bottom-up approach, by first creating all objects
// (in the order defined by the dependencies), then connecting the edges.
repeated TrackableObjectGraph.TrackableObject.ObjectReference dependencies =
15;
// Removed when forking SavedObject from TrackableObjectGraph.
reserved "attributes";
reserved 2;
// Slot variables owned by this object. This describes the three-way
// (optimizer, variable, slot variable) relationship; none of the three
// depend on the others directly.
//
// Note: currently only valid if kind == "user_object".
repeated TrackableObjectGraph.TrackableObject.SlotVariableReference
slot_variables = 3;
oneof kind {
SavedUserObject user_object = 4;
SavedAsset asset = 5;
SavedFunction function = 6;
SavedVariable variable = 7;
SavedBareConcreteFunction bare_concrete_function = 8;
SavedConstant constant = 9;
SavedResource resource = 10;
CapturedTensor captured_tensor = 12;
}
// Stores the functions used to save and restore this object. At most one of
// `saveable_objects` or `registered_saver` is defined for each SavedObject.
// See the comment below for the difference between SaveableObject and
// registered savers.
map<string, SaveableObject> saveable_objects = 11;
// The fields below are filled when the user serializes a registered Trackable
// class or an object with a registered saver function.
//
// Registered classes may save additional metadata and supersede the
// default loading process where nodes are recreated from the proto.
// If the registered class cannot be found, then the object will load as one
// of the default trackable objects: Autotrackable (a class similar to
// tf.Module), tf.function, or tf.Variable.
//
// Unlike SaveableObjects, which store the functions for saving and restoring
// from tensors, registered savers allow Trackables to write checkpoint shards
// directly (e.g. for performance or coordination reasons).
// *All registered savers must be available when loading the SavedModel.*
// The name of the registered class of the form "{package}.{class_name}".
// This field is used to search for the registered class at loading time.
string registered_name = 13;
// The user-generated proto storing metadata for this object, to be passed to
// the registered class's _deserialize_from_proto method when this object is
// loaded from the SavedModel.
google.protobuf.Any serialized_user_proto = 14;
// String name of the registered saver. At most one of `saveable_objects` or
// `registered_saver` is defined for each SavedObject.
string registered_saver = 16;
}
// A SavedUserObject is an object (in the object-oriented language of the
// TensorFlow program) of some user- or framework-defined class other than
// those handled specifically by the other kinds of SavedObjects.
//
// This object cannot be evaluated as a tensor, and therefore cannot be bound
// to an input of a function.
message SavedUserObject {
// Corresponds to a registration of the type to use in the loading program.
string identifier = 1;
// Version information from the producer of this SavedUserObject.
VersionDef version = 2;
// Metadata for deserializing this object.
//
// Deprecated! At the time of deprecation, Keras was the only user of this
// field, and its saving and loading code will be updated shortly.
// Please save your application-specific metadata to a separate file.
string metadata = 3 [deprecated = true];
}
// A SavedAsset points to an asset in the MetaGraph.
//
// When bound to a function this object evaluates to a tensor with the absolute
// filename. Users should not depend on a particular part of the filename to
// remain stable (e.g. basename could be changed).
message SavedAsset {
// Index into `MetaGraphDef.asset_file_def[]` that describes the Asset.
//
// Only the field `AssetFileDef.filename` is used. Other fields, such as
// `AssetFileDef.tensor_info`, MUST be ignored.
int32 asset_file_def_index = 1;
}
// A function with multiple signatures, possibly with non-Tensor arguments.
message SavedFunction {
repeated string concrete_functions = 1;
FunctionSpec function_spec = 2;
}
message CapturedTensor {
// Name of captured tensor
string name = 1;
// Name of concrete function which contains the computed graph tensor.
string concrete_function = 2;
}
// Stores low-level information about a concrete function. Referenced in either
// a SavedFunction or a SavedBareConcreteFunction.
message SavedConcreteFunction {
repeated int32 bound_inputs = 2;
// Input in canonicalized form that was received to create this concrete
// function.
StructuredValue canonicalized_input_signature = 3;
// Output that was the return value of this function after replacing all
// Tensors with TensorSpecs. This can be an arbitrarily nested structure and will
// be used to reconstruct the full structure from pure tensors.
StructuredValue output_signature = 4;
}
message SavedBareConcreteFunction {
// Identifies a SavedConcreteFunction.
string concrete_function_name = 1;
// A sequence of unique strings, one per Tensor argument.
repeated string argument_keywords = 2;
// The prefix of `argument_keywords` which may be identified by position.
int64 allowed_positional_arguments = 3;
// The spec of the function that this ConcreteFunction is traced from. This
// allows the ConcreteFunction to be called with nested structure inputs. This
// field may not be populated. If this field is absent, the concrete function
// can only be called with flat inputs.
// TODO(b/169361281): support calling saved ConcreteFunction with structured
// inputs in C++ SavedModel API.
FunctionSpec function_spec = 4;
}
message SavedConstant {
// An Operation name for a ConstantOp in this SavedObjectGraph's MetaGraph.
string operation = 1;
}
// Represents a Variable that is initialized by loading the contents from the
// checkpoint.
message SavedVariable {
DataType dtype = 1;
TensorShapeProto shape = 2;
bool trainable = 3;
VariableSynchronization synchronization = 4;
VariableAggregation aggregation = 5;
string name = 6;
string device = 7;
// List of component variables for a distributed variable.
//
// When this field is non-empty, the SavedVariable will be assumed
// to be a distributed variable defined by the components listed here.
//
// This is only supported by experimental loaders at the moment.
repeated SavedVariable experimental_distributed_variable_components = 8;
}
// Represents `FunctionSpec` used in `Function`. This represents a
// function that has been wrapped as a TensorFlow `Function`.
message FunctionSpec {
// Full arg spec from inspect.getfullargspec().
StructuredValue fullargspec = 1;
// Whether this represents a class method.
bool is_method = 2;
// The input signature, if specified.
StructuredValue input_signature = 5;
// Whether the function should be compiled by XLA.
//
// The public interface to `tf.function` uses an optional boolean to
// represent three distinct states for this field. Unfortunately, proto3
// removes the ability to explicitly check for the presence or absence of a
// field, so we instead map to an enum.
//
// See `tf.function` for details.
enum JitCompile {
DEFAULT = 0;
ON = 1;
OFF = 2;
}
JitCompile jit_compile = 6;
reserved 3, 4;
}
// A SavedResource represents a TF object that holds state during its lifetime.
// An object of this type can have a reference to a
// create_resource() and an initialize() function.
message SavedResource {
// A device specification indicating a required placement for the resource
// creation function, e.g. "CPU". An empty string allows the user to select a
// device.
string device = 1;
}
message SaveableObject {
// Node ids of concrete functions for saving and loading from a checkpoint.
// These functions save and restore directly from tensors.
int32 save_function = 2;
int32 restore_function = 3;
}
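How a loader might act on the `kind` oneof above can be sketched as follows. This is a hypothetical illustration in plain Python, not OpenVINO or TensorFlow code; the handler strings are illustrative and a `SavedObject` is modeled as a one-entry dict keyed by its kind.

```python
# Illustrative dispatch on the SavedObject `kind` oneof (hypothetical sketch).
def restore_node(saved_object):
    """Describe how a SavedObject node would be restored, by kind."""
    handlers = {
        "asset": lambda o: "resolve MetaGraphDef.asset_file_def[%d]" % o["asset_file_def_index"],
        "variable": lambda o: "load tensor '%s' from the checkpoint" % o["name"],
        "constant": lambda o: "look up ConstantOp '%s' in the MetaGraph" % o["operation"],
    }
    kind, payload = next(iter(saved_object.items()))
    handler = handlers.get(kind)
    # Unhandled kinds fall back to default trackable loading, as the
    # comment on `registered_name` above describes.
    return handler(payload) if handler else "recreate '%s' via default trackable loading" % kind
```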
@@ -0,0 +1,96 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.*/
// Modification Copyright (C) 2018-2023 Intel Corporation
// Protocol buffers for saved tensor slices. It's used for the brain tensor
// ops checkpoints and the V3 checkpoints in dist_belief.
// A checkpoint file is an sstable. The value for each record is a serialized
// SavedTensorSlices message (defined below).
//
// Each checkpoint file has a record with the empty key (""), which corresponds
// to a SavedTensorSlices message that contains a "meta" entry serving as a
// table of contents of all the tensor slices saved in this file. Since the key
// is "", it's always the first record in each file.
//
// Each of the rest of the records in a checkpoint stores the raw data of a
// particular tensor slice, in SavedSlice format. The corresponding key is an
// ordered code that encodes the name of the tensor and the slice
// information. The name is also stored in the SavedSlice message for ease of
// debugging and manual examination.
syntax = "proto3";
package tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "SavedTensorSliceProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.util";
import "tensor_shape.proto";
import "tensor_slice.proto";
import "tensor.proto";
import "types.proto";
import "versions.proto";
// Metadata describing the set of slices of the same tensor saved in a
// checkpoint file.
message SavedSliceMeta {
// Name of the tensor.
string name = 1;
// Shape of the tensor
TensorShapeProto shape = 2;
// Type of the tensor
DataType type = 3;
// Explicit list of slices saved in the checkpoint file.
repeated TensorSliceProto slice = 4;
};
// Metadata describing the set of tensor slices saved in a checkpoint file.
// It is always stored at the beginning of each checkpoint file.
message SavedTensorSliceMeta {
// Each SavedSliceMeta describes the slices for one tensor.
repeated SavedSliceMeta tensor = 1;
// Compatibility version of this checkpoint. See core/public/version.h
// for version history.
VersionDef versions = 2;
};
// Saved tensor slice: it stores the name of the tensor, the slice, and the
// raw data.
message SavedSlice {
// Name of the tensor that this slice belongs to. This must be identical to
// the name used to encode the key for this record.
string name = 1;
// Extent of the slice. Must have one entry for each of the dimensions of the
// tensor that this slice belongs to.
TensorSliceProto slice = 2;
// The raw data of the slice is stored as a TensorProto. Only raw data are
// stored (we don't fill in fields such as dtype or tensor_shape).
TensorProto data = 3;
};
// Each record in a v3 checkpoint file is a serialized SavedTensorSlices
// message.
message SavedTensorSlices {
// This is only present at the first item of each checkpoint file and serves
// as a table of contents, listing all the tensor slices saved in this file.
SavedTensorSliceMeta meta = 1;
// This exists in all but the first item of each checkpoint file.
SavedSlice data = 2;
};
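The record layout described in the header comment above can be sketched in plain Python (no TensorFlow): a checkpoint file is a key-sorted sequence of records, so the empty-key `SavedTensorSlices` "meta" record always sorts first.

```python
# Sketch of the checkpoint sstable layout: one empty-key table-of-contents
# record, then one SavedSlice record per tensor slice, sorted by key.
def build_checkpoint_records(meta, slices):
    """slices maps an ordered-code key (tensor name + slice info) to raw bytes."""
    records = [("", meta)]               # "meta" record, key ""
    records += sorted(slices.items())    # per-slice data records
    return records

records = build_checkpoint_records(b"meta-proto",
                                   {"b/slice0": b"\x01", "a/slice0": b"\x00"})
```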
@@ -0,0 +1,60 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.*/
// Modification Copyright (C) 2023 Intel Corporation
syntax = "proto3";
package tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "SaverProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.util";
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto";
// Protocol buffer representing the configuration of a Saver.
message SaverDef {
// The name of the tensor in which to specify the filename when saving or
// restoring a model checkpoint.
string filename_tensor_name = 1;
// The operation to run when saving a model checkpoint.
string save_tensor_name = 2;
// The operation to run when restoring a model checkpoint.
string restore_op_name = 3;
// Maximum number of checkpoints to keep. If 0, no checkpoints are deleted.
int32 max_to_keep = 4;
// Shard the save files, one per device that has Variable nodes.
bool sharded = 5;
// How often to keep an additional checkpoint. If not specified, only the last
// "max_to_keep" checkpoints are kept; if specified, in addition to keeping
// the last "max_to_keep" checkpoints, an additional checkpoint will be kept
// for every n hours of training.
float keep_checkpoint_every_n_hours = 6;
// A version number that identifies a different on-disk checkpoint format.
// Usually, each subclass of BaseSaverBuilder works with a particular
// version/format. However, it is possible that the same builder may be
// upgraded to support a newer checkpoint format in the future.
enum CheckpointFormatVersion {
// Internal legacy format.
LEGACY = 0;
// Deprecated format: tf.Saver() which works with tensorflow::table::Table.
V1 = 1;
// Current format: more efficient.
V2 = 2;
}
CheckpointFormatVersion version = 7;
}
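The interaction of `max_to_keep` and `keep_checkpoint_every_n_hours` can be sketched as below. This is an illustrative retention policy in plain Python, not the actual `tf.train.Saver` implementation: the newest `max_to_keep` checkpoints survive, plus one checkpoint per n-hour window of training time.

```python
# Hypothetical sketch of the checkpoint retention policy described above.
def checkpoints_to_keep(times_h, max_to_keep, every_n_hours):
    """times_h: ascending checkpoint timestamps in hours; returns kept times."""
    # max_to_keep == 0 means no checkpoints are deleted.
    kept = set(times_h if max_to_keep == 0 else times_h[-max_to_keep:])
    last_window = None
    for t in times_h:
        window = int(t // every_n_hours)
        if window != last_window:        # first checkpoint of each n-hour span
            kept.add(t)
            last_window = window
    return sorted(kept)
```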
@@ -0,0 +1,172 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.*/
// Modification Copyright (C) 2023 Intel Corporation
syntax = "proto3";
package tensorflow;
import "tensor.proto";
import "tensor_shape.proto";
import "types.proto";
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto";
// `StructuredValue` represents a dynamically typed value representing various
// data structures that are inspired by Python data structures typically used in
// TensorFlow functions as inputs and outputs.
//
// For example when saving a Layer there may be a `training` argument. If the
// user passes a boolean True/False, that switches between two concrete
// TensorFlow functions. In order to switch between them in the same way after
// loading the SavedModel, we need to represent "True" and "False".
//
// A more advanced example might be a function which takes a list of
// dictionaries mapping from strings to Tensors. In order to map from
// user-specified arguments `[{"a": tf.constant(1.)}, {"q": tf.constant(3.)}]`
// after load to the right saved TensorFlow function, we need to represent the
// nested structure and the strings, recording that we have a trace for anything
// matching `[{"a": tf.TensorSpec(None, tf.float32)}, {"q": tf.TensorSpec([],
// tf.float64)}]` as an example.
//
// Likewise functions may return nested structures of Tensors, for example
// returning a dictionary mapping from strings to Tensors. In order for the
// loaded function to return the same structure we need to serialize it.
//
// This is an ergonomic aid for working with loaded SavedModels, not a promise
// to serialize all possible function signatures. For example we do not expect
// to pickle generic Python objects, and ideally we'd stay language-agnostic.
message StructuredValue {
// The kind of value.
oneof kind {
// Represents None.
NoneValue none_value = 1;
// Represents a double-precision floating-point value (a Python `float`).
double float64_value = 11;
// Represents a signed integer value, limited to 64 bits.
// Larger values from Python's arbitrary-precision integers are unsupported.
sint64 int64_value = 12;
// Represents a string of Unicode characters stored in a Python `str`.
// In Python 3, this is exactly what type `str` is.
// In Python 2, this is the UTF-8 encoding of the characters.
// For strings with ASCII characters only (as often used in TensorFlow code)
// there is effectively no difference between the language versions.
// The obsolescent `unicode` type of Python 2 is not supported here.
string string_value = 13;
// Represents a boolean value.
bool bool_value = 14;
// Represents a TensorShape.
tensorflow.TensorShapeProto tensor_shape_value = 31;
// Represents an enum value for dtype.
tensorflow.DataType tensor_dtype_value = 32;
// Represents a value for tf.TensorSpec.
TensorSpecProto tensor_spec_value = 33;
// Represents a value for tf.TypeSpec.
TypeSpecProto type_spec_value = 34;
// Represents a value for tf.BoundedTensorSpec.
BoundedTensorSpecProto bounded_tensor_spec_value = 35;
// Represents a list of `Value`.
ListValue list_value = 51;
// Represents a tuple of `Value`.
TupleValue tuple_value = 52;
// Represents a dict `Value`.
DictValue dict_value = 53;
// Represents Python's namedtuple.
NamedTupleValue named_tuple_value = 54;
}
}
// Represents None.
message NoneValue {}
// Represents a Python list.
message ListValue {
repeated StructuredValue values = 1;
}
// Represents a Python tuple.
message TupleValue {
repeated StructuredValue values = 1;
}
// Represents a Python dict keyed by `str`.
// The comment on Unicode from Value.string_value applies analogously.
message DictValue {
map<string, StructuredValue> fields = 1;
}
// Represents a (key, value) pair.
message PairValue {
string key = 1;
StructuredValue value = 2;
}
// Represents Python's namedtuple.
message NamedTupleValue {
string name = 1;
repeated PairValue values = 2;
}
// A protobuf to represent tf.TensorSpec.
message TensorSpecProto {
string name = 1;
tensorflow.TensorShapeProto shape = 2;
tensorflow.DataType dtype = 3;
}
// A protobuf to represent tf.BoundedTensorSpec.
message BoundedTensorSpecProto {
string name = 1;
tensorflow.TensorShapeProto shape = 2;
tensorflow.DataType dtype = 3;
tensorflow.TensorProto minimum = 4;
tensorflow.TensorProto maximum = 5;
}
// Represents a tf.TypeSpec
message TypeSpecProto {
enum TypeSpecClass {
UNKNOWN = 0;
SPARSE_TENSOR_SPEC = 1; // tf.SparseTensorSpec
INDEXED_SLICES_SPEC = 2; // tf.IndexedSlicesSpec
RAGGED_TENSOR_SPEC = 3; // tf.RaggedTensorSpec
TENSOR_ARRAY_SPEC = 4; // tf.TensorArraySpec
DATA_DATASET_SPEC = 5; // tf.data.DatasetSpec
DATA_ITERATOR_SPEC = 6; // IteratorSpec from data/ops/iterator_ops.py
OPTIONAL_SPEC = 7; // tf.OptionalSpec
PER_REPLICA_SPEC = 8; // PerReplicaSpec from distribute/values.py
VARIABLE_SPEC = 9; // tf.VariableSpec
ROW_PARTITION_SPEC = 10; // RowPartitionSpec from ragged/row_partition.py
reserved 11;
REGISTERED_TYPE_SPEC = 12; // The type registered as type_spec_class_name.
EXTENSION_TYPE_SPEC = 13; // Subclasses of tf.ExtensionType
}
TypeSpecClass type_spec_class = 1;
// The value returned by TypeSpec._serialize().
StructuredValue type_state = 2;
// The name of the TypeSpec class.
// * If type_spec_class == REGISTERED_TYPE_SPEC, the TypeSpec class is
// the one registered under this name. For types registered outside
// core TensorFlow by an add-on library, that library must be loaded
// before this value can be deserialized by nested_structure_coder.
// * If type_spec_class specifies a particular TypeSpec class, this field is
// redundant with the type_spec_class enum, and is only used for error
// reporting in older binaries that do not know the type_spec_class enum.
string type_spec_class_name = 3;
// The number of flat tensor components required by this TypeSpec.
int32 num_flat_components = 4;
}
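The mapping from plain Python values onto the `StructuredValue` kinds above can be sketched as follows (a hedged illustration covering only the none/bool/int64/float64/string/list/tuple/dict kinds; TensorSpecs and TypeSpecs are omitted):

```python
# Illustrative encoder from Python values to StructuredValue-shaped dicts.
def to_structured_value(v):
    if v is None:
        return {"none_value": {}}
    if isinstance(v, bool):              # must check bool before int
        return {"bool_value": v}
    if isinstance(v, int):
        return {"int64_value": v}
    if isinstance(v, float):
        return {"float64_value": v}
    if isinstance(v, str):
        return {"string_value": v}
    if isinstance(v, list):
        return {"list_value": {"values": [to_structured_value(x) for x in v]}}
    if isinstance(v, tuple):
        return {"tuple_value": {"values": [to_structured_value(x) for x in v]}}
    if isinstance(v, dict):
        return {"dict_value": {"fields": {k: to_structured_value(x) for k, x in v.items()}}}
    raise TypeError("no StructuredValue kind for %r" % type(v))
```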
@@ -0,0 +1,78 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.*/
// Modification Copyright (C) 2023 Intel Corporation
syntax = "proto3";
package tensorflow;
import "tensor_shape.proto";
import "tensor_slice.proto";
import "types.proto";
import "versions.proto";
option cc_enable_arenas = true;
option java_outer_classname = "TensorBundleProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.util";
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto";
// Protos used in the tensor bundle module (tf/core/util/tensor_bundle/).
// Special header that is associated with a bundle.
//
// TODO(zongheng,zhifengc): maybe in the future, we can add information about
// which binary produced this checkpoint, timestamp, etc. Sometimes, these can be
// valuable debugging information. And if needed, these can be used as defensive
// information ensuring that the reader (binary version) of the checkpoint and the
// writer (binary version) match within a certain range, etc.
message BundleHeaderProto {
// Number of data files in the bundle.
int32 num_shards = 1;
// An enum indicating the endianness of the platform that produced this
// bundle. A bundle can only be read by a platform with matching endianness.
// Defaults to LITTLE, as most modern platforms are little-endian.
//
// Affects the binary tensor data bytes only, not the metadata in protobufs.
enum Endianness {
LITTLE = 0;
BIG = 1;
}
Endianness endianness = 2;
// Versioning of the tensor bundle format.
VersionDef version = 3;
}
// Describes the metadata related to a checkpointed tensor.
message BundleEntryProto {
// The tensor dtype and shape.
DataType dtype = 1;
TensorShapeProto shape = 2;
// The binary content of the tensor lies in:
// File "shard_id": bytes [offset, offset + size).
int32 shard_id = 3;
int64 offset = 4;
int64 size = 5;
// The CRC32C checksum of the tensor bytes.
fixed32 crc32c = 6;
// Iff present, this entry represents a partitioned tensor. The previous
// fields are interpreted as follows:
//
// "dtype", "shape": describe the full tensor.
// "shard_id", "offset", "size", "crc32c": all IGNORED.
// The information for each slice can be looked up in its own
// BundleEntryProto, keyed by each "slice_name".
repeated TensorSliceProto slices = 7;
}
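How a reader uses a `BundleEntryProto` can be sketched like this: for a non-partitioned entry, the tensor bytes live in shard file `shard_id` at `[offset, offset + size)`. The sketch below is illustrative plain Python using in-memory file objects; real readers resolve shard filenames and verify `crc32c`.

```python
# Sketch of resolving a BundleEntryProto to raw tensor bytes.
import io

def read_entry_bytes(shards, entry):
    """shards: list of seekable file-like objects indexed by shard_id."""
    f = shards[entry["shard_id"]]
    f.seek(entry["offset"])
    return f.read(entry["size"])         # bytes [offset, offset + size)

shard0 = io.BytesIO(b"headerTENSORBYTEStrailer")
data = read_entry_bytes([shard0], {"shard_id": 0, "offset": 6, "size": 11})
```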
@@ -0,0 +1,92 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.*/
// Modification Copyright (C) 2023 Intel Corporation
syntax = "proto3";
package tensorflow;
import "wrappers.proto";
option cc_enable_arenas = true;
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto";
// A TensorBundle addition which saves extra information about the objects which
// own variables, allowing for more robust checkpoint loading into modified
// programs.
message TrackableObjectGraph {
message TrackableObject {
message ObjectReference {
// An index into `TrackableObjectGraph.nodes`, indicating the object
// being referenced.
int32 node_id = 1;
// A user-provided name for the edge.
string local_name = 2;
}
message SerializedTensor {
// A name for the Tensor. Simple variables have only one
// `SerializedTensor` named "VARIABLE_VALUE" by convention. This value may
// be restored on object creation as an optimization.
string name = 1;
// The full name of the variable/tensor, if applicable. Used to allow
// name-based loading of checkpoints which were saved using an
// object-based API. Should match the checkpoint key which would have been
// assigned by tf.train.Saver.
string full_name = 2;
// The generated name of the Tensor in the checkpoint.
string checkpoint_key = 3;
// Deprecated bool field for optional restore. This field has never been
// set to True.
reserved "optional_restore";
reserved 4;
}
message SlotVariableReference {
// An index into `TrackableObjectGraph.nodes`, indicating the
// variable object this slot was created for.
int32 original_variable_node_id = 1;
// The name of the slot (e.g. "m"/"v").
string slot_name = 2;
// An index into `TrackableObjectGraph.nodes`, indicating the
// `Object` with the value of the slot variable.
int32 slot_variable_node_id = 3;
}
// Objects which this object depends on.
repeated ObjectReference children = 1;
// Serialized data specific to this object.
repeated SerializedTensor attributes = 2;
// Slot variables owned by this object.
repeated SlotVariableReference slot_variables = 3;
// The registered saver used to save this object. If this saver is not
// present when loading the checkpoint, then loading will fail.
RegisteredSaver registered_saver = 4;
// Whether this object has checkpoint values or descendants with checkpoint
// values. This is computed at save time to avoid traversing the entire
// object graph proto when restoring (which also has to traverse the live
// object graph).
google.protobuf.BoolValue has_checkpoint_values = 5;
}
repeated TrackableObject nodes = 1;
}
message RegisteredSaver {
// The name of the registered saver/restore function.
string name = 1;
// Unique auto-generated name of the object.
string object_name = 2;
}
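Traversal of a `TrackableObjectGraph` can be sketched as below: node 0 is the root, `children` edges are `node_id` indices into `nodes`, and each node's `attributes` carry the `checkpoint_key` of its serialized tensors. The node/key values here are illustrative.

```python
# Illustrative depth-first walk collecting checkpoint keys from a
# TrackableObjectGraph modeled as a list of dict nodes.
def collect_checkpoint_keys(nodes, node_id=0, seen=None):
    seen = set() if seen is None else seen
    if node_id in seen:                  # graphs may contain cycles
        return []
    seen.add(node_id)
    node = nodes[node_id]
    keys = [a["checkpoint_key"] for a in node.get("attributes", [])]
    for child in node.get("children", []):
        keys += collect_checkpoint_keys(nodes, child["node_id"], seen)
    return keys

graph = [
    {"children": [{"node_id": 1, "local_name": "dense"}]},
    {"attributes": [{"name": "VARIABLE_VALUE",
                     "checkpoint_key": "dense/kernel/.ATTRIBUTES/VARIABLE_VALUE"}]},
]
```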
@@ -0,0 +1,124 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Modification Copyright (C) 2023 Intel Corporation
// Wrappers for primitive (non-message) types. These types are useful
// for embedding primitives in the `google.protobuf.Any` type and for places
// where we need to distinguish between the absence of a primitive
// typed field and its default value.
//
// These wrappers have no meaningful use within repeated fields as they lack
// the ability to detect presence on individual elements.
// These wrappers have no meaningful use within a map or a oneof since
// individual entries of a map or fields of a oneof can already detect presence.
syntax = "proto3";
package google.protobuf;
option csharp_namespace = "Google.Protobuf.WellKnownTypes";
option cc_enable_arenas = true;
option go_package = "google.golang.org/protobuf/types/known/wrapperspb";
option java_package = "com.google.protobuf";
option java_outer_classname = "WrappersProto";
option java_multiple_files = true;
option objc_class_prefix = "GPB";
// Wrapper message for `double`.
//
// The JSON representation for `DoubleValue` is JSON number.
message DoubleValue {
// The double value.
double value = 1;
}
// Wrapper message for `float`.
//
// The JSON representation for `FloatValue` is JSON number.
message FloatValue {
// The float value.
float value = 1;
}
// Wrapper message for `int64`.
//
// The JSON representation for `Int64Value` is JSON string.
message Int64Value {
// The int64 value.
int64 value = 1;
}
// Wrapper message for `uint64`.
//
// The JSON representation for `UInt64Value` is JSON string.
message UInt64Value {
// The uint64 value.
uint64 value = 1;
}
// Wrapper message for `int32`.
//
// The JSON representation for `Int32Value` is JSON number.
message Int32Value {
// The int32 value.
int32 value = 1;
}
// Wrapper message for `uint32`.
//
// The JSON representation for `UInt32Value` is JSON number.
message UInt32Value {
// The uint32 value.
uint32 value = 1;
}
// Wrapper message for `bool`.
//
// The JSON representation for `BoolValue` is JSON `true` and `false`.
message BoolValue {
// The bool value.
bool value = 1;
}
// Wrapper message for `string`.
//
// The JSON representation for `StringValue` is JSON string.
message StringValue {
// The string value.
string value = 1;
}
// Wrapper message for `bytes`.
//
// The JSON representation for `BytesValue` is JSON string.
message BytesValue {
// The bytes value.
bytes value = 1;
}
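The reason these wrappers exist can be sketched in plain Python: a proto3 scalar `bool` field cannot distinguish "unset" from `false`, while a wrapped value can. Absence is modeled here as `None` versus a present `BoolValue`-like object.

```python
# Presence vs. default value, the distinction wrapper messages restore.
class BoolValue:
    def __init__(self, value=False):
        self.value = value

def effective_flag(field, default=True):
    """field is None when the wrapper message is absent; its .value otherwise."""
    return default if field is None else field.value
```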
@@ -0,0 +1,482 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <stdlib.h>
#include <fstream>
#include <string>
#include "graph_iterator_saved_model.hpp"
#include "openvino/core/type/element_type.hpp"
#include "tensor_bundle.pb.h"
#include "trackable_object_graph.pb.h"
#ifdef ENABLE_SNAPPY_COMPRESSION
# include "snappy.h"
#endif
namespace ov {
namespace frontend {
namespace tensorflow {
template <typename T>
static T smReadFixed(const char* ptr) {
T result = 0;
for (uint8_t i = 0; i < sizeof(T); ++i) {
// Cast through uint8_t to avoid sign extension of negative chars.
result |= static_cast<T>(static_cast<uint8_t>(ptr[i])) << (i * 8);
}
return result;
}
template <typename T>
static T smUnpack(char*& ptr, const char* ptr_end) {
T result = 0;
for (uint8_t i = 0; i < sizeof(T) * 7 && ptr < ptr_end; i += 7) {
T byte = *(ptr++);
if (byte & 0x80) {
result |= ((byte & 0x7F) << i);
} else {
result |= byte << i;
return result;
}
}
return 0;
}
struct VIBlock {
uint64_t m_size;
uint64_t m_offset;
void read(char*& ptr, const char* ptr_end) {
m_offset = smUnpack<uint64_t>(ptr, ptr_end);
m_size = smUnpack<uint64_t>(ptr, ptr_end);
}
};
struct VIFooter {
VIBlock m_metaIndex;
VIBlock m_index;
void read(char*& ptr, const char* ptr_end) {
m_index.read(ptr, ptr_end);
m_metaIndex.read(ptr, ptr_end);
}
void read(std::ifstream& fs) {
fs.seekg(0, std::ios::end);
size_t size = fs.tellg();
char footerData[48] = {}, *ptr = &footerData[0];
fs.seekg(size - sizeof(footerData));
fs.read(ptr, sizeof(footerData));
// https://github.com/tensorflow/tensorflow/blob/9659b7bdca80a8ef8240eb021d4da089034eeb00/tensorflow/tsl/lib/io/format.cc#L59
ptr += sizeof(footerData) - 8;
uint32_t magic_lo = *reinterpret_cast<const uint32_t*>(ptr);
uint32_t magic_hi = *reinterpret_cast<const uint32_t*>(ptr + 4);
uint64_t magic_no = (static_cast<uint64_t>(magic_hi) << 32) | static_cast<uint64_t>(magic_lo);
FRONT_END_GENERAL_CHECK(magic_no == 0xdb4775248b80fb57ull, "Wrong index file, magic number mismatch detected");
ptr = &footerData[0];
m_metaIndex.read(ptr, ptr + sizeof(footerData));
m_index.read(ptr, ptr + sizeof(footerData));
}
};
void SavedModelVariablesIndex::read_variables_index_block(std::ifstream& fs,
const VIBlock& index,
std::vector<char>& data,
uint32_t& offset,
uint32_t& offset_end) {
size_t block_size = index.m_size;
data.clear();
data.resize(block_size + 5 /*kBlockTrailerSize*/);
fs.seekg(index.m_offset, std::ios::beg);
fs.read(data.data(), data.size());
#ifndef ENABLE_SNAPPY_COMPRESSION
FRONT_END_GENERAL_CHECK(data[block_size] == 0, "Compressed files aren't supported");
#else
FRONT_END_GENERAL_CHECK(data[block_size] == 0 || data[block_size] == 1, "Compression method isn't supported");
if (data[block_size] == 1) {
size_t uncompressed_length = 0;
// The compressed payload occupies block_size bytes; the 5-byte block trailer
// (compression type + crc32c) is not part of the Snappy stream.
FRONT_END_GENERAL_CHECK(snappy::GetUncompressedLength(data.data(), block_size, &uncompressed_length),
"Cannot retrieve uncompressed block length");
std::string uncompressed_string;
uncompressed_string.reserve(uncompressed_length);
FRONT_END_GENERAL_CHECK(snappy::Uncompress(data.data(), block_size, &uncompressed_string),
"Cannot uncompress block data");
data.resize(uncompressed_length);
std::copy(uncompressed_string.begin(), uncompressed_string.end(), data.begin());
block_size = uncompressed_length;
}
#endif
uint32_t numRestarts = smReadFixed<uint32_t>(data.data() + block_size - sizeof(uint32_t));
size_t maxRestarts = (block_size - sizeof(uint32_t)) / sizeof(uint32_t);
FRONT_END_GENERAL_CHECK(maxRestarts >= numRestarts, "Wrong restarts value");
offset_end = static_cast<uint32_t>(block_size) - ((numRestarts + 1) * sizeof(uint32_t));
offset = smReadFixed<uint32_t>(data.data() + offset_end);
}
void SavedModelVariablesIndex::read_variables_index_pair(char*& ptr,
const char* ptr_end,
std::string& key,
char*& value,
uint32_t& val_length) {
uint32_t shared, nonShared;
shared = smUnpack<uint32_t>(ptr, ptr_end);
nonShared = smUnpack<uint32_t>(ptr, ptr_end);
val_length = smUnpack<uint32_t>(ptr, ptr_end);
// Key inherits last part of string (shared-size bytes) and appends new string
// shared_part_key1 //resize(0) + append(shared_part_key1)
// ............key2 //resize(12) + append(key2)
// ............key3 //resize(12) + append(key3)
// new_shared_key4 //resize(0) + append(new_shared_key4)
// ...........key5 //resize(11) + append(key5)
key.resize(shared);
key.append(ptr, nonShared);
value = ptr + nonShared;
ptr = value + val_length;
}
void SavedModelVariablesIndex::read_variables_index(std::ifstream& fs,
std::map<std::string, std::vector<char>>& varIndex) {
VIFooter footer;
footer.read(fs);
std::vector<VIBlock> secondLevel;
std::vector<char> blockData;
uint32_t offset = 0, offset_end = 0;
read_variables_index_block(fs, footer.m_index, blockData, offset, offset_end);
char *ptr = blockData.data() + offset, *ptr_end = blockData.data() + offset_end, *value = nullptr;
std::string key = "";
uint32_t valLength;
while (ptr < ptr_end) {
read_variables_index_pair(ptr, ptr_end, key, value, valLength);
VIBlock valBlock;
valBlock.read(value, value + valLength);
secondLevel.push_back(valBlock);
ptr = value + valLength;
}
for (auto& block : secondLevel) {
read_variables_index_block(fs, block, blockData, offset, offset_end);
key = "";
ptr = blockData.data() + offset;
ptr_end = blockData.data() + offset_end;
while (ptr < ptr_end) {
read_variables_index_pair(ptr, ptr_end, key, value, valLength);
varIndex[key] = std::vector<char>(value, value + valLength);
}
}
}
void SavedModelVariablesIndex::read_bundle_header() {
auto item = m_variables_index.find("");
FRONT_END_GENERAL_CHECK(item != m_variables_index.end(), "Bundle Header isn't found in index");
::tensorflow::BundleHeaderProto bundleHeader;
FRONT_END_GENERAL_CHECK(bundleHeader.ParseFromArray(item->second.data(), static_cast<int>(item->second.size())),
"Bundle Header: Cannot parse Bundle Header");
FRONT_END_GENERAL_CHECK(bundleHeader.version().producer() == 1, "Bundle Header: Unsupported producer version");
FRONT_END_GENERAL_CHECK(bundleHeader.version().min_consumer() == 0, "Bundle Header: Unsupported consumer version");
FRONT_END_GENERAL_CHECK(bundleHeader.endianness() == 0, "Bundle Header: BIG endian isn't supported");
m_total_shards = bundleHeader.num_shards();
}
void SavedModelVariablesIndex::read_checkpointable_object_graph() {
m_variables_map.clear();
auto item = m_variables_index.find("_CHECKPOINTABLE_OBJECT_GRAPH");
FRONT_END_GENERAL_CHECK(item != m_variables_index.end(), "Checkpointable Object Graph isn't found in index");
::tensorflow::BundleEntryProto entry;
FRONT_END_GENERAL_CHECK(entry.ParseFromArray(item->second.data(), static_cast<int>(item->second.size())),
"CMO: Cannot parse Bundle Entry");
FRONT_END_GENERAL_CHECK(entry.slices().empty(), "CMO: Slices are not supported");
auto shard = m_data_files.find(entry.shard_id());
FRONT_END_GENERAL_CHECK(shard != m_data_files.end(), "CMO: data file isn't found");
std::vector<char> data(entry.size());
::tensorflow::TrackableObjectGraph tog;
// TODO: clarify the origin of this offset;
// it looks like a reinterpret_cast artifact
// https://github.com/tensorflow/tensorflow/blob/d90f1947ebcf510b23c238f43c2191e5b3817cb3/tensorflow/cc/experimental/libexport/load.cc#L70
int chg = 6;
shard->second->seekg(entry.offset() + chg);
shard->second->read(data.data(), entry.size() - chg);
// This verification might need to be removed:
// https://github.com/tensorflow/tensorflow/blob/d90f1947ebcf510b23c238f43c2191e5b3817cb3/tensorflow/cc/experimental/libexport/load.cc#L73
// FRONT_END_GENERAL_CHECK(tog.ParseFromArray(data.data(), static_cast<int>(data.size()) - chg), "CMO: Trackable
// Object Graph couldn't be read");
tog.ParseFromArray(data.data(), static_cast<int>(data.size()) - chg);
for (const auto& node : tog.nodes()) {
for (const auto& attr : node.attributes()) {
m_variables_map[attr.full_name()] = attr.checkpoint_key();
}
}
}
bool GraphIteratorSavedModel::is_valid_signature(const ::tensorflow::SignatureDef& signature) const {
const std::map<::tensorflow::DataType, ov::element::Type> types{
{::tensorflow::DataType::DT_BOOL, ov::element::boolean},
{::tensorflow::DataType::DT_INT16, ov::element::i16},
{::tensorflow::DataType::DT_INT32, ov::element::i32},
{::tensorflow::DataType::DT_INT64, ov::element::i64},
{::tensorflow::DataType::DT_HALF, ov::element::f16},
{::tensorflow::DataType::DT_FLOAT, ov::element::f32},
{::tensorflow::DataType::DT_DOUBLE, ov::element::f64},
{::tensorflow::DataType::DT_UINT8, ov::element::u8},
{::tensorflow::DataType::DT_INT8, ov::element::i8},
{::tensorflow::DataType::DT_BFLOAT16, ov::element::bf16},
{::tensorflow::DataType::DT_STRING, ov::element::undefined}};
for (const auto& it : signature.inputs()) {
if (it.second.name().empty() || types.find(it.second.dtype()) == types.end())
return false;
}
for (const auto& it : signature.outputs()) {
if (it.second.name().empty() || types.find(it.second.dtype()) == types.end())
return false;
}
return true;
}
bool SavedModelVariablesIndex::read_variables(std::ifstream& vi_stream, const std::string& path) {
m_variables_index.clear();
read_variables_index(vi_stream, m_variables_index);
read_bundle_header();
std::vector<char> suffix(20);
for (int32_t shard = 0; shard < m_total_shards; ++shard) {
std::snprintf(suffix.data(), suffix.size(), "data-%05d-of-%05d", shard, m_total_shards);
std::string fullPath = ov::util::path_join({path, "variables", std::string("variables.") + suffix.data()});
m_data_files[shard] =
std::shared_ptr<std::ifstream>(new std::ifstream(fullPath, std::ifstream::in | std::ifstream::binary));
FRONT_END_GENERAL_CHECK(m_data_files[shard]->is_open(), "Saved Model's variable data file does not exist");
}
read_checkpointable_object_graph();
return true;
}
#if defined(OPENVINO_ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
bool SavedModelVariablesIndex::read_variables(std::ifstream& vi_stream, const std::wstring& path) {
m_variables_index.clear();
read_variables_index(vi_stream, m_variables_index);
read_bundle_header();
std::vector<wchar_t> suffix(20);
for (int32_t shard = 0; shard < m_total_shards; ++shard) {
swprintf_s(suffix.data(), suffix.size(), L"data-%05d-of-%05d", shard, m_total_shards);
std::wstring fullPath =
ov::util::path_join_w({path, L"variables", std::wstring(L"variables.") + suffix.data()});
m_data_files[shard] =
std::shared_ptr<std::ifstream>(new std::ifstream(fullPath, std::ifstream::in | std::ifstream::binary));
FRONT_END_GENERAL_CHECK(m_data_files[shard]->is_open(), "Saved Model's variable data file does not exist");
}
read_checkpointable_object_graph();
return true;
}
#endif
struct PtrNode {
const ::tensorflow::NodeDef* node;
std::vector<PtrNode*> inputs;
std::vector<PtrNode*> outputs;
PtrNode() : node(nullptr), inputs(), outputs() {}
PtrNode(const ::tensorflow::NodeDef& src_node, const std::map<std::string, PtrNode*>& node_dictionary) {
node = &src_node;
std::vector<std::string> parsedName;
for (const auto& input_name : node->input()) {
parse_node_name(input_name, parsedName);
auto input_node = node_dictionary.find(parsedName[0]);
if (input_node == node_dictionary.end()) {
continue;
}
input_node->second->outputs.push_back(this);
inputs.push_back(input_node->second);
}
}
void find_parent_by_op(const std::string& op, std::vector<PtrNode*>& result) const {
for (auto input : inputs) {
if (input->op() == op) {
result.push_back(input);
}
input->find_parent_by_op(op, result);
}
}
static void parse_node_name(const std::string& name, std::vector<std::string>& result) {
result.clear();
size_t left_pos = name.find_first_of('^'), right_pos = name.find(':');
if (left_pos != std::string::npos && left_pos < right_pos) {
++left_pos;
} else {
left_pos = 0;
}
while (right_pos != std::string::npos && right_pos > left_pos) {
result.push_back(name.substr(left_pos, right_pos - left_pos));
left_pos = right_pos + 1;
right_pos = name.find(':', left_pos);
}
result.push_back(name.substr(left_pos, name.length() - left_pos));
}
const std::string& op() const {
return node->op();
}
};
static void read_stateful_partitioned_call(const std::shared_ptr<::tensorflow::GraphDef> graph_def,
const ::tensorflow::NodeDef& partCall,
std::map<std::string, PtrNode*>& node_dictionary) {
FRONT_END_GENERAL_CHECK(partCall.op() == "StatefulPartitionedCall", "Passed node isn't StatefulPartitionedCall");
std::string func_name = partCall.attr().at("f").func().name();
FRONT_END_GENERAL_CHECK(graph_def->has_library(), "GraphDef contains functions, but doesn't have the library");
const ::tensorflow::FunctionDef* func_def = nullptr;
for (const auto& func : graph_def->library().function()) {
if (func.signature().name() == func_name) {
func_def = &func;
break;
}
}
FRONT_END_GENERAL_CHECK(func_def, "Function isn't found in the library");
std::map<std::string, PtrNode*> nodes;
// Filling temporary input nodes for exact function
for (int i = 0; i < func_def->signature().input_arg_size(); ++i) {
const auto& input_arg = func_def->signature().input_arg(i).name();
const auto& parent_input = partCall.input(i);
auto input_node = node_dictionary.find(parent_input);
if (input_node != node_dictionary.end()) {
nodes[input_arg] = input_node->second;
}
}
// Parsing nodes and inline partitioned calls
for (const auto& node : func_def->node_def()) {
nodes[node.name()] = new PtrNode(node, nodes);
if (node.op() == "StatefulPartitionedCall") {
read_stateful_partitioned_call(graph_def, node, nodes);
}
}
// Removing temporary input nodes
for (int i = 0; i < func_def->signature().input_arg_size(); ++i) {
const auto& input_arg = func_def->signature().input_arg(i).name();
auto input_node = nodes.find(input_arg);
if (input_node != nodes.end()) {
nodes.erase(input_node);
}
}
// Moving nodes to the global dictionary
for (const auto& node : nodes) {
std::string global_name = partCall.name() + "/" + node.first;
node_dictionary[global_name] = node.second;
}
}
void GraphIteratorSavedModel::map_assignvariable(const std::shared_ptr<::tensorflow::GraphDef> graph_def,
std::map<std::string, std::string>& variables_map) const {
std::map<std::string, PtrNode*> nodes;
for (const auto& node : graph_def->node()) {
nodes[node.name()] = new PtrNode(node, nodes);
if (node.op() == "StatefulPartitionedCall") {
read_stateful_partitioned_call(graph_def, node, nodes);
}
}
for (const auto& node : nodes) {
if (node.second->op() != "AssignVariableOp") {
continue;
}
// TODO: assets reading
std::vector<PtrNode*> restorev2_nodes;
std::vector<PtrNode*> varhandle_nodes;
node.second->find_parent_by_op("RestoreV2", restorev2_nodes);
node.second->find_parent_by_op("VarHandleOp", varhandle_nodes);
FRONT_END_GENERAL_CHECK(restorev2_nodes.size() == 1, "Found unexpected amount of RestoreV2 nodes");
FRONT_END_GENERAL_CHECK(varhandle_nodes.size() == 1, "Found unexpected amount of VarHandleOp nodes");
std::vector<std::string> restore_output;
// Expected path is: RestoreV2 -(output_index)-(0)-> Identity -(0)-(1)-> AssignVariableOp
PtrNode::parse_node_name(node.second->inputs[1]->node->input(0), restore_output);
int output_index = std::atoi(restore_output[restore_output.size() - 1].c_str());
// Expected path is: Const(tensor_names) -(0)-(1)-> RestoreV2
const auto& variable_name =
restorev2_nodes[0]->inputs[1]->node->attr().at("value").tensor().string_val(output_index);
variables_map[varhandle_nodes[0]->node->name()] = variable_name;
}
nodes.clear();
}
bool GraphIteratorSavedModel::is_supported(const std::string& path) {
return ov::util::directory_exists(path) && ov::util::file_exists(ov::util::path_join({path, "saved_model.pb"}));
}
#if defined(OPENVINO_ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
bool GraphIteratorSavedModel::is_supported(const std::wstring& path) {
return ov::util::directory_exists(path) && ov::util::file_exists(ov::util::path_join_w({path, L"saved_model.pb"}));
}
#endif
template <>
std::basic_string<char> get_saved_model_name<char>() {
return "/saved_model.pb";
}
template <>
std::basic_string<char> get_variables_index_name<char>() {
return "/variables/variables.index";
}
#if defined(OPENVINO_ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
template <>
std::basic_string<wchar_t> get_saved_model_name<wchar_t>() {
return L"/saved_model.pb";
}
template <>
std::basic_string<wchar_t> get_variables_index_name<wchar_t>() {
return L"/variables/variables.index";
}
#endif
} // namespace tensorflow
} // namespace frontend
} // namespace ov


@ -36,6 +36,25 @@ std::vector<T> reorder_ops_by_names(const std::vector<std::string>& names, const
}
return resulted_ops;
};
/// \brief Adds known input names from Saved Model file format
/// \param[in] node Node which should be updated
/// \param[in] saved_model_names Map of names from saved model
/// \returns True if node was updated, false otherwise
static bool apply_saved_model_names(std::shared_ptr<ov::Node> node,
const std::shared_ptr<std::map<std::string, std::string>>& saved_model_names) {
for (size_t i = 0; i < node->get_output_size(); ++i) {
const auto& node_names = node->get_output_tensor(i).get_names();
for (const auto& name : node_names) {
const auto& saved_model_name = saved_model_names->find(name);
if (saved_model_name != saved_model_names->end()) {
node->set_friendly_name(saved_model_name->second);
return true;
}
}
}
return false;
}
} // namespace
TranslateSession::TranslateSession(const ov::frontend::InputModel::Ptr& input_model,
@ -94,6 +113,8 @@ void TranslateSession::translate_graph(const ov::frontend::InputModel::Ptr& inpu
const auto& model_inputs = model_tf->get_inputs();
const auto& model_outputs = model_tf->get_outputs();
const auto& model_frozen_inputs = model_tf->get_tensor_values();
const auto& saved_model_inputs = model_tf->get_saved_model_input_names();
const auto& saved_model_outputs = model_tf->get_saved_model_output_names();
// fill ng_op_map with Constant outputs for frozen inputs
for (const auto& frozen_input : model_frozen_inputs) {
@ -123,6 +144,11 @@ void TranslateSession::translate_graph(const ov::frontend::InputModel::Ptr& inpu
auto param = std::make_shared<ov::opset8::Parameter>(input_type, input_shape);
set_node_name(input_name, param);
if (saved_model_inputs.get() && saved_model_inputs->size() > 0) {
if (!apply_saved_model_names(param, saved_model_inputs)) {
param->get_output_tensor(0).add_names({"saved_model_unused"});
}
}
params.push_back(param);
ng_op_map[input_name] = {param};
}
@ -273,10 +299,30 @@ void TranslateSession::translate_graph(const ov::frontend::InputModel::Ptr& inpu
if (port_type == "none") {
for (const auto& node_output : ng_op_map[operation_name]) {
auto result_node = std::make_shared<ov::opset8::Result>(node_output);
// to be aligned with Legacy Frontend we set a name along with output port index
// though, the Result name is not used in the OV API 2.0 but it is checked in MO args tests
result_node->set_friendly_name(model_output_name + ":0");
results.push_back(result_node);
// Customize output name in case we have mapping from Saved Model format
if (saved_model_outputs.get() && saved_model_outputs->size() > 0) {
bool isUsed = false;
for (const auto& name : model_output_tensor_place->get_names()) {
auto saved_model_name = saved_model_outputs->find(name);
if (saved_model_name == saved_model_outputs->end()) {
saved_model_name = saved_model_outputs->find(name + ":0");
}
if (saved_model_name != saved_model_outputs->end()) {
result_node->set_friendly_name(saved_model_name->second);
results.push_back(result_node);
isUsed = true;
break;
}
}
if (!isUsed) {
result_node->get_input_tensor(0).add_names({"saved_model_unused"});
}
} else {
// to be aligned with Legacy Frontend we set a name along with output port index
// though, the Result name is not used in the OV API 2.0 but it is checked in MO args tests
result_node->set_friendly_name(model_output_name + ":0");
results.push_back(result_node);
}
}
} else if (port_type == "out") {
const auto& node_outputs = ng_op_map[operation_name];


@ -31,6 +31,10 @@ public:
std::shared_ptr<ov::Model> get_body_ov_model(const std::string& body_graph_name);
ov::frontend::InputModel::Ptr get_input_model(void) const {
return m_input_model;
}
private:
const ov::frontend::InputModel::Ptr m_input_model;
const std::shared_ptr<TranslatorDictionaryType> m_translator_map;


@ -0,0 +1,57 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <string>
#include <vector>
#include "internal_operation.hpp"
namespace ov {
namespace frontend {
namespace tensorflow {
/// Pseudo-entity for storing strings
class StringConstant : public InternalOperation {
public:
OPENVINO_OP("StringConstant", "ov::frontend::tensorflow::util", InternalOperation);
StringConstant(ov::Any data, const std::shared_ptr<DecoderBase>& decoder = std::make_shared<DecoderFake>())
: InternalOperation(decoder, {}, 1),
m_data(data) {
validate_and_infer_types();
}
StringConstant(std::string& str, const std::shared_ptr<DecoderBase>& decoder = std::make_shared<DecoderFake>())
: InternalOperation(decoder, {}, 1),
m_data({str}) {
validate_and_infer_types();
}
StringConstant(const std::shared_ptr<DecoderBase>& decoder = std::make_shared<DecoderFake>())
: InternalOperation(decoder, {}, 1) {
validate_and_infer_types();
}
void validate_and_infer_types() override {
set_output_type(0, ov::element::dynamic, ov::PartialShape::dynamic());
}
ov::Any get_data() {
return m_data;
}
std::string& get_string() {
return m_data.as<std::vector<std::string>>()[0];
}
private:
ov::Any m_data;
ov::Shape m_shape;
};
} // namespace tensorflow
} // namespace frontend
} // namespace ov


@ -0,0 +1,27 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/pass.hpp"
namespace ov {
namespace frontend {
namespace tensorflow {
namespace pass {
// This transformation removes isolated subgraphs of Parameters and
// Results which are marked as unused by the Saved Model settings
class SavedModelUnusedRemover : public ov::pass::ModelPass {
public:
OPENVINO_RTTI("ov::frontend::tensorflow::pass::SavedModelUnusedRemover");
SavedModelUnusedRemover() {}
bool run_on_model(const std::shared_ptr<ov::Model>& m) override;
};
} // namespace pass
} // namespace tensorflow
} // namespace frontend
} // namespace ov


@ -4,6 +4,7 @@
#include "helper_transforms/const_to_result_remover.hpp"
#include "helper_ops/string_constant.hpp"
#include "helper_ops/unsupported_constant.hpp"
#include "openvino/opsets/opset10.hpp"
@ -22,7 +23,8 @@ bool ConstToResultRemover::run_on_model(const std::shared_ptr<ov::Model>& m) {
for (const auto& result : m->get_results()) {
auto unsupported_const = as_type_ptr<UnsupportedConstant>(result->get_input_node_shared_ptr(0));
auto const_node = as_type_ptr<Constant>(result->get_input_node_shared_ptr(0));
if (unsupported_const || const_node) {
auto string_const = as_type_ptr<StringConstant>(result->get_input_node_shared_ptr(0));
if (unsupported_const || const_node || string_const) {
results_to_remove.push_back(result);
}
}


@ -0,0 +1,74 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "helper_transforms/saved_model_unused_remover.hpp"
#include "openvino/opsets/opset8.hpp"
using namespace std;
using namespace ov::opset8;
namespace ov {
namespace frontend {
namespace tensorflow {
namespace pass {
bool SavedModelUnusedRemover::run_on_model(const std::shared_ptr<ov::Model>& m) {
ParameterVector params_to_remove;
ResultVector results_to_remove;
// There are two cases:
// 1. An unused Result, with or without an unused Parameter behind it
// 2. An unused Parameter
for (const auto& result : m->get_results()) {
bool isUsed = false;
for (size_t i = 0; i < result->get_input_size(); ++i) {
const auto& node_names = result->get_input_tensor(i).get_names();
isUsed = std::find(node_names.begin(), node_names.end(), "saved_model_unused") == node_names.end();
}
if (!isUsed) {
results_to_remove.push_back(result);
continue;
}
auto param = as_type_ptr<Parameter>(result->get_input_node_shared_ptr(0));
if (param) {
for (size_t i = 0; i < param->get_output_size(); ++i) {
const auto& node_names = param->get_output_tensor(i).get_names();
isUsed = std::find(node_names.begin(), node_names.end(), "saved_model_unused") == node_names.end();
}
if (!isUsed) {
results_to_remove.push_back(result);
params_to_remove.push_back(param);
}
}
}
for (const auto& param : m->get_parameters()) {
bool isUsed = false;
for (size_t i = 0; i < param->get_output_size(); ++i) {
const auto& node_names = param->get_output_tensor(i).get_names();
isUsed = std::find(node_names.begin(), node_names.end(), "saved_model_unused") == node_names.end();
}
if (!isUsed && std::find(params_to_remove.begin(), params_to_remove.end(), param) == params_to_remove.end()) {
params_to_remove.push_back(param);
}
}
for (const auto& result : results_to_remove) {
m->remove_result(result);
}
for (const auto& param : params_to_remove) {
m->remove_parameter(param);
}
return true;
}
} // namespace pass
} // namespace tensorflow
} // namespace frontend
} // namespace ov


@ -3,6 +3,7 @@
//
#include "common_op_table.hpp"
#include "helper_ops/string_constant.hpp"
#include "helper_ops/unsupported_constant.hpp"
#include "openvino/opsets/opset8.hpp"
@ -16,10 +17,15 @@ namespace tensorflow {
namespace op {
OutputVector translate_const_op(const NodeContext& node) {
auto ov_type = node.get_attribute<element::Type>("dtype");
auto ov_type = node.get_attribute_as_any("dtype");
std::shared_ptr<Node> const_node;
if (ov_type == element::dynamic) {
const_node = std::make_shared<UnsupportedConstant>();
if (!ov_type.is<ov::element::Type>() || ov_type.as<ov::element::Type>() == ov::element::dynamic ||
ov_type.as<ov::element::Type>() == ov::element::undefined) {
if (ov_type.is<std::string>() && ov_type.as<std::string>() == "DT_STRING") {
const_node = std::make_shared<StringConstant>(node.get_attribute_as_any("value"));
} else {
const_node = std::make_shared<UnsupportedConstant>();
}
} else {
auto tensor = node.get_attribute<Tensor>("value");
const_node = std::make_shared<Constant>(tensor.get_element_type(), tensor.get_shape(), tensor.data());


@ -14,7 +14,13 @@ namespace tensorflow {
namespace op {
OutputVector translate_identity_op(const NodeContext& node) {
vector<string> supported_ops = {"Identity", "PreventGradient", "Snapshot", "StopGradient"};
vector<string> supported_ops = {"Identity",
"PreventGradient",
"Snapshot",
"StopGradient",
"ReadVariableOp",
"ShardedFilename",
"MergeV2Checkpoints"};
default_op_checks(node, 1, supported_ops);
auto input = node.get_input(0);


@ -416,6 +416,32 @@ if(ENABLE_OV_TF_LITE_FRONTEND)
set(flatbuffers_DEPENDENCY ${flatbuffers_DEPENDENCY} PARENT_SCOPE)
endif()
#
# Snappy Compression
#
if(ENABLE_SNAPPY_COMPRESSION)
function(tf_build_snappy)
set(BUILD_SHARED_LIBS OFF)
set(SNAPPY_BUILD_BENCHMARKS OFF)
set(SNAPPY_BUILD_TESTS OFF)
set(INSTALL_GTEST OFF)
set(CMAKE_COMPILE_WARNING_AS_ERROR OFF)
set(CMAKE_CXX_STANDARD 14)
if(NOT CMAKE_CXX_COMPILER_ID STREQUAL "MSVC")
# Removes 3rd party errors which may affect OpenVINO CI
if(CMAKE_CXX_COMPILER_ID MATCHES "Clang")
if(NOT CMAKE_CXX_FLAGS MATCHES "-Werror=return-type")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror=return-type")
endif()
endif()
endif()
add_subdirectory(snappy EXCLUDE_FROM_ALL)
endfunction()
tf_build_snappy()
ov_install_static_lib(snappy ${OV_CPACK_COMP_CORE})
endif()
#
# ONNX
#

thirdparty/snappy vendored Submodule

@ -0,0 +1 @@
Subproject commit dc05e026488865bc69313a68bcc03ef2e4ea8e83