From 79683c24cafd7ca02e7cd9f6fac9f2c0f0a18b16 Mon Sep 17 00:00:00 2001 From: Roman Kazantsev Date: Thu, 15 Jun 2023 01:01:33 +0400 Subject: [PATCH] [TF FE] Support models with v1 checkpoints (#18000) * [TF FE] Support models with v1 checkpoints Signed-off-by: Kazantsev, Roman * Fix build issue * Fix issue in load_impl * Fix MO unit-tests * Check input_checkpoint in argv * Fix build issue * Update src/frontends/tensorflow/src/checkpoint_v1_reader.cpp Co-authored-by: Maxim Vafin * Apply comments from code-review: split CheckpointV1Reader constructor --------- Signed-off-by: Kazantsev, Roman Co-authored-by: Maxim Vafin --- .../tensorflow/src/checkpoint_utils.cpp | 277 ++++++++++++++++++ .../tensorflow/src/checkpoint_utils.hpp | 100 +++++++ .../tensorflow/src/checkpoint_v1_reader.cpp | 262 +++++++++++++++++ .../tensorflow/src/checkpoint_v1_reader.hpp | 81 +++++ .../tensorflow/src/decoder_argdef.cpp | 1 + .../tensorflow/src/decoder_proto.cpp | 106 +------ .../tensorflow/src/decoder_proto.hpp | 2 - src/frontends/tensorflow/src/frontend.cpp | 107 ++++++- .../tensorflow/src/graph_iterator_proto.hpp | 45 ++- .../src/graph_iterator_proto_txt.hpp | 20 +- src/frontends/tensorflow/src/input_model.cpp | 15 + src/frontends/tensorflow/src/input_model.hpp | 4 + src/frontends/tensorflow/src/op/variable.cpp | 64 ++++ src/frontends/tensorflow/src/op_table.cpp | 20 +- src/frontends/tensorflow/src/tf_utils.cpp | 192 ++++++++++++ src/frontends/tensorflow/src/tf_utils.hpp | 31 ++ .../tensorflow/src/variables_index.cpp | 81 +---- tools/mo/openvino/tools/mo/front/tf/loader.py | 11 +- .../tools/mo/moc_frontend/pipeline.py | 7 +- .../mo/utils/freeze_placeholder_test.py | 1 + .../mo/utils/test_mo_model_analysis_actual.py | 1 + .../conversion_incorrect_models_test.py | 10 +- .../conversion_with_checkpoint_v1_test.py | 42 +++ .../test_models/model_with_variable_v1.pbtxt | 53 ++++ tools/mo/unit_tests/moc_tf_fe/utils.py | 38 +++ 25 files changed, 1352 insertions(+), 219 deletions(-) create mode 100644 src/frontends/tensorflow/src/checkpoint_utils.cpp create mode 100644 src/frontends/tensorflow/src/checkpoint_utils.hpp create mode 100644 src/frontends/tensorflow/src/checkpoint_v1_reader.cpp create mode 100644 src/frontends/tensorflow/src/checkpoint_v1_reader.hpp create mode 100644 src/frontends/tensorflow/src/op/variable.cpp create mode 100644 src/frontends/tensorflow/src/tf_utils.cpp create mode 100644 src/frontends/tensorflow/src/tf_utils.hpp create mode 100644 tools/mo/unit_tests/moc_tf_fe/conversion_with_checkpoint_v1_test.py create mode 100644 tools/mo/unit_tests/moc_tf_fe/test_models/model_with_variable_v1.pbtxt create mode 100644 tools/mo/unit_tests/moc_tf_fe/utils.py diff --git a/src/frontends/tensorflow/src/checkpoint_utils.cpp b/src/frontends/tensorflow/src/checkpoint_utils.cpp new file mode 100644 index 00000000000..a63cb942e0b --- /dev/null +++ b/src/frontends/tensorflow/src/checkpoint_utils.cpp @@ -0,0 +1,277 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "checkpoint_utils.hpp" + +#include +#include + +#include "openvino/frontend/exception.hpp" + +namespace ov { +namespace frontend { +namespace tensorflow { + +static const char escape1 = '\000'; +static const char null_character = '\xff'; +static const char separator = '\001'; +static const char escape2 = '\xff'; +static const char ffcharacter = '\000'; +static const char escape1_separator[2] = {escape1, separator}; +static const int max_signed64_length = 10; + +// This array maps encoding length to header 
bits in the first two bytes. +static const char length_to_header_bits[1 + max_signed64_length][2] = {{0, 0}, + {'\x80', 0}, + {'\xc0', 0}, + {'\xe0', 0}, + {'\xf0', 0}, + {'\xf8', 0}, + {'\xfc', 0}, + {'\xfe', 0}, + {'\xff', 0}, + {'\xff', '\x80'}, + {'\xff', '\xc0'}}; + +inline bool byte_is_0_or_255(char c) { + return (static_cast(c + 1)) < 2; +} + +static inline const unsigned char* char_ptr_to_unsigned_char_ptr(const char* p) { + const void* void_ptr = static_cast(p); + return static_cast(void_ptr); +} + +static inline const char* unsigned_char_ptr_to_char_ptr(const unsigned char* p) { + const void* void_ptr = static_cast(p); + return static_cast(void_ptr); +} + +// return a pointer to the first byte in the range "[start..limit)" +// whose value is 0 or 255 (escape1 or escape2) +inline const char* find_special_byte(const char* start, const char* limit) { + // If these constants were ever changed, this routine needs to change + const char* current = start; + while (current < limit && !byte_is_0_or_255(*current)) { + ++current; + } + return current; +} + +// encode "source" and append to "dest", escaping special characters +inline static void encode_string_piece(std::string& dest, const std::string& source) { + const char* current = source.data(); + const char* limit = current + source.size(); + const char* copy_start = current; + while (true) { + current = find_special_byte(current, limit); + if (current >= limit) + break; // No more special characters that need escaping + char c = *(current++); + if (c == escape1) { + dest.append(copy_start, current - copy_start - 1); + dest.push_back(escape1); + dest.push_back(null_character); + copy_start = current; + } else { + FRONT_END_GENERAL_CHECK(c == escape2, "[TensorFlow Frontend] incorrect model: corrupted checkpoint"); + dest.append(copy_start, current - copy_start - 1); + dest.push_back(escape2); + dest.push_back(ffcharacter); + copy_start = current; + } + } + if (current > copy_start) { + dest.append(copy_start, current - copy_start); + } +} + +// reverse bytes of 64-bit number +static void convert_to_big_endian64(char* dst, uint64_t v) { + for (int i = 0; i < 8; ++i) { + dst[i] = (v >> (56 - 8 * i)) & 0xff; + } +} + +// compute floored log2(n) +static int log2_floor32(uint32_t n) { + if (n == 0) + return -1; + int log = 0; + uint32_t value = n; + for (int i = 4; i >= 0; --i) { + int shift = (1 << i); + uint32_t x = value >> shift; + if (x != 0) { + value = x; + log += shift; + } + } + return log; +} + +// compute floored log2(n) +static int log2_floor64(uint64_t n) { + const uint32_t topbits = static_cast(n >> 32); + if (topbits == 0) { + // Top bits are zero, so scan in bottom bits + return log2_floor32(static_cast(n)); + } else { + return 32 + log2_floor32(topbits); + } +} + +// calculate the encoding length in bytes of the signed number n +static inline int signed_encoding_length(int64_t n) { + int log2 = log2_floor64(n < 0 ? ~n : n) + 1; + return log2 / 7 + 1; +} + +void write_signed_num_increasing(std::string& dest, int64_t val) { + const uint64_t x = val < 0 ? ~val : val; + if (x < 64) { // fast path for encoding length == 1 + dest += length_to_header_bits[1][0] ^ static_cast(val); + return; + } + // buf = val in network byte order, sign extended to 10 bytes + const char sign_byte = val < 0 ? 
'\xff' : '\0'; + char buf[max_signed64_length] = { + sign_byte, + sign_byte, + }; + convert_to_big_endian64(buf + 2, val); + const int len = signed_encoding_length(x); + FRONT_END_GENERAL_CHECK(0 < len && len <= max_signed64_length, + "[TensorFlow Frontend] internal error: write_signed_num_increasing failed"); + char* const begin = buf + max_signed64_length - len; + begin[0] ^= length_to_header_bits[len][0]; + begin[1] ^= length_to_header_bits[len][1]; // ok because len >= 2 + dest.append(begin, len); +} + +void write_num_increasing(std::string& dest, uint64_t val) { + // Values are encoded with a single byte length prefix, followed + // by the actual value in big-endian format with leading 0 bytes + // dropped. + unsigned char buf[9]; // 8 bytes for value plus one byte for length + int len = 0; + while (val > 0) { + ++len; + buf[9 - len] = (val & 0xff); + val >>= 8; + } + buf[9 - len - 1] = len; + ++len; + dest.append(unsigned_char_ptr_to_char_ptr(&buf[0]) + 9 - len, len); +} + +std::string encode_tensor_name_slice(const std::string& name, + const std::vector& starts, + const std::vector lengths) { + std::string buffer; + // All the tensor slice keys will start with a 0 + write_num_increasing(buffer, 0); + encode_string_piece(buffer, name); + buffer.append(escape1_separator, 2); + write_num_increasing(buffer, starts.size()); + + FRONT_END_GENERAL_CHECK( + starts.size() == lengths.size(), + "[TensorFlow Frontend] internal error or inconsistent model: check consistency of checkpoint files"); + for (size_t d = 0; d < starts.size(); ++d) { + write_signed_num_increasing(buffer, starts[d]); + write_signed_num_increasing(buffer, lengths[d]); + } + return buffer; +} + +uint32_t decode_fixed32(const char* ptr) { + uint32_t result; + std::memcpy(&result, ptr, sizeof(result)); + return result; +} + +const char* get_varint32_ptr(const char* p, const char* limit, uint32_t& value) { + if (p < limit) { + uint32_t result = *(char_ptr_to_unsigned_char_ptr(p)); + if ((result & 128) == 0) { + value = result; + return p + 1; + } + } + uint32_t result = 0; + for (uint32_t shift = 0; shift <= 28 && p < limit; shift += 7) { + uint32_t byte = *(char_ptr_to_unsigned_char_ptr(p)); + ++p; + if (byte & 128) { + // More bytes are present + result |= ((byte & 127) << shift); + } else { + result |= (byte << shift); + value = result; + return p; + } + } + return nullptr; +} + +const char* get_varint64_ptr(const char* p, const char* limit, uint64_t* value) { + uint64_t result = 0; + for (uint32_t shift = 0; shift <= 63 && p < limit; shift += 7) { + uint64_t byte = *(char_ptr_to_unsigned_char_ptr(p)); + ++p; + if (byte & 128) { + // More bytes are present + result |= ((byte & 127) << shift); + } else { + result |= (byte << shift); + *value = result; + return p; + } + } + return nullptr; +} + +bool get_varint64(std::string& input, uint64_t* value) { + const char* p = input.data(); + const char* limit = p + input.size(); + const char* q = get_varint64_ptr(p, limit, value); + if (q == nullptr) { + return false; + } else { + input = std::string(q, limit - q); + return true; + } +} + +const char* decode_entry(const char* p, + const char* limit, + uint32_t& shared, + uint32_t& non_shared, + uint32_t& value_length) { + if (limit - p < 3) + return nullptr; + shared = char_ptr_to_unsigned_char_ptr(p)[0]; + non_shared = char_ptr_to_unsigned_char_ptr(p)[1]; + value_length = char_ptr_to_unsigned_char_ptr(p)[2]; + if ((shared | non_shared | value_length) < 128) { + // Fast path: all three values are encoded in one byte each + p += 3; 
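+        // worked example: header bytes {0x05, 0x03, 0x08} decode to shared = 5,
+        // non_shared = 3 and value_length = 8 without touching the varint decoder;
+        // a byte with the high bit set (>= 128) in any of the three positions
+        // forces the slow varint path below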
+    } else {
+        if ((p = get_varint32_ptr(p, limit, shared)) == nullptr)
+            return nullptr;
+        if ((p = get_varint32_ptr(p, limit, non_shared)) == nullptr)
+            return nullptr;
+        if ((p = get_varint32_ptr(p, limit, value_length)) == nullptr)
+            return nullptr;
+    }
+
+    if (static_cast<uint32_t>(limit - p) < (non_shared + value_length)) {
+        return nullptr;
+    }
+    return p;
+}
+} // namespace tensorflow
+} // namespace frontend
+} // namespace ov
diff --git a/src/frontends/tensorflow/src/checkpoint_utils.hpp b/src/frontends/tensorflow/src/checkpoint_utils.hpp
new file mode 100644
index 00000000000..07ffb2d7ecb
--- /dev/null
+++ b/src/frontends/tensorflow/src/checkpoint_utils.hpp
@@ -0,0 +1,100 @@
+// Copyright (C) 2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <fstream>
+
+#include <string>
+#include <vector>
+
+#include "openvino/frontend/exception.hpp"
+
+namespace ov {
+namespace frontend {
+namespace tensorflow {
+
+#define VARIABLES_INDEX_FOOTER_SIZE 48
+#define BLOCK_TRAILER_SIZE 5
+#define SAVED_TENSOR_SLICES_KEY ""
+
+template <typename T>
+static T smUnpack(char*& ptr, const char* ptr_end) {
+    T result = 0;
+    for (uint8_t i = 0; i <= sizeof(T) * 7 && ptr < ptr_end; i += 7) {
+        T byte = *(ptr++);
+        if (byte & 0x80) {
+            result |= ((byte & 0x7F) << i);
+        } else {
+            result |= byte << i;
+            return result;
+        }
+    }
+    return 0;
+}
+
+/// \brief Structure for storing information about a block in a Variables Index file.
+/// It defines only the offset and the block size, with no information about the exact content.
+struct VIBlock {
+    uint64_t m_size;
+    uint64_t m_offset;
+
+    void read(char*& ptr, const char* ptr_end) {
+        m_offset = smUnpack<uint64_t>(ptr, ptr_end);
+        m_size = smUnpack<uint64_t>(ptr, ptr_end);
+    }
+};
+
+/// \brief Structure for storing Variables Index footer information.
+/// It contains descriptions of two blocks and a magic number for file verification.
+/// Currently, it is placed in the last VARIABLES_INDEX_FOOTER_SIZE bytes of a file.
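+/// Footer layout (following the SSTable-style table format TensorFlow uses, see the
+/// link in read() below): varint64-encoded (offset, size) handles for the metaindex
+/// and index blocks, zero padding, and an 8-byte magic number 0xdb4775248b80fb57
+/// stored little-endian in the last 8 bytes.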
+struct VIFooter {
+    VIBlock m_metaIndex;
+    VIBlock m_index;
+
+    void read(char*& ptr, const char* ptr_end) {
+        m_index.read(ptr, ptr_end);
+        m_metaIndex.read(ptr, ptr_end);
+    }
+
+    void read(std::ifstream& fs) {
+        fs.seekg(0, std::ios::end);
+        size_t size = fs.tellg();
+        FRONT_END_GENERAL_CHECK(size >= VARIABLES_INDEX_FOOTER_SIZE,
+                                "Wrong index file, file size is less than minimal expected");
+
+        char footerData[VARIABLES_INDEX_FOOTER_SIZE] = {}, *ptr = &footerData[0];
+        fs.seekg(size - sizeof(footerData));
+        fs.read(ptr, sizeof(footerData));
+
+        // https://github.com/tensorflow/tensorflow/blob/9659b7bdca80a8ef8240eb021d4da089034eeb00/tensorflow/tsl/lib/io/format.cc#L59
+        ptr += sizeof(footerData) - 8;
+        uint32_t magic_lo = *reinterpret_cast<const uint32_t*>(ptr);
+        uint32_t magic_hi = *reinterpret_cast<const uint32_t*>(ptr + 4);
+        uint64_t magic_no = (static_cast<uint64_t>(magic_hi) << 32) | static_cast<uint64_t>(magic_lo);
+
+        FRONT_END_GENERAL_CHECK(magic_no == 0xdb4775248b80fb57ull, "Wrong index file, magic number mismatch detected");
+
+        ptr = &footerData[0];
+        m_metaIndex.read(ptr, ptr + sizeof(footerData));
+        m_index.read(ptr, ptr + sizeof(footerData));
+    }
+};
+
+uint32_t decode_fixed32(const char* ptr);
+
+const char* decode_entry(const char* p,
+                         const char* limit,
+                         uint32_t& shared,
+                         uint32_t& non_shared,
+                         uint32_t& value_length);
+
+bool get_varint64(std::string& input, uint64_t* value);
+
+std::string encode_tensor_name_slice(const std::string& name,
+                                     const std::vector<int64_t>& starts,
+                                     const std::vector<int64_t> lengths);
+} // namespace tensorflow
+} // namespace frontend
+} // namespace ov
\ No newline at end of file
diff --git a/src/frontends/tensorflow/src/checkpoint_v1_reader.cpp b/src/frontends/tensorflow/src/checkpoint_v1_reader.cpp
new file mode 100644
index 00000000000..b51b8c91b3e
--- /dev/null
+++ b/src/frontends/tensorflow/src/checkpoint_v1_reader.cpp
@@ -0,0 +1,262 @@
+// Copyright (C) 2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "checkpoint_v1_reader.hpp"
+
+#include "checkpoint_utils.hpp"
+#include "openvino/frontend/exception.hpp"
+#include "openvino/util/file_util.hpp"
+#include "saved_tensor_slice.pb.h"
+#include "tf_utils.hpp"
+
+#ifdef ENABLE_SNAPPY_COMPRESSION
+#    include "snappy.h"
+#endif
+
+using namespace ov::frontend::tensorflow;
+
+namespace {
+std::vector<std::string> list_files_in_dir(const std::string& directory_path) {
+    std::vector<std::string> res;
+    try {
+        ov::util::iterate_files(
+            directory_path,
+            [&res](const std::string& file_path, bool is_dir) {
+                if (!is_dir) {
+                    res.push_back(file_path);
+                }
+            },
+            false,
+            true);
+    } catch (...) {
+        // ignore exceptions: an unreadable directory simply yields an empty list
+    }
+    return res;
+}
+} // namespace
+
+CheckpointV1Reader::CheckpointV1Reader(const std::string& checkpoints) : m_checkpoints(checkpoints) {}
+
+void CheckpointV1Reader::initialize() {
+    // figure out if the input is a single checkpoint file or a directory of shards
+    std::vector<std::string> checkpoints_paths;
+    if (ov::util::directory_exists(m_checkpoints)) {
+        checkpoints_paths = list_files_in_dir(m_checkpoints);
+    } else if (ov::util::file_exists(m_checkpoints)) {
+        checkpoints_paths = {m_checkpoints};
+    } else {
+        FRONT_END_GENERAL_CHECK(false, "[TensorFlow Frontend] incorrect checkpoint: the checkpoint does not exist");
+    }
+
+    m_variables_info_map.clear();
+
+    for (const auto& checkpoint_path : checkpoints_paths) {
+        // create ifstream for each shard
+        std::shared_ptr<std::ifstream> shard_stream =
+            std::make_shared<std::ifstream>(checkpoint_path, std::ifstream::in | std::ifstream::binary);
+        FRONT_END_GENERAL_CHECK(
+            shard_stream && shard_stream->is_open(),
+            "[TensorFlow Frontend] incorrect model: checkpoint file " + checkpoint_path + " does not exist");
+        const int32_t shard_ind = static_cast<int32_t>(m_shards.size());
+        m_shards.push_back(shard_stream);
+        m_shard_names.push_back(checkpoint_path);
+        std::string value;
+        find_entry(shard_stream, checkpoint_path, SAVED_TENSOR_SLICES_KEY, value);
+
+        // parse the master entry stored under the empty key
+        // it is present only as the first item of each checkpoint file and serves
+        // as a table of contents, listing all the tensor slices saved in this file
+        ::tensorflow::SavedTensorSlices sts;
+        FRONT_END_GENERAL_CHECK(sts.ParseFromArray(value.data(), static_cast<int>(value.size())),
+                                "[TensorFlow Frontend] incorrect input checkpoint file or internal error: cannot parse "
+                                "SavedTensorSlices entry");
+        for (const auto& saved_slice_meta : sts.meta().tensor()) {
+            // parse shapes and types for variables
+            VariableInfo var_info;
+            var_info.shard_id = shard_ind;
+            auto variable_name = saved_slice_meta.name(); // original variable name (not encoded)
+            var_info.variable_shape = saved_slice_meta.shape();
+            var_info.variable_type = saved_slice_meta.type();
+
+            // save starts and lengths of slices, needed for variable name encoding
+            for (const auto& slice : saved_slice_meta.slice()) {
+                for (const auto& extent : slice.extent()) {
+                    var_info.starts.push_back(extent.start());
+                    if (extent.has_length()) {
+                        var_info.lenghts.push_back(extent.length());
+                    } else {
+                        var_info.lenghts.push_back(-1);
+                    }
+                }
+            }
+            m_variables_info_map[variable_name] = var_info;
+        }
+    }
+}
+
+void CheckpointV1Reader::seek_block(const std::string& shard_name,
+                                    const std::string& target_key,
+                                    const char* block_ptr,
+                                    const uint32_t restarts,
+                                    std::string& value) const {
+    // parsing of the next key starts at the end of the previous value, so advance accordingly
+    const char* curr_value_pos = block_ptr;
+    const char* limit = block_ptr + restarts; // restarts come right after data
+    std::string key = "";
+
+    bool is_found = false;
+    while (true) {
+        FRONT_END_GENERAL_CHECK(
+            curr_value_pos < limit,
+            "[TensorFlow Frontend] incorrect model: no more entries to return, invalid checkpoint file " + shard_name);
+
+        // decode the next entry; each entry looks as follows:
+        // | shared (1 byte) | non-shared (1 byte) | value_length (1 byte) | key (non-shared bytes) |
+        // | value (value_length bytes) |
+        uint32_t shared, non_shared, value_length;
+        curr_value_pos = decode_entry(curr_value_pos, limit, shared, non_shared, value_length);
+        FRONT_END_GENERAL_CHECK(
+            curr_value_pos && key.size() >= shared,
+            "[TensorFlow Frontend] incorrect model: corruption error in checkpoint file " + shard_name);
+
+        key.resize(shared);
+        key.append(curr_value_pos, non_shared);
+        value = std::string(curr_value_pos + non_shared, value_length);
+        curr_value_pos += (non_shared + value_length);
+
+        if (key.compare(target_key) >= 0) {
+            is_found = true;
+            break;
+        }
+    }
+    FRONT_END_GENERAL_CHECK(
+        is_found,
+        "[TensorFlow Frontend] incorrect input model: checkpoint file " + shard_name + " may be corrupted");
+}
+
+void CheckpointV1Reader::init_block(const std::shared_ptr<std::ifstream>& shard,
+                                    const std::string& shard_name,
+                                    uint64_t offset,
+                                    uint64_t size,
+                                    std::string& block,
+                                    uint64_t& restart_offset) const {
+    // check the size of the shard
+    FRONT_END_GENERAL_CHECK(shard,
+                            "[TensorFlow Frontend] internal error: nullptr pointer to checkpoint file " + shard_name);
+    shard->seekg(0, shard->end);
+    uint64_t shard_size = static_cast<uint64_t>(shard->tellg());
+    FRONT_END_GENERAL_CHECK(offset < shard_size,
+                            "[TensorFlow Frontend] internal error or inconsistent checkpoint file: block offset is "
+                            "out-of-range for checkpoint file " +
+                                shard_name);
+    auto n = size + BLOCK_TRAILER_SIZE;
+    FRONT_END_GENERAL_CHECK(n < (shard_size - offset),
+                            "[TensorFlow Frontend] internal error or inconsistent checkpoint file: block size is "
+                            "out-of-range for checkpoint file " +
+                                shard_name);
+
+    // read the block and decompress it if needed
+    std::vector<char> buf(n);
+    shard->seekg(offset);
+    shard->read(buf.data(), n);
+#ifndef ENABLE_SNAPPY_COMPRESSION
+    FRONT_END_GENERAL_CHECK(buf[size] == 0,
+                            "[TensorFlow Frontend] internal error: compression method for given block is not supported "
+                            "for checkpoint file " +
+                                shard_name);
+    block = std::string(buf.data(), size);
+#else
+    FRONT_END_GENERAL_CHECK(buf[size] == 0 || buf[size] == 1,
+                            "[TensorFlow Frontend] internal error: compression method for given block is not supported "
+                            "for checkpoint file " +
+                                shard_name);
+    if (buf[size] == 1) {
+        size_t uncompressed_length = 0;
+        FRONT_END_GENERAL_CHECK(
+            snappy::GetUncompressedLength(buf.data(), n, &uncompressed_length),
+            "[TensorFlow Frontend] internal error: cannot retrieve uncompressed block length for checkpoint file " +
+                shard_name);
+        block.clear();
+        block.reserve(uncompressed_length);
+        snappy::Uncompress(buf.data(), n, &block);
+    } else {
+        block = std::string(buf.data(), size);
+    }
+#endif
+    const char* data = block.data();
+    size = block.size();
+
+    // find block characteristics: max_restarts_allowed, num_restarts and restart_offset
+    FRONT_END_GENERAL_CHECK(
+        size >= sizeof(uint32_t),
+        "[TensorFlow Frontend] internal error: block size must be not less than 4 bytes in checkpoint file " +
+            shard_name);
+    size_t max_restarts_allowed = (size - sizeof(uint32_t)) / sizeof(uint32_t);
+    uint32_t num_restarts = decode_fixed32(data + size - sizeof(uint32_t));
+    FRONT_END_GENERAL_CHECK(
+        num_restarts <= max_restarts_allowed,
+        "[TensorFlow Frontend] internal error: num_restarts is greater than max_restarts_allowed in checkpoint file " +
+            shard_name);
+    restart_offset = size - (1 + num_restarts) * sizeof(uint32_t);
+}
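+
+// The lookup below follows the table format referenced above: the fixed-size footer
+// points at the index block, an index-block entry maps the key to a varint64
+// (offset, size) handle of a data block, and the data block stores the final
+// key/value entry for that key.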
+void CheckpointV1Reader::find_entry(const std::shared_ptr<std::ifstream>& shard,
+                                    const std::string& shard_name,
+                                    const std::string& entry_key,
+                                    std::string& entry_value) {
+    // read the footer of the shard file to get the offset and size of the index block
+    VIFooter footer;
+    footer.read(*shard);
+    uint64_t block_offset = footer.m_index.m_offset;
+    uint64_t block_size = footer.m_index.m_size;
+    std::string block;
+
+    // initialize the index block
+    uint64_t restart_offset = 0;
+    init_block(shard, shard_name, block_offset, block_size, block, restart_offset);
+
+    // seek the entry in the index block
+    // this entry contains the offset and size of the data block
+    seek_block(shard_name, entry_key, block.data(), static_cast<uint32_t>(restart_offset), entry_value);
+
+    // initialize the data block
+    FRONT_END_GENERAL_CHECK(
+        get_varint64(entry_value, &block_offset) && get_varint64(entry_value, &block_size),
+        "[TensorFlow Frontend] incorrect input model: bad block handle in checkpoint file " + shard_name);
+    init_block(shard, shard_name, block_offset, block_size, block, restart_offset);
+
+    // seek the final entry in the data block
+    seek_block(shard_name, entry_key, block.data(), static_cast<uint32_t>(restart_offset), entry_value);
+}
+
+void CheckpointV1Reader::read_variable(const std::string& variable_name, ov::Any& data) {
+    FRONT_END_GENERAL_CHECK(m_variables_info_map.count(variable_name) > 0,
+                            "[TensorFlow Frontend] incorrect input model: checkpoint files do not contain data for "
+                            "the required variable " +
+                                variable_name);
+    auto var_info = m_variables_info_map[variable_name];
+    auto shard_id = var_info.shard_id;
+    FRONT_END_GENERAL_CHECK(shard_id < static_cast<int32_t>(m_shards.size()),
+                            "[TensorFlow Frontend] internal error: shard_id is greater than the number of shards");
+    FRONT_END_GENERAL_CHECK(
+        m_shards.size() == m_shard_names.size(),
+        "[TensorFlow Frontend] internal error: the number of shards does not match the number of their names");
+    auto shard_ptr = m_shards[shard_id];
+    auto shard_name = m_shard_names[shard_id];
+    auto encoded_name = encode_tensor_name_slice(variable_name, var_info.starts, var_info.lenghts);
+    std::string raw_data;
+    find_entry(shard_ptr, shard_name, encoded_name, raw_data);
+
+    // the found entry holds a SavedTensorSlices message with the data for the requested slice
+    ::tensorflow::SavedTensorSlices sts;
+    FRONT_END_GENERAL_CHECK(sts.ParseFromArray(raw_data.data(), static_cast<int>(raw_data.size())),
+                            "[TensorFlow Frontend] incorrect input checkpoint file or internal error: cannot parse "
+                            "SavedTensorSlices entry");
+    data = unpack_tensor_proto(sts.data().data(), var_info.variable_shape, var_info.variable_type);
+}
diff --git a/src/frontends/tensorflow/src/checkpoint_v1_reader.hpp b/src/frontends/tensorflow/src/checkpoint_v1_reader.hpp
new file mode 100644
index 00000000000..bfae3b139a2
--- /dev/null
+++ b/src/frontends/tensorflow/src/checkpoint_v1_reader.hpp
@@ -0,0 +1,81 @@
+// Copyright (C) 2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <fstream>
+
+#include <string>
+#include <unordered_map>
+
+#include "checkpoint_utils.hpp"
+#include "openvino/core/any.hpp"
+#include "openvino/frontend/exception.hpp"
+#include "saved_tensor_slice.pb.h"
+#include "tensor_shape.pb.h"
+#include "types.pb.h"
+
+namespace ov {
+namespace frontend {
+namespace tensorflow {
+// stores information about the shape, type, and shard id for a Variable
+struct VariableInfo {
+    ::tensorflow::TensorShapeProto variable_shape;
+    ::tensorflow::DataType variable_type;
+    int32_t shard_id;
+    size_t offset;
+    size_t size;
+    std::vector<int64_t> starts;
+    std::vector<int64_t> lenghts;
+};
+
+// reads v1 checkpoints
+// it parses the value, shape and type for Variable nodes
+class CheckpointV1Reader {
+    const std::string m_checkpoints;
+    // a map from a Variable name to its information
+    std::unordered_map<std::string, VariableInfo> m_variables_info_map;
+    // a vector of streams for shards, where a shard is one checkpoint file
+    std::vector<std::shared_ptr<std::ifstream>> m_shards;
+    // a vector of shard names
+    std::vector<std::string> m_shard_names;
+
+public:
+    /// \brief constructs CheckpointV1Reader for a given checkpoint file or directory of checkpoint files
+    CheckpointV1Reader(const std::string& checkpoints);
+
+    /// \brief initialize Checkpoint V1 reader
+    void initialize();
+
+    /// \brief Produces an ov::Any object that wraps ov::Tensor for the requested variable;
+    /// it can also wrap a string tensor
+    /// \param variable_name the requested variable name
+    /// \param data a reference to the result
+    void read_variable(const std::string& variable_name, ov::Any& data);
+
+private:
+    /// \brief finds a non-master key entry using already cached offsets and sizes of data blocks
+    void find_entry(const std::shared_ptr<std::ifstream>& shard,
+                    const std::string& shard_name,
+                    const std::string& entry_key,
+                    std::string& value);
+
+    void seek_block(const std::string& shard_name,
+                    const std::string& target,
+                    const char* shard_data,
+                    const uint32_t restarts,
+                    std::string& value) const;
+
+    void init_block(const std::shared_ptr<std::ifstream>& shard,
+                    const std::string& shard_name,
+                    uint64_t offset,
+                    uint64_t size,
+                    std::string& block,
+                    uint64_t& restart_offset) const;
+};
+
+} // namespace tensorflow
+} // namespace frontend
+} // namespace ov
diff --git a/src/frontends/tensorflow/src/decoder_argdef.cpp b/src/frontends/tensorflow/src/decoder_argdef.cpp
index cf3cb3379b1..3430bcbe5e3 100644
--- a/src/frontends/tensorflow/src/decoder_argdef.cpp
+++ b/src/frontends/tensorflow/src/decoder_argdef.cpp
@@ -8,6 +8,7 @@
 #include "op_def.pb.h"
 #include "openvino/frontend/tensorflow/node_context.hpp"
 #include "openvino/frontend/tensorflow/special_types.hpp"
+#include "tf_utils.hpp"
 #include "types.pb.h"
 
 namespace ov {
diff --git a/src/frontends/tensorflow/src/decoder_proto.cpp b/src/frontends/tensorflow/src/decoder_proto.cpp
index 936b566b75a..2488973c102
100644 --- a/src/frontends/tensorflow/src/decoder_proto.cpp +++ b/src/frontends/tensorflow/src/decoder_proto.cpp @@ -8,6 +8,7 @@ #include "node_def.pb.h" #include "openvino/frontend/tensorflow/node_context.hpp" #include "openvino/frontend/tensorflow/special_types.hpp" +#include "tf_utils.hpp" #include "types.pb.h" namespace ov { @@ -82,24 +83,6 @@ void extract_compressed_tensor_content(const ::tensorflow::TensorProto& tensor_p #endif } // namespace -ov::element::Type get_ov_type(const ::tensorflow::DataType& type) { - static const std::map<::tensorflow::DataType, ov::element::Type> type_map{ - {::tensorflow::DataType::DT_BOOL, ov::element::boolean}, - {::tensorflow::DataType::DT_INT16, ov::element::i16}, - {::tensorflow::DataType::DT_INT32, ov::element::i32}, - {::tensorflow::DataType::DT_INT64, ov::element::i64}, - {::tensorflow::DataType::DT_HALF, ov::element::f16}, - {::tensorflow::DataType::DT_FLOAT, ov::element::f32}, - {::tensorflow::DataType::DT_DOUBLE, ov::element::f64}, - {::tensorflow::DataType::DT_UINT8, ov::element::u8}, - {::tensorflow::DataType::DT_INT8, ov::element::i8}, - {::tensorflow::DataType::DT_BFLOAT16, ov::element::bf16}}; - - auto it = type_map.find(type); - // for all unsupported types return dynamic type - return it == type_map.end() ? ov::element::dynamic : it->second; -} - ov::Any DecoderProto::get_attribute(const std::string& name) const { auto attrs = decode_attribute_helper(name); if (attrs.empty()) { @@ -194,92 +177,7 @@ ov::Any DecoderProto::get_attribute(const std::string& name) const { } case ::tensorflow::AttrValue::ValueCase::kTensor: { - const auto& tensor_proto = attrs[0].tensor(); - const auto& tf_shape = tensor_proto.tensor_shape(); - ov::PartialShape pshape; - for (int i = 0; i < tf_shape.dim_size(); i++) { - pshape.push_back(tf_shape.dim(i).size()); - } - FRONT_END_GENERAL_CHECK(pshape.is_static(), "Dynamic shapes are not supported for Tensor attribute."); - const auto& tf_type = tensor_proto.dtype(); - auto ov_type = get_ov_type(tf_type); - if (tf_type != ::tensorflow::DataType::DT_STRING) { - FRONT_END_GENERAL_CHECK( - ov_type.is_static(), - "Encountered unknown element type " + DataType_Name(tf_type) + " on an empty tensor_proto"); - } else { - auto data = std::vector(); - for (const auto& item : tensor_proto.string_val()) { - data.push_back(item); - } - return data; - } - ov::Tensor res(ov_type, pshape.get_shape()); - auto tensor_content = tensor_proto.tensor_content(); - if (!tensor_content.empty() && tensor_proto.has_tensor_shape()) { - switch (ov_type) { - case ov::element::u8: - extract_tensor_content(tensor_content, &res); - break; - case ov::element::i8: - extract_tensor_content(tensor_content, &res); - break; - case ov::element::i16: - extract_tensor_content(tensor_content, &res); - break; - case ov::element::i32: - extract_tensor_content(tensor_content, &res); - break; - case ov::element::i64: - extract_tensor_content(tensor_content, &res); - break; - case ov::element::f16: - extract_tensor_content(tensor_content, &res); - break; - case ov::element::f32: - extract_tensor_content(tensor_content, &res); - break; - case ov::element::f64: - extract_tensor_content(tensor_content, &res); - break; - case ov::element::bf16: - extract_tensor_content(tensor_content, &res); - break; - default: - FRONT_END_THROW("Encountered unknown element type " + ov_type.get_type_name()); - } - } else { - int64_t val_size = 0; - switch (ov_type) { - case ov::element::boolean: - val_size = tensor_proto.bool_val_size(); - 
extract_compressed_tensor_content(tensor_proto, val_size, &res); - break; - case ov::element::i32: - val_size = tensor_proto.int_val_size(); - extract_compressed_tensor_content(tensor_proto, val_size, &res); - break; - case ov::element::i64: - val_size = tensor_proto.int64_val_size(); - extract_compressed_tensor_content(tensor_proto, val_size, &res); - break; - case ov::element::f16: - val_size = tensor_proto.half_val_size(); - extract_compressed_tensor_content(tensor_proto, val_size, &res); - break; - case ov::element::f32: - val_size = tensor_proto.float_val_size(); - extract_compressed_tensor_content(tensor_proto, val_size, &res); - break; - case ov::element::f64: - val_size = tensor_proto.double_val_size(); - extract_compressed_tensor_content(tensor_proto, val_size, &res); - break; - default: - FRONT_END_THROW("Encountered unknown element type " + ov_type.get_type_name()); - } - } - return res; + return unpack_tensor_proto(attrs[0].tensor()); } case ::tensorflow::AttrValue::ValueCase::kPlaceholder: FRONT_END_GENERAL_CHECK(false, diff --git a/src/frontends/tensorflow/src/decoder_proto.hpp b/src/frontends/tensorflow/src/decoder_proto.hpp index 6d1bdbd11ee..338bfdeccea 100644 --- a/src/frontends/tensorflow/src/decoder_proto.hpp +++ b/src/frontends/tensorflow/src/decoder_proto.hpp @@ -22,8 +22,6 @@ namespace ov { namespace frontend { namespace tensorflow { -ov::element::Type get_ov_type(const ::tensorflow::DataType& type); - void parse_producer_name(const std::string& producer_port_name, std::string& producer_name, std::string& producer_output_port_name, diff --git a/src/frontends/tensorflow/src/frontend.cpp b/src/frontends/tensorflow/src/frontend.cpp index b47aef03c91..06036576354 100644 --- a/src/frontends/tensorflow/src/frontend.cpp +++ b/src/frontends/tensorflow/src/frontend.cpp @@ -21,6 +21,7 @@ #include "openvino/op/util/multi_subgraph_base.hpp" #include "openvino/pass/manager.hpp" #include "openvino/util/common_util.hpp" +#include "openvino/util/file_util.hpp" #include "openvino/util/log.hpp" #include "so_extension.hpp" #include "tf_framework_node.hpp" @@ -103,10 +104,13 @@ bool FrontEnd::supported_impl(const std::vector& variants) const { // Last boolean flag in `variants` (if presented) is reserved for FE configuration size_t extra_variants_num = variants.size() > 0 && variants[variants.size() - 1].is() ? 
1 : 0; - // TODO: support checkpoint format + // For TF1 models it can be a case of two input variants: input model and v1 checkpoints if (variants.size() != 1 + extra_variants_num) return false; + // to figure out if the model with v1 checkpoints is supported, + // it is sufficient to check only the input model format + // avoid parsing of checkpoints here if (variants[0].is()) { std::string model_path = variants[0].as(); if (ov::util::ends_with(model_path, ".pb") && GraphIteratorProto::is_supported(model_path)) { @@ -122,6 +126,18 @@ bool FrontEnd::supported_impl(const std::vector& variants) const { // handle text protobuf format return true; } + } else if (variants[0].is>() && variants[0].as>().size() == 2) { + // here, we assume to get the input model path and checkpoints directory + auto paths = variants[0].as>(); + auto model_path = paths[0]; + auto checkpoints_dir = paths[1]; + if (GraphIteratorProto::is_supported(model_path)) { + // binary protobuf format with checkpoints + return true; + } else if (GraphIteratorProtoTxt::is_supported(model_path)) { + // text protobuf format with checkpoints + return true; + } } #if defined(OPENVINO_ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32) else if (variants[0].is()) { @@ -140,6 +156,18 @@ bool FrontEnd::supported_impl(const std::vector& variants) const { // handle text protobuf format return true; } + } else if (variants[0].is>() && variants[0].as>().size() == 2) { + // here, we assume to get the input model path and checkpoints directory + auto paths = variants[0].as>(); + auto model_path = ov::util::wstring_to_string(paths[0]); + auto checkpoints_dir = ov::util::wstring_to_string(paths[1]); + if (GraphIteratorProto::is_supported(model_path)) { + // binary protobuf format with checkpoints + return true; + } else if (GraphIteratorProtoTxt::is_supported(model_path)) { + // text protobuf format with checkpoints + return true; + } } #endif else if (variants[0].is()) { @@ -150,13 +178,14 @@ bool FrontEnd::supported_impl(const std::vector& variants) const { } ov::frontend::InputModel::Ptr FrontEnd::load_impl(const std::vector& variants) const { - // TODO: support checkpoint format - // Last boolean flag in `variants` (if presented) is reserved for FE configuration size_t extra_variants_num = variants.size() > 0 && variants[variants.size() - 1].is() ? 
1 : 0; - FRONT_END_GENERAL_CHECK(variants.size() == 1 + extra_variants_num, - "[TensorFlow Frontend] Internal error or inconsistent input model: the frontend supports " - "frozen formats (.pb and .pbtxt), SavedModel and MetaGraph (.meta) formats."); + + // For TF1 models it can be a case of two input variants: input model and v1 checkpoints + FRONT_END_GENERAL_CHECK( + variants.size() == 1 + extra_variants_num, + "[TensorFlow Frontend] Internal error or inconsistent input model: the frontend supports " + "frozen formats (.pb and .pbtxt), SavedModel and MetaGraph (.meta) formats, and v1 checkpoints."); if (variants[0].is()) { auto model_path = variants[0].as(); @@ -175,6 +204,7 @@ ov::frontend::InputModel::Ptr FrontEnd::load_impl(const std::vector& va graph_iterator->get_variables_index(), graph_iterator->get_saved_model_input_names(), graph_iterator->get_saved_model_output_names(), + nullptr, true); } else if (GraphIteratorMeta::is_supported(model_path)) { auto graph_iterator = std::make_shared(model_path); @@ -183,11 +213,42 @@ ov::frontend::InputModel::Ptr FrontEnd::load_impl(const std::vector& va graph_iterator->get_variables_index(), graph_iterator->get_metagraph_input_names(), graph_iterator->get_metagraph_output_names(), + nullptr, true); } else if (GraphIteratorProtoTxt::is_supported(model_path)) { // handle text protobuf format return std::make_shared(std::make_shared(model_path), m_telemetry); } + } else if (variants[0].is>()) { + // here, we assume to get the input model path and checkpoints directory + auto paths = variants[0].as>(); + FRONT_END_GENERAL_CHECK( + paths.size() == 2, + "[TensorFlow Frontend] Internal error or inconsistent input model: the frontend supports " + "frozen formats (.pb and .pbtxt), SavedModel and MetaGraph (.meta) formats, and v1 checkpoints."); + auto model_path = paths[0]; + auto checkpoints_dir = paths[1]; + if (GraphIteratorProto::is_supported(model_path)) { + auto graph_iterator = std::make_shared(model_path, checkpoints_dir); + // handle binary protobuf format with checkpoints + return std::make_shared(graph_iterator, + m_telemetry, + nullptr, + nullptr, + nullptr, + graph_iterator->get_checkpoint_v1_reader(), + false); + } else if (GraphIteratorProtoTxt::is_supported(model_path)) { + auto graph_iterator = std::make_shared(model_path, checkpoints_dir); + // handle text protobuf format with checkpoints + return std::make_shared(graph_iterator, + m_telemetry, + nullptr, + nullptr, + nullptr, + graph_iterator->get_checkpoint_v1_reader(), + false); + } } #if defined(OPENVINO_ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32) else if (variants[0].is()) { @@ -209,6 +270,7 @@ ov::frontend::InputModel::Ptr FrontEnd::load_impl(const std::vector& va graph_iterator->get_variables_index(), graph_iterator->get_saved_model_input_names(), graph_iterator->get_saved_model_output_names(), + nullptr, true); } else if (GraphIteratorMeta::is_supported(model_path)) { auto graph_iterator = std::make_shared(model_path); @@ -217,11 +279,42 @@ ov::frontend::InputModel::Ptr FrontEnd::load_impl(const std::vector& va graph_iterator->get_variables_index(), graph_iterator->get_metagraph_input_names(), graph_iterator->get_metagraph_output_names(), + nullptr, true); } else if (GraphIteratorProtoTxt::is_supported(model_path)) { // handle text protobuf format with a path in Unicode return std::make_shared(std::make_shared(model_path), m_telemetry); } + } else if (variants[0].is>()) { + // here, we assume to get the input model path and checkpoints directory + auto paths = 
variants[0].as>(); + FRONT_END_GENERAL_CHECK( + paths.size() == 2, + "[TensorFlow Frontend] Internal error or inconsistent input model: the frontend supports " + "frozen formats (.pb and .pbtxt), SavedModel and MetaGraph (.meta) formats, and v1 checkpoints."); + auto model_path = ov::util::wstring_to_string(paths[0]); + auto checkpoints_dir = ov::util::wstring_to_string(paths[1]); + if (GraphIteratorProto::is_supported(model_path)) { + auto graph_iterator = std::make_shared(model_path, checkpoints_dir); + // handle binary protobuf format with checkpoints + return std::make_shared(graph_iterator, + m_telemetry, + nullptr, + nullptr, + nullptr, + graph_iterator->get_checkpoint_v1_reader(), + false); + } else if (GraphIteratorProtoTxt::is_supported(model_path)) { + auto graph_iterator = std::make_shared(model_path, checkpoints_dir); + // handle text protobuf format with checkpoints + return std::make_shared(graph_iterator, + m_telemetry, + nullptr, + nullptr, + nullptr, + graph_iterator->get_checkpoint_v1_reader(), + false); + } } #endif else if (variants[0].is()) { @@ -232,7 +325,7 @@ ov::frontend::InputModel::Ptr FrontEnd::load_impl(const std::vector& va FRONT_END_GENERAL_CHECK(false, "[TensorFlow Frontend] Internal error or inconsistent input model: the frontend supports " - "frozen formats (.pb and .pbtxt), SavedModel and MetaGraph (.meta) formats."); + "frozen formats (.pb and .pbtxt), SavedModel and MetaGraph (.meta), and v1 checkpoints."); return nullptr; } diff --git a/src/frontends/tensorflow/src/graph_iterator_proto.hpp b/src/frontends/tensorflow/src/graph_iterator_proto.hpp index 43ddb485e03..8b073b08373 100644 --- a/src/frontends/tensorflow/src/graph_iterator_proto.hpp +++ b/src/frontends/tensorflow/src/graph_iterator_proto.hpp @@ -7,6 +7,7 @@ #include #include +#include "checkpoint_v1_reader.hpp" #include "decoder_argdef.hpp" #include "decoder_proto.hpp" #include "graph.pb.h" @@ -22,6 +23,7 @@ class GraphIteratorProto : public GraphIterator { protected: std::shared_ptr<::tensorflow::GraphDef> m_graph_def; std::shared_ptr<::tensorflow::FunctionDef> m_func_def; + std::shared_ptr m_checkpoint_v1_reader; size_t node_index = 0; std::vector> m_decoders; @@ -32,6 +34,7 @@ protected: GraphIteratorProto() : m_graph_def(std::make_shared<::tensorflow::GraphDef>()), m_func_def(nullptr), + m_checkpoint_v1_reader(nullptr), m_library_map() {} void initialize_decoders_and_library() { @@ -52,12 +55,20 @@ protected: } } + template + void initialize_v1_checkpoints(const std::basic_string& checkpoint_directory) { + m_checkpoint_v1_reader = std::make_shared(checkpoint_directory); + m_checkpoint_v1_reader->initialize(); + } + public: GraphIteratorProto(const std::shared_ptr<::tensorflow::GraphDef>& graph_def, const std::shared_ptr<::tensorflow::FunctionDef>& func_def, - const std::unordered_map& library_map) + const std::unordered_map& library_map, + const std::shared_ptr checkpoint_v1_reader) : m_graph_def(graph_def), m_func_def(func_def), + m_checkpoint_v1_reader(checkpoint_v1_reader), m_library_map(library_map) { auto nodes_size = m_func_def->node_def_size(); auto input_size = m_func_def->signature().input_arg_size(); @@ -91,11 +102,13 @@ public: } } + /// \brief Construct GraphIterator for the frozen model without v1 checkpoints template - GraphIteratorProto(const std::basic_string& path) + GraphIteratorProto(const std::basic_string& model_path) : m_graph_def(std::make_shared<::tensorflow::GraphDef>()), - m_func_def(nullptr) { - std::ifstream pb_stream(path.c_str(), std::ios::in | 
std::ifstream::binary); + m_func_def(nullptr), + m_checkpoint_v1_reader(nullptr) { + std::ifstream pb_stream(model_path, std::ios::in | std::ifstream::binary); FRONT_END_GENERAL_CHECK(pb_stream && pb_stream.is_open(), "Model file does not exist"); FRONT_END_GENERAL_CHECK(m_graph_def->ParseFromIstream(&pb_stream), "Model cannot be parsed"); @@ -103,11 +116,26 @@ public: initialize_decoders_and_library(); } + /// \brief Construct GraphIterator for the frozen model with v1 checkpoints + template + GraphIteratorProto(const std::basic_string& model_path, const std::basic_string& checkpoint_directory) + : m_graph_def(std::make_shared<::tensorflow::GraphDef>()), + m_func_def(nullptr), + m_checkpoint_v1_reader(nullptr) { + std::ifstream pb_stream(model_path, std::ios::in | std::ifstream::binary); + + FRONT_END_GENERAL_CHECK(pb_stream && pb_stream.is_open(), "Model file does not exist"); + FRONT_END_GENERAL_CHECK(m_graph_def->ParseFromIstream(&pb_stream), "Model cannot be parsed"); + + initialize_decoders_and_library(); + initialize_v1_checkpoints(checkpoint_directory); + } + /// \brief Check if the input file is supported template static bool is_supported(const std::basic_string& path) { try { - std::ifstream pb_stream(path.c_str(), std::ios::in | std::ifstream::binary); + std::ifstream pb_stream(path, std::ios::in | std::ifstream::binary); auto graph_def = std::make_shared<::tensorflow::GraphDef>(); return pb_stream && pb_stream.is_open() && graph_def->ParsePartialFromIstream(&pb_stream) && graph_def->node_size() > 0; @@ -116,6 +144,11 @@ public: } } + /// \brief Get checkpoint v1 reader for checkpoint restoring in translator for Variable operation + std::shared_ptr get_checkpoint_v1_reader() const { + return m_checkpoint_v1_reader; + } + /// \brief Set iterator to the start position void reset() override { node_index = 0; @@ -152,7 +185,7 @@ public: auto func = m_graph_def->library().function(func_ind); auto func_ptr = std::make_shared<::tensorflow::FunctionDef>(func); - return std::make_shared(m_graph_def, func_ptr, m_library_map); + return std::make_shared(m_graph_def, func_ptr, m_library_map, m_checkpoint_v1_reader); } return nullptr; diff --git a/src/frontends/tensorflow/src/graph_iterator_proto_txt.hpp b/src/frontends/tensorflow/src/graph_iterator_proto_txt.hpp index 3b483a9688d..6d5b6494f76 100644 --- a/src/frontends/tensorflow/src/graph_iterator_proto_txt.hpp +++ b/src/frontends/tensorflow/src/graph_iterator_proto_txt.hpp @@ -17,9 +17,10 @@ namespace tensorflow { class GraphIteratorProtoTxt : public GraphIteratorProto { public: + /// \brief Construct GraphIterator for the frozen model in text format without v1 checkpoints template GraphIteratorProtoTxt(const std::basic_string& path) : GraphIteratorProto() { - std::ifstream pbtxt_stream(path.c_str(), std::ios::in); + std::ifstream pbtxt_stream(path, std::ios::in); FRONT_END_GENERAL_CHECK(pbtxt_stream && pbtxt_stream.is_open(), "Model file does not exist"); auto input_stream = std::make_shared<::google::protobuf::io::IstreamInputStream>(&pbtxt_stream); FRONT_END_GENERAL_CHECK(input_stream, "Model cannot be read"); @@ -31,6 +32,23 @@ public: initialize_decoders_and_library(); } + /// \brief Construct GraphIterator for the frozen model in text format with v1 checkpoints + template + GraphIteratorProtoTxt(const std::basic_string& path, const std::basic_string& checkpoint_directory) + : GraphIteratorProto() { + std::ifstream pbtxt_stream(path, std::ios::in); + FRONT_END_GENERAL_CHECK(pbtxt_stream && pbtxt_stream.is_open(), "Model file does not 
exist"); + auto input_stream = std::make_shared<::google::protobuf::io::IstreamInputStream>(&pbtxt_stream); + FRONT_END_GENERAL_CHECK(input_stream, "Model cannot be read"); + auto is_parsed = ::google::protobuf::TextFormat::Parse(input_stream.get(), m_graph_def.get()); + FRONT_END_GENERAL_CHECK( + is_parsed, + "[TensorFlow Frontend] Incorrect model or internal error: Model in text Protobuf format cannot be parsed."); + + initialize_decoders_and_library(); + initialize_v1_checkpoints(checkpoint_directory); + } + /// \brief Check if the input file is supported template static bool is_supported(const std::basic_string& path) { diff --git a/src/frontends/tensorflow/src/input_model.cpp b/src/frontends/tensorflow/src/input_model.cpp index 27a184368bb..d63ff7cf5f5 100644 --- a/src/frontends/tensorflow/src/input_model.cpp +++ b/src/frontends/tensorflow/src/input_model.cpp @@ -59,6 +59,7 @@ public: const std::shared_ptr& variables_index, const std::shared_ptr> saved_model_input_names, const std::shared_ptr> saved_model_output_names, + const std::shared_ptr checkpoint_v1_reader, const bool native_format = false); std::vector get_inputs() const; std::vector get_outputs() const; @@ -86,6 +87,7 @@ public: std::shared_ptr get_variables_index() const; std::shared_ptr> get_saved_model_input_names() const; std::shared_ptr> get_saved_model_output_names() const; + std::shared_ptr get_checkpoint_v1_reader() const; private: void load_places(); @@ -110,6 +112,7 @@ private: std::shared_ptr m_variables_index; std::shared_ptr> m_saved_model_input_names; std::shared_ptr> m_saved_model_output_names; + std::shared_ptr m_checkpoint_v1_reader; bool m_native_format; bool m_custom_inputs; @@ -256,6 +259,10 @@ std::shared_ptr> InputModel::InputModelTFImpl return m_saved_model_output_names; } +std::shared_ptr InputModel::InputModelTFImpl::get_checkpoint_v1_reader() const { + return m_checkpoint_v1_reader; +} + std::vector> InputModel::InputModelTFImpl::get_op_places() { return topologically_sort_op_nodes(); } @@ -417,6 +424,7 @@ InputModel::InputModelTFImpl::InputModelTFImpl( const std::shared_ptr& variables_index, const std::shared_ptr> saved_model_input_names, const std::shared_ptr> saved_model_output_names, + const std::shared_ptr checkpoint_v1_reader, const bool native_format) : m_graph_iterator(graph_iterator), m_input_model(input_model), @@ -424,6 +432,7 @@ InputModel::InputModelTFImpl::InputModelTFImpl( m_variables_index(variables_index), m_saved_model_input_names(saved_model_input_names), m_saved_model_output_names(saved_model_output_names), + m_checkpoint_v1_reader(checkpoint_v1_reader), m_native_format(native_format) { FRONT_END_GENERAL_CHECK(m_graph_iterator, "Null pointer specified for GraphIterator"); m_input_names = graph_iterator->get_input_names(); @@ -602,6 +611,7 @@ InputModel::InputModel(const GraphIterator::Ptr& graph_iterator, const std::shared_ptr& variables_index, const std::shared_ptr> saved_model_input_names, const std::shared_ptr> saved_model_output_names, + const std::shared_ptr checkpoint_v1_reader, const bool native_format) : _impl{std::make_shared(graph_iterator, *this, @@ -609,6 +619,7 @@ InputModel::InputModel(const GraphIterator::Ptr& graph_iterator, variables_index, saved_model_input_names, saved_model_output_names, + checkpoint_v1_reader, native_format)} {} std::shared_ptr InputModel::get_variables_index() { @@ -623,6 +634,10 @@ std::shared_ptr> InputModel::get_saved_model_ return _impl->get_saved_model_output_names(); } +std::shared_ptr InputModel::get_checkpoint_v1_reader() const { + 
return _impl->get_checkpoint_v1_reader();
+}
+
 std::vector<std::string> InputModel::get_input_names() const {
     return _impl->get_input_names();
 }
diff --git a/src/frontends/tensorflow/src/input_model.hpp b/src/frontends/tensorflow/src/input_model.hpp
index 179f244b8bf..a95a4447cc0 100644
--- a/src/frontends/tensorflow/src/input_model.hpp
+++ b/src/frontends/tensorflow/src/input_model.hpp
@@ -4,6 +4,7 @@
 
 #pragma once
 
+#include "checkpoint_v1_reader.hpp"
 #include "openvino/frontend/extension/telemetry.hpp"
 #include "openvino/frontend/graph_iterator.hpp"
 #include "openvino/frontend/input_model.hpp"
@@ -35,6 +36,7 @@ public:
               const std::shared_ptr<VariablesIndex>& variables_index = {},
               const std::shared_ptr<std::map<std::string, std::string>> saved_model_input_names = nullptr,
               const std::shared_ptr<std::map<std::string, std::string>> saved_model_output_names = nullptr,
+              const std::shared_ptr<CheckpointV1Reader> checkpoint_v1_reader = nullptr,
               const bool native_format = false);
 
     std::vector<ov::frontend::Place::Ptr> get_inputs() const override;
@@ -53,6 +55,8 @@ public:
     std::shared_ptr<std::map<std::string, std::string>> get_saved_model_input_names() const;
     std::shared_ptr<std::map<std::string, std::string>> get_saved_model_output_names() const;
 
+    std::shared_ptr<CheckpointV1Reader> get_checkpoint_v1_reader() const;
+
     std::map<std::string, std::shared_ptr<TensorPlace>> get_tensor_places() const;
 };
 
diff --git a/src/frontends/tensorflow/src/op/variable.cpp b/src/frontends/tensorflow/src/op/variable.cpp
new file mode 100644
index 00000000000..f7028aad358
--- /dev/null
+++ b/src/frontends/tensorflow/src/op/variable.cpp
@@ -0,0 +1,64 @@
+// Copyright (C) 2018-2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "common_op_table.hpp"
+#include "helper_ops/string_constant.hpp"
+#include "helper_ops/unsupported_constant.hpp"
+#include "input_model.hpp"
+#include "openvino/frontend/tensorflow/node_context.hpp"
+#include "openvino/opsets/opset11.hpp"
+#include "translate_session.hpp"
+
+using namespace std;
+using namespace ov;
+using namespace ov::frontend::tensorflow;
+using namespace ov::opset11;
+
+namespace ov {
+namespace frontend {
+namespace tensorflow {
+namespace op {
+
+OutputVector translate_variable_op(const NodeContext& node) {
+    default_op_checks(node, 0, {"Variable"});
+    auto variable_name = node.get_name();
+
+    auto translate_session = node.get_translate_session();
+    TENSORFLOW_OP_VALIDATION(node,
+                             translate_session,
+                             "[TensorFlow Frontend] Internal error: Translate session is nullptr.");
+    auto model = dynamic_pointer_cast<InputModel>(translate_session->get_input_model());
+    TENSORFLOW_OP_VALIDATION(
+        node,
+        model,
+        "[TensorFlow Frontend] Internal error: input model cannot be cast to TensorFlow Frontend InputModel.");
+    auto checkpoint_v1_reader = model->get_checkpoint_v1_reader();
+    TENSORFLOW_OP_VALIDATION(node,
+                             checkpoint_v1_reader,
+                             "[TensorFlow Frontend] incorrect input model: checkpoint to restore variable " +
+                                 variable_name + " is not provided.");
+
+    ov::Any variable_data;
+    checkpoint_v1_reader->read_variable(variable_name, variable_data);
+
+    shared_ptr<Node> const_node = nullptr;
+    if (variable_data.is<ov::Tensor>()) {
+        auto ov_tensor = variable_data.as<ov::Tensor>();
+        const_node = make_shared<Constant>(ov_tensor);
+    } else if (variable_data.is<std::vector<std::string>>()) {
+        // a case of a string tensor that should be assigned to the variable
+        const_node = make_shared<StringConstant>(variable_data, node.get_decoder());
+    } else {
+        // data of unknown type; assign to the outer const_node (re-declaring it here
+        // would leave the outer pointer nullptr and set_node_name would dereference it)
+        const_node = make_shared<UnsupportedConstant>("Variable of unsupported type", node.get_decoder());
+    }
+
+    set_node_name(variable_name, const_node);
+    return {const_node};
+}
+
+} // namespace op
+} // namespace tensorflow
+} // namespace frontend
+} // namespace ov
diff --git a/src/frontends/tensorflow/src/op_table.cpp
b/src/frontends/tensorflow/src/op_table.cpp index 76aa77111a0..42202e9a0d2 100644 --- a/src/frontends/tensorflow/src/op_table.cpp +++ b/src/frontends/tensorflow/src/op_table.cpp @@ -20,29 +20,30 @@ namespace op { #define TF_OP_CONVERTER(op) OutputVector op(const ov::frontend::tensorflow::NodeContext& node) -TF_OP_CONVERTER(translate_if_op); +TF_OP_CONVERTER(translate_assignvariable_op); TF_OP_CONVERTER(translate_block_lstm_op); TF_OP_CONVERTER(translate_fifo_queue_op); TF_OP_CONVERTER(translate_gru_block_cell_op); TF_OP_CONVERTER(translate_hash_table_op); +TF_OP_CONVERTER(translate_if_op); TF_OP_CONVERTER(translate_iterator_get_next_op); TF_OP_CONVERTER(translate_iterator_op); +TF_OP_CONVERTER(translate_mergev2checkpoint_op); TF_OP_CONVERTER(translate_partitioned_call_op); +TF_OP_CONVERTER(translate_placeholder_linked_op); TF_OP_CONVERTER(translate_queue_dequeue_op); TF_OP_CONVERTER(translate_queue_dequeue_many_op); +TF_OP_CONVERTER(translate_readvariable_op); +TF_OP_CONVERTER(translate_restorev2_op); TF_OP_CONVERTER(translate_sparse_fill_empty_rows_op); TF_OP_CONVERTER(translate_sparse_reshape_op); TF_OP_CONVERTER(translate_sparse_segment_sum_op); -TF_OP_CONVERTER(translate_varisinitialized_op); -TF_OP_CONVERTER(translate_readvariable_op); -TF_OP_CONVERTER(translate_assignvariable_op); -TF_OP_CONVERTER(translate_varhandle_op); -TF_OP_CONVERTER(translate_restorev2_op); TF_OP_CONVERTER(translate_staticregexfullmatch_op); TF_OP_CONVERTER(translate_stringjoin_op); -TF_OP_CONVERTER(translate_mergev2checkpoint_op); +TF_OP_CONVERTER(translate_varhandle_op); +TF_OP_CONVERTER(translate_variable_op); +TF_OP_CONVERTER(translate_varisinitialized_op); TF_OP_CONVERTER(translate_while_op); -TF_OP_CONVERTER(translate_placeholder_linked_op); const std::map get_supported_ops() { return { @@ -276,6 +277,9 @@ const std::map get_supported_ops() { {"VarHandleOp", CreatorFunction(translate_varhandle_op)}, {"VariableV2", CreatorFunction(translate_varhandle_op)}, + // Translator for Checkpoint V1 + {"Variable", CreatorFunction(translate_variable_op)}, + // Translators for internal operations {"BlockLSTM", CreatorFunction(translate_block_lstm_op)}, {"GRUBlockCell", CreatorFunction(translate_gru_block_cell_op)}, diff --git a/src/frontends/tensorflow/src/tf_utils.cpp b/src/frontends/tensorflow/src/tf_utils.cpp new file mode 100644 index 00000000000..7cccb507088 --- /dev/null +++ b/src/frontends/tensorflow/src/tf_utils.cpp @@ -0,0 +1,192 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "tf_utils.hpp" + +#include + +#include "openvino/core/type/element_type.hpp" +#include "openvino/frontend/exception.hpp" +#include "openvino/runtime/tensor.hpp" + +using namespace ov; + +namespace { + +template +void extract_tensor_content(const std::string& tensor_content, ov::Tensor* values) { + const auto tensor_content_size = tensor_content.size(); + FRONT_END_GENERAL_CHECK(tensor_content_size % sizeof(T) == 0, + "Size of tensor_content (", + tensor_content_size, + ") is not a multiple of ", + sizeof(T)); + + const T* tensor_values = reinterpret_cast(tensor_content.data()); + FRONT_END_GENERAL_CHECK(values->get_size() == tensor_content_size / sizeof(T), + "Size of tensor is not equal to tensor_content size."); + std::copy(tensor_values, tensor_values + tensor_content_size / sizeof(T), values->data()); +} + +#if defined(_MSC_VER) +# pragma warning(push) +# pragma warning(disable : 4244) // possible loss of data +# pragma warning(disable : 4267) // possible loss of data +#endif 
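+// TensorProto keeps "compressed" constants in the typed repeated *_val fields:
+// only the leading values are serialized and the last serialized value is
+// implicitly broadcast to the rest of the tensor. For example, a 1x4 f32 tensor
+// with float_val = {1.0, 2.0} unpacks to {1.0, 2.0, 2.0, 2.0}; an empty
+// float_val yields all zeros.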
+template <typename T>
+void extract_compressed_tensor_content(const ::tensorflow::TensorProto& tensor_proto,
+                                       int64_t val_size,
+                                       ov::Tensor* values) {
+    auto val_lastsaved = static_cast<T>(0);
+    auto values_data = values->data<T>();
+    for (size_t i = 0; i < values->get_size(); i++) {
+        if (val_size == 0) {
+            values_data[i] = static_cast<T>(0);
+        } else if (static_cast<int64_t>(i) < val_size) {
+            auto val_i = static_cast<T>(0);
+            switch (values->get_element_type()) {
+            // TODO: there are more element types to support here
+            case ov::element::boolean:
+                val_i = tensor_proto.bool_val()[i];
+                break;
+            case ov::element::i32:
+                val_i = tensor_proto.int_val()[i];
+                break;
+            case ov::element::i64:
+                val_i = tensor_proto.int64_val()[i];
+                break;
+            case ov::element::f16:
+                val_i = float16::from_bits(tensor_proto.half_val()[i]);
+                break;
+            case ov::element::f32:
+                val_i = tensor_proto.float_val()[i];
+                break;
+            case ov::element::f64:
+                val_i = tensor_proto.double_val()[i];
+                break;
+            default:
+                FRONT_END_THROW("Encountered unknown element type " + values->get_element_type().get_type_name());
+            }
+            values_data[i] = val_i;
+            val_lastsaved = val_i;
+        } else {
+            values_data[i] = val_lastsaved;
+        }
+    }
+}
+#if defined(_MSC_VER)
+#    pragma warning(pop)
+#endif
+}  // namespace
+
+ov::element::Type ov::frontend::tensorflow::get_ov_type(const ::tensorflow::DataType& type) {
+    static const std::map<::tensorflow::DataType, ov::element::Type> type_map{
+        {::tensorflow::DataType::DT_BOOL, ov::element::boolean},
+        {::tensorflow::DataType::DT_INT16, ov::element::i16},
+        {::tensorflow::DataType::DT_INT32, ov::element::i32},
+        {::tensorflow::DataType::DT_INT64, ov::element::i64},
+        {::tensorflow::DataType::DT_HALF, ov::element::f16},
+        {::tensorflow::DataType::DT_FLOAT, ov::element::f32},
+        {::tensorflow::DataType::DT_DOUBLE, ov::element::f64},
+        {::tensorflow::DataType::DT_UINT8, ov::element::u8},
+        {::tensorflow::DataType::DT_INT8, ov::element::i8},
+        {::tensorflow::DataType::DT_BFLOAT16, ov::element::bf16}};
+
+    auto it = type_map.find(type);
+    // for all unsupported types return dynamic type
+    return it == type_map.end() ? ov::element::dynamic : it->second;
+}
+
+ov::Any ov::frontend::tensorflow::unpack_tensor_proto(const ::tensorflow::TensorProto& tensor_proto) {
+    return unpack_tensor_proto(tensor_proto, tensor_proto.tensor_shape(), tensor_proto.dtype());
+}
+
+ov::Any ov::frontend::tensorflow::unpack_tensor_proto(const ::tensorflow::TensorProto& tensor_proto,
+                                                      const ::tensorflow::TensorShapeProto& tensor_shape,
+                                                      const ::tensorflow::DataType& tensor_type) {
+    ov::PartialShape pshape;
+    for (int i = 0; i < tensor_shape.dim_size(); i++) {
+        pshape.push_back(tensor_shape.dim(i).size());
+    }
+    FRONT_END_GENERAL_CHECK(pshape.is_static(), "Dynamic shapes are not supported for Tensor attribute.");
+    ov::element::Type ov_type = get_ov_type(tensor_type);
+
+    if (tensor_type != ::tensorflow::DataType::DT_STRING) {
+        FRONT_END_GENERAL_CHECK(
+            ov_type.is_static(),
+            "Encountered unknown element type " + DataType_Name(tensor_type) + " on an empty tensor_proto");
+    } else {
+        auto data = std::vector<std::string>();
+        for (const auto& item : tensor_proto.string_val()) {
+            data.push_back(item);
+        }
+        return data;
+    }
+    ov::Tensor res(ov_type, pshape.get_shape());
+    auto tensor_content = tensor_proto.tensor_content();
+    if (!tensor_content.empty() && tensor_proto.has_tensor_shape()) {
+        switch (ov_type) {
+        case ov::element::u8:
+            extract_tensor_content<uint8_t>(tensor_content, &res);
+            break;
+        case ov::element::i8:
+            extract_tensor_content<int8_t>(tensor_content, &res);
+            break;
+        case ov::element::i16:
+            extract_tensor_content<int16_t>(tensor_content, &res);
+            break;
+        case ov::element::i32:
+            extract_tensor_content<int32_t>(tensor_content, &res);
+            break;
+        case ov::element::i64:
+            extract_tensor_content<int64_t>(tensor_content, &res);
+            break;
+        case ov::element::f16:
+            extract_tensor_content<float16>(tensor_content, &res);
+            break;
+        case ov::element::f32:
+            extract_tensor_content<float>(tensor_content, &res);
+            break;
+        case ov::element::f64:
+            extract_tensor_content<double>(tensor_content, &res);
+            break;
+        case ov::element::bf16:
+            extract_tensor_content<bfloat16>(tensor_content, &res);
+            break;
+        default:
+            FRONT_END_THROW("Encountered unknown element type " + ov_type.get_type_name());
+        }
+    } else {
+        int64_t val_size = 0;
+        switch (ov_type) {
+        case ov::element::boolean:
+            val_size = tensor_proto.bool_val_size();
+            extract_compressed_tensor_content<bool>(tensor_proto, val_size, &res);
+            break;
+        case ov::element::i32:
+            val_size = tensor_proto.int_val_size();
+            extract_compressed_tensor_content<int32_t>(tensor_proto, val_size, &res);
+            break;
+        case ov::element::i64:
+            val_size = tensor_proto.int64_val_size();
+            extract_compressed_tensor_content<int64_t>(tensor_proto, val_size, &res);
+            break;
+        case ov::element::f16:
+            val_size = tensor_proto.half_val_size();
+            extract_compressed_tensor_content<float16>(tensor_proto, val_size, &res);
+            break;
+        case ov::element::f32:
+            val_size = tensor_proto.float_val_size();
+            extract_compressed_tensor_content<float>(tensor_proto, val_size, &res);
+            break;
+        case ov::element::f64:
+            val_size = tensor_proto.double_val_size();
+            extract_compressed_tensor_content<double>(tensor_proto, val_size, &res);
+            break;
+        default:
+            FRONT_END_THROW("Encountered unknown element type " + ov_type.get_type_name());
+        }
+    }
+    return res;
+}
diff --git a/src/frontends/tensorflow/src/tf_utils.hpp b/src/frontends/tensorflow/src/tf_utils.hpp
new file mode 100644
index 00000000000..8c8af31e231
--- /dev/null
+++ b/src/frontends/tensorflow/src/tf_utils.hpp
@@ -0,0 +1,31 @@
+// Copyright (C) 2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include "attr_value.pb.h"
+#include "node_def.pb.h"
+#include "openvino/core/partial_shape.hpp" +#include "openvino/core/type.hpp" +#include "openvino/core/type/element_type.hpp" +#include "openvino/runtime/tensor.hpp" +#include "tensor.pb.h" +#include "tensor_shape.pb.h" +#include "types.pb.h" + +namespace ov { +namespace frontend { +namespace tensorflow { + +ov::element::Type get_ov_type(const ::tensorflow::DataType& type); + +ov::Any unpack_tensor_proto(const ::tensorflow::TensorProto& tensor_proto); + +ov::Any unpack_tensor_proto(const ::tensorflow::TensorProto& tensor_proto, + const ::tensorflow::TensorShapeProto& tensor_shape, + const ::tensorflow::DataType& tensor_type); + +} // namespace tensorflow +} // namespace frontend +} // namespace ov diff --git a/src/frontends/tensorflow/src/variables_index.cpp b/src/frontends/tensorflow/src/variables_index.cpp index 426c138c42b..7534a6dcbd0 100644 --- a/src/frontends/tensorflow/src/variables_index.cpp +++ b/src/frontends/tensorflow/src/variables_index.cpp @@ -7,6 +7,7 @@ #include #include +#include "checkpoint_utils.hpp" #include "graph_iterator_saved_model.hpp" #include "openvino/core/type/element_type.hpp" #include "tensor_bundle.pb.h" @@ -20,80 +21,6 @@ namespace ov { namespace frontend { namespace tensorflow { -template -static T smReadFixed(const char* ptr) { - T result = 0; - for (uint8_t i = 0; i < sizeof(T); ++i) { - result |= static_cast(ptr[i]) << (i * 8); - } - return result; -} - -template -static T smUnpack(char*& ptr, const char* ptr_end) { - T result = 0; - for (uint8_t i = 0; i < sizeof(T) * 7 && ptr < ptr_end; i += 7) { - T byte = *(ptr++); - if (byte & 0x80) { - result |= ((byte & 0x7F) << i); - } else { - result |= byte << i; - return result; - } - } - return 0; -} - -/// \brief Structure is for storing information about block in Varaibles Index file. -/// It defines only offset and block size, no information about exact content. -struct VIBlock { - uint64_t m_size; - uint64_t m_offset; - - void read(char*& ptr, const char* ptr_end) { - m_offset = smUnpack(ptr, ptr_end); - m_size = smUnpack(ptr, ptr_end); - } -}; - -#define VARIABLES_INDEX_FOOTER_SIZE 48 - -/// \brief Structure is for storing information about Variables Index footer information. -/// It contains description of two blocks and a magic number for a file verification. -/// Currently, it is placed in last VARIABLES_INDEX_FOOTER_SIZE bytes at the end of a file. 
-struct VIFooter { - VIBlock m_metaIndex; - VIBlock m_index; - - void read(char*& ptr, const char* ptr_end) { - m_index.read(ptr, ptr_end); - m_metaIndex.read(ptr, ptr_end); - } - - void read(std::ifstream& fs) { - fs.seekg(0, std::ios::end); - size_t size = fs.tellg(); - FRONT_END_GENERAL_CHECK(size >= VARIABLES_INDEX_FOOTER_SIZE, - "Wrong index file, file size is less than minimal expected"); - - char footerData[VARIABLES_INDEX_FOOTER_SIZE] = {}, *ptr = &footerData[0]; - fs.seekg(size - sizeof(footerData)); - fs.read(ptr, sizeof(footerData)); - - // https://github.com/tensorflow/tensorflow/blob/9659b7bdca80a8ef8240eb021d4da089034eeb00/tensorflow/tsl/lib/io/format.cc#L59 - ptr += sizeof(footerData) - 8; - uint32_t magic_lo = *reinterpret_cast(ptr); - uint32_t magic_hi = *reinterpret_cast(ptr + 4); - uint64_t magic_no = (static_cast(magic_hi) << 32) | static_cast(magic_lo); - - FRONT_END_GENERAL_CHECK(magic_no == 0xdb4775248b80fb57ull, "Wrong index file, magic number mismatch detected"); - - ptr = &footerData[0]; - m_metaIndex.read(ptr, ptr + sizeof(footerData)); - m_index.read(ptr, ptr + sizeof(footerData)); - } -}; - void VariablesIndex::read_variables_index_block(std::ifstream& fs, const VIBlock& index, std::vector& data, @@ -101,7 +28,7 @@ void VariablesIndex::read_variables_index_block(std::ifstream& fs, uint32_t& offset_end) { size_t block_size = index.m_size; data.clear(); - data.resize(block_size + 5 /*kBlockTrailerSize*/); + data.resize(block_size + BLOCK_TRAILER_SIZE); FRONT_END_GENERAL_CHECK(index.m_offset <= m_variables_index_size, "Block offset is bigger than variables index size"); FRONT_END_GENERAL_CHECK(index.m_offset + data.size() <= m_variables_index_size, @@ -124,11 +51,11 @@ void VariablesIndex::read_variables_index_block(std::ifstream& fs, block_size = uncompressed_length; } #endif - uint32_t numRestarts = smReadFixed(data.data() + block_size - sizeof(uint32_t)); + uint32_t numRestarts = decode_fixed32(data.data() + block_size - sizeof(uint32_t)); size_t maxRestarts = (block_size - sizeof(uint32_t)) / sizeof(uint32_t); FRONT_END_GENERAL_CHECK(maxRestarts >= numRestarts, "Wrong restarts value"); offset_end = static_cast(block_size) - ((numRestarts + 1) * sizeof(uint32_t)); - offset = smReadFixed(data.data() + offset_end); + offset = decode_fixed32(data.data() + offset_end); } void VariablesIndex::read_variables_index_pair(char*& ptr, diff --git a/tools/mo/openvino/tools/mo/front/tf/loader.py b/tools/mo/openvino/tools/mo/front/tf/loader.py index e1a60f4bd81..9b78b2da2fa 100644 --- a/tools/mo/openvino/tools/mo/front/tf/loader.py +++ b/tools/mo/openvino/tools/mo/front/tf/loader.py @@ -334,14 +334,9 @@ def convert_to_pb(argv: argparse.Namespace): if "tensorflow" in env_setup and env_setup["tensorflow"] >= LooseVersion("2.0.0"): tf.keras.backend.clear_session() - # if this is already binary or text frozen format .pb or .pbtxt, - # there is no need to create auxiliary binary frozen protobuf - if argv.input_model and not argv.input_checkpoint and \ - isinstance(argv.input_model, str): - return None - - # Saved Model format and MetaGraph format is supported without freezing - if argv.saved_model_dir or argv.input_meta_graph: + # any model format on disk is accepted by TensorFlow Frontend + # only model from memory requires temporal saving on a disk + if (argv.input_model and isinstance(argv.input_model, str)) or argv.saved_model_dir or argv.input_meta_graph: return None user_output_node_names_list = argv.output if argv.output else None diff --git 
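As background for the decode_fixed32 and varint readers that checkpoint_utils consolidates above, here is a small Python sketch of the same LevelDB-style primitives (illustrative only, not part of the patch; the function names are hypothetical):

    def decode_fixed32(data: bytes, pos: int = 0) -> int:
        # fixed-width little-endian 32-bit value, as used for block restart counts
        return int.from_bytes(data[pos:pos + 4], "little")

    def decode_varint64(data: bytes, pos: int = 0):
        # base-128 varint: 7 payload bits per byte, MSB set means more bytes follow
        result, shift = 0, 0
        while pos < len(data):
            byte = data[pos]
            pos += 1
            result |= (byte & 0x7F) << shift
            if not byte & 0x80:
                return result, pos
            shift += 7
        raise ValueError("truncated varint")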
diff --git a/tools/mo/openvino/tools/mo/moc_frontend/pipeline.py b/tools/mo/openvino/tools/mo/moc_frontend/pipeline.py
index 8da2b41427f..32d43a698f6 100644
--- a/tools/mo/openvino/tools/mo/moc_frontend/pipeline.py
+++ b/tools/mo/openvino/tools/mo/moc_frontend/pipeline.py
@@ -34,7 +34,12 @@ def moc_pipeline(argv: argparse.Namespace, moc_front_end: FrontEnd):
         raise Exception("ONNX frontend does not support input model as BytesIO object. "
                         "Please use use_legacy_frontend=True to convert the model.")
     else:
-        if argv.input_model:
+        input_checkpoint = getattr(argv, 'input_checkpoint', None)
+        if argv.input_model and input_checkpoint:
+            # frozen graph with v1 checkpoints
+            input_model = moc_front_end.load([argv.input_model, argv.input_checkpoint])
+        elif argv.input_model:
+            # frozen model without v1 checkpoints
             input_model = moc_front_end.load(argv.input_model)
         elif argv.saved_model_dir:
             input_model = moc_front_end.load(argv.saved_model_dir)
diff --git a/tools/mo/unit_tests/mo/utils/freeze_placeholder_test.py b/tools/mo/unit_tests/mo/utils/freeze_placeholder_test.py
index 2c4fa3c2d1d..fdcba729f03 100644
--- a/tools/mo/unit_tests/mo/utils/freeze_placeholder_test.py
+++ b/tools/mo/unit_tests/mo/utils/freeze_placeholder_test.py
@@ -28,6 +28,7 @@ def base_args_config(use_legacy_fe: bool = None, use_new_fe: bool = None):
     args.framework = "onnx"
     args.model_name = None
     args.input_model = None
+    args.input_checkpoint = None
     args.silent = True
     args.transform = []
     args.scale = None
diff --git a/tools/mo/unit_tests/mo/utils/test_mo_model_analysis_actual.py b/tools/mo/unit_tests/mo/utils/test_mo_model_analysis_actual.py
index 37873e98510..0837d0c8637 100644
--- a/tools/mo/unit_tests/mo/utils/test_mo_model_analysis_actual.py
+++ b/tools/mo/unit_tests/mo/utils/test_mo_model_analysis_actual.py
@@ -27,6 +27,7 @@ def base_args_config():
     args.framework = 'onnx'
     args.model_name = None
     args.input_model = None
+    args.input_checkpoint = None
     args.silent = True
     args.transform=[]
     args.scale = None
diff --git a/tools/mo/unit_tests/moc_tf_fe/conversion_incorrect_models_test.py b/tools/mo/unit_tests/moc_tf_fe/conversion_incorrect_models_test.py
index d3484af3814..a9f2a4a4c4f 100644
--- a/tools/mo/unit_tests/moc_tf_fe/conversion_incorrect_models_test.py
+++ b/tools/mo/unit_tests/moc_tf_fe/conversion_incorrect_models_test.py
@@ -33,7 +33,7 @@ class TestMoFreezePlaceholderTFFE(unittest.TestCase):
     def test_conversion_fake_pb_model(self, use_new_frontend, use_legacy_frontend, framework):
         with self.assertRaisesRegex(Exception,
                                     "Internal error or inconsistent input model: the frontend supports frozen formats"
-                                    " \(.pb and .pbtxt\), SavedModel and MetaGraph \(.meta\) formats."):
+                                    " \(.pb and .pbtxt\), SavedModel and MetaGraph \(.meta\), and v1 checkpoints."):
             path = os.path.dirname(__file__)
             input_model = os.path.join(path, "test_models", "fake.pb")
@@ -51,7 +51,7 @@ class TestMoFreezePlaceholderTFFE(unittest.TestCase):
             (
                     False, False, "tf",
                     r"Internal error or inconsistent input model: the frontend supports frozen formats"
-                    " \(.pb and .pbtxt\), SavedModel and MetaGraph \(.meta\) formats."
+                    " \(.pb and .pbtxt\), SavedModel and MetaGraph \(.meta\), and v1 checkpoints."
             ),
             # new frontend
             (
@@ -61,7 +61,7 @@ class TestMoFreezePlaceholderTFFE(unittest.TestCase):
             (
                     True, False, "tf",
                     r"Internal error or inconsistent input model: the frontend supports frozen formats"
-                    " \(.pb and .pbtxt\), SavedModel and MetaGraph \(.meta\) formats."
+                    " \(.pb and .pbtxt\), SavedModel and MetaGraph \(.meta\), and v1 checkpoints."
            ),
        ],
    )
@@ -83,7 +83,7 @@ class TestMoFreezePlaceholderTFFE(unittest.TestCase):
             (
                     False, False, "tf",
                     r"Internal error or inconsistent input model: the frontend supports frozen formats"
-                    " \(.pb and .pbtxt\), SavedModel and MetaGraph \(.meta\) formats."
+                    " \(.pb and .pbtxt\), SavedModel and MetaGraph \(.meta\), and v1 checkpoints."
             ),
             # new frontend
             (
@@ -93,7 +93,7 @@ class TestMoFreezePlaceholderTFFE(unittest.TestCase):
             (
                     True, False, "tf",
                     r"Internal error or inconsistent input model: the frontend supports frozen formats"
-                    " \(.pb and .pbtxt\), SavedModel and MetaGraph \(.meta\) formats."
+                    " \(.pb and .pbtxt\), SavedModel and MetaGraph \(.meta\), and v1 checkpoints."
             ),
         ],
     )
diff --git a/tools/mo/unit_tests/moc_tf_fe/conversion_with_checkpoint_v1_test.py b/tools/mo/unit_tests/moc_tf_fe/conversion_with_checkpoint_v1_test.py
new file mode 100644
index 00000000000..4b7a37495cb
--- /dev/null
+++ b/tools/mo/unit_tests/moc_tf_fe/conversion_with_checkpoint_v1_test.py
@@ -0,0 +1,42 @@
+# Copyright (C) 2018-2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import tempfile
+import unittest
+
+import numpy as np
+
+from unit_tests.moc_tf_fe.utils import basic_check
+
+
+class TestBasicConversion(unittest.TestCase):
+    def prepare_checkpoint_v1(self):
+        # only quite old TensorFlow versions can produce checkpoint v1 files,
+        # so a hard-coded byte stream with checkpoint v1 content is used here;
+        # it stores the Variable global_step with value 14108582 of int64 type
+        buffer_checkpoint = [
+            0x00, 0x00, 0x1B, 0x0A, 0x19, 0x0A, 0x13, 0x0A, 0x0B, 0x67, 0x6C, 0x6F,
+            0x62, 0x61, 0x6C, 0x5F, 0x73, 0x74, 0x65, 0x70, 0x12, 0x00, 0x18, 0x09,
+            0x22, 0x00, 0x12, 0x02, 0x08, 0x01, 0x00, 0x0F, 0x19, 0x00, 0x67, 0x6C,
+            0x6F, 0x62, 0x61, 0x6C, 0x5F, 0x73, 0x74, 0x65, 0x70, 0x00, 0x01, 0x00,
+            0x12, 0x17, 0x0A, 0x0B, 0x67, 0x6C, 0x6F, 0x62, 0x61, 0x6C, 0x5F, 0x73,
+            0x74, 0x65, 0x70, 0x12, 0x00, 0x1A, 0x06, 0x52, 0x04, 0xA6, 0x8F, 0xDD,
+            0x06, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x62, 0x29,
+            0x33, 0xD3, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0xC0,
+            0xF2, 0xA1, 0xB0, 0x00, 0x01, 0x02, 0x01, 0x00, 0x51, 0x00, 0x00, 0x00,
+            0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x1A, 0x13, 0xD9, 0x46, 0x56, 0x08,
+            0x63, 0x0E, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+            0x00, 0x00, 0x57, 0xFB, 0x80, 0x8B, 0x24, 0x75, 0x47, 0xDB]
+        return buffer_checkpoint
+
+    def test_basic_checkpoint_v1(self):
+        ckpt_file = tempfile.NamedTemporaryFile(delete=False)
+        checkpoint_byte_stream = self.prepare_checkpoint_v1()
+        ckpt_file.write(bytes(checkpoint_byte_stream))
+        ckpt_file.close()
+        basic_check(input_model="model_with_variable_v1.pbtxt", argv_input=None,
+                    input_data={'input1': np.array([[1]], dtype=np.int64)},
+                    expected_dtype=np.int64, expected_value=np.array([[14108583]], dtype=np.int64),
+                    use_new_frontend=True, use_legacy_frontend=False, input_checkpoint=ckpt_file.name)
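A quick sanity check on the hard-coded stream above: the four bytes 0xA6 0x8F 0xDD 0x06 (directly after 0x52 0x04 in the sixth row) are a base-128 varint, and decoding them reproduces the stored global_step value (illustrative snippet, not part of the patch):

    value, shift = 0, 0
    for byte in (0xA6, 0x8F, 0xDD, 0x06):
        value |= (byte & 0x7F) << shift
        shift += 7
    print(value)  # 14108582; the test feeds input 1 and expects 14108583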
diff --git a/tools/mo/unit_tests/moc_tf_fe/test_models/model_with_variable_v1.pbtxt b/tools/mo/unit_tests/moc_tf_fe/test_models/model_with_variable_v1.pbtxt
new file mode 100644
index 00000000000..4bd69c1afbe
--- /dev/null
+++ b/tools/mo/unit_tests/moc_tf_fe/test_models/model_with_variable_v1.pbtxt
@@ -0,0 +1,53 @@
+node {
+  name: "input1"
+  op: "Placeholder"
+  attr {
+    key: "dtype"
+    value {
+      type: DT_INT64
+    }
+  }
+  attr {
+    key: "shape"
+    value {
+      shape {
+        dim {
+          size: 1
+        }
+        dim {
+          size: 1
+        }
+      }
+    }
+  }
+}
+node {
+  name: "global_step"
+  op: "Variable"
+  device: "/device:CPU:0"
+  attr {
+    key: "dtype"
+    value {
+      type: DT_INT64
+    }
+  }
+  attr {
+    key: "shape"
+    value {
+      shape {
+      }
+    }
+  }
+}
+node {
+  name: "add"
+  op: "Add"
+  input: "input1"
+  input: "global_step"
+  attr {
+    key: "T"
+    value {
+      type: DT_INT64
+    }
+  }
+}
diff --git a/tools/mo/unit_tests/moc_tf_fe/utils.py b/tools/mo/unit_tests/moc_tf_fe/utils.py
new file mode 100644
index 00000000000..4e73b1d3aae
--- /dev/null
+++ b/tools/mo/unit_tests/moc_tf_fe/utils.py
@@ -0,0 +1,38 @@
+# Copyright (C) 2018-2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+
+import numpy as np
+
+from openvino.runtime import Core
+from openvino.tools.mo.convert import convert_model
+
+
+def basic_check(input_model, argv_input, input_data, expected_dtype, expected_value, freeze_placeholder_with_value=None,
+                input_shape=None, only_conversion=False, input_model_is_text=True, use_new_frontend=True,
+                use_legacy_frontend=False, extensions=None, input_checkpoint=None):
+    path = os.path.dirname(__file__)
+    input_model = os.path.join(path, "test_models", input_model)
+
+    ov_model = convert_model(input_model, input=argv_input,
+                             freeze_placeholder_with_value=freeze_placeholder_with_value,
+                             input_shape=input_shape, input_model_is_text=input_model_is_text,
+                             use_new_frontend=use_new_frontend, use_legacy_frontend=use_legacy_frontend,
+                             framework="tf", extensions=extensions, input_checkpoint=input_checkpoint)
+
+    if only_conversion:
+        return ov_model
+
+    ie = Core()
+    exec_net = ie.compile_model(ov_model, "CPU")
+    req = exec_net.create_infer_request()
+    results = req.infer(input_data)
+    values = list(results.values())[0]
+    if expected_dtype is not None:
+        assert values.dtype == expected_dtype
+    assert np.allclose(values, expected_value), "Expected and actual values are different." \
+                                                " Expected value: {}, actual value: {}".format(expected_value, values)
+
+    return ov_model
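For reference, a minimal usage sketch of the conversion path this patch enables, mirroring the call made in basic_check above (illustrative only, not part of the change; the graph and checkpoint file names are hypothetical placeholders):

    from openvino.tools.mo.convert import convert_model

    # a frozen GraphDef in text form plus its accompanying v1 checkpoint;
    # the frontend restores Variable values from the checkpoint at conversion time
    ov_model = convert_model("graph.pbtxt",
                             input_model_is_text=True,
                             input_checkpoint="model.ckpt",
                             use_new_frontend=True,
                             framework="tf")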