[TF FE] Support conversion of models with non-standard extensions in the path (#15875)
* [TF FE] Support conversion of models with non-standard extensions in the path Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Update tools/mo/unit_tests/moc_tf_fe/conversion_basic_models.py --------- Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
parent
87bcbc1747
commit
900332c46e
@ -58,11 +58,12 @@ bool FrontEnd::supported_impl(const std::vector<ov::Any>& variants) const {
|
||||
if (variants.size() != 1)
|
||||
return false;
|
||||
|
||||
// Validating first path, it must contain a model
|
||||
if (variants[0].is<std::string>()) {
|
||||
std::string suffix = ".pb";
|
||||
std::string model_path = variants[0].as<std::string>();
|
||||
if (ov::util::ends_with(model_path, suffix.c_str())) {
|
||||
if (ov::util::ends_with(model_path, ".pb") && GraphIteratorProto::is_supported(model_path)) {
|
||||
// handle binary protobuf format
|
||||
// for automatic deduction of the frontend to convert the model
|
||||
// we have more strict rule that is to have `.pb` extension in the path
|
||||
return true;
|
||||
}
|
||||
}
|
||||
@ -70,12 +71,16 @@ bool FrontEnd::supported_impl(const std::vector<ov::Any>& variants) const {
|
||||
else if (variants[0].is<std::wstring>()) {
|
||||
std::wstring suffix = L".pb";
|
||||
std::wstring model_path = variants[0].as<std::wstring>();
|
||||
if (ov::util::ends_with(model_path, suffix)) {
|
||||
if (ov::util::ends_with(model_path, suffix) && GraphIteratorProto::is_supported(model_path)) {
|
||||
// handle binary protobuf format with a path in Unicode
|
||||
// for automatic deduction of the frontend to convert the model
|
||||
// we have more strict rule that is to have `.pb` extension in the path
|
||||
return true;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
else if (variants[0].is<GraphIterator::Ptr>()) {
|
||||
// this is used for OpenVINO with TensorFlow Integration
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
@ -83,33 +88,36 @@ bool FrontEnd::supported_impl(const std::vector<ov::Any>& variants) const {
|
||||
|
||||
ov::frontend::InputModel::Ptr FrontEnd::load_impl(const std::vector<ov::Any>& variants) const {
|
||||
// TODO: Support other TensorFlow formats: SavedModel, .meta, checkpoint, pbtxt
|
||||
if (variants.size() == 1) {
|
||||
// a case when binary protobuf format is provided
|
||||
if (variants[0].is<std::string>()) {
|
||||
std::string suffix = ".pb";
|
||||
std::string model_path = variants[0].as<std::string>();
|
||||
if (ov::util::ends_with(model_path, suffix.c_str())) {
|
||||
return std::make_shared<InputModel>(
|
||||
std::make_shared<::ov::frontend::tensorflow::GraphIteratorProto>(model_path),
|
||||
m_telemetry);
|
||||
}
|
||||
}
|
||||
#if defined(OPENVINO_ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
|
||||
else if (variants[0].is<std::wstring>()) {
|
||||
std::wstring suffix = L".pb";
|
||||
std::wstring model_path = variants[0].as<std::wstring>();
|
||||
if (ov::util::ends_with(model_path, suffix)) {
|
||||
return std::make_shared<InputModel>(
|
||||
std::make_shared<::ov::frontend::tensorflow::GraphIteratorProto>(model_path),
|
||||
m_telemetry);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
else if (variants[0].is<GraphIterator::Ptr>()) {
|
||||
auto graph_iterator = variants[0].as<GraphIterator::Ptr>();
|
||||
return std::make_shared<InputModel>(graph_iterator, m_telemetry);
|
||||
FRONT_END_GENERAL_CHECK(variants.size() == 1,
|
||||
"[TensorFlow Frontend] Internal error or inconsistent input model: the frontend supports "
|
||||
"only frozen binary protobuf format.");
|
||||
|
||||
if (variants[0].is<std::string>()) {
|
||||
auto model_path = variants[0].as<std::string>();
|
||||
if (GraphIteratorProto::is_supported(model_path)) {
|
||||
// handle binary protobuf format
|
||||
return std::make_shared<InputModel>(std::make_shared<GraphIteratorProto>(model_path), m_telemetry);
|
||||
}
|
||||
}
|
||||
#if defined(OPENVINO_ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
|
||||
else if (variants[0].is<std::wstring>()) {
|
||||
std::wstring model_path = variants[0].as<std::wstring>();
|
||||
if (GraphIteratorProto::is_supported(model_path)) {
|
||||
// handle binary protobuf format with a path in Unicode
|
||||
return std::make_shared<InputModel>(std::make_shared<GraphIteratorProto>(model_path), m_telemetry);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
else if (variants[0].is<GraphIterator::Ptr>()) {
|
||||
// this is used for OpenVINO with TensorFlow Integration
|
||||
auto graph_iterator = variants[0].as<GraphIterator::Ptr>();
|
||||
return std::make_shared<InputModel>(graph_iterator, m_telemetry);
|
||||
}
|
||||
|
||||
FRONT_END_GENERAL_CHECK(false,
|
||||
"[TensorFlow Frontend] Internal error or inconsistent input model: the frontend supports "
|
||||
"only frozen binary protobuf format.");
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
@ -88,29 +88,40 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
/// Set iterator to the start position
|
||||
/// \brief Check if the input file is supported
|
||||
template <typename T>
|
||||
static bool is_supported(const std::basic_string<T>& path) {
|
||||
std::ifstream pb_stream(path, std::ios::in | std::ifstream::binary);
|
||||
auto graph_def = std::make_shared<::tensorflow::GraphDef>();
|
||||
return pb_stream && pb_stream.is_open() && graph_def->ParsePartialFromIstream(&pb_stream);
|
||||
}
|
||||
|
||||
/// \brief Set iterator to the start position
|
||||
void reset() override {
|
||||
node_index = 0;
|
||||
}
|
||||
|
||||
/// \brief Return a number of nodes in the graph
|
||||
size_t size() const override {
|
||||
return m_decoders.size();
|
||||
}
|
||||
|
||||
/// Moves to the next node in the graph
|
||||
/// \brief Move to the next node in the graph
|
||||
void next() override {
|
||||
node_index++;
|
||||
}
|
||||
|
||||
/// \brief Check if the graph is fully traversed
|
||||
bool is_end() const override {
|
||||
return node_index >= m_decoders.size();
|
||||
}
|
||||
|
||||
/// Return NodeContext for the current node that iterator points to
|
||||
/// \brief Return NodeContext for the current node that iterator points to
|
||||
std::shared_ptr<DecoderBase> get_decoder() const override {
|
||||
return m_decoders[node_index];
|
||||
}
|
||||
|
||||
/// \brief Get GraphIterator for library funnction by name
|
||||
std::shared_ptr<GraphIterator> get_body_graph_iterator(const std::string& func_name) const override {
|
||||
if (m_library_map.count(func_name)) {
|
||||
auto func_ind = m_library_map.at(func_name);
|
||||
@ -127,10 +138,12 @@ public:
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/// \brief Get input names in the original order. Used for the library functions
|
||||
std::vector<std::string> get_input_names() const override {
|
||||
return m_input_names;
|
||||
}
|
||||
|
||||
/// \brief Get output names in the original order. Used for the library functions
|
||||
std::vector<std::string> get_output_names() const override {
|
||||
return m_output_names;
|
||||
}
|
||||
|
@ -309,3 +309,41 @@ class TestMoFreezePlaceholderTFFE(unittest.TestCase):
|
||||
def test_conversion_model_oneshot_iterator_default(self):
|
||||
self.basic("model_oneshot_iterator.pbtxt", None, None, None, None,
|
||||
None, None, True, True, False, False)
|
||||
|
||||
@generate(
|
||||
*[
|
||||
(
|
||||
"in2{f32}->[0.0 0.0 0.0 0.0]",
|
||||
{"in1": np.array([[1.0, 2.0], [3.0, 4.0]])},
|
||||
np.array([[1.0, 2.0], [3.0, 4.0]]),
|
||||
np.float32,
|
||||
),
|
||||
(
|
||||
"in2->[1.0 15.0 15.5 1.0]",
|
||||
{"in1": np.array([[2.0, 4.0], [12.0, 8.0]])},
|
||||
np.array([[3.0, 19.0], [27.5, 9.0]]),
|
||||
np.float32,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_conversion_model_with_non_standard_extension(self, input_freezing_value, inputs, expected,
|
||||
dtype):
|
||||
self.basic("model_fp32.frozen", input_freezing_value, inputs, dtype, expected, only_conversion=False,
|
||||
input_model_is_text=False, use_new_frontend=True,
|
||||
use_legacy_frontend=False)
|
||||
|
||||
def test_conversion_fake_model(self):
|
||||
with self.assertRaisesRegex(Exception,
|
||||
"Internal error or inconsistent input model: the frontend supports "
|
||||
"only frozen binary protobuf format."):
|
||||
self.basic("fake.pb", None, None, None, None,
|
||||
only_conversion=True, input_model_is_text=False, use_new_frontend=True,
|
||||
use_legacy_frontend=False)
|
||||
|
||||
def test_conversion_dir_model(self):
|
||||
with self.assertRaisesRegex(Exception,
|
||||
"Internal error or inconsistent input model: the frontend supports "
|
||||
"only frozen binary protobuf format."):
|
||||
self.basic(".", None, None, None, None,
|
||||
only_conversion=True, input_model_is_text=False, use_new_frontend=True,
|
||||
use_legacy_frontend=False)
|
||||
|
2
tools/mo/unit_tests/moc_tf_fe/test_models/fake.pb
Normal file
2
tools/mo/unit_tests/moc_tf_fe/test_models/fake.pb
Normal file
@ -0,0 +1,2 @@
|
||||
dcfsdcdsdcs
|
||||
cscscsc
|
@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:8a33c91148b5e72ca03608c7d2ee18229ee4b610344dadd6896efeb6ac7b93e0
|
||||
size 141
|
Loading…
Reference in New Issue
Block a user