Exec extensions (#963)
* Fixes * Removed some tests for extensions * Added const * Removed unknown pragma
This commit is contained in:
@@ -38,7 +38,6 @@ public:
|
||||
/**
|
||||
* @brief This class is a C++ helper to work with objects created using extensions.
|
||||
*/
|
||||
IE_SUPPRESS_DEPRECATED_START_WIN
|
||||
class INFERENCE_ENGINE_API_CLASS(Extension) : public IExtension {
|
||||
public:
|
||||
/**
|
||||
@@ -69,41 +68,6 @@ public:
|
||||
*/
|
||||
void Release() noexcept override {}
|
||||
|
||||
/**
|
||||
* @deprecated Use IExtension::getImplTypes to get implementation types for a particular node.
|
||||
* The method will removed in 2021.1 release.
|
||||
* @brief Gets the array with types of layers which are included in the extension
|
||||
*
|
||||
* @param types Types array
|
||||
* @param size Size of the types array
|
||||
* @param resp Response descriptor
|
||||
* @return Status code
|
||||
*/
|
||||
INFERENCE_ENGINE_DEPRECATED("Use IExtension::getImplTypes to get implementation types for a particular node")
|
||||
StatusCode getPrimitiveTypes(char**& types, unsigned int& size, ResponseDesc* resp) noexcept override {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
return actual->getPrimitiveTypes(types, size, resp);
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated Use IExtension::getImplementation to get a concrete implementation.
|
||||
* The method will be removed in 2021.1 release.
|
||||
* @brief Gets the factory with implementations for a given layer
|
||||
*
|
||||
* @param factory Factory with implementations
|
||||
* @param cnnLayer A layer to get the factory for
|
||||
* @param resp Response descriptor
|
||||
* @return Status code
|
||||
*/
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
INFERENCE_ENGINE_DEPRECATED("Use IExtension::getImplementation to get a concrete implementation")
|
||||
StatusCode getFactoryFor(ILayerImplFactory*& factory, const CNNLayer* cnnLayer,
|
||||
ResponseDesc* resp) noexcept override {
|
||||
return actual->getFactoryFor(factory, cnnLayer, resp);
|
||||
}
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
|
||||
/**
|
||||
* @brief Returns operation sets
|
||||
* This method throws an exception if it was not implemented
|
||||
|
||||
@@ -144,81 +144,11 @@ public:
|
||||
ResponseDesc* resp) noexcept = 0;
|
||||
};
|
||||
|
||||
/**
|
||||
* @deprecated Implement IExtension::getImplTypes and IExtension::getImplementation
|
||||
* The interface will be removed in 2021.1 release.
|
||||
* @interface ILayerImplFactory
|
||||
* @brief This class provides interface for extension factories
|
||||
*/
|
||||
class INFERENCE_ENGINE_DEPRECATED("Implement IExtension::getImplTypes and IExtension::getImplementation") ILayerImplFactory {
|
||||
public:
|
||||
/**
|
||||
* @brief A shared pointer to the ILayerImplFactory interface
|
||||
*/
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
using Ptr = std::shared_ptr<ILayerImplFactory>;
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
|
||||
using ImplCreator = std::function<ILayerImpl*()>;
|
||||
|
||||
/**
|
||||
* @brief Destructor
|
||||
*/
|
||||
virtual ~ILayerImplFactory() = default;
|
||||
|
||||
/**
|
||||
* @brief Gets all possible implementations for the given cnn Layer
|
||||
*
|
||||
* @param impls the vector with implementations which is ordered by priority
|
||||
* @param resp response descriptor
|
||||
* @return status code
|
||||
*/
|
||||
virtual StatusCode getImplementations(std::vector<ILayerImpl::Ptr>& impls, ResponseDesc* resp) noexcept = 0;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief This class is the main extension interface
|
||||
*/
|
||||
class INFERENCE_ENGINE_API_CLASS(IExtension) : public InferenceEngine::details::IRelease {
|
||||
public:
|
||||
/**
|
||||
* @deprecated Use IExtension::getImplementation to get a concrete implementation
|
||||
* The method will be removed in 2021.1 release.
|
||||
* @brief Provides a factory for a specified CNNLayer
|
||||
* @param factory A factory returned from an extension plugin
|
||||
* @param cnnLayer A CNNLayer object to provide factory for
|
||||
* @param resp Response descriptor
|
||||
* @return Status code
|
||||
*/
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
INFERENCE_ENGINE_DEPRECATED("Use IExtension::getImplementation to get a concrete implementation")
|
||||
virtual StatusCode getFactoryFor(ILayerImplFactory*& factory, const CNNLayer* cnnLayer,
|
||||
ResponseDesc* resp) noexcept {
|
||||
(void)factory;
|
||||
(void)cnnLayer;
|
||||
(void)resp;
|
||||
return NOT_IMPLEMENTED;
|
||||
}
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
|
||||
/**
|
||||
* @deprecated Use IExtension::getImplTypes to get implementation types for a particular node
|
||||
* The method will be removed in 2021.1 release.
|
||||
* @brief Fills passed array with types of layers which kernel implementations are included in the extension
|
||||
*
|
||||
* @param types Array to store the layer types
|
||||
* @param size Size of the layer types array
|
||||
* @param resp Response descriptor
|
||||
* @return Status code
|
||||
*/
|
||||
INFERENCE_ENGINE_DEPRECATED("Use IExtension::getImplTypes to get implementation types for a particular node")
|
||||
virtual StatusCode getPrimitiveTypes(char**& types, unsigned int& size, ResponseDesc* resp) noexcept {
|
||||
(void)types;
|
||||
(void)size;
|
||||
(void)resp;
|
||||
return NOT_IMPLEMENTED;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Returns operation sets
|
||||
* This method throws an exception if it was not implemented
|
||||
|
||||
@@ -511,7 +511,7 @@ std::shared_ptr<CNNNetworkImpl> convertFunctionToICNNNetwork(const std::shared_p
|
||||
this->node = node;
|
||||
}
|
||||
};
|
||||
static std::vector<std::shared_ptr<Builder::INodeConverter>> convertors = {
|
||||
const static std::vector<std::shared_ptr<Builder::INodeConverter>> convertors = {
|
||||
std::make_shared<Builder::NodeConverter<::ngraph::op::Abs>>(),
|
||||
std::make_shared<Builder::NodeConverter<::ngraph::op::Acos>>(),
|
||||
std::make_shared<Builder::NodeConverter<::ngraph::op::v1::Add>>(),
|
||||
|
||||
@@ -31,8 +31,6 @@ InferenceEngine::ILayerImpl::Ptr MKLDNNExtensionManager::CreateImplementation(co
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
|
||||
std::shared_ptr<InferenceEngine::ILayerImplFactory> MKLDNNExtensionManager::CreateExtensionFactory(
|
||||
const InferenceEngine::CNNLayerPtr &layer) {
|
||||
if (!layer)
|
||||
@@ -40,9 +38,10 @@ std::shared_ptr<InferenceEngine::ILayerImplFactory> MKLDNNExtensionManager::Crea
|
||||
std::shared_ptr<ILayerImplFactory> factory;
|
||||
for (auto& ext : _extensions) {
|
||||
ResponseDesc responseDesc;
|
||||
StatusCode rc;
|
||||
StatusCode rc = GENERAL_ERROR;
|
||||
ILayerImplFactory* factory_ptr = nullptr;
|
||||
rc = ext->getFactoryFor(factory_ptr, layer.get(), &responseDesc);
|
||||
if (auto mkldnnExt = std::dynamic_pointer_cast<Extensions::Cpu::MKLDNNExtensions>(ext))
|
||||
rc = mkldnnExt->getFactoryFor(factory_ptr, layer.get(), &responseDesc);
|
||||
if (rc != OK) {
|
||||
factory = nullptr;
|
||||
continue;
|
||||
@@ -55,5 +54,3 @@ std::shared_ptr<InferenceEngine::ILayerImplFactory> MKLDNNExtensionManager::Crea
|
||||
}
|
||||
return factory;
|
||||
}
|
||||
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include <ie_iextension.h>
|
||||
#include <ie_layers.h>
|
||||
#include "ie_ishape_infer_extension.hpp"
|
||||
#include "nodes/list.hpp"
|
||||
|
||||
namespace MKLDNNPlugin {
|
||||
|
||||
@@ -18,9 +19,7 @@ public:
|
||||
using Ptr = std::shared_ptr<MKLDNNExtensionManager>;
|
||||
MKLDNNExtensionManager() = default;
|
||||
InferenceEngine::ILayerImpl::Ptr CreateImplementation(const std::shared_ptr<ngraph::Node>& op);
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
std::shared_ptr<InferenceEngine::ILayerImplFactory> CreateExtensionFactory(const InferenceEngine::CNNLayerPtr& Layer);
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
void AddExtension(InferenceEngine::IExtensionPtr extension);
|
||||
|
||||
private:
|
||||
|
||||
@@ -153,8 +153,6 @@ protected:
|
||||
std::vector<LayerConfig> confs;
|
||||
};
|
||||
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
|
||||
template <class IMPL>
|
||||
class ImplFactory : public ILayerImplFactory {
|
||||
public:
|
||||
@@ -174,16 +172,12 @@ protected:
|
||||
InferenceEngine::CNNLayerPtr cnnLayer;
|
||||
};
|
||||
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
|
||||
template <typename __prim>
|
||||
inline void extRegister(MKLDNNExtensions * extInstance, const char * __type) {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
extInstance->AddExt(__type,
|
||||
[](const CNNLayer* layer) -> InferenceEngine::ILayerImplFactory* {
|
||||
return new __prim(layer);
|
||||
});
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
#define REG_FACTORY_FOR(__prim, __type) \
|
||||
|
||||
@@ -13,10 +13,34 @@
|
||||
#include <algorithm>
|
||||
|
||||
namespace InferenceEngine {
|
||||
|
||||
class ILayerImplFactory {
|
||||
public:
|
||||
/**
|
||||
* @brief A shared pointer to the ILayerImplFactory interface
|
||||
*/
|
||||
using Ptr = std::shared_ptr<ILayerImplFactory>;
|
||||
|
||||
using ImplCreator = std::function<ILayerImpl*()>;
|
||||
|
||||
/**
|
||||
* @brief Destructor
|
||||
*/
|
||||
virtual ~ILayerImplFactory() = default;
|
||||
|
||||
/**
|
||||
* @brief Gets all possible implementations for the given cnn Layer
|
||||
*
|
||||
* @param impls the vector with implementations which is ordered by priority
|
||||
* @param resp response descriptor
|
||||
* @return status code
|
||||
*/
|
||||
virtual StatusCode getImplementations(std::vector<ILayerImpl::Ptr>& impls, ResponseDesc* resp) noexcept = 0;
|
||||
};
|
||||
|
||||
namespace Extensions {
|
||||
namespace Cpu {
|
||||
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
using ext_factory = std::function<InferenceEngine::ILayerImplFactory*(const InferenceEngine::CNNLayer*)>;
|
||||
|
||||
struct ExtensionsHolder {
|
||||
@@ -27,13 +51,14 @@ class MKLDNNExtensions : public IExtension {
|
||||
public:
|
||||
MKLDNNExtensions();
|
||||
|
||||
StatusCode getPrimitiveTypes(char**& types, unsigned int& size, ResponseDesc* resp) noexcept override {
|
||||
virtual StatusCode
|
||||
getPrimitiveTypes(char**& types, unsigned int& size, ResponseDesc* resp) noexcept {
|
||||
collectTypes(types, size, extensionsHolder->list);
|
||||
return OK;
|
||||
}
|
||||
|
||||
StatusCode
|
||||
getFactoryFor(ILayerImplFactory*& factory, const CNNLayer* cnnLayer, ResponseDesc* resp) noexcept override {
|
||||
virtual StatusCode
|
||||
getFactoryFor(ILayerImplFactory*& factory, const CNNLayer* cnnLayer, ResponseDesc* resp) noexcept {
|
||||
auto& factories = extensionsHolder->list;
|
||||
if (factories.find(cnnLayer->type) == factories.end()) {
|
||||
std::string errorMsg = std::string("Factory for ") + cnnLayer->type + " wasn't found!";
|
||||
@@ -80,8 +105,6 @@ private:
|
||||
}
|
||||
};
|
||||
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
|
||||
} // namespace Cpu
|
||||
} // namespace Extensions
|
||||
} // namespace InferenceEngine
|
||||
|
||||
@@ -36,9 +36,7 @@ void MKLDNNGenericNode::initSupportedPrimitiveDescriptors() {
|
||||
|
||||
std::vector<InferenceEngine::ILayerImpl::Ptr> impls_no_exec;
|
||||
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
InferenceEngine::StatusCode rc = extFactory->getImplementations(impls_no_exec, &resp);
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
for (const auto& impl : impls_no_exec) {
|
||||
if (auto exec_impl = std::dynamic_pointer_cast<InferenceEngine::ILayerExecImpl>(impl)) {
|
||||
impls.emplace_back(exec_impl);
|
||||
|
||||
@@ -37,10 +37,7 @@ public:
|
||||
|
||||
|
||||
protected:
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
InferenceEngine::ILayerImplFactory::Ptr extFactory;
|
||||
InferenceEngine::IShapeInferImpl::Ptr extShapeInference;
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
std::vector<InferenceEngine::ILayerExecImpl::Ptr> impls;
|
||||
std::map<std::string, std::string> params;
|
||||
std::map<std::string, InferenceEngine::Blob::Ptr> blobs;
|
||||
|
||||
@@ -12,8 +12,6 @@
|
||||
#include <string>
|
||||
#include <map>
|
||||
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
|
||||
class ExtensionTestOp : public ngraph::op::Op {
|
||||
public:
|
||||
static constexpr ngraph::NodeTypeInfo type_info{"Test", 0};
|
||||
@@ -56,26 +54,6 @@ public:
|
||||
void Release() noexcept override {
|
||||
delete this;
|
||||
}
|
||||
InferenceEngine::StatusCode getFactoryFor(InferenceEngine::ILayerImplFactory*& factory,
|
||||
const InferenceEngine::CNNLayer* cnnLayer,
|
||||
InferenceEngine::ResponseDesc* resp) noexcept override {
|
||||
if (cnnLayer == nullptr || cnnLayer->type != "test")
|
||||
return InferenceEngine::GENERAL_ERROR;
|
||||
return InferenceEngine::OK;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Fills passed array with types of layers which kernel implementations are included in the extension
|
||||
*
|
||||
* @param types Array to store the layer types
|
||||
* @param size Size of the layer types array
|
||||
* @param resp Response descriptor
|
||||
* @return Status code
|
||||
*/
|
||||
InferenceEngine::StatusCode getPrimitiveTypes(char**& types, unsigned int& size, InferenceEngine::ResponseDesc* resp) noexcept override {
|
||||
size = 1;
|
||||
return InferenceEngine::OK;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Returns operation sets
|
||||
@@ -99,5 +77,3 @@ public:
|
||||
*/
|
||||
InferenceEngine::ILayerImpl::Ptr getImplementation(const std::shared_ptr<ngraph::Node>& node, const std::string& implType) override;
|
||||
};
|
||||
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
|
||||
@@ -362,23 +362,12 @@ class BadExtension : public InferenceEngine::IExtension {
|
||||
public:
|
||||
BadExtension() {}
|
||||
|
||||
InferenceEngine::StatusCode
|
||||
getPrimitiveTypes(char**& types, unsigned int& size, InferenceEngine::ResponseDesc* resp) noexcept override {
|
||||
return GENERAL_ERROR;
|
||||
};
|
||||
|
||||
void GetVersion(const InferenceEngine::Version*& versionInfo) const noexcept override {};
|
||||
|
||||
void Unload() noexcept override {};
|
||||
|
||||
void Release() noexcept override {}
|
||||
|
||||
InferenceEngine::StatusCode
|
||||
getFactoryFor(InferenceEngine::ILayerImplFactory*& factory, const InferenceEngine::CNNLayer* cnnLayer,
|
||||
InferenceEngine::ResponseDesc* resp) noexcept override {
|
||||
return InferenceEngine::StatusCode::NOT_IMPLEMENTED;
|
||||
};
|
||||
|
||||
std::map<std::string, ngraph::OpSet> getOpSets() override {
|
||||
static std::map<std::string, ngraph::OpSet> opsets;
|
||||
if (opsets.empty()) {
|
||||
|
||||
@@ -24,31 +24,6 @@ std::string getExtensionPath() {
|
||||
std::string("extension_tests") + IE_BUILD_POSTFIX);
|
||||
}
|
||||
|
||||
TEST(ExtensionTests, testGetFactoryFor) {
|
||||
IExtensionPtr extension = make_so_pointer<IExtension>(getExtensionPath());
|
||||
CNNLayer testLayer({"test1", "test", Precision::FP32});
|
||||
ILayerImplFactory* factory = nullptr;
|
||||
ResponseDesc resp;
|
||||
ASSERT_EQ(OK, extension->getFactoryFor(factory, &testLayer, &resp));
|
||||
}
|
||||
|
||||
TEST(ExtensionTests, testGetIncorrectFactoryFor) {
|
||||
IExtensionPtr extension = make_so_pointer<IExtension>(getExtensionPath());
|
||||
CNNLayer testLayer({"test1", "test_incorrect", Precision::FP32});
|
||||
ILayerImplFactory* factory = nullptr;
|
||||
ResponseDesc resp;
|
||||
ASSERT_NE(OK, extension->getFactoryFor(factory, &testLayer, &resp));
|
||||
}
|
||||
|
||||
TEST(ExtensionTests, testGetPrimitiveTypes) {
|
||||
IExtensionPtr extension = make_so_pointer<IExtension>(getExtensionPath());
|
||||
ResponseDesc resp;
|
||||
char **types;
|
||||
unsigned int size(0);
|
||||
ASSERT_EQ(OK, extension->getPrimitiveTypes(types, size, &resp));
|
||||
ASSERT_EQ(1, size);
|
||||
}
|
||||
|
||||
TEST(ExtensionTests, testGetOpSets) {
|
||||
IExtensionPtr extension = make_so_pointer<IExtension>(getExtensionPath());
|
||||
auto opsets = extension->getOpSets();
|
||||
|
||||
@@ -55,8 +55,7 @@ target_compile_definitions(${TARGET_NAME}
|
||||
MODELS_PATH=\"${MODELS_PATH}\" PARENT_SCOPE)
|
||||
|
||||
target_include_directories(${TARGET_NAME} PRIVATE
|
||||
${IE_MAIN_SOURCE_DIR}/src/extension
|
||||
${IE_MAIN_SOURCE_DIR}/src/extension/common)
|
||||
${IE_MAIN_SOURCE_DIR}/src/mkldnn_plugin)
|
||||
|
||||
target_link_libraries(${TARGET_NAME} PRIVATE ${LIBRARIES})
|
||||
|
||||
|
||||
@@ -23,8 +23,6 @@ struct extension_params {
|
||||
std::map<std::string, std::string> config;
|
||||
};
|
||||
|
||||
using ext_factory = std::function<InferenceEngine::ILayerImplFactory*(const InferenceEngine::CNNLayer *)>;
|
||||
|
||||
class FakePrimitiveImpl : public InferenceEngine::ILayerExecImpl {
|
||||
public:
|
||||
FakePrimitiveImpl(const InferenceEngine::CNNLayer *layer) {
|
||||
@@ -61,26 +59,8 @@ private:
|
||||
InferenceEngine::CNNLayer* cnnLayer;
|
||||
};
|
||||
|
||||
class FakePrimitiveFactory : public InferenceEngine::ILayerImplFactory {
|
||||
public:
|
||||
FakePrimitiveFactory(const InferenceEngine::CNNLayer *layer) {
|
||||
cnnLayer = const_cast<InferenceEngine::CNNLayer *>(layer);
|
||||
}
|
||||
// First implementation has more priority than next
|
||||
InferenceEngine::StatusCode getImplementations(std::vector<InferenceEngine::ILayerImpl::Ptr>& impls, InferenceEngine::ResponseDesc *resp) noexcept override {
|
||||
impls.push_back(InferenceEngine::ILayerImpl::Ptr(new FakePrimitiveImpl(cnnLayer)));
|
||||
return InferenceEngine::OK;
|
||||
}
|
||||
|
||||
private:
|
||||
InferenceEngine::CNNLayer * cnnLayer;
|
||||
};
|
||||
|
||||
class TestExtension : public InferenceEngine::IExtension {
|
||||
public:
|
||||
TestExtension() {
|
||||
factories["Fake"] = [](const InferenceEngine::CNNLayer * cnnLayer) -> InferenceEngine::ILayerImplFactory* { return new FakePrimitiveFactory(cnnLayer); };
|
||||
}
|
||||
void Release() noexcept override { delete this; }
|
||||
|
||||
void GetVersion(const InferenceEngine::Version *&versionInfo) const noexcept override
|
||||
@@ -90,28 +70,6 @@ public:
|
||||
}
|
||||
|
||||
void Unload() noexcept override {}
|
||||
StatusCode getPrimitiveTypes(char**& types, unsigned int& size, ResponseDesc* resp) noexcept override {
|
||||
types = new char *[factories.size()];
|
||||
size_t count = 0;
|
||||
for (auto it = factories.begin(); it != factories.end(); it++, count ++) {
|
||||
types[count] = new char[it->first.size() + 1];
|
||||
std::copy(it->first.begin(), it->first.end(), types[count]);
|
||||
types[count][it->first.size() ] = '\0';
|
||||
}
|
||||
return InferenceEngine::OK;
|
||||
}
|
||||
|
||||
StatusCode getFactoryFor(ILayerImplFactory *&factory, const CNNLayer *cnnLayer, ResponseDesc *resp) noexcept override {
|
||||
if (factories.find(cnnLayer->type) == factories.end()) {
|
||||
std::string errorMsg = std::string("Factory for ") + cnnLayer->type + " wasn't found!";
|
||||
errorMsg.copy(resp->msg, sizeof(resp->msg) - 1);
|
||||
return InferenceEngine::NOT_FOUND;
|
||||
}
|
||||
factory = factories[cnnLayer->type](cnnLayer);
|
||||
return InferenceEngine::OK;
|
||||
}
|
||||
private:
|
||||
std::map<std::string, ext_factory> factories;
|
||||
};
|
||||
|
||||
class NewFakePrimitiveImpl : public InferenceEngine::ILayerExecImpl {
|
||||
@@ -361,17 +319,17 @@ protected:
|
||||
#endif
|
||||
|
||||
TEST_F(smoke_ExtensionTest, MKLDNN_delete_extension) {
|
||||
std::shared_ptr<IExtension> ext(new TestExtension());
|
||||
std::shared_ptr<IExtension> ext(new NewTestExtension());
|
||||
checkExtensionRemoved({"MKLDNN", ext});
|
||||
}
|
||||
|
||||
TEST_F(smoke_ExtensionTest, MKLDNN_no_delete_extension_from_another_engine) {
|
||||
std::shared_ptr<IExtension> ext(new TestExtension());
|
||||
std::shared_ptr<IExtension> ext(new NewTestExtension());
|
||||
checkExtensionNotRemovedFromAnotherEngineObject({"MKLDNN", ext});
|
||||
}
|
||||
|
||||
TEST_F(smoke_ExtensionTest, MKLDNN_no_share_extension_between_engines) {
|
||||
std::shared_ptr<IExtension> ext(new TestExtension());
|
||||
std::shared_ptr<IExtension> ext(new NewTestExtension());
|
||||
checkNotSharedExtensions(ext, "CPU");
|
||||
}
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ private:
|
||||
|
||||
using fake_ext_factory = std::function<InferenceEngine::ILayerImplFactory*(const InferenceEngine::CNNLayer *)>;
|
||||
|
||||
class FakeConstExtensionFabric : public InferenceEngine::IExtension {
|
||||
class FakeConstExtensionFabric : public InferenceEngine::Extensions::Cpu::MKLDNNExtensions {
|
||||
public:
|
||||
FakeConstExtensionFabric() {
|
||||
factories["ConstLayer"] = [](const InferenceEngine::CNNLayer * cnnLayer) -> InferenceEngine::ILayerImplFactory* { return new ConstLayerFactory(cnnLayer); };
|
||||
|
||||
@@ -18,7 +18,7 @@ struct TestExtensionsHolder {
|
||||
};
|
||||
|
||||
|
||||
class FakeExtensions : public IExtension {
|
||||
class FakeExtensions : public Cpu::MKLDNNExtensions {
|
||||
public:
|
||||
void Unload() noexcept override {};
|
||||
|
||||
|
||||
@@ -415,7 +415,7 @@ private:
|
||||
};
|
||||
using fake_ext_factory = std::function<InferenceEngine::ILayerImplFactory*(const InferenceEngine::CNNLayer *)>;
|
||||
|
||||
class FakeExtensionFabric : public InferenceEngine::IExtension {
|
||||
class FakeExtensionFabric : public InferenceEngine::Extensions::Cpu::MKLDNNExtensions {
|
||||
public:
|
||||
FakeExtensionFabric() {
|
||||
factories["CustomNewConvolution"] = [](const InferenceEngine::CNNLayer * cnnLayer) -> InferenceEngine::ILayerImplFactory* { return new FakeGenericPrimitiveFactory(); };
|
||||
|
||||
@@ -9,7 +9,6 @@
|
||||
#include "tests_common.hpp"
|
||||
#include <ie_core.hpp>
|
||||
|
||||
|
||||
using namespace ::testing;
|
||||
using namespace std;
|
||||
using namespace mkldnn;
|
||||
@@ -292,7 +291,7 @@ private:
|
||||
InferenceEngine::CNNLayer * cnnLayer;
|
||||
};
|
||||
|
||||
class FakeFabric : public InferenceEngine::IExtension {
|
||||
class FakeFabric : public InferenceEngine::Extensions::Cpu::MKLDNNExtensions {
|
||||
public:
|
||||
FakeFabric() {
|
||||
factories["ReLU"] = [](const InferenceEngine::CNNLayer * cnnLayer) -> InferenceEngine::ILayerImplFactory* { return new FakeReLUFactory(cnnLayer); };
|
||||
|
||||
Reference in New Issue
Block a user