Rework model loading in FE manager, implement PDPD probing (#6358)

* Rework model loading in FE manager, implement PDPD probing

* Fix build

* Fix build

* Fix build

* Fix unicode

* Fix merge issues

* Fix codestyle

* Read frontends path from frontend_manager library location

* Fix codestyle

* Fix FE dependency

* Fix dependencies

* Fix codestyle

* Check if model file exists

* Revert adding model to lfs

* Add test model

* Fix cmake dependencies

* Apply review feedback

* Revert pugixml

* make getFrontendLibraryPath not public API

* Fix codestyle

* Apply fix from Ilya Lavrenov

* Add FE dependency in legacy tests

* Remove not needed dependency

* Better support Unicode

* Fix build

* Fix build

* Fix build

* Add dependency foe deprecated tests

* Fix dependency

* Fix typo

* Revert adding FE dependency to IESharedTests

* Remove relative paths from frontend unit tests

* Apply review feedback

* Fix typo

* Return allow-undefined, since kmb dependecies fail to link

* Fix merge conflict

* Compare functions in reader tests

* Simplify code to load from variants

* Remove supported_by_arguments from public api

* Fix codestyle

* Fix build

* Compare names in reader tests

* Fix wchar in variant

Co-authored-by: Ilya Churaev <ilya.churaev@intel.com>
This commit is contained in:
Maxim Vafin 2021-07-19 20:10:00 +03:00 committed by GitHub
parent 48c9eaba56
commit 960ba48e6f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
48 changed files with 727 additions and 390 deletions

View File

@ -73,6 +73,10 @@ function(_ie_target_no_deprecation_error)
else()
set(flags "-Wno-error=deprecated-declarations")
endif()
if(CMAKE_CROSSCOMPILING)
set_target_properties(${ARGV} PROPERTIES
INTERFACE_LINK_OPTIONS "-Wl,--allow-shlib-undefined")
endif()
set_target_properties(${ARGV} PROPERTIES INTERFACE_COMPILE_OPTIONS ${flags})
endif()

View File

@ -124,6 +124,7 @@ target_compile_definitions(${TARGET_NAME}_obj PRIVATE IMPLEMENT_INFERENCE_ENGINE
target_include_directories(${TARGET_NAME}_obj SYSTEM PRIVATE $<TARGET_PROPERTY:ngraph::ngraph,INTERFACE_INCLUDE_DIRECTORIES>
$<TARGET_PROPERTY:pugixml::static,INTERFACE_INCLUDE_DIRECTORIES>
$<TARGET_PROPERTY:ngraph::frontend_manager,INTERFACE_INCLUDE_DIRECTORIES>
$<TARGET_PROPERTY:xbyak,INTERFACE_INCLUDE_DIRECTORIES>)
target_include_directories(${TARGET_NAME}_obj PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}"
@ -160,7 +161,7 @@ if (TBBBIND_2_4_FOUND)
endif()
target_link_libraries(${TARGET_NAME} PRIVATE pugixml::static openvino::itt ${CMAKE_DL_LIBS} Threads::Threads
ngraph inference_engine_transformations)
ngraph ngraph::frontend_manager inference_engine_transformations)
target_include_directories(${TARGET_NAME} INTERFACE
$<BUILD_INTERFACE:${PUBLIC_HEADERS_DIR}>
@ -200,7 +201,7 @@ if(WIN32)
set_target_properties(${TARGET_NAME}_s PROPERTIES COMPILE_PDB_NAME ${TARGET_NAME}_s)
endif()
target_link_libraries(${TARGET_NAME}_s PRIVATE openvino::itt ${CMAKE_DL_LIBS} ngraph
target_link_libraries(${TARGET_NAME}_s PRIVATE openvino::itt ${CMAKE_DL_LIBS} ngraph ngraph::frontend_manager
inference_engine_transformations pugixml::static)
target_compile_definitions(${TARGET_NAME}_s PUBLIC USE_STATIC_IE)

View File

@ -9,6 +9,7 @@
#include <file_utils.h>
#include <ie_reader.hpp>
#include <ie_ir_version.hpp>
#include <frontend_manager/frontend_manager.hpp>
#include <fstream>
#include <istream>
@ -226,6 +227,26 @@ CNNNetwork details::ReadNetwork(const std::string& modelPath, const std::string&
return reader->read(modelStream, exts);
}
}
// Try to load with FrontEndManager
static ngraph::frontend::FrontEndManager manager;
ngraph::frontend::FrontEnd::Ptr FE;
ngraph::frontend::InputModel::Ptr inputModel;
if (!binPath.empty()) {
#if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
std::wstring weights_path = FileUtils::multiByteCharToWString(binPath.c_str());
#else
std::string weights_path = binPath;
#endif
FE = manager.load_by_model(model_path, weights_path);
if (FE) inputModel = FE->load(model_path, weights_path);
} else {
FE = manager.load_by_model(model_path);
if (FE) inputModel = FE->load(model_path);
}
if (inputModel) {
auto ngFunc = FE->convert(inputModel);
return CNNNetwork(ngFunc);
}
IE_THROW() << "Unknown model format! Cannot find reader for model format: " << fileExt << " and read the model: " << modelPath <<
". Please check that reader library exists in your PATH.";
}
@ -248,4 +269,4 @@ CNNNetwork details::ReadNetwork(const std::string& model, const Blob::CPtr& weig
IE_THROW() << "Unknown model format! Cannot find reader for the model and read it. Please check that reader library exists in your PATH.";
}
} // namespace InferenceEngine
} // namespace InferenceEngine

View File

@ -55,6 +55,11 @@ if(NGRAPH_ONNX_IMPORT_ENABLE)
add_dependencies(${TARGET_NAME} inference_engine_onnx_reader)
endif()
if(NGRAPH_PDPD_FRONTEND_ENABLE)
target_compile_definitions(${TARGET_NAME} PRIVATE
PDPD_TEST_MODELS="${CMAKE_CURRENT_SOURCE_DIR}/pdpd_reader/models/")
endif()
ie_faster_build(${TARGET_NAME}
PCH PRIVATE "precomp.hpp"
)

View File

@ -0,0 +1,84 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <set>
#include <string>
#include <fstream>
#include <ie_blob.h>
#include <ie_core.hpp>
#include <file_utils.h>
#include <ngraph/ngraph.hpp>
#include <ngraph/opsets/opset8.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
TEST(PDPD_Reader_Tests, ImportBasicModelToCore) {
auto model = std::string(PDPD_TEST_MODELS) + "relu.pdmodel";
InferenceEngine::Core ie;
auto cnnNetwork = ie.ReadNetwork(model);
auto function = cnnNetwork.getFunction();
const auto inputType = ngraph::element::f32;
const auto inputShape = ngraph::Shape{ 3 };
const auto data = std::make_shared<ngraph::opset8::Parameter>(inputType, inputShape);
data->set_friendly_name("x");
data->output(0).get_tensor().add_names({ "x" });
const auto relu = std::make_shared<ngraph::opset8::Relu>(data->output(0));
relu->set_friendly_name("relu_0.tmp_0");
relu->output(0).get_tensor().add_names({ "relu_0.tmp_0" });
const auto scale = std::make_shared<ngraph::opset8::Constant>(ngraph::element::f32, ngraph::Shape{ 1 }, std::vector<float>{1});
const auto bias = std::make_shared<ngraph::opset8::Constant>(ngraph::element::f32, ngraph::Shape{ 1 }, std::vector<float>{0});
const auto node_multiply = std::make_shared<ngraph::opset8::Multiply>(relu->output(0), scale);
const auto node_add = std::make_shared<ngraph::opset8::Add>(node_multiply, bias);
node_add->set_friendly_name("save_infer_model/scale_0.tmp_1");
node_add->output(0).get_tensor().add_names({ "save_infer_model/scale_0.tmp_1" });
const auto result = std::make_shared<ngraph::opset8::Result>(node_add->output(0));
result->set_friendly_name("save_infer_model/scale_0.tmp_1/Result");
const auto reference = std::make_shared<ngraph::Function>(
ngraph::NodeVector{ result },
ngraph::ParameterVector{ data },
"RefPDPDFunction");
const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::NAMES);
const FunctionsComparator::Result res = func_comparator(function, reference);
ASSERT_TRUE(res.valid);
}
#if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
TEST(PDPD_Reader_Tests, ImportBasicModelToCoreWstring) {
std::string win_dir_path{ PDPD_TEST_MODELS };
std::replace(win_dir_path.begin(), win_dir_path.end(), '/', '\\');
const std::wstring unicode_win_dir_path = FileUtils::multiByteCharToWString(win_dir_path.c_str());
auto model = unicode_win_dir_path + L"ひらがな日本語.pdmodel";
InferenceEngine::Core ie;
auto cnnNetwork = ie.ReadNetwork(model);
auto function = cnnNetwork.getFunction();
const auto inputType = ngraph::element::f32;
const auto inputShape = ngraph::Shape{ 3 };
const auto data = std::make_shared<ngraph::opset8::Parameter>(inputType, inputShape);
data->set_friendly_name("x");
data->output(0).get_tensor().add_names({ "x" });
const auto relu = std::make_shared<ngraph::opset8::Relu>(data->output(0));
relu->set_friendly_name("relu_0.tmp_0");
relu->output(0).get_tensor().add_names({ "relu_0.tmp_0" });
const auto scale = std::make_shared<ngraph::opset8::Constant>(ngraph::element::f32, ngraph::Shape{ 1 }, std::vector<float>{1});
const auto bias = std::make_shared<ngraph::opset8::Constant>(ngraph::element::f32, ngraph::Shape{ 1 }, std::vector<float>{0});
const auto node_multiply = std::make_shared<ngraph::opset8::Multiply>(relu->output(0), scale);
const auto node_add = std::make_shared<ngraph::opset8::Add>(node_multiply, bias);
node_add->set_friendly_name("save_infer_model/scale_0.tmp_1");
node_add->output(0).get_tensor().add_names({ "save_infer_model/scale_0.tmp_1" });
const auto result = std::make_shared<ngraph::opset8::Result>(node_add->output(0));
result->set_friendly_name("save_infer_model/scale_0.tmp_1/Result");
const auto reference = std::make_shared<ngraph::Function>(
ngraph::NodeVector{ result },
ngraph::ParameterVector{ data },
"RefPDPDFunction");
const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::NAMES);
const FunctionsComparator::Result res = func_comparator(function, reference);
ASSERT_TRUE(res.valid);
}
#endif

