ONNX FE handles old API extensions (#12644)

* ONNX FE handles old API extensions
by legacy conversion extension

* Move `LegacyOpExtension` class to dev API
This commit is contained in:
Pawel Raasz 2022-08-26 07:27:58 +02:00 committed by GitHub
parent 1818e120e3
commit 0c2c341da6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 33 additions and 2 deletions

View File

@ -0,0 +1,13 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/core/op_extension.hpp"
namespace ov {
/** @brief Class to distinguish legacy extension. */
class LegacyOpExtension : public BaseOpExtension {};
} // namespace ov

View File

@ -34,6 +34,7 @@ protected:
std::vector<Extension::Ptr> m_other_extensions;
std::vector<DecoderTransformationExtension::Ptr> m_transformation_extensions;
ExtensionHolder m_extensions;
std::once_flag has_legacy_extension;
};
} // namespace onnx

View File

@ -14,6 +14,7 @@
#include <sstream>
#include <utils/onnx_internal.hpp>
#include "legacy_op_extension.hpp"
#include "onnx_common/onnx_model_validator.hpp"
#include "openvino/frontend/extension/telemetry.hpp"
#include "ops_bridge.hpp"
@ -165,5 +166,10 @@ void FrontEnd::add_extension(const std::shared_ptr<ov::Extension>& extension) {
m_extensions.conversions.push_back(onnx_conv_ext);
} else if (auto progress_reporter = std::dynamic_pointer_cast<ProgressReporterExtension>(extension)) {
m_extensions.progress_reporter = progress_reporter;
} else if (const auto& legacy_ext = std::dynamic_pointer_cast<ov::LegacyOpExtension>(extension)) {
m_other_extensions.push_back(legacy_ext);
std::call_once(has_legacy_extension, [this] {
m_extensions.conversions.push_back(ngraph::onnx_import::detail::get_legacy_conversion_extension());
});
}
}

View File

@ -58,6 +58,9 @@ void unregister_operator(const std::string& name, std::int64_t version, const st
legacy_conversion_extension->unregister_operator(name, version, domain);
}
const LegacyConversionExtension::Ptr detail::get_legacy_conversion_extension() {
return legacy_conversion_extension;
}
} // namespace onnx_import
} // namespace ngraph

View File

@ -7,6 +7,7 @@
#include <memory>
#include <string>
#include "legacy_conversion_extension.hpp"
#include "ngraph/function.hpp"
#include "openvino/frontend/extension/holder.hpp"
@ -53,6 +54,11 @@ std::shared_ptr<Function> decode_to_framework_nodes(std::shared_ptr<ONNX_NAMESPA
///
/// \return A nGraph function.
void convert_decoded_function(std::shared_ptr<Function> function);
/// \brief Get the legacy conversion extension.
///
/// \return const LegacyConversionExtension::Ptr
const LegacyConversionExtension::Ptr get_legacy_conversion_extension();
} // namespace detail
} // namespace onnx_import
} // namespace ngraph

View File

@ -24,6 +24,7 @@
#endif
#include "ie_itt.hpp"
#include "legacy/ie_reader.hpp"
#include "legacy_op_extension.hpp"
#include "ngraph/function.hpp"
#include "ngraph/type/element_type.hpp"
#include "ngraph/variant.hpp"
@ -41,7 +42,7 @@ namespace ov {
/*
* @brief Wrapper for old IE extensions to new API
*/
class ExtensionWrapper : public ov::BaseOpExtension {
class ExtensionWrapper : public ov::LegacyOpExtension {
public:
ExtensionWrapper(const InferenceEngine::IExtensionPtr& ext, const std::string& opset, const std::string& name)
: m_ext(ext),

View File

@ -42,6 +42,7 @@ protected:
void TearDown() override {
std::remove(m_out_xml_path.c_str());
std::remove(m_out_bin_path.c_str());
ov::shutdown();
}
};
@ -70,7 +71,7 @@ TEST_F(CustomOpsSerializationTest, CustomOpUser_MO) {
// a shared library for ONNX don't make sence in static OpenVINO build
#ifndef OPENVINO_STATIC_LIBRARY
TEST_F(CustomOpsSerializationTest, DISABLED_CustomOpUser_ONNXImporter) {
TEST_F(CustomOpsSerializationTest, CustomOpUser_ONNXImporter) {
const std::string model = CommonTestUtils::getModelFromTestModelZoo(IR_SERIALIZATION_MODELS_PATH "custom_op.onnx");
InferenceEngine::Core ie;