Fixed input/output shape initialization (#1695)
* Fixed input/output shape initialization * Use template_extension library in tests
This commit is contained in:
parent
97842212c3
commit
3928f8806d
@ -23,6 +23,8 @@ OpImplementation::OpImplementation(const std::shared_ptr<ngraph::Node> &node) {
|
||||
if (castedNode->get_input_element_type(0) != ngraph::element::f32 || castedNode->get_output_element_type(0) != ngraph::element::f32)
|
||||
THROW_IE_EXCEPTION << "Operation supports only FP32 tensors.";
|
||||
add = castedNode->getAddAttr();
|
||||
inShape = castedNode->get_input_shape(0);
|
||||
outShape = castedNode->get_output_shape(0);
|
||||
} catch (InferenceEngine::details::InferenceEngineException& ex) {
|
||||
error = ex.what();
|
||||
}
|
||||
|
@ -5,8 +5,6 @@
|
||||
|
||||
set(TARGET_NAME ieFuncTests)
|
||||
|
||||
add_subdirectory(extension_lib)
|
||||
|
||||
addIeTargetTest(
|
||||
NAME ${TARGET_NAME}
|
||||
ROOT ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
@ -22,7 +20,7 @@ addIeTargetTest(
|
||||
inference_engine_transformations
|
||||
ADD_CPPLINT
|
||||
DEPENDENCIES
|
||||
extension_tests
|
||||
template_extension
|
||||
mock_engine
|
||||
inference_engine_ir_reader
|
||||
inference_engine_ir_v7_reader
|
||||
|
@ -48,10 +48,10 @@ public:
|
||||
try {
|
||||
auto extension = InferenceEngine::make_so_pointer<InferenceEngine::IExtension>(
|
||||
FileUtils::makeSharedLibraryName<char>({},
|
||||
std::string("extension_tests") + IE_BUILD_POSTFIX));
|
||||
std::string("template_extension") + IE_BUILD_POSTFIX));
|
||||
ie.AddExtension(extension);
|
||||
} catch (const InferenceEngine::details::InferenceEngineException & ex) {
|
||||
ASSERT_STR_CONTAINS(ex.what(), "name: experimental. Opset");
|
||||
ASSERT_STR_CONTAINS(ex.what(), "name: custom_opset. Opset");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
@ -1,18 +0,0 @@
|
||||
# Copyright (C) 2020 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
|
||||
set(TARGET_NAME "extension_tests")
|
||||
|
||||
file(GLOB_RECURSE SRC src/*.cpp)
|
||||
|
||||
add_definitions(-DIMPLEMENT_INFERENCE_EXTENSION_API)
|
||||
|
||||
add_library(${TARGET_NAME} SHARED ${SRC})
|
||||
|
||||
target_link_libraries(${TARGET_NAME} PRIVATE inference_engine ${NGRAPH_LIBRARIES})
|
||||
target_include_directories(${TARGET_NAME} PRIVATE
|
||||
${IE_MAIN_SOURCE_DIR}/include
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/include)
|
||||
|
||||
add_dependencies(${TARGET_NAME} inference_engine ngraph)
|
@ -1,79 +0,0 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ie_iextension.h>
|
||||
#include <ie_api.h>
|
||||
#include <ngraph/ngraph.hpp>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <map>
|
||||
|
||||
class ExtensionTestOp : public ngraph::op::Op {
|
||||
public:
|
||||
static constexpr ngraph::NodeTypeInfo type_info{"Test", 0};
|
||||
const ngraph::NodeTypeInfo& get_type_info() const override { return type_info; }
|
||||
|
||||
ExtensionTestOp() = default;
|
||||
explicit ExtensionTestOp(const ngraph::Output<ngraph::Node>& arg): Op({arg}) {
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
void validate_and_infer_types() override {
|
||||
auto input_shape = get_input_partial_shape(0).to_shape();
|
||||
|
||||
ngraph::Shape output_shape(input_shape);
|
||||
for (int i = 0; i < input_shape.size(); ++i) {
|
||||
output_shape[i] = input_shape[i];
|
||||
}
|
||||
|
||||
set_output_type(0, get_input_element_type(0), ngraph::PartialShape(output_shape));
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::Node> clone_with_new_inputs(const ngraph::OutputVector& new_args) const override {
|
||||
if (new_args.size() != 1) {
|
||||
throw ngraph::ngraph_error("Incorrect number of new arguments");
|
||||
}
|
||||
|
||||
return std::make_shared<ExtensionTestOp>(new_args.at(0));
|
||||
}
|
||||
|
||||
bool visit_attributes(ngraph::AttributeVisitor& visitor) override {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
class TestExtension : public InferenceEngine::IExtension {
|
||||
public:
|
||||
TestExtension() = default;
|
||||
void GetVersion(const InferenceEngine::Version*& versionInfo) const noexcept override;
|
||||
void Unload() noexcept override {}
|
||||
void Release() noexcept override {
|
||||
delete this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Returns operation sets
|
||||
* This method throws an exception if it was not implemented
|
||||
* @return map of opset name to opset
|
||||
*/
|
||||
std::map<std::string, ngraph::OpSet> getOpSets() override;
|
||||
|
||||
/**
|
||||
* @brief Returns vector of implementation types
|
||||
* @param node shared pointer to nGraph op
|
||||
* @return vector of strings
|
||||
*/
|
||||
std::vector<std::string> getImplTypes(const std::shared_ptr<ngraph::Node>& node) override;
|
||||
|
||||
/**
|
||||
* @brief Returns implementation for specific nGraph op
|
||||
* @param node shared pointer to nGraph op
|
||||
* @param implType implementation type
|
||||
* @return shared pointer to implementation
|
||||
*/
|
||||
InferenceEngine::ILayerImpl::Ptr getImplementation(const std::shared_ptr<ngraph::Node>& node, const std::string& implType) override;
|
||||
};
|
@ -1,92 +0,0 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
#include <extension.hpp>
|
||||
#include <ngraph/opsets/opset.hpp>
|
||||
#include <ngraph/factory.hpp>
|
||||
|
||||
#include <unordered_map>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
|
||||
constexpr ngraph::NodeTypeInfo ExtensionTestOp::type_info;
|
||||
|
||||
class FakeImplementation: public InferenceEngine::ILayerExecImpl {
|
||||
public:
|
||||
InferenceEngine::StatusCode getSupportedConfigurations(std::vector<InferenceEngine::LayerConfig>& conf,
|
||||
InferenceEngine::ResponseDesc* resp) noexcept override {
|
||||
return InferenceEngine::OK;
|
||||
}
|
||||
|
||||
InferenceEngine::StatusCode init(InferenceEngine::LayerConfig& config, InferenceEngine::ResponseDesc* resp) noexcept override {
|
||||
return InferenceEngine::OK;
|
||||
}
|
||||
|
||||
InferenceEngine::StatusCode execute(std::vector<InferenceEngine::Blob::Ptr>& inputs,
|
||||
std::vector<InferenceEngine::Blob::Ptr>& outputs,
|
||||
InferenceEngine::ResponseDesc* resp) noexcept override {
|
||||
return InferenceEngine::OK;
|
||||
}
|
||||
};
|
||||
|
||||
void TestExtension::GetVersion(const InferenceEngine::Version*& versionInfo) const noexcept {
|
||||
static InferenceEngine::Version ExtensionDescription = {
|
||||
{ 2, 0 }, // extension API version
|
||||
"2.0",
|
||||
"ie-test-ext" // extension description message
|
||||
};
|
||||
|
||||
versionInfo = &ExtensionDescription;
|
||||
}
|
||||
|
||||
std::map<std::string, ngraph::OpSet> TestExtension::getOpSets() {
|
||||
std::map<std::string, ngraph::OpSet> opsets;
|
||||
ngraph::OpSet opset;
|
||||
opset.insert<ExtensionTestOp>();
|
||||
opsets["experimental"] = opset;
|
||||
return opsets;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Returns vector of implementation types
|
||||
* @param node shared pointer to nGraph op
|
||||
* @return vector of strings
|
||||
*/
|
||||
std::vector<std::string> TestExtension::getImplTypes(const std::shared_ptr<ngraph::Node>& node) {
|
||||
if (std::dynamic_pointer_cast<ExtensionTestOp>(node)) {
|
||||
return {"CPU"};
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Returns implementation for specific nGraph op
|
||||
* @param node shared pointer to nGraph op
|
||||
* @param implType implementation type
|
||||
* @return shared pointer to implementation
|
||||
*/
|
||||
InferenceEngine::ILayerImpl::Ptr TestExtension::getImplementation(const std::shared_ptr<ngraph::Node>& node, const std::string& implType) {
|
||||
if (std::dynamic_pointer_cast<ExtensionTestOp>(node) && implType == "CPU") {
|
||||
return std::make_shared<FakeImplementation>();
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Exported function
|
||||
INFERENCE_EXTENSION_API(InferenceEngine::StatusCode) InferenceEngine::CreateExtension(InferenceEngine::IExtension*& ext,
|
||||
InferenceEngine::ResponseDesc* resp) noexcept {
|
||||
try {
|
||||
ext = new TestExtension();
|
||||
return OK;
|
||||
} catch (std::exception& ex) {
|
||||
if (resp) {
|
||||
std::string err = ((std::string)"Couldn't create extension: ") + ex.what();
|
||||
err.copy(resp->msg, 255);
|
||||
}
|
||||
return InferenceEngine::GENERAL_ERROR;
|
||||
}
|
||||
}
|
@ -64,7 +64,7 @@ public:
|
||||
void safeAddExtension(InferenceEngine::Core & ie) {
|
||||
try {
|
||||
auto extension = InferenceEngine::make_so_pointer<InferenceEngine::IExtension>(
|
||||
FileUtils::makeSharedLibraryName<char>({}, "extension_tests"));
|
||||
FileUtils::makeSharedLibraryName<char>({}, "template_extension"));
|
||||
ie.AddExtension(extension);
|
||||
} catch (const InferenceEngine::details::InferenceEngineException & ex) {
|
||||
ASSERT_STR_CONTAINS(ex.what(), "name: experimental");
|
||||
|
@ -21,6 +21,7 @@ addIeTargetTest(
|
||||
${OpenCV_LIBRARIES}
|
||||
ADD_CPPLINT
|
||||
DEPENDENCIES
|
||||
template_extension
|
||||
mock_engine
|
||||
LABELS
|
||||
IE
|
||||
|
@ -21,7 +21,7 @@ using ExtensionTests = ::testing::Test;
|
||||
|
||||
std::string getExtensionPath() {
|
||||
return FileUtils::makeSharedLibraryName<char>({},
|
||||
std::string("extension_tests") + IE_BUILD_POSTFIX);
|
||||
std::string("template_extension") + IE_BUILD_POSTFIX);
|
||||
}
|
||||
|
||||
TEST(ExtensionTests, testGetOpSets) {
|
||||
@ -55,4 +55,4 @@ TEST(ExtensionTests, testGetImplementationThrowsIfNgraphNodeIsNullPtr) {
|
||||
IExtensionPtr extension = make_so_pointer<IExtension>(getExtensionPath());
|
||||
ASSERT_THROW(extension->getImplementation(std::shared_ptr<ngraph::Node> (), ""),
|
||||
InferenceEngine::details::InferenceEngineException);
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user