unified approach to handle extensions by frontends

This commit is contained in:
Mateusz Bencer 2022-03-17 14:24:20 +01:00
parent 658748f83d
commit bf85ac24a6
11 changed files with 22 additions and 18 deletions

View File

@ -5,6 +5,7 @@
#include "conversion_extension.hpp"
#include "openvino/frontend/exception.hpp"
#include "openvino/frontend/paddle/frontend.hpp"
#include "openvino/frontend/plugin_loader.hpp"
#include "paddle_utils.hpp"
#include "so_extension.hpp"
@ -31,7 +32,10 @@ class PaddleFrontendWrapper : public ov::frontend::paddle::FrontEnd {
m_transformation_extensions.end())
<< "DecoderTransformationExtension is not registered.";
} else if (auto so_ext = std::dynamic_pointer_cast<ov::detail::SOExtension>(extension)) {
EXPECT_NE(std::find(m_extensions.begin(), m_extensions.end(), so_ext), m_extensions.end())
const auto frontend_shared_data = std::static_pointer_cast<FrontEndSharedData>(m_shared_object);
EXPECT_TRUE(frontend_shared_data) << "Incorrect type of shared object was used";
const auto extensions = frontend_shared_data->extensions();
EXPECT_NE(std::find(extensions.begin(), extensions.end(), so_ext), extensions.end())
<< "SOExtension is not registered.";
}
}

View File

@ -22,7 +22,10 @@ namespace frontend {
class FRONTEND_API FrontEnd {
friend class FrontEndManager;
protected:
std::shared_ptr<void> m_shared_object = {}; // Library handle
private:
std::shared_ptr<FrontEnd> m_actual = {};
public:

View File

@ -9,7 +9,7 @@
#include "openvino/frontend/extension/op.hpp"
#include "openvino/frontend/manager.hpp"
#include "openvino/frontend/place.hpp"
#include "plugin_loader.hpp"
#include "openvino/frontend/plugin_loader.hpp"
#include "so_extension.hpp"
#include "utils.hpp"
@ -75,6 +75,11 @@ void FrontEnd::normalize(const std::shared_ptr<Model>& model) const {
void FrontEnd::add_extension(const std::shared_ptr<ov::Extension>& extension) {
if (m_actual) {
if (auto so_ext = std::dynamic_pointer_cast<ov::detail::SOExtension>(extension)) {
if (std::dynamic_pointer_cast<ov::BaseOpExtension>(so_ext->extension())) {
add_extension_to_shared_data(m_shared_object, so_ext->extension());
}
}
add_extension_to_shared_data(m_shared_object, extension);
m_actual->add_extension(extension);
return;

View File

@ -9,8 +9,8 @@
#include "ngraph/except.hpp"
#include "openvino/frontend/exception.hpp"
#include "openvino/frontend/plugin_loader.hpp"
#include "openvino/util/env_util.hpp"
#include "plugin_loader.hpp"
#include "utils.hpp"
using namespace ov;

View File

@ -20,9 +20,9 @@
#include <string>
#include <vector>
#include "openvino/frontend/plugin_loader.hpp"
#include "openvino/util/file_util.hpp"
#include "openvino/util/shared_object.hpp"
#include "plugin_loader.hpp"
using namespace ov;
using namespace ov::frontend;

View File

@ -5,8 +5,8 @@
#include "utils.hpp"
#include "openvino/frontend/exception.hpp"
#include "openvino/frontend/plugin_loader.hpp"
#include "openvino/util/file_util.hpp"
#include "plugin_loader.hpp"
#ifndef _WIN32
# include <dlfcn.h>

View File

@ -45,7 +45,6 @@ protected:
InputModel::Ptr load_impl(const std::vector<ov::Any>& params) const override;
private:
std::vector<ov::Extension::Ptr> extensions;
std::shared_ptr<TelemetryExtension> m_telemetry;
};

View File

@ -11,6 +11,7 @@
#include "ngraph/runtime/aligned_buffer.hpp"
#include "ngraph/runtime/shared_buffer.hpp"
#include "openvino/core/any.hpp"
#include "openvino/frontend/plugin_loader.hpp"
#include "openvino/util/file_util.hpp"
#include "so_extension.hpp"
#include "xml_parse_utils.h"
@ -102,12 +103,7 @@ bool FrontEnd::supported_impl(const std::vector<ov::Any>& variants) const {
void FrontEnd::add_extension(const ov::Extension::Ptr& ext) {
if (auto telemetry = std::dynamic_pointer_cast<TelemetryExtension>(ext)) {
m_telemetry = telemetry;
} else if (auto so_ext = std::dynamic_pointer_cast<ov::detail::SOExtension>(ext)) {
if (std::dynamic_pointer_cast<ov::BaseOpExtension>(so_ext->extension())) {
extensions.emplace_back(so_ext->extension());
}
} else if (std::dynamic_pointer_cast<ov::BaseOpExtension>(ext))
extensions.emplace_back(ext);
}
}
InputModel::Ptr FrontEnd::load_impl(const std::vector<ov::Any>& variants) const {
@ -117,7 +113,9 @@ InputModel::Ptr FrontEnd::load_impl(const std::vector<ov::Any>& variants) const
auto create_extensions_map = [&]() -> std::unordered_map<ov::DiscreteTypeInfo, ov::BaseOpExtension::Ptr> {
std::unordered_map<ov::DiscreteTypeInfo, ov::BaseOpExtension::Ptr> exts;
for (const auto& ext : extensions) {
const auto frontend_shared_data = std::static_pointer_cast<FrontEndSharedData>(m_shared_object);
OPENVINO_ASSERT(frontend_shared_data, "Shared object has invalid type");
for (auto& ext : frontend_shared_data->extensions()) {
if (auto base_ext = std::dynamic_pointer_cast<ov::BaseOpExtension>(ext))
exts.insert({base_ext->get_type_info(), base_ext});
}

View File

@ -77,10 +77,6 @@ protected:
std::function<std::map<std::string, OutputVector>(const std::map<std::string, Output<Node>>&,
const std::shared_ptr<OpPlace>&)> func);
// m_extensions should be the first member here,
// m_extensions can contain SO Extension (holder for other Extensions),
// so it should be released last.
std::vector<Extension::Ptr> m_extensions;
TelemetryExtension::Ptr m_telemetry;
std::vector<DecoderTransformationExtension::Ptr> m_transformation_extensions;
std::vector<ConversionExtensionBase::Ptr> m_conversion_extensions;

View File

@ -365,7 +365,6 @@ void FrontEnd::add_extension(const std::shared_ptr<ov::Extension>& extension) {
m_transformation_extensions.push_back(transformation);
} else if (const auto& so_ext = std::dynamic_pointer_cast<ov::detail::SOExtension>(extension)) {
add_extension(so_ext->extension());
m_extensions.push_back(so_ext);
} else if (auto common_conv_ext = std::dynamic_pointer_cast<ov::frontend::ConversionExtension>(extension)) {
m_conversion_extensions.push_back(common_conv_ext);
m_op_translators[common_conv_ext->get_op_type()] = [=](const NodeContext& context) {