[TF FE] Support conversion of models with non-standard extensions in the path (#15875)

* [TF FE] Support conversion of models with non-standard extensions in the path

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

* Update tools/mo/unit_tests/moc_tf_fe/conversion_basic_models.py

---------

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
Roman Kazantsev 2023-02-23 11:29:14 +04:00 committed by GitHub
parent 87bcbc1747
commit 900332c46e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 96 additions and 32 deletions

View File

@ -58,11 +58,12 @@ bool FrontEnd::supported_impl(const std::vector<ov::Any>& variants) const {
if (variants.size() != 1)
return false;
// Validating first path, it must contain a model
if (variants[0].is<std::string>()) {
std::string suffix = ".pb";
std::string model_path = variants[0].as<std::string>();
if (ov::util::ends_with(model_path, suffix.c_str())) {
if (ov::util::ends_with(model_path, ".pb") && GraphIteratorProto::is_supported(model_path)) {
// handle binary protobuf format
// for automatic deduction of the frontend to convert the model
// we have more strict rule that is to have `.pb` extension in the path
return true;
}
}
@ -70,12 +71,16 @@ bool FrontEnd::supported_impl(const std::vector<ov::Any>& variants) const {
else if (variants[0].is<std::wstring>()) {
std::wstring suffix = L".pb";
std::wstring model_path = variants[0].as<std::wstring>();
if (ov::util::ends_with(model_path, suffix)) {
if (ov::util::ends_with(model_path, suffix) && GraphIteratorProto::is_supported(model_path)) {
// handle binary protobuf format with a path in Unicode
// for automatic deduction of the frontend to convert the model
// we have more strict rule that is to have `.pb` extension in the path
return true;
}
}
#endif
else if (variants[0].is<GraphIterator::Ptr>()) {
// this is used for OpenVINO with TensorFlow Integration
return true;
}
return false;
@ -83,33 +88,36 @@ bool FrontEnd::supported_impl(const std::vector<ov::Any>& variants) const {
ov::frontend::InputModel::Ptr FrontEnd::load_impl(const std::vector<ov::Any>& variants) const {
// TODO: Support other TensorFlow formats: SavedModel, .meta, checkpoint, pbtxt
if (variants.size() == 1) {
// a case when binary protobuf format is provided
if (variants[0].is<std::string>()) {
std::string suffix = ".pb";
std::string model_path = variants[0].as<std::string>();
if (ov::util::ends_with(model_path, suffix.c_str())) {
return std::make_shared<InputModel>(
std::make_shared<::ov::frontend::tensorflow::GraphIteratorProto>(model_path),
m_telemetry);
}
}
#if defined(OPENVINO_ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
else if (variants[0].is<std::wstring>()) {
std::wstring suffix = L".pb";
std::wstring model_path = variants[0].as<std::wstring>();
if (ov::util::ends_with(model_path, suffix)) {
return std::make_shared<InputModel>(
std::make_shared<::ov::frontend::tensorflow::GraphIteratorProto>(model_path),
m_telemetry);
}
}
#endif
else if (variants[0].is<GraphIterator::Ptr>()) {
auto graph_iterator = variants[0].as<GraphIterator::Ptr>();
return std::make_shared<InputModel>(graph_iterator, m_telemetry);
FRONT_END_GENERAL_CHECK(variants.size() == 1,
"[TensorFlow Frontend] Internal error or inconsistent input model: the frontend supports "
"only frozen binary protobuf format.");
if (variants[0].is<std::string>()) {
auto model_path = variants[0].as<std::string>();
if (GraphIteratorProto::is_supported(model_path)) {
// handle binary protobuf format
return std::make_shared<InputModel>(std::make_shared<GraphIteratorProto>(model_path), m_telemetry);
}
}
#if defined(OPENVINO_ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
else if (variants[0].is<std::wstring>()) {
std::wstring model_path = variants[0].as<std::wstring>();
if (GraphIteratorProto::is_supported(model_path)) {
// handle binary protobuf format with a path in Unicode
return std::make_shared<InputModel>(std::make_shared<GraphIteratorProto>(model_path), m_telemetry);
}
}
#endif
else if (variants[0].is<GraphIterator::Ptr>()) {
// this is used for OpenVINO with TensorFlow Integration
auto graph_iterator = variants[0].as<GraphIterator::Ptr>();
return std::make_shared<InputModel>(graph_iterator, m_telemetry);
}
FRONT_END_GENERAL_CHECK(false,
"[TensorFlow Frontend] Internal error or inconsistent input model: the frontend supports "
"only frozen binary protobuf format.");
return nullptr;
}

View File

@ -88,29 +88,40 @@ public:
}
}
/// Set iterator to the start position
/// \brief Check whether the file at the given path holds a model in binary protobuf format
/// \param path  Path to the model file (narrow or wide string)
/// \return true if the file opens successfully and protobuf accepts its content as a GraphDef
template <typename T>
static bool is_supported(const std::basic_string<T>& path) {
    // Open in binary mode; a failed open leaves the stream in a bad state.
    std::ifstream pb_stream(path, std::ios::in | std::ifstream::binary);
    if (!pb_stream || !pb_stream.is_open()) {
        return false;
    }
    // Partial parsing is sufficient to decide whether the content is a GraphDef message.
    auto graph_def = std::make_shared<::tensorflow::GraphDef>();
    return graph_def->ParsePartialFromIstream(&pb_stream);
}
/// \brief Set iterator to the start position
// Rewinds to the first decoder so the graph can be traversed again.
void reset() override {
node_index = 0;
}
/// \brief Return the number of node decoders in the graph
size_t size() const override {
return m_decoders.size();
}
/// \brief Move the iterator to the next node in the graph
// No bounds check here; callers are expected to consult is_end() first.
void next() override {
node_index++;
}
/// \brief Check if the graph is fully traversed
// True once the index has moved past the last decoder.
bool is_end() const override {
return node_index >= m_decoders.size();
}
/// \brief Return the decoder for the current node that the iterator points to
// Precondition: !is_end(); otherwise the index is out of the decoders' range.
std::shared_ptr<DecoderBase> get_decoder() const override {
return m_decoders[node_index];
}
/// \brief Get GraphIterator for a library function by name
std::shared_ptr<GraphIterator> get_body_graph_iterator(const std::string& func_name) const override {
if (m_library_map.count(func_name)) {
auto func_ind = m_library_map.at(func_name);
@ -127,10 +138,12 @@ public:
return nullptr;
}
/// \brief Get input names in the original order. Used for the library functions
// Returns a copy of the stored names vector.
std::vector<std::string> get_input_names() const override {
return m_input_names;
}
/// \brief Get output names in the original order. Used for the library functions
// Returns a copy of the stored names vector.
std::vector<std::string> get_output_names() const override {
return m_output_names;
}

View File

@ -309,3 +309,41 @@ class TestMoFreezePlaceholderTFFE(unittest.TestCase):
def test_conversion_model_oneshot_iterator_default(self):
    # Smoke-check conversion of a text-format (.pbtxt) model containing OneShotIterator.
    # NOTE(review): the positional args map onto self.basic(...)'s parameters, which
    # are defined elsewhere in this class -- confirm against its signature before editing.
    self.basic("model_oneshot_iterator.pbtxt", None, None, None, None,
               None, None, True, True, False, False)
@generate(
    *[
        # (input_freezing_value, feed dict, expected output, expected dtype)
        (
            "in2{f32}->[0.0 0.0 0.0 0.0]",
            {"in1": np.array([[1.0, 2.0], [3.0, 4.0]])},
            np.array([[1.0, 2.0], [3.0, 4.0]]),
            np.float32,
        ),
        (
            "in2->[1.0 15.0 15.5 1.0]",
            {"in1": np.array([[2.0, 4.0], [12.0, 8.0]])},
            np.array([[3.0, 19.0], [27.5, 9.0]]),
            np.float32,
        ),
    ],
)
def test_conversion_model_with_non_standard_extension(self, input_freezing_value, inputs, expected,
                                                      dtype):
    # The model file carries a non-standard ".frozen" extension: the new TF frontend
    # must still recognize and convert it as a frozen binary protobuf model.
    self.basic("model_fp32.frozen", input_freezing_value, inputs, dtype, expected, only_conversion=False,
               input_model_is_text=False, use_new_frontend=True,
               use_legacy_frontend=False)
def test_conversion_fake_model(self):
    # A file whose content is not a valid protobuf message (despite the ".pb" name)
    # must be rejected with the frontend's "inconsistent input model" error.
    with self.assertRaisesRegex(Exception,
                                "Internal error or inconsistent input model: the frontend supports "
                                "only frozen binary protobuf format."):
        self.basic("fake.pb", None, None, None, None,
                   only_conversion=True, input_model_is_text=False, use_new_frontend=True,
                   use_legacy_frontend=False)
def test_conversion_dir_model(self):
    # Passing a directory ("."), not a model file, must also fail with the same
    # "inconsistent input model" error rather than crashing or converting.
    with self.assertRaisesRegex(Exception,
                                "Internal error or inconsistent input model: the frontend supports "
                                "only frozen binary protobuf format."):
        self.basic(".", None, None, None, None,
                   only_conversion=True, input_model_is_text=False, use_new_frontend=True,
                   use_legacy_frontend=False)

View File

@ -0,0 +1,2 @@
dcfsdcdsdcs
cscscsc

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8a33c91148b5e72ca03608c7d2ee18229ee4b610344dadd6896efeb6ac7b93e0
size 141