View File

@ -24,7 +24,7 @@ def moc_pipeline(argv: argparse.Namespace):
str(fem.get_available_front_ends())))
log.debug('Initializing new FE for framework {}'.format(argv.framework))
fe = fem.load_by_framework(argv.framework)
input_model = fe.load_from_file(argv.input_model)
input_model = fe.load(argv.input_model)
user_shapes, outputs, freeze_placeholder = fe_user_data_repack(
input_model, argv.placeholder_shapes, argv.placeholder_data_types,

View File

@ -25,7 +25,7 @@ extern "C" MOCK_API void* GetFrontEndData()
{
FrontEndPluginInfo* res = new FrontEndPluginInfo();
res->m_name = "mock_mo_ngraph_frontend";
res->m_creator = [](FrontEndCapFlags flags) { return std::make_shared<FrontEndMockPy>(flags); };
res->m_creator = []() { return std::make_shared<FrontEndMockPy>(); };
return res;
}

View File

@ -292,11 +292,9 @@ public:
/// was called with correct arguments during test execution
struct MOCK_API FeStat
{
FrontEndCapFlags m_load_flags;
std::vector<std::string> m_load_paths;
int m_convert_model = 0;
// Getters
FrontEndCapFlags load_flags() const { return m_load_flags; }
std::vector<std::string> load_paths() const { return m_load_paths; }
int convert_model() const { return m_convert_model; }
};
@ -309,13 +307,8 @@ class MOCK_API FrontEndMockPy : public FrontEnd
static FeStat m_stat;
public:
FrontEndMockPy(FrontEndCapFlags flags) { m_stat.m_load_flags = flags; }
FrontEndMockPy() {}
InputModel::Ptr load_from_file(const std::string& path) const override
{
m_stat.m_load_paths.push_back(path);
return std::make_shared<InputModelMockPy>();
}
std::shared_ptr<ngraph::Function> convert(InputModel::Ptr model) const override
{
@ -326,4 +319,15 @@ public:
static FeStat get_stat() { return m_stat; }
static void clear_stat() { m_stat = {}; }
protected:
InputModel::Ptr load_impl(const std::vector<std::shared_ptr<Variant>>& params) const override
{
if (params.size() > 0 && is_type<VariantWrapper<std::string>>(params[0]))
{
auto path = as_type_ptr<VariantWrapper<std::string>>(params[0])->get();
m_stat.m_load_paths.push_back(path);
}
return std::make_shared<InputModelMockPy>();
}
};

View File

@ -17,7 +17,6 @@ static void register_mock_frontend_stat(py::module m)
m.def("clear_frontend_statistic", &FrontEndMockPy::clear_stat);
py::class_<FeStat> feStat(m, "FeStat", py::dynamic_attr());
feStat.def_property_readonly("load_flags", &FeStat::load_flags);
feStat.def_property_readonly("load_paths", &FeStat::load_paths);
feStat.def_property_readonly("convert_model", &FeStat::convert_model);
}

View File

@ -75,4 +75,27 @@ namespace ngraph
{
}
};
template <typename T>
inline std::shared_ptr<Variant> make_variant(const T& p)
{
return std::dynamic_pointer_cast<VariantImpl<T>>(std::make_shared<VariantWrapper<T>>(p));
}
template <size_t N>
inline std::shared_ptr<Variant> make_variant(const char (&s)[N])
{
return std::dynamic_pointer_cast<VariantImpl<std::string>>(
std::make_shared<VariantWrapper<std::string>>(s));
}
#if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
template <size_t N>
inline std::shared_ptr<Variant> make_variant(const wchar_t (&s)[N])
{
return std::dynamic_pointer_cast<VariantImpl<std::wstring>>(
std::make_shared<VariantWrapper<std::wstring>>(s));
}
#endif
} // namespace ngraph

View File

@ -10,6 +10,7 @@
#include "frontend_manager_defs.hpp"
#include "input_model.hpp"
#include "ngraph/function.hpp"
#include "ngraph/variant.hpp"
namespace ngraph
{
@ -26,43 +27,31 @@ namespace ngraph
virtual ~FrontEnd();
/// \brief Loads an input model by specified model file path
/// If model is stored in several files (e.g. model topology and model weights) -
/// frontend implementation is responsible to handle this case, generally frontend may
/// retrieve other file names from main file
/// \param path Main model file path
/// \return Loaded input model
virtual InputModel::Ptr load_from_file(const std::string& path) const;
/// \brief Validates if FrontEnd can recognize model with parameters specified.
/// Same parameters should be used to load model.
/// \param vars Any number of parameters of any type. What kind of parameters
/// are accepted is determined by each FrontEnd individually, typically it is
/// std::string containing path to the model file. For more information please
/// refer to specific FrontEnd documentation.
/// \return true if model recognized, false - otherwise.
template <typename... Types>
inline bool supported(const Types&... vars) const
{
return supported_impl({make_variant(vars)...});
}
/// \brief Loads an input model by specified number of model files
/// This shall be used for cases when client knows all model files (model, weights, etc)
/// \param paths Array of model files
/// \return Loaded input model
virtual InputModel::Ptr load_from_files(const std::vector<std::string>& paths) const;
/// \brief Loads an input model by already loaded memory buffer
/// Memory structure is frontend-defined and is not specified in generic API
/// \param model Model memory buffer
/// \return Loaded input model
virtual InputModel::Ptr load_from_memory(const void* model) const;
/// \brief Loads an input model from set of memory buffers
/// Memory structure is frontend-defined and is not specified in generic API
/// \param modelParts Array of model memory buffers
/// \return Loaded input model
virtual InputModel::Ptr
load_from_memory_fragments(const std::vector<const void*>& modelParts) const;
/// \brief Loads an input model by input stream representing main model file
/// \param stream Input stream of main model
/// \return Loaded input model
virtual InputModel::Ptr load_from_stream(std::istream& stream) const;
/// \brief Loads an input model by input streams representing all model files
/// \param streams Array of input streams for model
/// \return Loaded input model
virtual InputModel::Ptr
load_from_streams(const std::vector<std::istream*>& streams) const;
/// \brief Loads an input model by any specified arguments. Each FrontEnd separately
/// defines what arguments it can accept.
/// \param vars Any number of parameters of any type. What kind of parameters
/// are accepted is determined by each FrontEnd individually, typically it is
/// std::string containing path to the model file. For more information please
/// refer to specific FrontEnd documentation.
/// \return Loaded input model.
template <typename... Types>
inline InputModel::Ptr load(const Types&... vars) const
{
return load_impl({make_variant(vars)...});
}
/// \brief Completely convert and normalize entire function, throws if it is not
/// possible
@ -95,8 +84,20 @@ namespace ngraph
/// \brief Runs normalization passes on function that was loaded with partial conversion
/// \param function partially converted nGraph function
virtual void normalize(std::shared_ptr<ngraph::Function> function) const;
protected:
virtual bool
supported_impl(const std::vector<std::shared_ptr<Variant>>& variants) const;
virtual InputModel::Ptr
load_impl(const std::vector<std::shared_ptr<Variant>>& variants) const;
};
template <>
inline bool FrontEnd::supported(const std::vector<std::shared_ptr<Variant>>& variants) const
{
return supported_impl(variants);
}
} // namespace frontend
} // namespace ngraph

View File

