Files
openvino/docs/template_extension/extension.cpp

98 lines
2.9 KiB
C++
Raw Normal View History

2020-05-22 22:34:00 +03:00
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "extension.hpp"
#include "cpu_kernel.hpp"
#include "op.hpp"
#include <ngraph/ngraph.hpp>
Tests and docs for registering custom ONNX operators (#2416) * Add tests, examples and documentation changes for custom ONNX operators registration mechanism * Change snippet paths * fix CoreThreadingTests.ReadNetwork - data race in ops_bridge * Make TemplateExtension::Operation externally visible * changes after review * apply code format * use std::int64_t * forward declare get_attribute_value specializations * introduce unregister_operator in onnx_importer * onnx_custom_op - lock mem first then take a buffer * func tests - create template_extension via make_so_pointer * fix build with NGRAPH_ONNX_IMPORT_ENABLE=OFF * remove exports from Operation and Extension * Move multithreaded AddExtension test to a different directory so it can be excluded when NGRAPH_ONNX_IMPORT_ENABLE=OFF * Don't include Extension tests if ENABLE_MKL_DNN=OFF * fix excluding onnx_reader tests * include extension tests only if mkl is enabled * add comment on empty blob * use register_operator conditionally in template_extension * fix docs after review * create static library from onnx_custom_op * add additional test for unregister_operator * move model example after register step * revert changes to unit tests * update ngraphConfig.cmake.in header * add headers to onnx_custom_op * changes to docs CMakeLists * remove redundant onnx_importer dependency * remove extension directory from func tests * make onnx_importer a component of ngraph package * docs fixes * update header of ngraph/cmake/share/ngraphConfig.cmake.in with ngraph_onnx_importer_FOUND
2020-10-12 06:36:19 +02:00
#ifdef NGRAPH_ONNX_IMPORT_ENABLED
#include <onnx_import/onnx_utils.hpp>
#endif
2020-05-22 22:34:00 +03:00
#include <cstdio>
#include <exception>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
using namespace TemplateExtension;
Tests and docs for registering custom ONNX operators (#2416) * Add tests, examples and documentation changes for custom ONNX operators registration mechanism * Change snippet paths * fix CoreThreadingTests.ReadNetwork - data race in ops_bridge * Make TemplateExtension::Operation externally visible * changes after review * apply code format * use std::int64_t * forward declare get_attribute_value specializations * introduce unregister_operator in onnx_importer * onnx_custom_op - lock mem first then take a buffer * func tests - create template_extension via make_so_pointer * fix build with NGRAPH_ONNX_IMPORT_ENABLE=OFF * remove exports from Operation and Extension * Move multithreaded AddExtension test to a different directory so it can be excluded when NGRAPH_ONNX_IMPORT_ENABLE=OFF * Don't include Extension tests if ENABLE_MKL_DNN=OFF * fix excluding onnx_reader tests * include extension tests only if mkl is enabled * add comment on empty blob * use register_operator conditionally in template_extension * fix docs after review * create static library from onnx_custom_op * add additional test for unregister_operator * move model example after register step * revert changes to unit tests * update ngraphConfig.cmake.in header * add headers to onnx_custom_op * changes to docs CMakeLists * remove redundant onnx_importer dependency * remove extension directory from func tests * make onnx_importer a component of ngraph package * docs fixes * update header of ngraph/cmake/share/ngraphConfig.cmake.in with ngraph_onnx_importer_FOUND
2020-10-12 06:36:19 +02:00
//! [extension:ctor]
Extension::Extension() {
#ifdef NGRAPH_ONNX_IMPORT_ENABLED
    // Converter used by the ONNX importer: it turns an ONNX node into a
    // TemplateExtension::Operation fed with the node's first input and its
    // integer "add" attribute.
    auto makeOperation = [](const ngraph::onnx_import::Node& onnxNode) -> ngraph::OutputVector {
        ngraph::OutputVector inputs{onnxNode.get_ng_inputs()};
        const int64_t addend = onnxNode.get_attribute_value<int64_t>("add");
        return {std::make_shared<Operation>(inputs.at(0), addend)};
    };
    // Registered under (Operation::type_info.name, version 1, "custom_domain");
    // the destructor unregisters the same triple.
    // NOTE(review): the registration appears to be process-global, so destroying one
    // Extension while another instance is still alive would remove it for both — confirm.
    ngraph::onnx_import::register_operator(Operation::type_info.name, 1, "custom_domain", makeOperation);
#endif
}
//! [extension:ctor]
//! [extension:dtor]
Extension::~Extension() {
#ifdef NGRAPH_ONNX_IMPORT_ENABLED
    // Undo the ONNX importer registration made in the constructor. The
    // (name, version, domain) triple must match the one used there exactly.
    const auto opName = Operation::type_info.name;
    ngraph::onnx_import::unregister_operator(opName, 1, "custom_domain");
#endif
}
//! [extension:dtor]
2020-05-22 22:34:00 +03:00
//! [extension:GetVersion]
void Extension::GetVersion(const InferenceEngine::Version *&versionInfo) const noexcept {
    // Function-local static, so the pointer handed back to the caller stays
    // valid after this function returns.
    static InferenceEngine::Version versionDescription = {
        {1, 0},         // extension API version
        "1.0",          // presumably the build-number field of InferenceEngine::Version — confirm
        "template_ext"  // extension description message
    };
    versionInfo = &versionDescription;
}
//! [extension:GetVersion]
//! [extension:getOpSets]
std::map<std::string, ngraph::OpSet> Extension::getOpSets() {
std::map<std::string, ngraph::OpSet> opsets;
ngraph::OpSet opset;
opset.insert<Operation>();
opsets["custom_opset"] = opset;
return opsets;
}
//! [extension:getOpSets]
//! [extension:getImplTypes]
std::vector<std::string> Extension::getImplTypes(const std::shared_ptr<ngraph::Node> &node) {
    // Only TemplateExtension::Operation has an implementation, and only for "CPU".
    // Any other node (including a null one) gets an empty list.
    const bool isTemplateOp = dynamic_cast<const Operation *>(node.get()) != nullptr;
    return isTemplateOp ? std::vector<std::string>{"CPU"} : std::vector<std::string>{};
}
//! [extension:getImplTypes]
//! [extension:getImplementation]
InferenceEngine::ILayerImpl::Ptr Extension::getImplementation(const std::shared_ptr<ngraph::Node> &node, const std::string &implType) {
    // Guard clause: only a TemplateExtension::Operation requested for "CPU" is
    // served. Everything else yields nullptr.
    if (implType != "CPU" || !std::dynamic_pointer_cast<Operation>(node)) {
        return nullptr;
    }
    return std::make_shared<OpImplementation>(node);
}
//! [extension:getImplementation]
//! [extension:CreateExtension]
// Exported function
// Factory entry point exported from the extension library. On success `ext`
// receives a heap-allocated Extension and OK is returned. On failure `ext` is
// set to nullptr, GENERAL_ERROR is returned, and, if `resp` is non-null, a
// description is written to resp->msg. That message is truncated to the buffer
// and always NUL-terminated.
INFERENCE_EXTENSION_API(InferenceEngine::StatusCode) InferenceEngine::CreateExtension(InferenceEngine::IExtension *&ext,
                                                                                       InferenceEngine::ResponseDesc *resp) noexcept {
    // Error reporting must not allocate or throw: this function is noexcept, so
    // anything escaping it would call std::terminate. std::snprintf truncates to
    // sizeof(resp->msg) and always writes the terminating NUL, which
    // std::string::copy (used previously) does not.
    const auto reportError = [resp](const char *detail) noexcept {
        if (resp != nullptr) {
            std::snprintf(resp->msg, sizeof(resp->msg), "Couldn't create extension: %s", detail);
        }
    };
    ext = nullptr;
    try {
        // Raw new is dictated by the C-style factory ABI: ownership passes to the caller.
        ext = new Extension();
        return InferenceEngine::OK;
    } catch (const std::exception &ex) {
        reportError(ex.what());
    } catch (...) {
        reportError("unknown exception");
    }
    return InferenceEngine::GENERAL_ERROR;
}
//! [extension:CreateExtension]