[TF FE] Support models with v1 checkpoints (#18000)
* [TF FE] Support models with v1 checkpoints
* Fix build issue
* Fix issue in load_impl
* Fix MO unit-tests
* Check input_checkpoint in argv
* Fix build issue
* Update src/frontends/tensorflow/src/checkpoint_v1_reader.cpp
* Apply comments from code-review: split CheckpointV1Reader constructor

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>
commit 79683c24ca (parent 967faf878a)
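With this change, a TF1 frozen graph can be loaded together with its v1 checkpoint by passing the pair of paths to the TensorFlow frontend; on the Model Optimizer side the same pair is expressed via --input_model and --input_checkpoint, whose presence in argv this PR checks. A minimal usage sketch, not part of the diff, with hypothetical file names:

#include <memory>
#include <string>
#include <vector>

#include <openvino/frontend/manager.hpp>

int main() {
    ov::frontend::FrontEndManager manager;
    auto fe = manager.load_by_framework("tf");
    // {path to frozen graph, path to a v1 checkpoint file or directory of shards}
    std::vector<std::string> paths = {"model.pb", "model.ckpt"};
    ov::frontend::InputModel::Ptr input_model = fe->load(paths);
    std::shared_ptr<ov::Model> model = fe->convert(input_model);
    return model ? 0 : 1;
}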
277  src/frontends/tensorflow/src/checkpoint_utils.cpp  Normal file
@@ -0,0 +1,277 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "checkpoint_utils.hpp"

#include <cstring>
#include <vector>

#include "openvino/frontend/exception.hpp"

namespace ov {
namespace frontend {
namespace tensorflow {

static const char escape1 = '\000';
static const char null_character = '\xff';
static const char separator = '\001';
static const char escape2 = '\xff';
static const char ffcharacter = '\000';
static const char escape1_separator[2] = {escape1, separator};
static const int max_signed64_length = 10;

// This array maps encoding length to header bits in the first two bytes.
static const char length_to_header_bits[1 + max_signed64_length][2] = {{0, 0},
                                                                       {'\x80', 0},
                                                                       {'\xc0', 0},
                                                                       {'\xe0', 0},
                                                                       {'\xf0', 0},
                                                                       {'\xf8', 0},
                                                                       {'\xfc', 0},
                                                                       {'\xfe', 0},
                                                                       {'\xff', 0},
                                                                       {'\xff', '\x80'},
                                                                       {'\xff', '\xc0'}};

inline bool byte_is_0_or_255(char c) {
    return (static_cast<unsigned char>(c + 1)) < 2;
}

static inline const unsigned char* char_ptr_to_unsigned_char_ptr(const char* p) {
    const void* void_ptr = static_cast<const void*>(p);
    return static_cast<const unsigned char*>(void_ptr);
}

static inline const char* unsigned_char_ptr_to_char_ptr(const unsigned char* p) {
    const void* void_ptr = static_cast<const void*>(p);
    return static_cast<const char*>(void_ptr);
}

// return a pointer to the first byte in the range "[start..limit)"
// whose value is 0 or 255 (escape1 or escape2)
inline const char* find_special_byte(const char* start, const char* limit) {
    // If these constants were ever changed, this routine needs to change
    const char* current = start;
    while (current < limit && !byte_is_0_or_255(*current)) {
        ++current;
    }
    return current;
}

// encode "source" and append to "dest", escaping special characters
inline static void encode_string_piece(std::string& dest, const std::string& source) {
    const char* current = source.data();
    const char* limit = current + source.size();
    const char* copy_start = current;
    while (true) {
        current = find_special_byte(current, limit);
        if (current >= limit)
            break;  // No more special characters that need escaping
        char c = *(current++);
        if (c == escape1) {
            dest.append(copy_start, current - copy_start - 1);
            dest.push_back(escape1);
            dest.push_back(null_character);
            copy_start = current;
        } else {
            FRONT_END_GENERAL_CHECK(c == escape2, "[TensorFlow Frontend] incorrect model: corrupted checkpoint");
            dest.append(copy_start, current - copy_start - 1);
            dest.push_back(escape2);
            dest.push_back(ffcharacter);
            copy_start = current;
        }
    }
    if (current > copy_start) {
        dest.append(copy_start, current - copy_start);
    }
}

// store the 64-bit number v in big-endian byte order
static void convert_to_big_endian64(char* dst, uint64_t v) {
    for (int i = 0; i < 8; ++i) {
        dst[i] = (v >> (56 - 8 * i)) & 0xff;
    }
}

// compute floored log2(n)
static int log2_floor32(uint32_t n) {
    if (n == 0)
        return -1;
    int log = 0;
    uint32_t value = n;
    for (int i = 4; i >= 0; --i) {
        int shift = (1 << i);
        uint32_t x = value >> shift;
        if (x != 0) {
            value = x;
            log += shift;
        }
    }
    return log;
}

// compute floored log2(n)
static int log2_floor64(uint64_t n) {
    const uint32_t topbits = static_cast<uint32_t>(n >> 32);
    if (topbits == 0) {
        // Top bits are zero, so scan in bottom bits
        return log2_floor32(static_cast<uint32_t>(n));
    } else {
        return 32 + log2_floor32(topbits);
    }
}

// calculate the encoding length in bytes of the signed number n
static inline int signed_encoding_length(int64_t n) {
    int log2 = log2_floor64(n < 0 ? ~n : n) + 1;
    return log2 / 7 + 1;
}

void write_signed_num_increasing(std::string& dest, int64_t val) {
    const uint64_t x = val < 0 ? ~val : val;
    if (x < 64) {  // fast path for encoding length == 1
        dest += length_to_header_bits[1][0] ^ static_cast<char>(val);
        return;
    }
    // buf = val in network byte order, sign extended to 10 bytes
    const char sign_byte = val < 0 ? '\xff' : '\0';
    char buf[max_signed64_length] = {
        sign_byte,
        sign_byte,
    };
    convert_to_big_endian64(buf + 2, val);
    const int len = signed_encoding_length(x);
    FRONT_END_GENERAL_CHECK(0 < len && len <= max_signed64_length,
                            "[TensorFlow Frontend] internal error: write_signed_num_increasing failed");
    char* const begin = buf + max_signed64_length - len;
    begin[0] ^= length_to_header_bits[len][0];
    begin[1] ^= length_to_header_bits[len][1];  // ok because len >= 2
    dest.append(begin, len);
}

void write_num_increasing(std::string& dest, uint64_t val) {
    // Values are encoded with a single byte length prefix, followed
    // by the actual value in big-endian format with leading 0 bytes
    // dropped.
    unsigned char buf[9];  // 8 bytes for value plus one byte for length
    int len = 0;
    while (val > 0) {
        ++len;
        buf[9 - len] = (val & 0xff);
        val >>= 8;
    }
    buf[9 - len - 1] = len;
    ++len;
    dest.append(unsigned_char_ptr_to_char_ptr(&buf[0]) + 9 - len, len);
}
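As a worked example of the length-prefixed encoding above (an illustrative standalone re-implementation, not part of the diff): encoding 300 (0x012C) produces the bytes 02 01 2c, that is, a one-byte length prefix of 2 followed by the value in big-endian order with leading zero bytes dropped.

#include <cstdint>
#include <cstdio>
#include <string>

// standalone re-implementation of write_num_increasing, for illustration
static void encode_num_increasing(std::string& dest, uint64_t val) {
    unsigned char buf[9];  // 8 value bytes plus one length byte
    int len = 0;
    while (val > 0) {  // emit value bytes, least significant first
        ++len;
        buf[9 - len] = val & 0xff;
        val >>= 8;
    }
    buf[9 - len - 1] = static_cast<unsigned char>(len);  // length prefix
    ++len;
    dest.append(reinterpret_cast<const char*>(buf) + 9 - len, len);
}

int main() {
    std::string out;
    encode_num_increasing(out, 300);
    for (unsigned char c : out)
        std::printf("%02x ", c);  // prints: 02 01 2c
    std::printf("\n");
}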
std::string encode_tensor_name_slice(const std::string& name,
                                     const std::vector<int64_t>& starts,
                                     const std::vector<int64_t>& lengths) {
    std::string buffer;
    // All the tensor slice keys will start with a 0
    write_num_increasing(buffer, 0);
    encode_string_piece(buffer, name);
    buffer.append(escape1_separator, 2);
    write_num_increasing(buffer, starts.size());

    FRONT_END_GENERAL_CHECK(
        starts.size() == lengths.size(),
        "[TensorFlow Frontend] internal error or inconsistent model: check consistency of checkpoint files");
    for (size_t d = 0; d < starts.size(); ++d) {
        write_signed_num_increasing(buffer, starts[d]);
        write_signed_num_increasing(buffer, lengths[d]);
    }
    return buffer;
}

// decode a 32-bit fixed-width value stored in little-endian byte order
// (assumes a little-endian host)
uint32_t decode_fixed32(const char* ptr) {
    uint32_t result;
    std::memcpy(&result, ptr, sizeof(result));
    return result;
}

const char* get_varint32_ptr(const char* p, const char* limit, uint32_t& value) {
    if (p < limit) {
        uint32_t result = *(char_ptr_to_unsigned_char_ptr(p));
        if ((result & 128) == 0) {
            value = result;
            return p + 1;
        }
    }
    uint32_t result = 0;
    for (uint32_t shift = 0; shift <= 28 && p < limit; shift += 7) {
        uint32_t byte = *(char_ptr_to_unsigned_char_ptr(p));
        ++p;
        if (byte & 128) {
            // More bytes are present
            result |= ((byte & 127) << shift);
        } else {
            result |= (byte << shift);
            value = result;
            return p;
        }
    }
    return nullptr;
}

const char* get_varint64_ptr(const char* p, const char* limit, uint64_t* value) {
    uint64_t result = 0;
    for (uint32_t shift = 0; shift <= 63 && p < limit; shift += 7) {
        uint64_t byte = *(char_ptr_to_unsigned_char_ptr(p));
        ++p;
        if (byte & 128) {
            // More bytes are present
            result |= ((byte & 127) << shift);
        } else {
            result |= (byte << shift);
            *value = result;
            return p;
        }
    }
    return nullptr;
}

bool get_varint64(std::string& input, uint64_t* value) {
    const char* p = input.data();
    const char* limit = p + input.size();
    const char* q = get_varint64_ptr(p, limit, value);
    if (q == nullptr) {
        return false;
    } else {
        input = std::string(q, limit - q);
        return true;
    }
}
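The varints above use the same base-128 scheme as protobuf: seven payload bits per byte, with the high bit set on every byte except the last. An illustrative standalone decoder (not part of the diff) applied to the two-byte sequence 96 01 recovers 150:

#include <cstdint>
#include <cstdio>

// minimal base-128 varint decoder mirroring get_varint64_ptr, for illustration
static const unsigned char* decode_varint64(const unsigned char* p,
                                            const unsigned char* limit,
                                            uint64_t& value) {
    uint64_t result = 0;
    for (uint32_t shift = 0; shift <= 63 && p < limit; shift += 7) {
        uint64_t byte = *p++;
        if (byte & 128) {
            result |= (byte & 127) << shift;  // continuation bit set, more bytes follow
        } else {
            value = result | (byte << shift);  // final byte
            return p;
        }
    }
    return nullptr;  // truncated or malformed input
}

int main() {
    const unsigned char data[] = {0x96, 0x01};  // 150 = 22 + (1 << 7)
    uint64_t v = 0;
    decode_varint64(data, data + sizeof(data), v);
    std::printf("%llu\n", static_cast<unsigned long long>(v));  // prints 150
}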
const char* decode_entry(const char* p,
                         const char* limit,
                         uint32_t& shared,
                         uint32_t& non_shared,
                         uint32_t& value_length) {
    if (limit - p < 3)
        return nullptr;
    shared = char_ptr_to_unsigned_char_ptr(p)[0];
    non_shared = char_ptr_to_unsigned_char_ptr(p)[1];
    value_length = char_ptr_to_unsigned_char_ptr(p)[2];
    if ((shared | non_shared | value_length) < 128) {
        // Fast path: all three values are encoded in one byte each
        p += 3;
    } else {
        if ((p = get_varint32_ptr(p, limit, shared)) == nullptr)
            return nullptr;
        if ((p = get_varint32_ptr(p, limit, non_shared)) == nullptr)
            return nullptr;
        if ((p = get_varint32_ptr(p, limit, value_length)) == nullptr)
            return nullptr;
    }

    if (static_cast<uint32_t>(limit - p) < (non_shared + value_length)) {
        return nullptr;
    }
    return p;
}
}  // namespace tensorflow
}  // namespace frontend
}  // namespace ov
100  src/frontends/tensorflow/src/checkpoint_utils.hpp  Normal file
@@ -0,0 +1,100 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <stdint.h>

#include <fstream>
#include <string>
#include <vector>

#include "openvino/frontend/exception.hpp"

namespace ov {
namespace frontend {
namespace tensorflow {

#define VARIABLES_INDEX_FOOTER_SIZE 48
#define BLOCK_TRAILER_SIZE 5
#define SAVED_TENSOR_SLICES_KEY ""

template <typename T>
static T smUnpack(char*& ptr, const char* ptr_end) {
    T result = 0;
    for (uint8_t i = 0; i <= sizeof(T) * 7 && ptr < ptr_end; i += 7) {
        T byte = *(ptr++);
        if (byte & 0x80) {
            result |= ((byte & 0x7F) << i);
        } else {
            result |= byte << i;
            return result;
        }
    }
    return 0;
}

/// \brief Structure for storing information about a block in the Variables Index file.
/// It stores only the offset and size of the block, not its content.
struct VIBlock {
    uint64_t m_size;
    uint64_t m_offset;

    void read(char*& ptr, const char* ptr_end) {
        m_offset = smUnpack<uint64_t>(ptr, ptr_end);
        m_size = smUnpack<uint64_t>(ptr, ptr_end);
    }
};

/// \brief Structure for storing the Variables Index footer.
/// It contains descriptions of two blocks and a magic number for file verification.
/// Currently, it occupies the last VARIABLES_INDEX_FOOTER_SIZE bytes of the file.
struct VIFooter {
    VIBlock m_metaIndex;
    VIBlock m_index;

    void read(char*& ptr, const char* ptr_end) {
        m_index.read(ptr, ptr_end);
        m_metaIndex.read(ptr, ptr_end);
    }

    void read(std::ifstream& fs) {
        fs.seekg(0, std::ios::end);
        size_t size = fs.tellg();
        FRONT_END_GENERAL_CHECK(size >= VARIABLES_INDEX_FOOTER_SIZE,
                                "Wrong index file, file size is less than minimal expected");

        char footerData[VARIABLES_INDEX_FOOTER_SIZE] = {}, *ptr = &footerData[0];
        fs.seekg(size - sizeof(footerData));
        fs.read(ptr, sizeof(footerData));

        // https://github.com/tensorflow/tensorflow/blob/9659b7bdca80a8ef8240eb021d4da089034eeb00/tensorflow/tsl/lib/io/format.cc#L59
        ptr += sizeof(footerData) - 8;
        uint32_t magic_lo = *reinterpret_cast<const uint32_t*>(ptr);
        uint32_t magic_hi = *reinterpret_cast<const uint32_t*>(ptr + 4);
        uint64_t magic_no = (static_cast<uint64_t>(magic_hi) << 32) | static_cast<uint64_t>(magic_lo);

        FRONT_END_GENERAL_CHECK(magic_no == 0xdb4775248b80fb57ull, "Wrong index file, magic number mismatch detected");

        ptr = &footerData[0];
        m_metaIndex.read(ptr, ptr + sizeof(footerData));
        m_index.read(ptr, ptr + sizeof(footerData));
    }
};
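The footer check above can be exercised in isolation; the sketch below (illustrative, with a hypothetical file path) reads only the 8-byte magic number 0xdb4775248b80fb57 that the leveldb-style table format stores at the very end of the index file. Like decode_fixed32, it assumes a little-endian host.

#include <cstdint>
#include <cstring>
#include <fstream>
#include <string>

bool has_variables_index_magic(const std::string& path) {
    std::ifstream fs(path, std::ios::binary);
    if (!fs)
        return false;
    fs.seekg(-8, std::ios::end);  // the magic number occupies the last 8 bytes
    char raw[8] = {};
    fs.read(raw, sizeof(raw));
    uint32_t lo = 0, hi = 0;
    std::memcpy(&lo, raw, 4);  // stored as two little-endian 32-bit halves
    std::memcpy(&hi, raw + 4, 4);
    const uint64_t magic = (static_cast<uint64_t>(hi) << 32) | lo;
    return magic == 0xdb4775248b80fb57ull;
}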
uint32_t decode_fixed32(const char* ptr);

const char* decode_entry(const char* p,
                         const char* limit,
                         uint32_t& shared,
                         uint32_t& non_shared,
                         uint32_t& value_length);

bool get_varint64(std::string& input, uint64_t* value);

std::string encode_tensor_name_slice(const std::string& name,
                                     const std::vector<int64_t>& starts,
                                     const std::vector<int64_t>& lengths);
}  // namespace tensorflow
}  // namespace frontend
}  // namespace ov
262  src/frontends/tensorflow/src/checkpoint_v1_reader.cpp  Normal file
@@ -0,0 +1,262 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "checkpoint_v1_reader.hpp"

#include "checkpoint_utils.hpp"
#include "openvino/frontend/exception.hpp"
#include "openvino/util/file_util.hpp"
#include "saved_tensor_slice.pb.h"
#include "tf_utils.hpp"

#ifdef ENABLE_SNAPPY_COMPRESSION
#    include "snappy.h"
#endif

using namespace ov::frontend::tensorflow;

namespace {
std::vector<std::string> list_files_in_dir(const std::string& directory_path) {
    std::vector<std::string> res;
    try {
        ov::util::iterate_files(
            directory_path,
            [&res](const std::string& file_path, bool is_dir) {
                auto file = ov::util::get_file_name(file_path);
                if (!is_dir) {
                    res.push_back(file_path);
                }
            },
            false,
            true);
    } catch (...) {
        // Ignore exceptions
    }
    return res;
}
}  // namespace

CheckpointV1Reader::CheckpointV1Reader(const std::string& checkpoints) : m_checkpoints(checkpoints) {}

void CheckpointV1Reader::initialize() {
    // figure out if the input is a file or a directory of checkpoints
    std::vector<std::string> checkpoints_paths;
    if (ov::util::directory_exists(m_checkpoints)) {
        checkpoints_paths = list_files_in_dir(m_checkpoints);
    } else if (ov::util::file_exists(m_checkpoints)) {
        checkpoints_paths = {m_checkpoints};
    } else {
        FRONT_END_GENERAL_CHECK(false, "[TensorFlow Frontend] incorrect checkpoint: the checkpoint does not exist");
    }

    m_variables_info_map.clear();

    for (auto checkpoint_path : checkpoints_paths) {
        // create ifstream for each shard
        std::shared_ptr<std::ifstream> shard_stream =
            std::make_shared<std::ifstream>(checkpoint_path, std::ifstream::in | std::ifstream::binary);
        FRONT_END_GENERAL_CHECK(
            shard_stream && shard_stream->is_open(),
            "[TensorFlow Frontend] incorrect model: checkpoint file " + checkpoint_path + " does not exist");
        const int32_t shard_ind = static_cast<int32_t>(m_shards.size());
        m_shards.push_back(shard_stream);
        m_shard_names.push_back(checkpoint_path);
        std::string value;
        find_entry(shard_stream, checkpoint_path, SAVED_TENSOR_SLICES_KEY, value);

        // parse the table-of-contents entry (stored under the empty key)
        // This is only present at the first item of each checkpoint file and serves
        // as a table of contents, listing all the tensor slices saved in this file.
        ::tensorflow::SavedTensorSlices sts;
        FRONT_END_GENERAL_CHECK(sts.ParseFromArray(value.data(), static_cast<int>(value.size())),
                                "[TensorFlow Frontend] incorrect input checkpoint file or internal error: cannot parse "
                                "SavedTensorSlices entry");
        for (const auto& saved_slice_meta : sts.meta().tensor()) {
            // parse shapes and types for variables
            VariableInfo var_info;
            var_info.shard_id = shard_ind;
            auto variable_name = saved_slice_meta.name();  // original variable name (not encoded)
            var_info.variable_shape = saved_slice_meta.shape();
            var_info.variable_type = saved_slice_meta.type();

            // save starts and lengths of slices for variable name encoding
            for (const auto& slice : saved_slice_meta.slice()) {
                for (const auto& extent : slice.extent()) {
                    var_info.starts.push_back(extent.start());
                    if (extent.has_length()) {
                        var_info.lengths.push_back(extent.length());
                    } else {
                        var_info.lengths.push_back(-1);
                    }
                }
            }
            m_variables_info_map[variable_name] = var_info;
        }
    }
}

void CheckpointV1Reader::seek_block(const std::string& shard_name,
                                    const std::string& target_key,
                                    const char* block_ptr,
                                    const uint32_t restarts,
                                    std::string& value) const {
    // parsing the next key starts at the end of value, so set value accordingly
    const char* curr_value_pos = block_ptr;
    const char* limit = block_ptr + restarts;  // restarts come right after data
    std::string key = "";

    bool is_found = false;
    while (true) {
        FRONT_END_GENERAL_CHECK(
            curr_value_pos < limit,
            "[TensorFlow Frontend] incorrect model: no more entries to return, invalid checkpoint file " + shard_name);

        // decode the next entry
        // each entry looks as follows:
        // | shared (1 byte) | non-shared (1 byte) | value_length (1 byte) | key (non-shared bytes) |
        // | value (value_length bytes) |
        uint32_t shared, non_shared, value_length;
        curr_value_pos = decode_entry(curr_value_pos, limit, shared, non_shared, value_length);
        FRONT_END_GENERAL_CHECK(
            curr_value_pos && key.size() >= shared,
            "[TensorFlow Frontend] incorrect model: corruption error in checkpoint file " + shard_name);

        key.resize(shared);
        key.append(curr_value_pos, non_shared);
        value = std::string(curr_value_pos + non_shared, value_length);
        curr_value_pos += (non_shared + value_length);

        if (key.compare(target_key) >= 0) {
            is_found = true;
            break;
        }
    }
    FRONT_END_GENERAL_CHECK(
        is_found,
        "[TensorFlow Frontend] incorrect input model: checkpoint file " + shard_name + " may be incorrect");
}

void CheckpointV1Reader::init_block(const std::shared_ptr<std::ifstream>& shard,
                                    const std::string& shard_name,
                                    uint64_t offset,
                                    uint64_t size,
                                    std::string& block,
                                    uint64_t& restart_offset) const {
    // check the size of the shard
    FRONT_END_GENERAL_CHECK(shard,
                            "[TensorFlow Frontend] internal error: nullptr pointer to checkpoint file " + shard_name);
    shard->seekg(0, shard->end);
    uint64_t shard_size = static_cast<uint64_t>(shard->tellg());
    FRONT_END_GENERAL_CHECK(offset < shard_size,
                            "[TensorFlow Frontend] internal error or inconsistent checkpoint file: block offset is "
                            "out-of-range for checkpoint file " +
                                shard_name);
    auto n = size + BLOCK_TRAILER_SIZE;
    FRONT_END_GENERAL_CHECK(n < (shard_size - offset),
                            "[TensorFlow Frontend] internal error or inconsistent checkpoint file: block size is "
                            "out-of-range for checkpoint file " +
                                shard_name);

    // read a block and decompress if needed
    std::vector<char> buf(n);
    shard->seekg(offset);
    shard->read(buf.data(), n);
#ifndef ENABLE_SNAPPY_COMPRESSION
    FRONT_END_GENERAL_CHECK(buf[size] == 0,
                            "[TensorFlow Frontend] internal error: compression method for given block is not supported "
                            "for checkpoint file " +
                                shard_name);
    block = std::string(buf.data(), size);
#else
    FRONT_END_GENERAL_CHECK(buf[size] == 0 || buf[size] == 1,
                            "[TensorFlow Frontend] internal error: compression method for given block is not supported "
                            "for checkpoint file " +
                                shard_name);
    if (buf[size] == 1) {
        size_t uncompressed_length = 0;
        FRONT_END_GENERAL_CHECK(
            snappy::GetUncompressedLength(buf.data(), n, &uncompressed_length),
            "[TensorFlow Frontend] internal error: cannot retrieve uncompressed block length for checkpoint file " +
                shard_name);
        block.clear();
        block.reserve(uncompressed_length);
        snappy::Uncompress(buf.data(), n, &block);
    } else {
        block = std::string(buf.data(), size);
    }
#endif
    const char* data = block.data();
    size = block.size();

    // find block characteristics: max_restarts_allowed, num_restarts and restart_offset
    FRONT_END_GENERAL_CHECK(
        size >= sizeof(uint32_t),
        "[TensorFlow Frontend] internal error: block size must be not less than 4 bytes in checkpoint file " +
            shard_name);
    size_t max_restarts_allowed = (size - sizeof(uint32_t)) / sizeof(uint32_t);
    uint32_t num_restarts = decode_fixed32(data + size - sizeof(uint32_t));
    FRONT_END_GENERAL_CHECK(
        num_restarts <= max_restarts_allowed,
        "[TensorFlow Frontend] internal error: num_restarts is greater than max_restarts_allowed in checkpoint file " +
            shard_name);
    restart_offset = size - (1 + num_restarts) * sizeof(uint32_t);
}

void CheckpointV1Reader::find_entry(const std::shared_ptr<std::ifstream>& shard,
                                    const std::string& shard_name,
                                    const std::string& entry_key,
                                    std::string& entry_value) {
    // read the footer of the shard file to get the offset and size of the index block
    VIFooter footer;
    footer.read(*shard);
    uint64_t block_offset = footer.m_index.m_offset;
    uint64_t block_size = footer.m_index.m_size;
    std::string block;

    // initialize the index block
    uint64_t restart_offset = 0;
    init_block(shard, shard_name, block_offset, block_size, block, restart_offset);

    // seek the entry in the index block
    // this entry contains the offset and size of the data block
    seek_block(shard_name, entry_key, block.data(), static_cast<uint32_t>(restart_offset), entry_value);

    // initialize the data block
    FRONT_END_GENERAL_CHECK(
        get_varint64(entry_value, &block_offset) && get_varint64(entry_value, &block_size),
        "[TensorFlow Frontend] incorrect input model: bad block handle in checkpoint file " + shard_name);
    init_block(shard, shard_name, block_offset, block_size, block, restart_offset);

    // seek the final entry in the data block
    seek_block(shard_name, entry_key, block.data(), static_cast<uint32_t>(restart_offset), entry_value);
}

void CheckpointV1Reader::read_variable(const std::string& variable_name, ov::Any& data) {
    FRONT_END_GENERAL_CHECK(m_variables_info_map.count(variable_name) > 0,
                            "[TensorFlow Frontend] incorrect input model: checkpoint files do not contain data for "
                            "the required variable " +
                                variable_name);
    auto var_info = m_variables_info_map[variable_name];
    auto shard_id = var_info.shard_id;
    FRONT_END_GENERAL_CHECK(shard_id < static_cast<int32_t>(m_shards.size()),
                            "[TensorFlow Frontend] internal error: shard_id is greater than the number of shards");
    FRONT_END_GENERAL_CHECK(
        m_shards.size() == m_shard_names.size(),
        "[TensorFlow Frontend] internal error: number of shards does not match the number of their names");
    auto shard_ptr = m_shards[shard_id];
    auto shard_name = m_shard_names[shard_id];
    auto encoded_name = encode_tensor_name_slice(variable_name, var_info.starts, var_info.lengths);
    std::string raw_data;
    find_entry(shard_ptr, shard_name, encoded_name, raw_data);

    // parse the SavedTensorSlices entry that holds the requested slice data
    ::tensorflow::SavedTensorSlices sts;
    FRONT_END_GENERAL_CHECK(sts.ParseFromArray(raw_data.data(), static_cast<int>(raw_data.size())),
                            "[TensorFlow Frontend] incorrect input checkpoint file or internal error: cannot parse "
                            "SavedTensorSlices entry");
    data = unpack_tensor_proto(sts.data().data(), var_info.variable_shape, var_info.variable_type);
}
81  src/frontends/tensorflow/src/checkpoint_v1_reader.hpp  Normal file
@@ -0,0 +1,81 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <sys/stat.h>

#include <unordered_map>
#include <vector>

#include "checkpoint_utils.hpp"
#include "openvino/core/any.hpp"
#include "openvino/frontend/exception.hpp"
#include "saved_tensor_slice.pb.h"
#include "tensor_shape.pb.h"
#include "types.pb.h"

namespace ov {
namespace frontend {
namespace tensorflow {
// stores information about shape, type, and shard id for a Variable
struct VariableInfo {
    ::tensorflow::TensorShapeProto variable_shape;
    ::tensorflow::DataType variable_type;
    int32_t shard_id;
    size_t offset;
    size_t size;
    std::vector<int64_t> starts;
    std::vector<int64_t> lengths;
};

// reads v1 checkpoints
// it parses the value, shape, and type for Variable nodes
class CheckpointV1Reader {
    const std::string m_checkpoints;
    // a map from Variable name to its information
    std::unordered_map<std::string, VariableInfo> m_variables_info_map;
    // a vector of streams for shards, where a shard is one checkpoint file
    std::vector<std::shared_ptr<std::ifstream>> m_shards;
    // a vector of shard names
    std::vector<std::string> m_shard_names;

public:
    /// \brief constructs CheckpointV1Reader for a given checkpoint file or a directory of checkpoint files
    CheckpointV1Reader(const std::string& checkpoints);

    /// \brief initialize the Checkpoint V1 reader
    void initialize();

    /// \brief Produces an ov::Any object that wraps an ov::Tensor for the requested variable;
    /// it can also wrap a string tensor
    /// \param variable_name the requested variable name
    /// \param data a reference to the result
    void read_variable(const std::string& variable_name, ov::Any& data);

private:
    /// \brief finds the value for the given entry key in the shard file
    void find_entry(const std::shared_ptr<std::ifstream>& shard,
                    const std::string& shard_name,
                    const std::string& entry_key,
                    std::string& value);

    void seek_block(const std::string& shard_name,
                    const std::string& target,
                    const char* shard_data,
                    const uint32_t restarts,
                    std::string& value) const;

    void init_block(const std::shared_ptr<std::ifstream>& shard,
                    const std::string& shard_name,
                    uint64_t offset,
                    uint64_t size,
                    std::string& block,
                    uint64_t& restart_offset) const;
};

}  // namespace tensorflow
}  // namespace frontend
}  // namespace ov
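A minimal usage sketch for the reader (illustrative only; the variable name and path below are hypothetical):

#include "checkpoint_v1_reader.hpp"

void restore_variable_example() {
    using namespace ov::frontend::tensorflow;
    // accepts a single v1 checkpoint file or a directory of shard files
    CheckpointV1Reader reader("/path/to/model.ckpt");
    reader.initialize();  // scans the shards and builds the variable info map

    ov::Any data;
    reader.read_variable("my_variable", data);  // throws if the variable is absent
    if (data.is<ov::Tensor>()) {
        ov::Tensor tensor = data.as<ov::Tensor>();
        // use tensor.get_shape(), tensor.data<float>(), ...
    }
}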
@@ -8,6 +8,7 @@
#include "op_def.pb.h"
#include "openvino/frontend/tensorflow/node_context.hpp"
#include "openvino/frontend/tensorflow/special_types.hpp"
#include "tf_utils.hpp"
#include "types.pb.h"

namespace ov {
@@ -8,6 +8,7 @@
#include "node_def.pb.h"
#include "openvino/frontend/tensorflow/node_context.hpp"
#include "openvino/frontend/tensorflow/special_types.hpp"
#include "tf_utils.hpp"
#include "types.pb.h"

namespace ov {
@@ -82,24 +83,6 @@ void extract_compressed_tensor_content(const ::tensorflow::TensorProto& tensor_p
#endif
}  // namespace

ov::element::Type get_ov_type(const ::tensorflow::DataType& type) {
    static const std::map<::tensorflow::DataType, ov::element::Type> type_map{
        {::tensorflow::DataType::DT_BOOL, ov::element::boolean},
        {::tensorflow::DataType::DT_INT16, ov::element::i16},
        {::tensorflow::DataType::DT_INT32, ov::element::i32},
        {::tensorflow::DataType::DT_INT64, ov::element::i64},
        {::tensorflow::DataType::DT_HALF, ov::element::f16},
        {::tensorflow::DataType::DT_FLOAT, ov::element::f32},
        {::tensorflow::DataType::DT_DOUBLE, ov::element::f64},
        {::tensorflow::DataType::DT_UINT8, ov::element::u8},
        {::tensorflow::DataType::DT_INT8, ov::element::i8},
        {::tensorflow::DataType::DT_BFLOAT16, ov::element::bf16}};

    auto it = type_map.find(type);
    // for all unsupported types return dynamic type
    return it == type_map.end() ? ov::element::dynamic : it->second;
}

ov::Any DecoderProto::get_attribute(const std::string& name) const {
    auto attrs = decode_attribute_helper(name);
    if (attrs.empty()) {
@@ -194,92 +177,7 @@ ov::Any DecoderProto::get_attribute(const std::string& name) const {
    }

    case ::tensorflow::AttrValue::ValueCase::kTensor: {
        const auto& tensor_proto = attrs[0].tensor();
        const auto& tf_shape = tensor_proto.tensor_shape();
        ov::PartialShape pshape;
        for (int i = 0; i < tf_shape.dim_size(); i++) {
            pshape.push_back(tf_shape.dim(i).size());
        }
        FRONT_END_GENERAL_CHECK(pshape.is_static(), "Dynamic shapes are not supported for Tensor attribute.");
        const auto& tf_type = tensor_proto.dtype();
        auto ov_type = get_ov_type(tf_type);
        if (tf_type != ::tensorflow::DataType::DT_STRING) {
            FRONT_END_GENERAL_CHECK(
                ov_type.is_static(),
                "Encountered unknown element type " + DataType_Name(tf_type) + " on an empty tensor_proto");
        } else {
            auto data = std::vector<std::string>();
            for (const auto& item : tensor_proto.string_val()) {
                data.push_back(item);
            }
            return data;
        }
        ov::Tensor res(ov_type, pshape.get_shape());
        auto tensor_content = tensor_proto.tensor_content();
        if (!tensor_content.empty() && tensor_proto.has_tensor_shape()) {
            switch (ov_type) {
            case ov::element::u8:
                extract_tensor_content<uint8_t>(tensor_content, &res);
                break;
            case ov::element::i8:
                extract_tensor_content<int8_t>(tensor_content, &res);
                break;
            case ov::element::i16:
                extract_tensor_content<int16_t>(tensor_content, &res);
                break;
            case ov::element::i32:
                extract_tensor_content<int32_t>(tensor_content, &res);
                break;
            case ov::element::i64:
                extract_tensor_content<int64_t>(tensor_content, &res);
                break;
            case ov::element::f16:
                extract_tensor_content<float16>(tensor_content, &res);
                break;
            case ov::element::f32:
                extract_tensor_content<float>(tensor_content, &res);
                break;
            case ov::element::f64:
                extract_tensor_content<double>(tensor_content, &res);
                break;
            case ov::element::bf16:
                extract_tensor_content<bfloat16>(tensor_content, &res);
                break;
            default:
                FRONT_END_THROW("Encountered unknown element type " + ov_type.get_type_name());
            }
        } else {
            int64_t val_size = 0;
            switch (ov_type) {
            case ov::element::boolean:
                val_size = tensor_proto.bool_val_size();
                extract_compressed_tensor_content<bool>(tensor_proto, val_size, &res);
                break;
            case ov::element::i32:
                val_size = tensor_proto.int_val_size();
                extract_compressed_tensor_content<int32_t>(tensor_proto, val_size, &res);
                break;
            case ov::element::i64:
                val_size = tensor_proto.int64_val_size();
                extract_compressed_tensor_content<int64_t>(tensor_proto, val_size, &res);
                break;
            case ov::element::f16:
                val_size = tensor_proto.half_val_size();
                extract_compressed_tensor_content<float16>(tensor_proto, val_size, &res);
                break;
            case ov::element::f32:
                val_size = tensor_proto.float_val_size();
                extract_compressed_tensor_content<float>(tensor_proto, val_size, &res);
                break;
            case ov::element::f64:
                val_size = tensor_proto.double_val_size();
                extract_compressed_tensor_content<double>(tensor_proto, val_size, &res);
                break;
            default:
                FRONT_END_THROW("Encountered unknown element type " + ov_type.get_type_name());
            }
        }
        return res;
        return unpack_tensor_proto(attrs[0].tensor());
    }
    case ::tensorflow::AttrValue::ValueCase::kPlaceholder:
        FRONT_END_GENERAL_CHECK(false,
@@ -22,8 +22,6 @@ namespace ov {
namespace frontend {
namespace tensorflow {

ov::element::Type get_ov_type(const ::tensorflow::DataType& type);

void parse_producer_name(const std::string& producer_port_name,
                         std::string& producer_name,
                         std::string& producer_output_port_name,
@@ -21,6 +21,7 @@
#include "openvino/op/util/multi_subgraph_base.hpp"
#include "openvino/pass/manager.hpp"
#include "openvino/util/common_util.hpp"
#include "openvino/util/file_util.hpp"
#include "openvino/util/log.hpp"
#include "so_extension.hpp"
#include "tf_framework_node.hpp"
@@ -103,10 +104,13 @@ bool FrontEnd::supported_impl(const std::vector<ov::Any>& variants) const {
    // Last boolean flag in `variants` (if presented) is reserved for FE configuration
    size_t extra_variants_num = variants.size() > 0 && variants[variants.size() - 1].is<bool>() ? 1 : 0;

    // TODO: support checkpoint format
    // For TF1 models it can be a case of two input variants: input model and v1 checkpoints
    if (variants.size() != 1 + extra_variants_num)
        return false;

    // to figure out if the model with v1 checkpoints is supported,
    // it is sufficient to check only the input model format
    // avoid parsing of checkpoints here
    if (variants[0].is<std::string>()) {
        std::string model_path = variants[0].as<std::string>();
        if (ov::util::ends_with(model_path, ".pb") && GraphIteratorProto::is_supported(model_path)) {
@@ -122,6 +126,18 @@ bool FrontEnd::supported_impl(const std::vector<ov::Any>& variants) const {
            // handle text protobuf format
            return true;
        }
    } else if (variants[0].is<std::vector<std::string>>() && variants[0].as<std::vector<std::string>>().size() == 2) {
        // here, we assume to get the input model path and checkpoints directory
        auto paths = variants[0].as<std::vector<std::string>>();
        auto model_path = paths[0];
        auto checkpoints_dir = paths[1];
        if (GraphIteratorProto::is_supported(model_path)) {
            // binary protobuf format with checkpoints
            return true;
        } else if (GraphIteratorProtoTxt::is_supported(model_path)) {
            // text protobuf format with checkpoints
            return true;
        }
    }
#if defined(OPENVINO_ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
    else if (variants[0].is<std::wstring>()) {
@@ -140,6 +156,18 @@ bool FrontEnd::supported_impl(const std::vector<ov::Any>& variants) const {
            // handle text protobuf format
            return true;
        }
    } else if (variants[0].is<std::vector<std::wstring>>() && variants[0].as<std::vector<std::wstring>>().size() == 2) {
        // here, we assume to get the input model path and checkpoints directory
        auto paths = variants[0].as<std::vector<std::wstring>>();
        auto model_path = ov::util::wstring_to_string(paths[0]);
        auto checkpoints_dir = ov::util::wstring_to_string(paths[1]);
        if (GraphIteratorProto::is_supported(model_path)) {
            // binary protobuf format with checkpoints
            return true;
        } else if (GraphIteratorProtoTxt::is_supported(model_path)) {
            // text protobuf format with checkpoints
            return true;
        }
    }
#endif
    else if (variants[0].is<GraphIterator::Ptr>()) {
@@ -150,13 +178,14 @@ bool FrontEnd::supported_impl(const std::vector<ov::Any>& variants) const {
}

ov::frontend::InputModel::Ptr FrontEnd::load_impl(const std::vector<ov::Any>& variants) const {
    // TODO: support checkpoint format

    // Last boolean flag in `variants` (if presented) is reserved for FE configuration
    size_t extra_variants_num = variants.size() > 0 && variants[variants.size() - 1].is<bool>() ? 1 : 0;
    FRONT_END_GENERAL_CHECK(variants.size() == 1 + extra_variants_num,
                            "[TensorFlow Frontend] Internal error or inconsistent input model: the frontend supports "
                            "frozen formats (.pb and .pbtxt), SavedModel and MetaGraph (.meta) formats.");

    // For TF1 models it can be a case of two input variants: input model and v1 checkpoints
    FRONT_END_GENERAL_CHECK(
        variants.size() == 1 + extra_variants_num,
        "[TensorFlow Frontend] Internal error or inconsistent input model: the frontend supports "
        "frozen formats (.pb and .pbtxt), SavedModel and MetaGraph (.meta) formats, and v1 checkpoints.");

    if (variants[0].is<std::string>()) {
        auto model_path = variants[0].as<std::string>();
@@ -175,6 +204,7 @@ ov::frontend::InputModel::Ptr FrontEnd::load_impl(const std::vector<ov::Any>& va
                                                graph_iterator->get_variables_index(),
                                                graph_iterator->get_saved_model_input_names(),
                                                graph_iterator->get_saved_model_output_names(),
                                                nullptr,
                                                true);
        } else if (GraphIteratorMeta::is_supported(model_path)) {
            auto graph_iterator = std::make_shared<GraphIteratorMeta>(model_path);
@@ -183,11 +213,42 @@ ov::frontend::InputModel::Ptr FrontEnd::load_impl(const std::vector<ov::Any>& va
                                                graph_iterator->get_variables_index(),
                                                graph_iterator->get_metagraph_input_names(),
                                                graph_iterator->get_metagraph_output_names(),
                                                nullptr,
                                                true);
        } else if (GraphIteratorProtoTxt::is_supported(model_path)) {
            // handle text protobuf format
            return std::make_shared<InputModel>(std::make_shared<GraphIteratorProtoTxt>(model_path), m_telemetry);
        }
    } else if (variants[0].is<std::vector<std::string>>()) {
        // here, we assume to get the input model path and checkpoints directory
        auto paths = variants[0].as<std::vector<std::string>>();
        FRONT_END_GENERAL_CHECK(
            paths.size() == 2,
            "[TensorFlow Frontend] Internal error or inconsistent input model: the frontend supports "
            "frozen formats (.pb and .pbtxt), SavedModel and MetaGraph (.meta) formats, and v1 checkpoints.");
        auto model_path = paths[0];
        auto checkpoints_dir = paths[1];
        if (GraphIteratorProto::is_supported(model_path)) {
            auto graph_iterator = std::make_shared<GraphIteratorProto>(model_path, checkpoints_dir);
            // handle binary protobuf format with checkpoints
            return std::make_shared<InputModel>(graph_iterator,
                                                m_telemetry,
                                                nullptr,
                                                nullptr,
                                                nullptr,
                                                graph_iterator->get_checkpoint_v1_reader(),
                                                false);
        } else if (GraphIteratorProtoTxt::is_supported(model_path)) {
            auto graph_iterator = std::make_shared<GraphIteratorProtoTxt>(model_path, checkpoints_dir);
            // handle text protobuf format with checkpoints
            return std::make_shared<InputModel>(graph_iterator,
                                                m_telemetry,
                                                nullptr,
                                                nullptr,
                                                nullptr,
                                                graph_iterator->get_checkpoint_v1_reader(),
                                                false);
        }
    }
#if defined(OPENVINO_ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
    else if (variants[0].is<std::wstring>()) {
@@ -209,6 +270,7 @@ ov::frontend::InputModel::Ptr FrontEnd::load_impl(const std::vector<ov::Any>& va
                                                graph_iterator->get_variables_index(),
                                                graph_iterator->get_saved_model_input_names(),
                                                graph_iterator->get_saved_model_output_names(),
                                                nullptr,
                                                true);
        } else if (GraphIteratorMeta::is_supported(model_path)) {
            auto graph_iterator = std::make_shared<GraphIteratorMeta>(model_path);
@@ -217,11 +279,42 @@ ov::frontend::InputModel::Ptr FrontEnd::load_impl(const std::vector<ov::Any>& va
                                                graph_iterator->get_variables_index(),
                                                graph_iterator->get_metagraph_input_names(),
                                                graph_iterator->get_metagraph_output_names(),
                                                nullptr,
                                                true);
        } else if (GraphIteratorProtoTxt::is_supported(model_path)) {
            // handle text protobuf format with a path in Unicode
            return std::make_shared<InputModel>(std::make_shared<GraphIteratorProtoTxt>(model_path), m_telemetry);
        }
    } else if (variants[0].is<std::vector<std::wstring>>()) {
        // here, we assume to get the input model path and checkpoints directory
        auto paths = variants[0].as<std::vector<std::wstring>>();
        FRONT_END_GENERAL_CHECK(
            paths.size() == 2,
            "[TensorFlow Frontend] Internal error or inconsistent input model: the frontend supports "
            "frozen formats (.pb and .pbtxt), SavedModel and MetaGraph (.meta) formats, and v1 checkpoints.");
        auto model_path = ov::util::wstring_to_string(paths[0]);
        auto checkpoints_dir = ov::util::wstring_to_string(paths[1]);
        if (GraphIteratorProto::is_supported(model_path)) {
            auto graph_iterator = std::make_shared<GraphIteratorProto>(model_path, checkpoints_dir);
            // handle binary protobuf format with checkpoints
            return std::make_shared<InputModel>(graph_iterator,
                                                m_telemetry,
                                                nullptr,
                                                nullptr,
                                                nullptr,
                                                graph_iterator->get_checkpoint_v1_reader(),
                                                false);
        } else if (GraphIteratorProtoTxt::is_supported(model_path)) {
            auto graph_iterator = std::make_shared<GraphIteratorProtoTxt>(model_path, checkpoints_dir);
            // handle text protobuf format with checkpoints
            return std::make_shared<InputModel>(graph_iterator,
                                                m_telemetry,
                                                nullptr,
                                                nullptr,
                                                nullptr,
                                                graph_iterator->get_checkpoint_v1_reader(),
                                                false);
        }
    }
#endif
    else if (variants[0].is<GraphIterator::Ptr>()) {
@@ -232,7 +325,7 @@ ov::frontend::InputModel::Ptr FrontEnd::load_impl(const std::vector<ov::Any>& va

    FRONT_END_GENERAL_CHECK(false,
                            "[TensorFlow Frontend] Internal error or inconsistent input model: the frontend supports "
                            "frozen formats (.pb and .pbtxt), SavedModel and MetaGraph (.meta) formats.");
                            "frozen formats (.pb and .pbtxt), SavedModel and MetaGraph (.meta), and v1 checkpoints.");

    return nullptr;
}
@@ -7,6 +7,7 @@
#include <fstream>
#include <vector>

#include "checkpoint_v1_reader.hpp"
#include "decoder_argdef.hpp"
#include "decoder_proto.hpp"
#include "graph.pb.h"
@@ -22,6 +23,7 @@ class GraphIteratorProto : public GraphIterator {
protected:
    std::shared_ptr<::tensorflow::GraphDef> m_graph_def;
    std::shared_ptr<::tensorflow::FunctionDef> m_func_def;
    std::shared_ptr<CheckpointV1Reader> m_checkpoint_v1_reader;

    size_t node_index = 0;
    std::vector<std::shared_ptr<DecoderBase>> m_decoders;
@@ -32,6 +34,7 @@ protected:
    GraphIteratorProto()
        : m_graph_def(std::make_shared<::tensorflow::GraphDef>()),
          m_func_def(nullptr),
          m_checkpoint_v1_reader(nullptr),
          m_library_map() {}

    void initialize_decoders_and_library() {
@@ -52,12 +55,20 @@ protected:
        }
    }

    template <typename T>
    void initialize_v1_checkpoints(const std::basic_string<T>& checkpoint_directory) {
        m_checkpoint_v1_reader = std::make_shared<CheckpointV1Reader>(checkpoint_directory);
        m_checkpoint_v1_reader->initialize();
    }

public:
    GraphIteratorProto(const std::shared_ptr<::tensorflow::GraphDef>& graph_def,
                       const std::shared_ptr<::tensorflow::FunctionDef>& func_def,
                       const std::unordered_map<std::string, int>& library_map)
                       const std::unordered_map<std::string, int>& library_map,
                       const std::shared_ptr<CheckpointV1Reader> checkpoint_v1_reader)
        : m_graph_def(graph_def),
          m_func_def(func_def),
          m_checkpoint_v1_reader(checkpoint_v1_reader),
          m_library_map(library_map) {
        auto nodes_size = m_func_def->node_def_size();
        auto input_size = m_func_def->signature().input_arg_size();
@@ -91,11 +102,13 @@ public:
        }
    }

    /// \brief Construct GraphIterator for the frozen model without v1 checkpoints
    template <typename T>
    GraphIteratorProto(const std::basic_string<T>& path)
    GraphIteratorProto(const std::basic_string<T>& model_path)
        : m_graph_def(std::make_shared<::tensorflow::GraphDef>()),
          m_func_def(nullptr) {
        std::ifstream pb_stream(path.c_str(), std::ios::in | std::ifstream::binary);
          m_func_def(nullptr),
          m_checkpoint_v1_reader(nullptr) {
        std::ifstream pb_stream(model_path, std::ios::in | std::ifstream::binary);

        FRONT_END_GENERAL_CHECK(pb_stream && pb_stream.is_open(), "Model file does not exist");
        FRONT_END_GENERAL_CHECK(m_graph_def->ParseFromIstream(&pb_stream), "Model cannot be parsed");
@@ -103,11 +116,26 @@ public:
        initialize_decoders_and_library();
    }

    /// \brief Construct GraphIterator for the frozen model with v1 checkpoints
    template <typename T>
    GraphIteratorProto(const std::basic_string<T>& model_path, const std::basic_string<T>& checkpoint_directory)
        : m_graph_def(std::make_shared<::tensorflow::GraphDef>()),
          m_func_def(nullptr),
          m_checkpoint_v1_reader(nullptr) {
        std::ifstream pb_stream(model_path, std::ios::in | std::ifstream::binary);

        FRONT_END_GENERAL_CHECK(pb_stream && pb_stream.is_open(), "Model file does not exist");
        FRONT_END_GENERAL_CHECK(m_graph_def->ParseFromIstream(&pb_stream), "Model cannot be parsed");

        initialize_decoders_and_library();
        initialize_v1_checkpoints(checkpoint_directory);
    }

    /// \brief Check if the input file is supported
    template <typename T>
    static bool is_supported(const std::basic_string<T>& path) {
        try {
            std::ifstream pb_stream(path.c_str(), std::ios::in | std::ifstream::binary);
            std::ifstream pb_stream(path, std::ios::in | std::ifstream::binary);
            auto graph_def = std::make_shared<::tensorflow::GraphDef>();
            return pb_stream && pb_stream.is_open() && graph_def->ParsePartialFromIstream(&pb_stream) &&
                   graph_def->node_size() > 0;
@@ -116,6 +144,11 @@ public:
        }
    }

    /// \brief Get checkpoint v1 reader for checkpoint restoring in the translator for the Variable operation
    std::shared_ptr<CheckpointV1Reader> get_checkpoint_v1_reader() const {
        return m_checkpoint_v1_reader;
    }

    /// \brief Set iterator to the start position
    void reset() override {
        node_index = 0;
@@ -152,7 +185,7 @@ public:

        auto func = m_graph_def->library().function(func_ind);
        auto func_ptr = std::make_shared<::tensorflow::FunctionDef>(func);
        return std::make_shared<GraphIteratorProto>(m_graph_def, func_ptr, m_library_map);
        return std::make_shared<GraphIteratorProto>(m_graph_def, func_ptr, m_library_map, m_checkpoint_v1_reader);
    }

    return nullptr;
@@ -17,9 +17,10 @@ namespace tensorflow {

class GraphIteratorProtoTxt : public GraphIteratorProto {
public:
    /// \brief Construct GraphIterator for the frozen model in text format without v1 checkpoints
    template <typename T>
    GraphIteratorProtoTxt(const std::basic_string<T>& path) : GraphIteratorProto() {
        std::ifstream pbtxt_stream(path.c_str(), std::ios::in);
        std::ifstream pbtxt_stream(path, std::ios::in);
        FRONT_END_GENERAL_CHECK(pbtxt_stream && pbtxt_stream.is_open(), "Model file does not exist");
        auto input_stream = std::make_shared<::google::protobuf::io::IstreamInputStream>(&pbtxt_stream);
        FRONT_END_GENERAL_CHECK(input_stream, "Model cannot be read");
@@ -31,6 +32,23 @@ public:
        initialize_decoders_and_library();
    }

    /// \brief Construct GraphIterator for the frozen model in text format with v1 checkpoints
    template <typename T>
    GraphIteratorProtoTxt(const std::basic_string<T>& path, const std::basic_string<T>& checkpoint_directory)
        : GraphIteratorProto() {
        std::ifstream pbtxt_stream(path, std::ios::in);
        FRONT_END_GENERAL_CHECK(pbtxt_stream && pbtxt_stream.is_open(), "Model file does not exist");
        auto input_stream = std::make_shared<::google::protobuf::io::IstreamInputStream>(&pbtxt_stream);
        FRONT_END_GENERAL_CHECK(input_stream, "Model cannot be read");
        auto is_parsed = ::google::protobuf::TextFormat::Parse(input_stream.get(), m_graph_def.get());
        FRONT_END_GENERAL_CHECK(
            is_parsed,
            "[TensorFlow Frontend] Incorrect model or internal error: Model in text Protobuf format cannot be parsed.");

        initialize_decoders_and_library();
        initialize_v1_checkpoints(checkpoint_directory);
    }

    /// \brief Check if the input file is supported
    template <typename T>
    static bool is_supported(const std::basic_string<T>& path) {
@@ -59,6 +59,7 @@ public:
                 const std::shared_ptr<VariablesIndex>& variables_index,
                 const std::shared_ptr<std::map<std::string, std::string>> saved_model_input_names,
                 const std::shared_ptr<std::map<std::string, std::string>> saved_model_output_names,
                 const std::shared_ptr<CheckpointV1Reader> checkpoint_v1_reader,
                 const bool native_format = false);
    std::vector<ov::frontend::Place::Ptr> get_inputs() const;
    std::vector<ov::frontend::Place::Ptr> get_outputs() const;
@@ -86,6 +87,7 @@ public:
    std::shared_ptr<VariablesIndex> get_variables_index() const;
    std::shared_ptr<std::map<std::string, std::string>> get_saved_model_input_names() const;
    std::shared_ptr<std::map<std::string, std::string>> get_saved_model_output_names() const;
    std::shared_ptr<CheckpointV1Reader> get_checkpoint_v1_reader() const;

private:
    void load_places();
@@ -110,6 +112,7 @@ private:
    std::shared_ptr<VariablesIndex> m_variables_index;
    std::shared_ptr<std::map<std::string, std::string>> m_saved_model_input_names;
    std::shared_ptr<std::map<std::string, std::string>> m_saved_model_output_names;
    std::shared_ptr<CheckpointV1Reader> m_checkpoint_v1_reader;

    bool m_native_format;
    bool m_custom_inputs;
@@ -256,6 +259,10 @@ std::shared_ptr<std::map<std::string, std::string>> InputModel::InputModelTFImpl
    return m_saved_model_output_names;
}

std::shared_ptr<CheckpointV1Reader> InputModel::InputModelTFImpl::get_checkpoint_v1_reader() const {
    return m_checkpoint_v1_reader;
}

std::vector<std::shared_ptr<OpPlace>> InputModel::InputModelTFImpl::get_op_places() {
    return topologically_sort_op_nodes();
}
@@ -417,6 +424,7 @@ InputModel::InputModelTFImpl::InputModelTFImpl(
    const std::shared_ptr<VariablesIndex>& variables_index,
    const std::shared_ptr<std::map<std::string, std::string>> saved_model_input_names,
    const std::shared_ptr<std::map<std::string, std::string>> saved_model_output_names,
    const std::shared_ptr<CheckpointV1Reader> checkpoint_v1_reader,
    const bool native_format)
    : m_graph_iterator(graph_iterator),
      m_input_model(input_model),
@@ -424,6 +432,7 @@ InputModel::InputModelTFImpl::InputModelTFImpl(
      m_variables_index(variables_index),
      m_saved_model_input_names(saved_model_input_names),
      m_saved_model_output_names(saved_model_output_names),
      m_checkpoint_v1_reader(checkpoint_v1_reader),
      m_native_format(native_format) {
    FRONT_END_GENERAL_CHECK(m_graph_iterator, "Null pointer specified for GraphIterator");
    m_input_names = graph_iterator->get_input_names();
@@ -602,6 +611,7 @@ InputModel::InputModel(const GraphIterator::Ptr& graph_iterator,
                       const std::shared_ptr<VariablesIndex>& variables_index,
                       const std::shared_ptr<std::map<std::string, std::string>> saved_model_input_names,
                       const std::shared_ptr<std::map<std::string, std::string>> saved_model_output_names,
                       const std::shared_ptr<CheckpointV1Reader> checkpoint_v1_reader,
                       const bool native_format)
    : _impl{std::make_shared<InputModelTFImpl>(graph_iterator,
                                               *this,
@@ -609,6 +619,7 @@ InputModel::InputModel(const GraphIterator::Ptr& graph_iterator,
                                               variables_index,
                                               saved_model_input_names,
                                               saved_model_output_names,
                                               checkpoint_v1_reader,
                                               native_format)} {}

std::shared_ptr<VariablesIndex> InputModel::get_variables_index() {
@@ -623,6 +634,10 @@ std::shared_ptr<std::map<std::string, std::string>> InputModel::get_saved_model_
    return _impl->get_saved_model_output_names();
}

std::shared_ptr<CheckpointV1Reader> InputModel::get_checkpoint_v1_reader() const {
    return _impl->get_checkpoint_v1_reader();
}

std::vector<std::string> InputModel::get_input_names() const {
    return _impl->get_input_names();
}
@@ -4,6 +4,7 @@

#pragma once

#include "checkpoint_v1_reader.hpp"
#include "openvino/frontend/extension/telemetry.hpp"
#include "openvino/frontend/graph_iterator.hpp"
#include "openvino/frontend/input_model.hpp"
@@ -35,6 +36,7 @@ public:
               const std::shared_ptr<VariablesIndex>& variables_index = {},
               const std::shared_ptr<std::map<std::string, std::string>> saved_model_input_names = nullptr,
               const std::shared_ptr<std::map<std::string, std::string>> saved_model_output_names = nullptr,
               const std::shared_ptr<CheckpointV1Reader> checkpoint_v1_reader = nullptr,
               const bool native_format = false);

    std::vector<ov::frontend::Place::Ptr> get_inputs() const override;
@@ -53,6 +55,8 @@ public:
    std::shared_ptr<std::map<std::string, std::string>> get_saved_model_input_names() const;
    std::shared_ptr<std::map<std::string, std::string>> get_saved_model_output_names() const;

    std::shared_ptr<CheckpointV1Reader> get_checkpoint_v1_reader() const;

    std::map<std::string, std::shared_ptr<TensorPlace>> get_tensor_places() const;
};
64  src/frontends/tensorflow/src/op/variable.cpp  Normal file
@@ -0,0 +1,64 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "common_op_table.hpp"
|
||||
#include "helper_ops/string_constant.hpp"
|
||||
#include "helper_ops/unsupported_constant.hpp"
|
||||
#include "input_model.hpp"
|
||||
#include "openvino/frontend/tensorflow/node_context.hpp"
|
||||
#include "openvino/opsets/opset11.hpp"
|
||||
#include "translate_session.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ov;
|
||||
using namespace ov::frontend::tensorflow;
|
||||
using namespace ov::opset11;
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace tensorflow {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_variable_op(const NodeContext& node) {
|
||||
default_op_checks(node, 0, {"Variable"});
|
||||
auto variable_name = node.get_name();
|
||||
|
||||
auto translate_session = node.get_translate_session();
|
||||
TENSORFLOW_OP_VALIDATION(node,
|
||||
translate_session,
|
||||
"[TensorFlow Frontend] Internal error: Translate session is nullptr.");
|
||||
auto model = dynamic_pointer_cast<ov::frontend::tensorflow::InputModel>(translate_session->get_input_model());
|
||||
TENSORFLOW_OP_VALIDATION(
|
||||
node,
|
||||
model,
|
||||
"[TensorFlow Frontend] Internal error: input model is unable to cast to TensorFlow Frontend InputModel.");
|
||||
auto checkpoint_v1_reader = model->get_checkpoint_v1_reader();
|
||||
TENSORFLOW_OP_VALIDATION(node,
|
||||
checkpoint_v1_reader,
|
||||
"[TensorFlow Frontend] incorrect input model: checkpoint to restore variable " +
|
||||
variable_name + " is not provided.");
|
||||
|
||||
ov::Any variable_data;
|
||||
checkpoint_v1_reader->read_variable(variable_name, variable_data);
|
||||
|
||||
shared_ptr<Node> const_node = nullptr;
|
||||
if (variable_data.is<ov::Tensor>()) {
|
||||
auto ov_tensor = variable_data.as<ov::Tensor>();
|
||||
const_node = make_shared<Constant>(ov_tensor);
|
||||
} else if (variable_data.is<std::vector<std::string>>()) {
|
||||
// a case of string tensor that should be assigned to the variable
|
||||
const_node = make_shared<StringConstant>(variable_data, node.get_decoder());
|
||||
} else {
|
||||
// data of unknown type
|
||||
        const_node = make_shared<UnsupportedConstant>("Variable of unsupported type", node.get_decoder());
    }

    set_node_name(variable_name, const_node);
    return {const_node};
}

} // namespace op
} // namespace tensorflow
} // namespace frontend
} // namespace ov
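
For intuition, the fallback chain in translate_variable_op can be sketched standalone. A minimal illustration in plain Python, with hypothetical names standing in for the ov::Any type probing (this is not frontend API):

import numpy as np

# sketch: mirror the branch order in translate_variable_op
def pick_constant_kind(variable_data):
    if isinstance(variable_data, np.ndarray):
        return "Constant"        # dense numeric tensor restored from the checkpoint
    if isinstance(variable_data, list) and all(isinstance(s, str) for s in variable_data):
        return "StringConstant"  # string tensor assigned to the variable
    return "UnsupportedConstant" # payload of unknown type

print(pick_constant_kind(np.array([14108582], dtype=np.int64)))  # Constant
print(pick_constant_kind(["a", "b"]))                            # StringConstant
print(pick_constant_kind(object()))                              # UnsupportedConstant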

@ -20,29 +20,30 @@ namespace op {

#define TF_OP_CONVERTER(op) OutputVector op(const ov::frontend::tensorflow::NodeContext& node)

TF_OP_CONVERTER(translate_if_op);
TF_OP_CONVERTER(translate_assignvariable_op);
TF_OP_CONVERTER(translate_block_lstm_op);
TF_OP_CONVERTER(translate_fifo_queue_op);
TF_OP_CONVERTER(translate_gru_block_cell_op);
TF_OP_CONVERTER(translate_hash_table_op);
TF_OP_CONVERTER(translate_if_op);
TF_OP_CONVERTER(translate_iterator_get_next_op);
TF_OP_CONVERTER(translate_iterator_op);
TF_OP_CONVERTER(translate_mergev2checkpoint_op);
TF_OP_CONVERTER(translate_partitioned_call_op);
TF_OP_CONVERTER(translate_placeholder_linked_op);
TF_OP_CONVERTER(translate_queue_dequeue_op);
TF_OP_CONVERTER(translate_queue_dequeue_many_op);
TF_OP_CONVERTER(translate_readvariable_op);
TF_OP_CONVERTER(translate_restorev2_op);
TF_OP_CONVERTER(translate_sparse_fill_empty_rows_op);
TF_OP_CONVERTER(translate_sparse_reshape_op);
TF_OP_CONVERTER(translate_sparse_segment_sum_op);
TF_OP_CONVERTER(translate_varisinitialized_op);
TF_OP_CONVERTER(translate_readvariable_op);
TF_OP_CONVERTER(translate_assignvariable_op);
TF_OP_CONVERTER(translate_varhandle_op);
TF_OP_CONVERTER(translate_restorev2_op);
TF_OP_CONVERTER(translate_staticregexfullmatch_op);
TF_OP_CONVERTER(translate_stringjoin_op);
TF_OP_CONVERTER(translate_mergev2checkpoint_op);
TF_OP_CONVERTER(translate_varhandle_op);
TF_OP_CONVERTER(translate_variable_op);
TF_OP_CONVERTER(translate_varisinitialized_op);
TF_OP_CONVERTER(translate_while_op);
TF_OP_CONVERTER(translate_placeholder_linked_op);

const std::map<std::string, CreatorFunction> get_supported_ops() {
    return {
@ -276,6 +277,9 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
        {"VarHandleOp", CreatorFunction(translate_varhandle_op)},
        {"VariableV2", CreatorFunction(translate_varhandle_op)},

        // Translator for Checkpoint V1
        {"Variable", CreatorFunction(translate_variable_op)},

        // Translators for internal operations
        {"BlockLSTM", CreatorFunction(translate_block_lstm_op)},
        {"GRUBlockCell", CreatorFunction(translate_gru_block_cell_op)},

192
src/frontends/tensorflow/src/tf_utils.cpp
Normal file
@ -0,0 +1,192 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "tf_utils.hpp"

#include <stdint.h>

#include "openvino/core/type/element_type.hpp"
#include "openvino/frontend/exception.hpp"
#include "openvino/runtime/tensor.hpp"

using namespace ov;

namespace {

template <typename T>
void extract_tensor_content(const std::string& tensor_content, ov::Tensor* values) {
    const auto tensor_content_size = tensor_content.size();
    FRONT_END_GENERAL_CHECK(tensor_content_size % sizeof(T) == 0,
                            "Size of tensor_content (",
                            tensor_content_size,
                            ") is not a multiple of ",
                            sizeof(T));

    const T* tensor_values = reinterpret_cast<const T*>(tensor_content.data());
    FRONT_END_GENERAL_CHECK(values->get_size() == tensor_content_size / sizeof(T),
                            "Size of tensor is not equal to tensor_content size.");
    std::copy(tensor_values, tensor_values + tensor_content_size / sizeof(T), values->data<T>());
}
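
What extract_tensor_content does can be shown compactly with NumPy; a minimal sketch, assuming little-endian raw bytes (the helper name is illustrative):

import numpy as np

# sketch: reinterpret raw tensor_content bytes as typed values, after checking
# that the byte length is an exact multiple of the element size
def extract_tensor_content(tensor_content: bytes, dtype) -> np.ndarray:
    itemsize = np.dtype(dtype).itemsize
    assert len(tensor_content) % itemsize == 0, "content size is not a multiple of element size"
    return np.frombuffer(tensor_content, dtype=dtype)

print(extract_tensor_content(b"\x01\x00\x00\x00\x02\x00\x00\x00", np.int32))  # [1 2]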

#if defined(_MSC_VER)
# pragma warning(push)
# pragma warning(disable : 4244) // possible loss of data
# pragma warning(disable : 4267) // possible loss of data
#endif
template <typename T>
void extract_compressed_tensor_content(const ::tensorflow::TensorProto& tensor_proto,
                                       int64_t val_size,
                                       ov::Tensor* values) {
    auto val_lastsaved = static_cast<T>(0);
    auto values_data = values->data<T>();
    for (size_t i = 0; i < values->get_size(); i++) {
        if (val_size == 0) {
            values_data[i] = static_cast<T>(0);
        } else if (static_cast<int64_t>(i) < val_size) {
            auto val_i = static_cast<T>(0);
            switch (values->get_element_type()) {
            // TODO: there are more element types to support here
            case ov::element::boolean:
                val_i = tensor_proto.bool_val()[i];
                break;
            case ov::element::i32:
                val_i = tensor_proto.int_val()[i];
                break;
            case ov::element::i64:
                val_i = tensor_proto.int64_val()[i];
                break;
            case ov::element::f16:
                val_i = float16::from_bits(tensor_proto.half_val()[i]);
                break;
            case ov::element::f32:
                val_i = tensor_proto.float_val()[i];
                break;
            case ov::element::f64:
                val_i = tensor_proto.double_val()[i];
                break;
            default:
                FRONT_END_THROW("Encountered unknown element type " + values->get_element_type().get_type_name());
            }
            values_data[i] = val_i;
            val_lastsaved = val_i;
        } else {
            values_data[i] = val_lastsaved;
        }
    }
}
#if defined(_MSC_VER)
# pragma warning(pop)
#endif
} // namespace
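
The loop in extract_compressed_tensor_content implements the TensorProto convention for sparsely stored values; a standalone sketch of the same fill rule (plain Python, illustrative name):

# sketch: a TensorProto may store fewer typed values than the tensor holds;
# trailing elements repeat the last stored value, and an empty value list
# yields all zeros
def expand_compressed_values(stored, total):
    out = []
    last = 0
    for i in range(total):
        if stored and i < len(stored):
            last = stored[i]
        out.append(last if stored else 0)
    return out

print(expand_compressed_values([7, 7, 3], 6))  # [7, 7, 3, 3, 3, 3]
print(expand_compressed_values([], 3))         # [0, 0, 0]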

ov::element::Type ov::frontend::tensorflow::get_ov_type(const ::tensorflow::DataType& type) {
    static const std::map<::tensorflow::DataType, ov::element::Type> type_map{
        {::tensorflow::DataType::DT_BOOL, ov::element::boolean},
        {::tensorflow::DataType::DT_INT16, ov::element::i16},
        {::tensorflow::DataType::DT_INT32, ov::element::i32},
        {::tensorflow::DataType::DT_INT64, ov::element::i64},
        {::tensorflow::DataType::DT_HALF, ov::element::f16},
        {::tensorflow::DataType::DT_FLOAT, ov::element::f32},
        {::tensorflow::DataType::DT_DOUBLE, ov::element::f64},
        {::tensorflow::DataType::DT_UINT8, ov::element::u8},
        {::tensorflow::DataType::DT_INT8, ov::element::i8},
        {::tensorflow::DataType::DT_BFLOAT16, ov::element::bf16}};

    auto it = type_map.find(type);
    // for all unsupported types return dynamic type
    return it == type_map.end() ? ov::element::dynamic : it->second;
}

ov::Any ov::frontend::tensorflow::unpack_tensor_proto(const ::tensorflow::TensorProto& tensor_proto) {
    return unpack_tensor_proto(tensor_proto, tensor_proto.tensor_shape(), tensor_proto.dtype());
}

ov::Any ov::frontend::tensorflow::unpack_tensor_proto(const ::tensorflow::TensorProto& tensor_proto,
                                                      const ::tensorflow::TensorShapeProto& tensor_shape,
                                                      const ::tensorflow::DataType& tensor_type) {
    ov::PartialShape pshape;
    for (int i = 0; i < tensor_shape.dim_size(); i++) {
        pshape.push_back(tensor_shape.dim(i).size());
    }
    FRONT_END_GENERAL_CHECK(pshape.is_static(), "Dynamic shapes are not supported for Tensor attribute.");
    ov::element::Type ov_type = get_ov_type(tensor_type);

    if (tensor_type != ::tensorflow::DataType::DT_STRING) {
        FRONT_END_GENERAL_CHECK(
            ov_type.is_static(),
            "Encountered unknown element type " + DataType_Name(tensor_type) + " on an empty tensor_proto");
    } else {
        auto data = std::vector<std::string>();
        for (const auto& item : tensor_proto.string_val()) {
            data.push_back(item);
        }
        return data;
    }
    ov::Tensor res(ov_type, pshape.get_shape());
    auto tensor_content = tensor_proto.tensor_content();
    if (!tensor_content.empty() && tensor_proto.has_tensor_shape()) {
        switch (ov_type) {
        case ov::element::u8:
            extract_tensor_content<uint8_t>(tensor_content, &res);
            break;
        case ov::element::i8:
            extract_tensor_content<int8_t>(tensor_content, &res);
            break;
        case ov::element::i16:
            extract_tensor_content<int16_t>(tensor_content, &res);
            break;
        case ov::element::i32:
            extract_tensor_content<int32_t>(tensor_content, &res);
            break;
        case ov::element::i64:
            extract_tensor_content<int64_t>(tensor_content, &res);
            break;
        case ov::element::f16:
            extract_tensor_content<float16>(tensor_content, &res);
            break;
        case ov::element::f32:
            extract_tensor_content<float>(tensor_content, &res);
            break;
        case ov::element::f64:
            extract_tensor_content<double>(tensor_content, &res);
            break;
        case ov::element::bf16:
            extract_tensor_content<bfloat16>(tensor_content, &res);
            break;
        default:
            FRONT_END_THROW("Encountered unknown element type " + ov_type.get_type_name());
        }
    } else {
        int64_t val_size = 0;
        switch (ov_type) {
        case ov::element::boolean:
            val_size = tensor_proto.bool_val_size();
            extract_compressed_tensor_content<bool>(tensor_proto, val_size, &res);
            break;
        case ov::element::i32:
            val_size = tensor_proto.int_val_size();
            extract_compressed_tensor_content<int32_t>(tensor_proto, val_size, &res);
            break;
        case ov::element::i64:
            val_size = tensor_proto.int64_val_size();
            extract_compressed_tensor_content<int64_t>(tensor_proto, val_size, &res);
            break;
        case ov::element::f16:
            val_size = tensor_proto.half_val_size();
            extract_compressed_tensor_content<float16>(tensor_proto, val_size, &res);
            break;
        case ov::element::f32:
            val_size = tensor_proto.float_val_size();
            extract_compressed_tensor_content<float>(tensor_proto, val_size, &res);
            break;
        case ov::element::f64:
            val_size = tensor_proto.double_val_size();
            extract_compressed_tensor_content<double>(tensor_proto, val_size, &res);
            break;
        default:
            FRONT_END_THROW("Encountered unknown element type " + ov_type.get_type_name());
        }
    }
    return res;
}

31
src/frontends/tensorflow/src/tf_utils.hpp
Normal file
@ -0,0 +1,31 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "attr_value.pb.h"
#include "node_def.pb.h"
#include "openvino/core/partial_shape.hpp"
#include "openvino/core/type.hpp"
#include "openvino/core/type/element_type.hpp"
#include "openvino/runtime/tensor.hpp"
#include "tensor.pb.h"
#include "tensor_shape.pb.h"
#include "types.pb.h"

namespace ov {
namespace frontend {
namespace tensorflow {

ov::element::Type get_ov_type(const ::tensorflow::DataType& type);

ov::Any unpack_tensor_proto(const ::tensorflow::TensorProto& tensor_proto);

ov::Any unpack_tensor_proto(const ::tensorflow::TensorProto& tensor_proto,
                            const ::tensorflow::TensorShapeProto& tensor_shape,
                            const ::tensorflow::DataType& tensor_type);

} // namespace tensorflow
} // namespace frontend
} // namespace ov

@ -7,6 +7,7 @@
#include <fstream>
#include <string>

#include "checkpoint_utils.hpp"
#include "graph_iterator_saved_model.hpp"
#include "openvino/core/type/element_type.hpp"
#include "tensor_bundle.pb.h"
@ -20,80 +21,6 @@ namespace ov {
namespace frontend {
namespace tensorflow {

template <typename T>
static T smReadFixed(const char* ptr) {
    T result = 0;
    for (uint8_t i = 0; i < sizeof(T); ++i) {
        result |= static_cast<const uint8_t>(ptr[i]) << (i * 8);
    }
    return result;
}

template <typename T>
static T smUnpack(char*& ptr, const char* ptr_end) {
    T result = 0;
    for (uint8_t i = 0; i < sizeof(T) * 7 && ptr < ptr_end; i += 7) {
        T byte = *(ptr++);
        if (byte & 0x80) {
            result |= ((byte & 0x7F) << i);
        } else {
            result |= byte << i;
            return result;
        }
    }
    return 0;
}
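
smUnpack decodes LevelDB-style base-128 varints: each byte carries seven payload bits and the high bit flags a continuation. A standalone sketch of the same scheme (illustrative name, not the frontend code):

# sketch: base-128 varint decoding as performed by smUnpack
def decode_varint(data: bytes, pos: int = 0):
    result = 0
    shift = 0
    while pos < len(data):
        byte = data[pos]
        pos += 1
        result |= (byte & 0x7F) << shift
        if not byte & 0x80:
            return result, pos
        shift += 7
    return 0, pos  # truncated input, mirroring smUnpack's return of 0

print(decode_varint(bytes([0xAC, 0x02])))  # (300, 2)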

/// \brief Structure for storing information about a block in a Variables Index file.
/// It defines only offset and block size, no information about exact content.
struct VIBlock {
    uint64_t m_size;
    uint64_t m_offset;

    void read(char*& ptr, const char* ptr_end) {
        m_offset = smUnpack<uint64_t>(ptr, ptr_end);
        m_size = smUnpack<uint64_t>(ptr, ptr_end);
    }
};

#define VARIABLES_INDEX_FOOTER_SIZE 48

/// \brief Structure for storing Variables Index footer information.
/// It contains descriptions of two blocks and a magic number for file verification.
/// Currently, it is placed in the last VARIABLES_INDEX_FOOTER_SIZE bytes of a file.
struct VIFooter {
    VIBlock m_metaIndex;
    VIBlock m_index;

    void read(char*& ptr, const char* ptr_end) {
        m_index.read(ptr, ptr_end);
        m_metaIndex.read(ptr, ptr_end);
    }

    void read(std::ifstream& fs) {
        fs.seekg(0, std::ios::end);
        size_t size = fs.tellg();
        FRONT_END_GENERAL_CHECK(size >= VARIABLES_INDEX_FOOTER_SIZE,
                                "Wrong index file, file size is less than minimal expected");

        char footerData[VARIABLES_INDEX_FOOTER_SIZE] = {}, *ptr = &footerData[0];
        fs.seekg(size - sizeof(footerData));
        fs.read(ptr, sizeof(footerData));

        // https://github.com/tensorflow/tensorflow/blob/9659b7bdca80a8ef8240eb021d4da089034eeb00/tensorflow/tsl/lib/io/format.cc#L59
        ptr += sizeof(footerData) - 8;
        uint32_t magic_lo = *reinterpret_cast<const uint32_t*>(ptr);
        uint32_t magic_hi = *reinterpret_cast<const uint32_t*>(ptr + 4);
        uint64_t magic_no = (static_cast<uint64_t>(magic_hi) << 32) | static_cast<uint64_t>(magic_lo);

        FRONT_END_GENERAL_CHECK(magic_no == 0xdb4775248b80fb57ull, "Wrong index file, magic number mismatch detected");

        ptr = &footerData[0];
        m_metaIndex.read(ptr, ptr + sizeof(footerData));
        m_index.read(ptr, ptr + sizeof(footerData));
    }
};
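
The constant 0xdb4775248b80fb57 checked above is the LevelDB/SSTable table magic, reassembled from two little-endian 32-bit words. A standalone sketch of the same verification (plain Python, illustrative helper):

# sketch: the last 8 bytes of the index file must decode, little-endian,
# to the table magic number
def has_valid_footer_magic(file_bytes: bytes) -> bool:
    if len(file_bytes) < 8:
        return False
    return int.from_bytes(file_bytes[-8:], "little") == 0xDB4775248B80FB57

# consistent with v1 checkpoints using the same table format: the hard-coded
# checkpoint byte stream in the unit test below ends with exactly these bytes
print(has_valid_footer_magic(bytes([0x57, 0xFB, 0x80, 0x8B, 0x24, 0x75, 0x47, 0xDB])))  # True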

void VariablesIndex::read_variables_index_block(std::ifstream& fs,
                                                const VIBlock& index,
                                                std::vector<char>& data,
@ -101,7 +28,7 @@ void VariablesIndex::read_variables_index_block(std::ifstream& fs,
                                                uint32_t& offset_end) {
    size_t block_size = index.m_size;
    data.clear();
    data.resize(block_size + 5 /*kBlockTrailerSize*/);
    data.resize(block_size + BLOCK_TRAILER_SIZE);
    FRONT_END_GENERAL_CHECK(index.m_offset <= m_variables_index_size,
                            "Block offset is bigger than variables index size");
    FRONT_END_GENERAL_CHECK(index.m_offset + data.size() <= m_variables_index_size,
@ -124,11 +51,11 @@ void VariablesIndex::read_variables_index_block(std::ifstream& fs,
        block_size = uncompressed_length;
    }
#endif
    uint32_t numRestarts = smReadFixed<uint32_t>(data.data() + block_size - sizeof(uint32_t));
    uint32_t numRestarts = decode_fixed32(data.data() + block_size - sizeof(uint32_t));
    size_t maxRestarts = (block_size - sizeof(uint32_t)) / sizeof(uint32_t);
    FRONT_END_GENERAL_CHECK(maxRestarts >= numRestarts, "Wrong restarts value");
    offset_end = static_cast<uint32_t>(block_size) - ((numRestarts + 1) * sizeof(uint32_t));
    offset = smReadFixed<uint32_t>(data.data() + offset_end);
    offset = decode_fixed32(data.data() + offset_end);
}
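
decode_fixed32 from the new checkpoint_utils replaces smReadFixed<uint32_t> here; both read a 32-bit little-endian word at a given position. A one-line sketch of the assumed behavior (illustrative name):

# sketch: little-endian fixed32 read, as in smReadFixed<uint32_t>/decode_fixed32
def decode_fixed32(data: bytes, pos: int) -> int:
    return int.from_bytes(data[pos:pos + 4], "little")

print(hex(decode_fixed32(bytes([0x78, 0x56, 0x34, 0x12]), 0)))  # 0x12345678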

void VariablesIndex::read_variables_index_pair(char*& ptr,

@ -334,14 +334,9 @@ def convert_to_pb(argv: argparse.Namespace):
    if "tensorflow" in env_setup and env_setup["tensorflow"] >= LooseVersion("2.0.0"):
        tf.keras.backend.clear_session()

    # if this is already binary or text frozen format .pb or .pbtxt,
    # there is no need to create auxiliary binary frozen protobuf
    if argv.input_model and not argv.input_checkpoint and \
            isinstance(argv.input_model, str):
        return None

    # Saved Model format and MetaGraph format are supported without freezing
    if argv.saved_model_dir or argv.input_meta_graph:
    # any model format on disk is accepted by the TensorFlow Frontend
    # only a model from memory requires temporary saving on disk
    if (argv.input_model and isinstance(argv.input_model, str)) or argv.saved_model_dir or argv.input_meta_graph:
        return None

    user_output_node_names_list = argv.output if argv.output else None

@ -34,7 +34,12 @@ def moc_pipeline(argv: argparse.Namespace, moc_front_end: FrontEnd):
        raise Exception("ONNX frontend does not support input model as BytesIO object. "
                        "Please use use_legacy_frontend=True to convert the model.")
    else:
        if argv.input_model:
        input_checkpoint = getattr(argv, 'input_checkpoint', None)
        if argv.input_model and input_checkpoint:
            # frozen format with v1 checkpoints
            input_model = moc_front_end.load([argv.input_model, argv.input_checkpoint])
        elif argv.input_model:
            # frozen model without v1 checkpoints
            input_model = moc_front_end.load(argv.input_model)
        elif argv.saved_model_dir:
            input_model = moc_front_end.load(argv.saved_model_dir)
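
For context, a minimal end-to-end usage sketch of the new loading path; convert_model and its input_checkpoint argument appear in the test utilities below, while model.pbtxt and model.ckpt are hypothetical file names:

# sketch: convert a frozen text graph together with its v1 checkpoint
from openvino.tools.mo.convert import convert_model

ov_model = convert_model("model.pbtxt", framework="tf", input_model_is_text=True,
                         input_checkpoint="model.ckpt")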

@ -28,6 +28,7 @@ def base_args_config(use_legacy_fe: bool = None, use_new_fe: bool = None):
    args.framework = "onnx"
    args.model_name = None
    args.input_model = None
    args.input_checkpoint = None
    args.silent = True
    args.transform = []
    args.scale = None

@ -27,6 +27,7 @@ def base_args_config():
    args.framework = 'onnx'
    args.model_name = None
    args.input_model = None
    args.input_checkpoint = None
    args.silent = True
    args.transform = []
    args.scale = None

@ -33,7 +33,7 @@ class TestMoFreezePlaceholderTFFE(unittest.TestCase):
    def test_conversion_fake_pb_model(self, use_new_frontend, use_legacy_frontend, framework):
        with self.assertRaisesRegex(Exception,
                                    "Internal error or inconsistent input model: the frontend supports frozen formats"
                                    " \(.pb and .pbtxt\), SavedModel and MetaGraph \(.meta\) formats."):
                                    " \(.pb and .pbtxt\), SavedModel and MetaGraph \(.meta\), and v1 checkpoints."):
            path = os.path.dirname(__file__)
            input_model = os.path.join(path, "test_models", "fake.pb")

@ -51,7 +51,7 @@ class TestMoFreezePlaceholderTFFE(unittest.TestCase):
            (
                False, False, "tf",
                r"Internal error or inconsistent input model: the frontend supports frozen formats"
                " \(.pb and .pbtxt\), SavedModel and MetaGraph \(.meta\) formats."
                " \(.pb and .pbtxt\), SavedModel and MetaGraph \(.meta\), and v1 checkpoints."
            ),
            # new frontend
            (
@ -61,7 +61,7 @@ class TestMoFreezePlaceholderTFFE(unittest.TestCase):
            (
                True, False, "tf",
                r"Internal error or inconsistent input model: the frontend supports frozen formats"
                " \(.pb and .pbtxt\), SavedModel and MetaGraph \(.meta\) formats."
                " \(.pb and .pbtxt\), SavedModel and MetaGraph \(.meta\), and v1 checkpoints."
            ),
        ],
    )
@ -83,7 +83,7 @@ class TestMoFreezePlaceholderTFFE(unittest.TestCase):
            (
                False, False, "tf",
                r"Internal error or inconsistent input model: the frontend supports frozen formats"
                " \(.pb and .pbtxt\), SavedModel and MetaGraph \(.meta\) formats."
                " \(.pb and .pbtxt\), SavedModel and MetaGraph \(.meta\), and v1 checkpoints."
            ),
            # new frontend
            (
@ -93,7 +93,7 @@ class TestMoFreezePlaceholderTFFE(unittest.TestCase):
            (
                True, False, "tf",
                r"Internal error or inconsistent input model: the frontend supports frozen formats"
                " \(.pb and .pbtxt\), SavedModel and MetaGraph \(.meta\) formats."
                " \(.pb and .pbtxt\), SavedModel and MetaGraph \(.meta\), and v1 checkpoints."
            ),
        ],
    )

@ -0,0 +1,42 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import tempfile
import unittest

import numpy as np

from unit_tests.moc_tf_fe.utils import basic_check


class TestBasicConversion(unittest.TestCase):
    def prepare_checkpoint_v1(self):
        # only quite old TensorFlow versions can produce checkpoint v1 files,
        # so use a hard-coded byte stream corresponding to checkpoint v1 content
        # this is a checkpoint v1 for Variable global_step with value = 14108582 of int64 type
        buffer_checkpoint = [
            0x00, 0x00, 0x1B, 0x0A, 0x19, 0x0A, 0x13, 0x0A, 0x0B, 0x67, 0x6C, 0x6F,
            0x62, 0x61, 0x6C, 0x5F, 0x73, 0x74, 0x65, 0x70, 0x12, 0x00, 0x18, 0x09,
            0x22, 0x00, 0x12, 0x02, 0x08, 0x01, 0x00, 0x0F, 0x19, 0x00, 0x67, 0x6C,
            0x6F, 0x62, 0x61, 0x6C, 0x5F, 0x73, 0x74, 0x65, 0x70, 0x00, 0x01, 0x00,
            0x12, 0x17, 0x0A, 0x0B, 0x67, 0x6C, 0x6F, 0x62, 0x61, 0x6C, 0x5F, 0x73,
            0x74, 0x65, 0x70, 0x12, 0x00, 0x1A, 0x06, 0x52, 0x04, 0xA6, 0x8F, 0xDD,
            0x06, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x62, 0x29,
            0x33, 0xD3, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0xC0,
            0xF2, 0xA1, 0xB0, 0x00, 0x01, 0x02, 0x01, 0x00, 0x51, 0x00, 0x00, 0x00,
            0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x1A, 0x13, 0xD9, 0x46, 0x56, 0x08,
            0x63, 0x0E, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x00, 0x00, 0x57, 0xFB, 0x80, 0x8B, 0x24, 0x75, 0x47, 0xDB]
        return buffer_checkpoint

    def test_basic_checkpoint_v1(self):
        ckpt_file = tempfile.NamedTemporaryFile(delete=False)
        checkpoint_byte_stream = self.prepare_checkpoint_v1()
        ckpt_file.write(bytes(checkpoint_byte_stream))
        ckpt_file.close()
        basic_check(input_model="model_with_variable_v1.pbtxt", argv_input=None,
                    input_data={'input1': np.array([[1]], dtype=np.int64)},
                    expected_dtype=np.int64, expected_value=np.array([[14108583]], dtype=np.int64),
                    use_new_frontend=True, use_legacy_frontend=False, input_checkpoint=ckpt_file.name)

@ -0,0 +1,53 @@
node {
  name: "input1"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_INT64
    }
  }
  attr {
    key: "shape"
    value {
      shape {
        dim {
          size: 1
        }
        dim {
          size: 1
        }
      }
    }
  }
}
node {
  name: "global_step"
  op: "Variable"
  device: "/device:CPU:0"
  attr {
    key: "dtype"
    value {
      type: DT_INT64
    }
  }
  attr {
    key: "shape"
    value {
      shape {
      }
    }
  }
}
node {
  name: "add"
  op: "Add"
  input: "input1"
  input: "global_step"
  attr {
    key: "T"
    value {
      type: DT_INT64
    }
  }
}

38
tools/mo/unit_tests/moc_tf_fe/utils.py
Normal file
@ -0,0 +1,38 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import os

import numpy as np

from openvino.runtime import Core
from openvino.tools.mo.convert import convert_model


def basic_check(input_model, argv_input, input_data, expected_dtype, expected_value, freeze_placeholder_with_value=None,
                input_shape=None, only_conversion=False, input_model_is_text=True, use_new_frontend=True,
                use_legacy_frontend=False, extensions=None, input_checkpoint=None):
    path = os.path.dirname(__file__)
    input_model = os.path.join(path, "test_models", input_model)

    ov_model = convert_model(input_model, input=argv_input,
                             freeze_placeholder_with_value=freeze_placeholder_with_value,
                             input_shape=input_shape, input_model_is_text=input_model_is_text,
                             use_new_frontend=use_new_frontend, use_legacy_frontend=use_legacy_frontend,
                             framework="tf", extensions=extensions, input_checkpoint=input_checkpoint)

    if only_conversion:
        return ov_model

    ie = Core()
    exec_net = ie.compile_model(ov_model, "CPU")
    req = exec_net.create_infer_request()
    results = req.infer(input_data)
    values = list(results.values())[0]
    if expected_dtype is not None:
        assert values.dtype == expected_dtype
    assert np.allclose(values,
                       expected_value), "Expected and actual values are different." \
                                        " Expected value: {}, actual value: {}".format(expected_value, values)

    return ov_model