@ -8,36 +8,14 @@
#include <string>
#include "frontend.hpp"
#include "frontend_manager_defs.hpp"
#include "ngraph/variant.hpp"
namespace ngraph
{
namespace frontend
{
/// Capabilities for requested FrontEnd
/// In general, frontend implementation may be divided into several libraries by capability
/// level It will allow faster load of frontend when only limited usage is expected by
/// client application as well as binary size can be minimized by removing not needed parts
/// from application's package
namespace FrontEndCapabilities
{
/// \brief Just reading and conversion, w/o any modifications; intended to be used in
/// Reader
static const int FEC_DEFAULT = 0;
/// \brief Topology cutting capability
static const int FEC_CUT = 1;
/// \brief Query entities by names, renaming and adding new names for operations and
/// tensors
static const int FEC_NAMES = 2;
/// \brief Partial model conversion and decoding capability
static const int FEC_WILDCARDS = 4;
}; // namespace FrontEndCapabilities
// -------------- FrontEndManager -----------------
using FrontEndCapFlags = int;
using FrontEndFactory = std::function<FrontEnd::Ptr(FrontEndCapFlags fec)>;
using FrontEndFactory = std::function<FrontEnd::Ptr()>;
/// \brief Frontend management class, loads available frontend plugins on construction
/// Allows load of frontends for particular framework, register new and list available
@ -62,26 +40,22 @@ namespace ngraph
/// \param framework Framework name. Throws exception if name is not in list of
/// available frontends
///
/// \param fec Frontend capabilities. It is recommended to use only
/// those capabilities which are needed to minimize load time
///
/// \return Frontend interface for further loading of models
FrontEnd::Ptr
load_by_framework(const std::string& framework,
FrontEndCapFlags fec = FrontEndCapabilities::FEC_DEFAULT);
FrontEnd::Ptr load_by_framework(const std::string& framework);
/// \brief Loads frontend by model file path. Selects and loads appropriate frontend
/// depending on model file extension and other file info (header)
/// \brief Loads frontend by model fragments described by each FrontEnd documentation.
/// Selects and loads appropriate frontend depending on model file extension and other
/// file info (header)
///
/// \param framework
/// Framework name. Throws exception if name is not in list of available frontends
///
/// \param fec Frontend capabilities. It is recommended to use only those capabilities
/// which are needed to minimize load time
///
/// \return Frontend interface for further loading of model
FrontEnd::Ptr load_by_model(const std::string& path,
FrontEndCapFlags fec = FrontEndCapabilities::FEC_DEFAULT);
template <typename... Types>
FrontEnd::Ptr load_by_model(const Types&... vars)
{
return load_by_model_impl({make_variant(vars)...});
}
/// \brief Gets list of registered frontends
std::vector<std::string> get_available_front_ends() const;
@ -97,6 +71,8 @@ namespace ngraph
private:
class Impl;
FrontEnd::Ptr load_by_model_impl(const std::vector<std::shared_ptr<Variant>>& variants);
std::unique_ptr<Impl> m_impl;
};
@ -119,4 +95,31 @@ namespace ngraph
} // namespace frontend
template <>
class FRONTEND_API VariantWrapper<std::shared_ptr<std::istream>>
: public VariantImpl<std::shared_ptr<std::istream>>
{
public:
static constexpr VariantTypeInfo type_info{"Variant::std::shared_ptr<std::istream>", 0};
const VariantTypeInfo& get_type_info() const override { return type_info; }
VariantWrapper(const value_type& value)
: VariantImpl<value_type>(value)
{
}
};
#if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
template <>
class FRONTEND_API VariantWrapper<std::wstring> : public VariantImpl<std::wstring>
{
public:
static constexpr VariantTypeInfo type_info{"Variant::std::wstring", 0};
const VariantTypeInfo& get_type_info() const override { return type_info; }
VariantWrapper(const value_type& value)
: VariantImpl<value_type>(value)
{
}
};
#endif
} // namespace ngraph

View File

@ -8,6 +8,7 @@
#include "frontend_manager/frontend_exceptions.hpp"
#include "frontend_manager/frontend_manager.hpp"
#include "plugin_loader.hpp"
#include "utils.hpp"
using namespace ngraph;
using namespace ngraph::frontend;
@ -23,11 +24,11 @@ public:
~Impl() = default;
FrontEnd::Ptr loadByFramework(const std::string& framework, FrontEndCapFlags fec)
FrontEnd::Ptr loadByFramework(const std::string& framework)
{
FRONT_END_INITIALIZATION_CHECK(
m_factories.count(framework), "FrontEnd for Framework ", framework, " is not found");
return m_factories[framework](fec);
return m_factories[framework]();
}
std::vector<std::string> availableFrontEnds() const
@ -42,9 +43,17 @@ public:
return keys;
}
FrontEnd::Ptr loadByModel(const std::string& path, FrontEndCapFlags fec)
FrontEnd::Ptr loadByModel(const std::vector<std::shared_ptr<Variant>>& variants)
{
FRONT_END_NOT_IMPLEMENTED(loadByModel);
for (const auto& factory : m_factories)
{
auto FE = factory.second();
if (FE->supported(variants))
{
return FE;
}
}
return FrontEnd::Ptr();
}
void registerFrontEnd(const std::string& name, FrontEndFactory creator)
@ -81,7 +90,7 @@ private:
}
else
{
registerFromDir(".");
registerFromDir(getFrontendLibraryPath());
}
}
};
@ -96,14 +105,15 @@ FrontEndManager& FrontEndManager::operator=(FrontEndManager&&) = default;
FrontEndManager::~FrontEndManager() = default;
FrontEnd::Ptr FrontEndManager::load_by_framework(const std::string& framework, FrontEndCapFlags fec)
FrontEnd::Ptr FrontEndManager::load_by_framework(const std::string& framework)
{
return m_impl->loadByFramework(framework, fec);
return m_impl->loadByFramework(framework);
}
FrontEnd::Ptr FrontEndManager::load_by_model(const std::string& path, FrontEndCapFlags fec)
FrontEnd::Ptr
FrontEndManager::load_by_model_impl(const std::vector<std::shared_ptr<Variant>>& variants)
{
return m_impl->loadByModel(path, fec);
return m_impl->loadByModel(variants);
}
std::vector<std::string> FrontEndManager::get_available_front_ends() const
@ -122,37 +132,15 @@ FrontEnd::FrontEnd() = default;
FrontEnd::~FrontEnd() = default;
InputModel::Ptr FrontEnd::load_from_file(const std::string& path) const
bool FrontEnd::supported_impl(const std::vector<std::shared_ptr<Variant>>& variants) const
{
FRONT_END_NOT_IMPLEMENTED(load_from_file);
return false;
}
InputModel::Ptr FrontEnd::load_from_files(const std::vector<std::string>& paths) const
InputModel::Ptr FrontEnd::load_impl(const std::vector<std::shared_ptr<Variant>>& params) const
{
FRONT_END_NOT_IMPLEMENTED(load_from_files);
FRONT_END_NOT_IMPLEMENTED(load_impl);
}
InputModel::Ptr FrontEnd::load_from_memory(const void* model) const
{
FRONT_END_NOT_IMPLEMENTED(load_from_memory);
}
InputModel::Ptr
FrontEnd::load_from_memory_fragments(const std::vector<const void*>& modelParts) const
{
FRONT_END_NOT_IMPLEMENTED(load_from_memory_fragments);
}
InputModel::Ptr FrontEnd::load_from_stream(std::istream& path) const
{
FRONT_END_NOT_IMPLEMENTED(load_from_stream);
}
InputModel::Ptr FrontEnd::load_from_streams(const std::vector<std::istream*>& paths) const
{
FRONT_END_NOT_IMPLEMENTED(load_from_streams);
}
std::shared_ptr<ngraph::Function> FrontEnd::convert(InputModel::Ptr model) const
{
FRONT_END_NOT_IMPLEMENTED(convert);
@ -422,3 +410,9 @@ Place::Ptr Place::get_source_tensor(int inputPortIndex) const
{
FRONT_END_NOT_IMPLEMENTED(get_source_tensor);
}
constexpr VariantTypeInfo VariantWrapper<std::shared_ptr<std::istream>>::type_info;
#if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
constexpr VariantTypeInfo VariantWrapper<std::wstring>::type_info;
#endif

View File

@ -0,0 +1,68 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "utils.hpp"
#include "frontend_manager/frontend_exceptions.hpp"
#include "plugin_loader.hpp"
#ifndef _WIN32
#include <dlfcn.h>
#include <limits.h>
#include <unistd.h>
#ifdef ENABLE_UNICODE_PATH_SUPPORT
#include <codecvt>
#include <locale>
#endif
#else
#if defined(WINAPI_FAMILY) && !WINAPI_PARTITION_DESKTOP
#error "Only WINAPI_PARTITION_DESKTOP is supported, because of GetModuleHandleEx[A|W]"
#endif
#ifndef NOMINMAX
#define NOMINMAX
#endif
#include <Windows.h>
#endif
namespace
{
std::string getPathName(const std::string& s)
{
size_t i = s.rfind(FileSeparator, s.length());
if (i != std::string::npos)
{
return (s.substr(0, i));
}
return {};
}
} // namespace
static std::string _getFrontendLibraryPath()
{
#ifdef _WIN32
CHAR ie_library_path[MAX_PATH];
HMODULE hm = NULL;
if (!GetModuleHandleExA(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS |
GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT,
reinterpret_cast<LPSTR>(ngraph::frontend::getFrontendLibraryPath),
&hm))
{
FRONT_END_INITIALIZATION_CHECK(false, "GetModuleHandle returned ", GetLastError());
}
GetModuleFileNameA(hm, (LPSTR)ie_library_path, sizeof(ie_library_path));
return getPathName(std::string(ie_library_path));
#elif defined(__APPLE__) || defined(__linux__)
Dl_info info;
dladdr(reinterpret_cast<void*>(ngraph::frontend::getFrontendLibraryPath), &info);
return getPathName(std::string(info.dli_fname)).c_str();
#else
#error "Unsupported OS"
#endif // _WIN32
}
std::string ngraph::frontend::getFrontendLibraryPath()
{
return _getFrontendLibraryPath();
}

View File

@ -0,0 +1,14 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <string>
#include "frontend_manager/frontend_manager_defs.hpp"
namespace ngraph
{
namespace frontend
{
FRONTEND_API std::string getFrontendLibraryPath();
} // namespace frontend
} // namespace ngraph

View File

@ -14,44 +14,33 @@ namespace ngraph
{
class PDPD_API FrontEndPDPD : public FrontEnd
{
static std::shared_ptr<Function>
convert_model(const std::shared_ptr<InputModelPDPD>& model);
public:
FrontEndPDPD() = default;
/**
* @brief Reads model from file and deducts file names of weights
* @param path path to folder which contains __model__ file or path to .pdmodel file
* @return InputModel::Ptr
*/
InputModel::Ptr load_from_file(const std::string& path) const override;
/**
* @brief Reads model and weights from files
* @param paths vector containing path to .pdmodel and .pdiparams files
* @return InputModel::Ptr
*/
InputModel::Ptr load_from_files(const std::vector<std::string>& paths) const override;
/**
* @brief Reads model from stream
* @param model_stream stream containing .pdmodel or __model__ files. Can only be used
* if model have no weights
* @return InputModel::Ptr
*/
InputModel::Ptr load_from_stream(std::istream& model_stream) const override;
/**
* @brief Reads model from stream
* @param paths vector of streams containing .pdmodel and .pdiparams files. Can't be
* used in case of multiple weight files
* @return InputModel::Ptr
*/
InputModel::Ptr
load_from_streams(const std::vector<std::istream*>& paths) const override;
/// \brief Completely convert the remaining, not converted part of a function.
/// \param partiallyConverted partially converted nGraph function
/// \return fully converted nGraph function
std::shared_ptr<Function> convert(InputModel::Ptr model) const override;
protected:
/// \brief Check if FrontEndPDPD can recognize model from given parts
/// \param params Can be path to folder which contains __model__ file or path to
/// .pdmodel file
/// \return InputModel::Ptr
bool supported_impl(
const std::vector<std::shared_ptr<Variant>>& variants) const override;
/// \brief Reads model from 1 or 2 given file names or 1 or 2 std::istream containing
/// model in protobuf format and weights
/// \param params Can contain path to folder with __model__ file or path to .pdmodel
/// file or 1 or 2 streams with model and weights
/// \return InputModel::Ptr
InputModel::Ptr
load_impl(const std::vector<std::shared_ptr<Variant>>& params) const override;
private:
static std::shared_ptr<Function>
convert_model(const std::shared_ptr<InputModelPDPD>& model);
};
} // namespace frontend

View File

@ -13,7 +13,6 @@ namespace ngraph
{
class OpPlacePDPD;
class TensorPlacePDPD;
class PDPD_API InputModelPDPD : public InputModel
{
friend class FrontEndPDPD;
@ -26,6 +25,9 @@ namespace ngraph
public:
explicit InputModelPDPD(const std::string& path);
#if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
explicit InputModelPDPD(const std::wstring& path);
#endif
explicit InputModelPDPD(const std::vector<std::istream*>& streams);
std::vector<Place::Ptr> get_inputs() const override;
std::vector<Place::Ptr> get_outputs() const override;

View File

@ -2,31 +2,26 @@
// SPDX-License-Identifier: Apache-2.0
//
#include <algorithm>
#include <chrono>
#include <fstream>
#include <map>
#include <memory>
#include <numeric>
#include <string>
#include <utility>
#include <vector>
#include "framework.pb.h"
#include <paddlepaddle_frontend/exceptions.hpp>
#include <paddlepaddle_frontend/frontend.hpp>
#include <paddlepaddle_frontend/model.hpp>
#include <paddlepaddle_frontend/place.hpp>
#include <ngraph/ngraph.hpp>
#include <ngraph/opsets/opset7.hpp>
#include <ngraph/variant.hpp>
#include <paddlepaddle_frontend/exceptions.hpp>
#include "decoder.hpp"
#include "node_context.hpp"
#include "op_table.hpp"
#include <functional>
#include "pdpd_utils.hpp"
#include "frontend_manager/frontend_manager.hpp"
@ -67,8 +62,45 @@ namespace ngraph
}
}
return CREATORS_MAP.at(op->type())(
NodeContext(DecoderPDPDProto(op_place), named_inputs));
try
{
return CREATORS_MAP.at(op->type())(
NodeContext(DecoderPDPDProto(op_place), named_inputs));
}
catch (...)
{
// TODO: define exception types
// In case of partial conversion we need to create generic ngraph op here
return NamedOutputs();
}
}
std::istream* variant_to_stream_ptr(const std::shared_ptr<Variant>& variant,
std::ifstream& ext_stream)
{
if (is_type<VariantWrapper<std::shared_ptr<std::istream>>>(variant))
{
auto m_stream =
as_type_ptr<VariantWrapper<std::shared_ptr<std::istream>>>(variant)->get();
return m_stream.get();
}
else if (is_type<VariantWrapper<std::string>>(variant))
{
const auto& model_path =
as_type_ptr<VariantWrapper<std::string>>(variant)->get();
ext_stream.open(model_path, std::ios::in | std::ifstream::binary);
}
#if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
else if (is_type<VariantWrapper<std::wstring>>(variant))
{
const auto& model_path =
as_type_ptr<VariantWrapper<std::wstring>>(variant)->get();
ext_stream.open(model_path, std::ios::in | std::ifstream::binary);
}
#endif
FRONT_END_INITIALIZATION_CHECK(ext_stream && ext_stream.is_open(),
"Cannot open model file.");
return &ext_stream;
}
} // namespace pdpd
@ -91,6 +123,7 @@ namespace ngraph
const auto& type = inp_place->getElementType();
auto param = std::make_shared<Parameter>(type, shape);
param->set_friendly_name(var->name());
param->output(0).get_tensor().add_names({var->name()});
nodes_dict[var->name()] = param;
parameter_nodes.push_back(param);
}
@ -155,41 +188,102 @@ namespace ngraph
return std::make_shared<ngraph::Function>(result_nodes, parameter_nodes);
}
InputModel::Ptr FrontEndPDPD::load_from_file(const std::string& path) const
bool FrontEndPDPD::supported_impl(
const std::vector<std::shared_ptr<Variant>>& variants) const
{
return load_from_files({path});
}
// FrontEndPDPD can only load model specified by one path, one file or two files.
if (variants.empty() || variants.size() > 2)
return false;
InputModel::Ptr FrontEndPDPD::load_from_files(const std::vector<std::string>& paths) const
{
if (paths.size() == 1)
// Validating first path, it must contain a model
if (is_type<VariantWrapper<std::string>>(variants[0]))
{
// The case when folder with __model__ and weight files is provided or .pdmodel file
return std::make_shared<InputModelPDPD>(paths[0]);
std::string suffix = ".pdmodel";
std::string model_path =
as_type_ptr<VariantWrapper<std::string>>(variants[0])->get();
if (!pdpd::endsWith(model_path, suffix))
{
model_path += pdpd::get_path_sep<char>() + "__model__";
}
std::ifstream model_str(model_path, std::ios::in | std::ifstream::binary);
// It is possible to validate here that protobuf can read model from the stream,
// but it will complicate the check, while it should be as quick as possible
return model_str && model_str.is_open();
}
else if (paths.size() == 2)
#if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
else if (is_type<VariantWrapper<std::wstring>>(variants[0]))
{
// The case when .pdmodel and .pdparams files are provided
std::ifstream model_stream(paths[0], std::ios::in | std::ifstream::binary);
FRONT_END_INITIALIZATION_CHECK(model_stream && model_stream.is_open(),
"Cannot open model file.");
std::ifstream weights_stream(paths[1], std::ios::in | std::ifstream::binary);
FRONT_END_INITIALIZATION_CHECK(weights_stream && weights_stream.is_open(),
"Cannot open weights file.");
return load_from_streams({&model_stream, &weights_stream});
std::wstring suffix = L".pdmodel";
std::wstring model_path =
as_type_ptr<VariantWrapper<std::wstring>>(variants[0])->get();
if (!pdpd::endsWith(model_path, suffix))
{
model_path += pdpd::get_path_sep<wchar_t>() + L"__model__";
}
std::ifstream model_str(model_path, std::ios::in | std::ifstream::binary);
// It is possible to validate here that protobuf can read model from the stream,
// but it will complicate the check, while it should be as quick as possible
return model_str && model_str.is_open();
}
FRONT_END_INITIALIZATION_CHECK(false, "Model can be loaded either from 1 or 2 files");
}
InputModel::Ptr FrontEndPDPD::load_from_stream(std::istream& model_stream) const
{
return load_from_streams({&model_stream});
#endif
else if (is_type<VariantWrapper<std::shared_ptr<std::istream>>>(variants[0]))
{
// Validating first stream, it must contain a model
std::shared_ptr<std::istream> p_model_stream =
as_type_ptr<VariantWrapper<std::shared_ptr<std::istream>>>(variants[0])->get();
paddle::framework::proto::ProgramDesc fw;
return fw.ParseFromIstream(p_model_stream.get());
}
return false;
}
InputModel::Ptr
FrontEndPDPD::load_from_streams(const std::vector<std::istream*>& streams) const
FrontEndPDPD::load_impl(const std::vector<std::shared_ptr<Variant>>& variants) const
{
return std::make_shared<InputModelPDPD>(streams);
if (variants.size() == 1)
{
// The case when folder with __model__ and weight files is provided or .pdmodel file
if (is_type<VariantWrapper<std::string>>(variants[0]))
{
std::string m_path =
as_type_ptr<VariantWrapper<std::string>>(variants[0])->get();
return std::make_shared<InputModelPDPD>(m_path);
}
#if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
else if (is_type<VariantWrapper<std::wstring>>(variants[0]))
{
std::wstring m_path =
as_type_ptr<VariantWrapper<std::wstring>>(variants[0])->get();
return std::make_shared<InputModelPDPD>(m_path);
}
#endif
// The case with only model stream provided and no weights. This means model has
// no learnable weights
else if (is_type<VariantWrapper<std::shared_ptr<std::istream>>>(variants[0]))
{
std::shared_ptr<std::istream> p_model_stream =
as_type_ptr<VariantWrapper<std::shared_ptr<std::istream>>>(variants[0])
->get();
return std::make_shared<InputModelPDPD>(
std::vector<std::istream*>{p_model_stream.get()});
}
}
else if (variants.size() == 2)
{
// The case when .pdmodel and .pdparams files are provided
std::ifstream model_stream;
std::ifstream weights_stream;
std::istream* p_model_stream =
pdpd::variant_to_stream_ptr(variants[0], model_stream);
std::istream* p_weights_stream =
pdpd::variant_to_stream_ptr(variants[1], weights_stream);
if (p_model_stream && p_weights_stream)
{
return std::make_shared<InputModelPDPD>(
std::vector<std::istream*>{p_model_stream, p_weights_stream});
}
}
PDPD_THROW("Model can be loaded either from 1 or 2 files/streams");
}
std::shared_ptr<ngraph::Function> FrontEndPDPD::convert(InputModel::Ptr model) const
@ -211,6 +305,6 @@ extern "C" PDPD_API void* GetFrontEndData()
{
FrontEndPluginInfo* res = new FrontEndPluginInfo();
res->m_name = "pdpd";
res->m_creator = [](FrontEndCapFlags) { return std::make_shared<FrontEndPDPD>(); };
res->m_creator = []() { return std::make_shared<FrontEndPDPD>(); };
return res;
}

View File

@ -11,6 +11,12 @@
#include "decoder.hpp"
#include "framework.pb.h"
#include "node_context.hpp"
#include "pdpd_utils.hpp"
#if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
#include <codecvt>
#include <locale>
#endif
namespace ngraph
{
@ -21,7 +27,8 @@ namespace ngraph
class InputModelPDPD::InputModelPDPDImpl
{
public:
InputModelPDPDImpl(const std::string& path, const InputModel& input_model);
template <typename T>
InputModelPDPDImpl(const std::basic_string<T>& path, const InputModel& input_model);
InputModelPDPDImpl(const std::vector<std::istream*>& streams,
const InputModel& input_model);
std::vector<Place::Ptr> getInputs() const;
@ -37,7 +44,6 @@ namespace ngraph
void setElementType(Place::Ptr place, const ngraph::element::Type&);
void setTensorValue(Place::Ptr place, const void* value);
std::vector<uint8_t> readWeight(const std::string& name, int64_t len);
std::vector<std::shared_ptr<OpPlacePDPD>> getOpPlaces() const { return m_op_places; }
std::map<std::string, std::shared_ptr<TensorPlacePDPD>> getVarPlaces() const
{
@ -50,7 +56,9 @@ namespace ngraph
private:
void loadPlaces();
void loadConsts(std::string folder_with_weights, std::istream* weight_stream);
template <typename T>
void loadConsts(const std::basic_string<T>& folder_with_weights,
std::istream* weight_stream);
std::vector<std::shared_ptr<OpPlacePDPD>> m_op_places;
std::map<std::string, std::shared_ptr<TensorPlacePDPD>> m_var_places;
@ -142,16 +150,6 @@ namespace ngraph
namespace pdpd
{
bool endsWith(const std::string& str, const std::string& suffix)
{
if (str.length() >= suffix.length())
{
return (0 ==
str.compare(str.length() - suffix.length(), suffix.length(), suffix));
}
return false;
}
void read_tensor(std::istream& is, char* data, size_t len)
{
std::vector<char> header(16);
@ -163,16 +161,81 @@ namespace ngraph
is.read(data, len);
}
template <typename T>
std::basic_string<T> get_const_path(const std::basic_string<T>& folder_with_weights,
const std::string& name)
{
return folder_with_weights + pdpd::get_path_sep<T>() + name;
}
#if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
template <>
std::basic_string<wchar_t> get_const_path(const std::basic_string<wchar_t>& folder,
const std::string& name)
{
std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> converter;
std::wstring _name = converter.from_bytes(name);
return folder + pdpd::get_path_sep<wchar_t>() + _name;
}
#endif
template <typename T>
std::basic_string<T> get_model_path(const std::basic_string<T>& path,
std::ifstream* weights_stream)
{
std::string model_file{path};
std::string ext = ".pdmodel";
if (pdpd::endsWith(model_file, ext))
{
std::string params_ext = ".pdiparams";
std::string weights_file{path};
weights_file.replace(weights_file.size() - ext.size(), ext.size(), params_ext);
weights_stream->open(weights_file, std::ios::binary);
// Don't throw error if file isn't opened
// It may mean that model don't have constants
}
else
{
model_file += pdpd::get_path_sep<T>() + "__model__";
}
return model_file;
}
#if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
template <>
std::basic_string<wchar_t> get_model_path(const std::basic_string<wchar_t>& path,
std::ifstream* weights_stream)
{
std::wstring model_file{path};
std::wstring ext = L".pdmodel";
if (pdpd::endsWith(model_file, ext))
{
std::wstring params_ext = L".pdiparams";
std::wstring weights_file{path};
weights_file.replace(weights_file.size() - ext.size(), ext.size(), params_ext);
weights_stream->open(weights_file, std::ios::binary);
// Don't throw error if file isn't opened
// It may mean that model don't have constants
}
else
{
model_file += pdpd::get_path_sep<wchar_t>() + L"__model__";
}
return model_file;
}
#endif
} // namespace pdpd
void InputModelPDPD::InputModelPDPDImpl::loadConsts(std::string folder_with_weights,
std::istream* weight_stream)
template <typename T>
void InputModelPDPD::InputModelPDPDImpl::loadConsts(
const std::basic_string<T>& folder_with_weights, std::istream* weight_stream)
{
for (const auto& item : m_var_places)
{
const auto& var_desc = item.second->getDesc();
const auto& name = item.first;
if (pdpd::endsWith(name, "feed") || pdpd::endsWith(name, "fetch"))
if (pdpd::endsWith(name, std::string{"feed"}) ||
pdpd::endsWith(name, std::string{"fetch"}))
continue;
if (!var_desc->persistable())
continue;
@ -192,7 +255,7 @@ namespace ngraph
}
else if (!folder_with_weights.empty())
{
std::ifstream is(folder_with_weights + "/" + name,
std::ifstream is(pdpd::get_const_path(folder_with_weights, name),
std::ios::in | std::ifstream::binary);
FRONT_END_GENERAL_CHECK(is && is.is_open(),
"Cannot open file for constant value.");
@ -210,35 +273,24 @@ namespace ngraph
}
}
InputModelPDPD::InputModelPDPDImpl::InputModelPDPDImpl(const std::string& path,
template <typename T>
InputModelPDPD::InputModelPDPDImpl::InputModelPDPDImpl(const std::basic_string<T>& path,
const InputModel& input_model)
: m_fw_ptr{std::make_shared<ProgramDesc>()}
, m_input_model(input_model)
{
std::string ext = ".pdmodel";
std::string model_file(path);
std::unique_ptr<std::ifstream> weights_stream;
if (model_file.length() >= ext.length() &&
(0 == model_file.compare(model_file.length() - ext.length(), ext.length(), ext)))
{
std::string weights_file(path);
weights_file.replace(weights_file.size() - ext.size(), ext.size(), ".pdiparams");
weights_stream = std::unique_ptr<std::ifstream>(
new std::ifstream(weights_file, std::ios::binary));
// Don't throw error if file isn't opened
// It may mean that model don't have constants
}
else
{
model_file += "/__model__";
}
std::string empty_str = "";
std::ifstream weights_stream;
std::ifstream pb_stream(pdpd::get_model_path<T>(path, &weights_stream),
std::ios::in | std::ifstream::binary);
std::ifstream pb_stream(model_file, std::ios::binary);
FRONT_END_GENERAL_CHECK(pb_stream && pb_stream.is_open(), "Model file doesn't exist");
FRONT_END_GENERAL_CHECK(m_fw_ptr->ParseFromIstream(&pb_stream),
"Model can't be parsed");
loadPlaces();
loadConsts(weights_stream ? "" : path, weights_stream.get());
loadConsts(weights_stream && weights_stream.is_open() ? std::basic_string<T>{} : path,
&weights_stream);
}
InputModelPDPD::InputModelPDPDImpl::InputModelPDPDImpl(
@ -257,7 +309,7 @@ namespace ngraph
loadPlaces();
if (streams.size() > 1)
loadConsts("", streams[1]);
loadConsts(std::string{""}, streams[1]);
}
std::vector<Place::Ptr> InputModelPDPD::InputModelPDPDImpl::getInputs() const
@ -367,6 +419,13 @@ namespace ngraph
{
}
#if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
InputModelPDPD::InputModelPDPD(const std::wstring& path)
: _impl{std::make_shared<InputModelPDPDImpl>(path, *this)}
{
}
#endif
InputModelPDPD::InputModelPDPD(const std::vector<std::istream*>& streams)
: _impl{std::make_shared<InputModelPDPDImpl>(streams, *this)}
{

View File

@ -32,12 +32,12 @@ namespace ngraph
}
else
{
scale = builder::make_constant(
dtype, Shape{1}, node.get_attribute<float>("scale"));
auto scale_val = node.get_attribute<float>("scale");
scale = ngraph::opset6::Constant::create(dtype, Shape{1}, {scale_val});
}
bias =
builder::make_constant(dtype, Shape{1}, node.get_attribute<float>("bias"));
auto bias_val = node.get_attribute<float>("bias");
bias = ngraph::opset6::Constant::create(dtype, Shape{1}, {bias_val});
auto bias_after_scale = node.get_attribute<bool>("bias_after_scale");
std::shared_ptr<Node> result_node;

View File

@ -0,0 +1,51 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "frontend_manager/frontend_exceptions.hpp"
namespace ngraph
{
namespace frontend
{
namespace pdpd
{
#ifdef _WIN32
const char PATH_SEPARATOR = '\\';
#if defined(ENABLE_UNICODE_PATH_SUPPORT)
const wchar_t WPATH_SEPARATOR = L'\\';
#endif
#else
const char PATH_SEPARATOR = '/';
#endif
template <typename T>
inline std::basic_string<T> get_path_sep()
{
return std::basic_string<T>{PATH_SEPARATOR};
}
#if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
template <>
inline std::basic_string<wchar_t> get_path_sep()
{
return std::basic_string<wchar_t>{WPATH_SEPARATOR};
}
#endif
template <typename T>
bool endsWith(const std::basic_string<T>& str, const std::basic_string<T>& suffix)
{
if (str.length() >= suffix.length())
{
return (0 ==
str.compare(str.length() - suffix.length(), suffix.length(), suffix));
}
return false;
}
} // namespace pdpd
} // namespace frontend
} // namespace ngraph

View File

@ -17,7 +17,6 @@ from ngraph.impl import Function
from ngraph.impl import Node
from ngraph.impl import PartialShape
from ngraph.frontend import FrontEnd
from ngraph.frontend import FrontEndCapabilities
from ngraph.frontend import FrontEndManager
from ngraph.frontend import GeneralFailure
from ngraph.frontend import NotImplementedFailure

View File

@ -11,7 +11,6 @@ Low level wrappers for the FrontEnd c++ api.
# main classes
from _pyngraph import FrontEndManager
from _pyngraph import FrontEnd
from _pyngraph import FrontEndCapabilities
from _pyngraph import InputModel
from _pyngraph import Place

View File

@ -19,10 +19,11 @@ void regclass_pyngraph_FrontEnd(py::module m)
m, "FrontEnd", py::dynamic_attr());
fem.doc() = "ngraph.impl.FrontEnd wraps ngraph::frontend::FrontEnd";
fem.def("load_from_file",
&ngraph::frontend::FrontEnd::load_from_file,
py::arg("path"),
R"(
fem.def(
"load",
[](ngraph::frontend::FrontEnd& self, const std::string& s) { return self.load(s); },
py::arg("path"),
R"(
Loads an input model by specified model file path.
Parameters
@ -32,7 +33,7 @@ void regclass_pyngraph_FrontEnd(py::module m)
Returns
----------
load_from_file : InputModel
load : InputModel
Loaded input model.
)");

View File

@ -38,7 +38,6 @@ void regclass_pyngraph_FrontEndManager(py::module m)
fem.def("load_by_framework",
&ngraph::frontend::FrontEndManager::load_by_framework,
py::arg("framework"),
py::arg("capabilities") = ngraph::frontend::FrontEndCapabilities::FEC_DEFAULT,
R"(
Loads frontend by name of framework and capabilities.
@ -47,10 +46,6 @@ void regclass_pyngraph_FrontEndManager(py::module m)
framework : str
Framework name. Throws exception if name is not in list of available frontends.
capabilities : int
Frontend capabilities. Default is FrontEndCapabilities.FEC_DEFAULT. It is recommended to use only
those capabilities which are needed to minimize load time.
Returns
----------
load_by_framework : FrontEnd
@ -58,30 +53,6 @@ void regclass_pyngraph_FrontEndManager(py::module m)
)");
}
void regclass_pyngraph_FEC(py::module m)
{
class FeCaps
{
public:
int get_caps() const { return m_caps; }
private:
int m_caps;
};
py::class_<FeCaps, std::shared_ptr<FeCaps>> type(m, "FrontEndCapabilities");
// type.doc() = "FrontEndCapabilities";
type.attr("DEFAULT") = ngraph::frontend::FrontEndCapabilities::FEC_DEFAULT;
type.attr("CUT") = ngraph::frontend::FrontEndCapabilities::FEC_CUT;
type.attr("NAMES") = ngraph::frontend::FrontEndCapabilities::FEC_NAMES;
type.attr("WILDCARDS") = ngraph::frontend::FrontEndCapabilities::FEC_WILDCARDS;
type.def(
"__eq__",
[](const FeCaps& a, const FeCaps& b) { return a.get_caps() == b.get_caps(); },
py::is_operator());
}
void regclass_pyngraph_GeneralFailureFrontEnd(py::module m)
{
static py::exception<ngraph::frontend::GeneralFailure> exc(std::move(m), "GeneralFailure");

View File

@ -9,7 +9,6 @@
namespace py = pybind11;
void regclass_pyngraph_FrontEndManager(py::module m);
void regclass_pyngraph_FEC(py::module m);
void regclass_pyngraph_NotImplementedFailureFrontEnd(py::module m);
void regclass_pyngraph_InitializationFailureFrontEnd(py::module m);
void regclass_pyngraph_OpConversionFailureFrontEnd(py::module m);

View File

@ -51,7 +51,6 @@ PYBIND11_MODULE(_pyngraph, m)
regclass_pyngraph_OpConversionFailureFrontEnd(m);
regclass_pyngraph_OpValidationFailureFrontEnd(m);
regclass_pyngraph_NotImplementedFailureFrontEnd(m);
regclass_pyngraph_FEC(m);
regclass_pyngraph_FrontEndManager(m);
regclass_pyngraph_FrontEnd(m);
regclass_pyngraph_InputModel(m);

View File

@ -18,7 +18,7 @@ extern "C" MOCK_API void* GetFrontEndData()
{
FrontEndPluginInfo* res = new FrontEndPluginInfo();
res->m_name = "mock_py";
res->m_creator = [](FrontEndCapFlags flags) { return std::make_shared<FrontEndMockPy>(flags); };
res->m_creator = []() { return std::make_shared<FrontEndMockPy>(); };
return res;
}

View File

@ -479,7 +479,6 @@ public:
struct MOCK_API FeStat
{
FrontEndCapFlags m_load_flags;
std::vector<std::string> m_load_paths;
int m_convert_model = 0;
int m_convert = 0;
@ -487,7 +486,6 @@ struct MOCK_API FeStat
int m_decode = 0;
int m_normalize = 0;
// Getters
FrontEndCapFlags load_flags() const { return m_load_flags; }
std::vector<std::string> load_paths() const { return m_load_paths; }
int convert_model() const { return m_convert_model; }
int convert() const { return m_convert; }
@ -501,11 +499,12 @@ class MOCK_API FrontEndMockPy : public FrontEnd
mutable FeStat m_stat;
public:
FrontEndMockPy(FrontEndCapFlags flags) { m_stat.m_load_flags = flags; }
FrontEndMockPy() {}
InputModel::Ptr load_from_file(const std::string& path) const override
InputModel::Ptr load_impl(const std::vector<std::shared_ptr<Variant>>& params) const override
{
m_stat.m_load_paths.push_back(path);
if (params.size() > 0 && is_type<VariantWrapper<std::string>>(params[0]))
m_stat.m_load_paths.push_back(as_type_ptr<VariantWrapper<std::string>>(params[0])->get());
return std::make_shared<InputModelMockPy>();
}

View File

@ -27,7 +27,6 @@ static void register_mock_frontend_stat(py::module m)
py::arg("frontend"));
py::class_<FeStat> feStat(m, "FeStat", py::dynamic_attr());
feStat.def_property_readonly("load_flags", &FeStat::load_flags);
feStat.def_property_readonly("load_paths", &FeStat::load_paths);
feStat.def_property_readonly("convert_model", &FeStat::convert_model);
feStat.def_property_readonly("convert", &FeStat::convert);

View File

@ -4,7 +4,7 @@
import pickle
from ngraph import PartialShape
from ngraph.frontend import FrontEndCapabilities, FrontEndManager, InitializationFailure
from ngraph.frontend import FrontEndManager, InitializationFailure
from ngraph.utils.types import get_element_type
import numpy as np
@ -31,28 +31,9 @@ def test_pickle():
pickle.dumps(fem)
@mock_needed
def test_load_by_framework_caps():
frontEnds = fem.get_available_front_ends()
assert frontEnds is not None
assert "mock_py" in frontEnds
caps = [FrontEndCapabilities.DEFAULT,
FrontEndCapabilities.CUT,
FrontEndCapabilities.NAMES,
FrontEndCapabilities.WILDCARDS,
FrontEndCapabilities.CUT | FrontEndCapabilities.NAMES | FrontEndCapabilities.WILDCARDS]
for cap in caps:
fe = fem.load_by_framework(framework="mock_py", capabilities=cap)
stat = get_fe_stat(fe)
assert cap == stat.load_flags
for i in range(len(caps) - 1):
for j in range(i + 1, len(caps)):
assert caps[i] != caps[j]
def test_load_by_unknown_framework():
frontEnds = fem.get_available_front_ends()
assert not("UnknownFramework" in frontEnds)
assert not ("UnknownFramework" in frontEnds)
try:
fem.load_by_framework("UnknownFramework")
except InitializationFailure as exc:
@ -62,10 +43,10 @@ def test_load_by_unknown_framework():
@mock_needed
def test_load_from_file():
def test_load():
fe = fem.load_by_framework(framework="mock_py")
assert fe is not None
model = fe.load_from_file("abc.bin")
model = fe.load("abc.bin")
assert model is not None
stat = get_fe_stat(fe)
assert "abc.bin" in stat.load_paths
@ -75,7 +56,7 @@ def test_load_from_file():
def test_convert_model():
fe = fem.load_by_framework(framework="mock_py")
assert fe is not None
model = fe.load_from_file(path="")
model = fe.load(path="")
func = fe.convert(model=model)
assert func is not None
stat = get_fe_stat(fe)
@ -86,7 +67,7 @@ def test_convert_model():
def test_convert_partially():
fe = fem.load_by_framework(framework="mock_py")
assert fe is not None
model = fe.load_from_file(path="")
model = fe.load(path="")
func = fe.convert_partially(model=model)
stat = get_fe_stat(fe)
assert stat.convert_partially == 1
@ -99,7 +80,7 @@ def test_convert_partially():
def test_decode_and_normalize():
fe = fem.load_by_framework(framework="mock_py")
assert fe is not None
model = fe.load_from_file(path="")
model = fe.load(path="")
func = fe.decode(model=model)
stat = get_fe_stat(fe)
assert stat.decode == 1
@ -113,7 +94,7 @@ def test_decode_and_normalize():
@mock_needed
def init_model():
fe = fem.load_by_framework(framework="mock_py")
model = fe.load_from_file(path="")
model = fe.load(path="")
return model
@ -379,7 +360,7 @@ def test_model_set_element_type():
@mock_needed
def init_place():
fe = fem.load_by_framework(framework="mock_py")
model = fe.load_from_file(path="")
model = fe.load(path="")
place = model.get_place_by_tensor_name(tensorName="")
return model, place

View File

@ -631,6 +631,7 @@ install(TARGETS unit-test
EXCLUDE_FROM_ALL)
############ FRONTEND ############
target_include_directories(unit-test PRIVATE ${FRONTEND_INCLUDE_PATH} frontend/shared/include)
target_link_libraries(unit-test PRIVATE frontend_manager cnpy)
add_subdirectory(frontend)

View File

@ -35,7 +35,7 @@ TEST(FrontEndManagerTest, testAvailableFrontEnds)
{
FrontEndManager fem;
ASSERT_NO_THROW(fem.register_front_end(
"mock", [](FrontEndCapFlags fec) { return std::make_shared<FrontEnd>(); }));
"mock", []() { return std::make_shared<FrontEnd>(); }));
auto frontends = fem.get_available_front_ends();
ASSERT_NE(std::find(frontends.begin(), frontends.end(), "mock"), frontends.end());
FrontEnd::Ptr fe;
@ -50,26 +50,6 @@ TEST(FrontEndManagerTest, testAvailableFrontEnds)
ASSERT_EQ(std::find(frontends.begin(), frontends.end(), "mock"), frontends.end());
}
TEST(FrontEndManagerTest, testLoadWithFlags)
{
int expFlags = FrontEndCapabilities::FEC_CUT | FrontEndCapabilities::FEC_WILDCARDS |
FrontEndCapabilities::FEC_NAMES;
int actualFlags = FrontEndCapabilities::FEC_DEFAULT;
FrontEndManager fem;
ASSERT_NO_THROW(fem.register_front_end("mock", [&actualFlags](int fec) {
actualFlags = fec;
return std::make_shared<FrontEnd>();
}));
auto frontends = fem.get_available_front_ends();
ASSERT_NE(std::find(frontends.begin(), frontends.end(), "mock"), frontends.end());
FrontEnd::Ptr fe;
ASSERT_NO_THROW(fe = fem.load_by_framework("mock", expFlags));
ASSERT_TRUE(actualFlags & FrontEndCapabilities::FEC_CUT);
ASSERT_TRUE(actualFlags & FrontEndCapabilities::FEC_WILDCARDS);
ASSERT_TRUE(actualFlags & FrontEndCapabilities::FEC_NAMES);
ASSERT_EQ(expFlags, actualFlags);
}
TEST(FrontEndManagerTest, testMockPluginFrontEnd)
{
std::string fePath = ngraph::file_util::get_directory(
@ -86,17 +66,13 @@ TEST(FrontEndManagerTest, testMockPluginFrontEnd)
TEST(FrontEndManagerTest, testDefaultFrontEnd)
{
FrontEndManager fem;
ASSERT_ANY_THROW(fem.load_by_model(""));
FrontEnd::Ptr fe;
ASSERT_NO_THROW(fe = fem.load_by_model(""));
ASSERT_FALSE(fe);
std::unique_ptr<FrontEnd> fePtr(new FrontEnd()); // to verify base destructor
FrontEnd::Ptr fe = std::make_shared<FrontEnd>();
ASSERT_ANY_THROW(fe->load_from_file(""));
ASSERT_ANY_THROW(fe->load_from_files({"", ""}));
ASSERT_ANY_THROW(fe->load_from_memory(nullptr));
ASSERT_ANY_THROW(fe->load_from_memory_fragments({nullptr, nullptr}));
std::stringstream str;
ASSERT_ANY_THROW(fe->load_from_stream(str));
ASSERT_ANY_THROW(fe->load_from_streams({&str, &str}));
fe = std::make_shared<FrontEnd>();
ASSERT_ANY_THROW(fe->load(""));
ASSERT_ANY_THROW(fe->convert(std::shared_ptr<Function>(nullptr)));
ASSERT_ANY_THROW(fe->convert(InputModel::Ptr(nullptr)));
ASSERT_ANY_THROW(fe->convert_partially(nullptr));

View File

@ -29,6 +29,6 @@ extern "C" MOCK_API void* GetFrontEndData()
{
FrontEndPluginInfo* res = new FrontEndPluginInfo();
res->m_name = "mock1";
res->m_creator = [](FrontEndCapFlags) { return std::make_shared<FrontEndMock>(); };
res->m_creator = []() { return std::make_shared<FrontEndMock>(); };
return res;
}

View File

@ -2,7 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "../shared/include/basic_api.hpp"
#include "basic_api.hpp"
using namespace ngraph;
using namespace ngraph::frontend;

View File

@ -2,7 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "../shared/include/cut_specific_model.hpp"
#include "cut_specific_model.hpp"
using namespace ngraph;
using namespace ngraph::frontend;

View File

@ -2,7 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "../shared/include/load_from.hpp"
#include "load_from.hpp"
using namespace ngraph;
using namespace ngraph::frontend;

View File

@ -2,7 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "../shared/include/partial_shape.hpp"
#include "partial_shape.hpp"
using namespace ngraph;
using namespace ngraph::frontend;

View File

@ -2,7 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "../shared/include/set_element_type.hpp"
#include "set_element_type.hpp"
using namespace ngraph;
using namespace ngraph::frontend;

View File

@ -2,8 +2,8 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "../include/basic_api.hpp"
#include "../include/utils.hpp"
#include "basic_api.hpp"
#include "utils.hpp"
using namespace ngraph;
using namespace ngraph::frontend;
@ -34,7 +34,7 @@ void FrontEndBasicTest::doLoadFromFile()
ASSERT_NO_THROW(frontends = m_fem.get_available_front_ends());
ASSERT_NO_THROW(m_frontEnd = m_fem.load_by_framework(m_feName));
ASSERT_NE(m_frontEnd, nullptr);
ASSERT_NO_THROW(m_inputModel = m_frontEnd->load_from_file(m_modelFile));
ASSERT_NO_THROW(m_inputModel = m_frontEnd->load(m_modelFile));
ASSERT_NE(m_inputModel, nullptr);
}

View File

@ -2,8 +2,8 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "../include/cut_specific_model.hpp"
#include "../include/utils.hpp"
#include "cut_specific_model.hpp"
#include "utils.hpp"
#include "ngraph/opsets/opset7.hpp"
using namespace ngraph;
@ -44,7 +44,7 @@ void FrontEndCutModelTest::doLoadFromFile()
ASSERT_NO_THROW(frontends = m_fem.get_available_front_ends());
ASSERT_NO_THROW(m_frontEnd = m_fem.load_by_framework(m_param.m_frontEndName));
ASSERT_NE(m_frontEnd, nullptr);
ASSERT_NO_THROW(m_inputModel = m_frontEnd->load_from_file(m_param.m_modelName));
ASSERT_NO_THROW(m_inputModel = m_frontEnd->load(m_param.m_modelName));
ASSERT_NE(m_inputModel, nullptr);
}

View File

@ -2,9 +2,9 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "../include/load_from.hpp"
#include "load_from.hpp"
#include <fstream>
#include "../include/utils.hpp"
#include "utils.hpp"
using namespace ngraph;
using namespace ngraph::frontend;
@ -23,18 +23,18 @@ void FrontEndLoadFromTest::SetUp()
m_param = GetParam();
}
///////////////////////////////////////////////////////////////////
///////////////////load from Variants//////////////////////
TEST_P(FrontEndLoadFromTest, testLoadFromFile)
TEST_P(FrontEndLoadFromTest, testLoadFromFilePath)
{
std::string model_path = m_param.m_modelsPath + m_param.m_file;
std::vector<std::string> frontends;
FrontEnd::Ptr fe;
ASSERT_NO_THROW(frontends = m_fem.get_available_front_ends());
ASSERT_NO_THROW(m_frontEnd = m_fem.load_by_framework(m_param.m_frontEndName));
ASSERT_NO_THROW(m_frontEnd = m_fem.load_by_model(model_path));
ASSERT_NE(m_frontEnd, nullptr);
ASSERT_NO_THROW(m_inputModel =
m_frontEnd->load_from_file(m_param.m_modelsPath + m_param.m_file));
ASSERT_NO_THROW(m_inputModel = m_frontEnd->load(model_path));
ASSERT_NE(m_inputModel, nullptr);
std::shared_ptr<ngraph::Function> function;
@ -42,21 +42,17 @@ TEST_P(FrontEndLoadFromTest, testLoadFromFile)
ASSERT_NE(function, nullptr);
}
TEST_P(FrontEndLoadFromTest, testLoadFromFiles)
TEST_P(FrontEndLoadFromTest, testLoadFromTwoFiles)
{
std::string model_path = m_param.m_modelsPath + m_param.m_files[0];
std::string weights_path = m_param.m_modelsPath + m_param.m_files[1];
std::vector<std::string> frontends;
FrontEnd::Ptr fe;
ASSERT_NO_THROW(frontends = m_fem.get_available_front_ends());
ASSERT_NO_THROW(m_frontEnd = m_fem.load_by_framework(m_param.m_frontEndName));
ASSERT_NO_THROW(m_frontEnd = m_fem.load_by_model(model_path, weights_path));
ASSERT_NE(m_frontEnd, nullptr);
auto dir_files = m_param.m_files;
for (auto& file : dir_files)
{
file = m_param.m_modelsPath + file;
}
ASSERT_NO_THROW(m_inputModel = m_frontEnd->load_from_files(dir_files));
ASSERT_NO_THROW(m_inputModel = m_frontEnd->load(model_path, weights_path));
ASSERT_NE(m_inputModel, nullptr);
std::shared_ptr<ngraph::Function> function;
@ -66,14 +62,16 @@ TEST_P(FrontEndLoadFromTest, testLoadFromFiles)
TEST_P(FrontEndLoadFromTest, testLoadFromStream)
{
auto ifs = std::make_shared<std::ifstream>(m_param.m_modelsPath + m_param.m_stream,
std::ios::in | std::ifstream::binary);
auto is = std::dynamic_pointer_cast<std::istream>(ifs);
std::vector<std::string> frontends;
FrontEnd::Ptr fe;
ASSERT_NO_THROW(frontends = m_fem.get_available_front_ends());
ASSERT_NO_THROW(m_frontEnd = m_fem.load_by_framework(m_param.m_frontEndName));
ASSERT_NO_THROW(m_frontEnd = m_fem.load_by_model(is));
ASSERT_NE(m_frontEnd, nullptr);
std::ifstream is(m_param.m_modelsPath + m_param.m_stream, std::ios::in | std::ifstream::binary);
ASSERT_NO_THROW(m_inputModel = m_frontEnd->load_from_stream(is));
ASSERT_NO_THROW(m_inputModel = m_frontEnd->load(is));
ASSERT_NE(m_inputModel, nullptr);
std::shared_ptr<ngraph::Function> function;
@ -81,23 +79,22 @@ TEST_P(FrontEndLoadFromTest, testLoadFromStream)
ASSERT_NE(function, nullptr);
}
TEST_P(FrontEndLoadFromTest, testLoadFromStreams)
TEST_P(FrontEndLoadFromTest, testLoadFromTwoStreams)
{
auto model_ifs = std::make_shared<std::ifstream>(m_param.m_modelsPath + m_param.m_streams[0],
std::ios::in | std::ifstream::binary);
auto weights_ifs = std::make_shared<std::ifstream>(m_param.m_modelsPath + m_param.m_streams[1],
std::ios::in | std::ifstream::binary);
auto model_is = std::dynamic_pointer_cast<std::istream>(model_ifs);
auto weights_is = std::dynamic_pointer_cast<std::istream>(weights_ifs);
std::vector<std::string> frontends;
FrontEnd::Ptr fe;
ASSERT_NO_THROW(frontends = m_fem.get_available_front_ends());
ASSERT_NO_THROW(m_frontEnd = m_fem.load_by_framework(m_param.m_frontEndName));
ASSERT_NO_THROW(m_frontEnd = m_fem.load_by_model(model_is, weights_is));
ASSERT_NE(m_frontEnd, nullptr);
std::vector<std::shared_ptr<std::ifstream>> is_vec;
std::vector<std::istream*> is_ptr_vec;
for (auto& file : m_param.m_streams)
{
is_vec.push_back(std::make_shared<std::ifstream>(m_param.m_modelsPath + file,
std::ios::in | std::ifstream::binary));
is_ptr_vec.push_back(is_vec.back().get());
}
ASSERT_NO_THROW(m_inputModel = m_frontEnd->load_from_streams(is_ptr_vec));
ASSERT_NO_THROW(m_inputModel = m_frontEnd->load(model_is, weights_is));
ASSERT_NE(m_inputModel, nullptr);
std::shared_ptr<ngraph::Function> function;

View File

@ -44,7 +44,7 @@ void FrontEndFuzzyOpTest::doLoadFromFile()
ASSERT_NO_THROW(frontends = m_fem.get_available_front_ends());
ASSERT_NO_THROW(m_frontEnd = m_fem.load_by_framework(m_feName));
ASSERT_NE(m_frontEnd, nullptr);
ASSERT_NO_THROW(m_inputModel = m_frontEnd->load_from_file(m_modelFile));
ASSERT_NO_THROW(m_inputModel = m_frontEnd->load(m_modelFile));
ASSERT_NE(m_inputModel, nullptr);
}

View File

@ -2,8 +2,8 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "../include/partial_shape.hpp"
#include "../include/utils.hpp"
#include "partial_shape.hpp"
#include "utils.hpp"
using namespace ngraph;
using namespace ngraph::frontend;
@ -42,7 +42,7 @@ void FrontEndPartialShapeTest::doLoadFromFile()
ASSERT_NO_THROW(frontends = m_fem.get_available_front_ends());
ASSERT_NO_THROW(m_frontEnd = m_fem.load_by_framework(m_baseParam.m_frontEndName));
ASSERT_NE(m_frontEnd, nullptr);
ASSERT_NO_THROW(m_inputModel = m_frontEnd->load_from_file(m_partShape.m_modelName));
ASSERT_NO_THROW(m_inputModel = m_frontEnd->load(m_partShape.m_modelName));
ASSERT_NE(m_inputModel, nullptr);
}

View File

@ -2,8 +2,8 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "../include/set_element_type.hpp"
#include "../include/utils.hpp"
#include "set_element_type.hpp"
#include "utils.hpp"
using namespace ngraph;
using namespace ngraph::frontend;
@ -35,7 +35,7 @@ void FrontEndElementTypeTest::doLoadFromFile()
ASSERT_NO_THROW(frontends = m_fem.get_available_front_ends());
ASSERT_NO_THROW(m_frontEnd = m_fem.load_by_framework(m_param.m_frontEndName));
ASSERT_NE(m_frontEnd, nullptr);
ASSERT_NO_THROW(m_inputModel = m_frontEnd->load_from_file(m_param.m_modelName));
ASSERT_NO_THROW(m_inputModel = m_frontEnd->load(m_param.m_modelName));
ASSERT_NE(m_inputModel, nullptr);
}

View File

@ -107,16 +107,16 @@ namespace ngraph
TEST(op, variant)
{
shared_ptr<Variant> var_std_string = make_shared<VariantWrapper<std::string>>("My string");
shared_ptr<Variant> var_std_string = make_variant<std::string>("My string");
ASSERT_TRUE((is_type<VariantWrapper<std::string>>(var_std_string)));
EXPECT_EQ((as_type_ptr<VariantWrapper<std::string>>(var_std_string)->get()), "My string");
shared_ptr<Variant> var_int64_t = make_shared<VariantWrapper<int64_t>>(27);
shared_ptr<Variant> var_int64_t = make_variant<int64_t>(27);
ASSERT_TRUE((is_type<VariantWrapper<int64_t>>(var_int64_t)));
EXPECT_FALSE((is_type<VariantWrapper<std::string>>(var_int64_t)));
EXPECT_EQ((as_type_ptr<VariantWrapper<int64_t>>(var_int64_t)->get()), 27);
shared_ptr<Variant> var_ship = make_shared<VariantWrapper<Ship>>(Ship{"Lollipop", 3, 4});
shared_ptr<Variant> var_ship = make_variant<Ship>(Ship{"Lollipop", 3, 4});
ASSERT_TRUE((is_type<VariantWrapper<Ship>>(var_ship)));
Ship& ship = as_type_ptr<VariantWrapper<Ship>>(var_ship)->get();
EXPECT_EQ(ship.name, "Lollipop");