Fixed input/output shape initialization (#1695)

* Fixed input/output shape initialization

* Use template_extension library in tests
This commit is contained in:
Ilya Churaev 2020-08-10 18:24:25 +03:00 committed by GitHub
parent 97842212c3
commit 3928f8806d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 9 additions and 197 deletions

View File

@ -23,6 +23,8 @@ OpImplementation::OpImplementation(const std::shared_ptr<ngraph::Node> &node) {
if (castedNode->get_input_element_type(0) != ngraph::element::f32 || castedNode->get_output_element_type(0) != ngraph::element::f32)
THROW_IE_EXCEPTION << "Operation supports only FP32 tensors.";
add = castedNode->getAddAttr();
inShape = castedNode->get_input_shape(0);
outShape = castedNode->get_output_shape(0);
} catch (InferenceEngine::details::InferenceEngineException& ex) {
error = ex.what();
}

View File

@ -5,8 +5,6 @@
set(TARGET_NAME ieFuncTests)
add_subdirectory(extension_lib)
addIeTargetTest(
NAME ${TARGET_NAME}
ROOT ${CMAKE_CURRENT_SOURCE_DIR}
@ -22,7 +20,7 @@ addIeTargetTest(
inference_engine_transformations
ADD_CPPLINT
DEPENDENCIES
extension_tests
template_extension
mock_engine
inference_engine_ir_reader
inference_engine_ir_v7_reader

View File

@ -48,10 +48,10 @@ public:
try {
auto extension = InferenceEngine::make_so_pointer<InferenceEngine::IExtension>(
FileUtils::makeSharedLibraryName<char>({},
std::string("extension_tests") + IE_BUILD_POSTFIX));
std::string("template_extension") + IE_BUILD_POSTFIX));
ie.AddExtension(extension);
} catch (const InferenceEngine::details::InferenceEngineException & ex) {
ASSERT_STR_CONTAINS(ex.what(), "name: experimental. Opset");
ASSERT_STR_CONTAINS(ex.what(), "name: custom_opset. Opset");
}
}
};

View File

@ -1,18 +0,0 @@
# Copyright (C) 2020 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
set(TARGET_NAME "extension_tests")
file(GLOB_RECURSE SRC src/*.cpp)
add_definitions(-DIMPLEMENT_INFERENCE_EXTENSION_API)
add_library(${TARGET_NAME} SHARED ${SRC})
target_link_libraries(${TARGET_NAME} PRIVATE inference_engine ${NGRAPH_LIBRARIES})
target_include_directories(${TARGET_NAME} PRIVATE
${IE_MAIN_SOURCE_DIR}/include
${CMAKE_CURRENT_SOURCE_DIR}/include)
add_dependencies(${TARGET_NAME} inference_engine ngraph)

View File

@ -1,79 +0,0 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <ie_iextension.h>
#include <ie_api.h>
#include <ngraph/ngraph.hpp>
#include <memory>
#include <vector>
#include <string>
#include <map>
class ExtensionTestOp : public ngraph::op::Op {
public:
static constexpr ngraph::NodeTypeInfo type_info{"Test", 0};
const ngraph::NodeTypeInfo& get_type_info() const override { return type_info; }
ExtensionTestOp() = default;
explicit ExtensionTestOp(const ngraph::Output<ngraph::Node>& arg): Op({arg}) {
constructor_validate_and_infer_types();
}
void validate_and_infer_types() override {
auto input_shape = get_input_partial_shape(0).to_shape();
ngraph::Shape output_shape(input_shape);
for (int i = 0; i < input_shape.size(); ++i) {
output_shape[i] = input_shape[i];
}
set_output_type(0, get_input_element_type(0), ngraph::PartialShape(output_shape));
}
std::shared_ptr<ngraph::Node> clone_with_new_inputs(const ngraph::OutputVector& new_args) const override {
if (new_args.size() != 1) {
throw ngraph::ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<ExtensionTestOp>(new_args.at(0));
}
bool visit_attributes(ngraph::AttributeVisitor& visitor) override {
return true;
}
};
class TestExtension : public InferenceEngine::IExtension {
public:
TestExtension() = default;
void GetVersion(const InferenceEngine::Version*& versionInfo) const noexcept override;
void Unload() noexcept override {}
void Release() noexcept override {
delete this;
}
/**
* @brief Returns operation sets
* This method throws an exception if it was not implemented
* @return map of opset name to opset
*/
std::map<std::string, ngraph::OpSet> getOpSets() override;
/**
* @brief Returns vector of implementation types
* @param node shared pointer to nGraph op
* @return vector of strings
*/
std::vector<std::string> getImplTypes(const std::shared_ptr<ngraph::Node>& node) override;
/**
* @brief Returns implementation for specific nGraph op
* @param node shared pointer to nGraph op
* @param implType implementation type
* @return shared pointer to implementation
*/
InferenceEngine::ILayerImpl::Ptr getImplementation(const std::shared_ptr<ngraph::Node>& node, const std::string& implType) override;
};

View File

@ -1,92 +0,0 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <extension.hpp>
#include <ngraph/opsets/opset.hpp>
#include <ngraph/factory.hpp>
#include <unordered_map>
#include <string>
#include <memory>
#include <vector>
#include <map>
IE_SUPPRESS_DEPRECATED_START
constexpr ngraph::NodeTypeInfo ExtensionTestOp::type_info;
class FakeImplementation: public InferenceEngine::ILayerExecImpl {
public:
InferenceEngine::StatusCode getSupportedConfigurations(std::vector<InferenceEngine::LayerConfig>& conf,
InferenceEngine::ResponseDesc* resp) noexcept override {
return InferenceEngine::OK;
}
InferenceEngine::StatusCode init(InferenceEngine::LayerConfig& config, InferenceEngine::ResponseDesc* resp) noexcept override {
return InferenceEngine::OK;
}
InferenceEngine::StatusCode execute(std::vector<InferenceEngine::Blob::Ptr>& inputs,
std::vector<InferenceEngine::Blob::Ptr>& outputs,
InferenceEngine::ResponseDesc* resp) noexcept override {
return InferenceEngine::OK;
}
};
void TestExtension::GetVersion(const InferenceEngine::Version*& versionInfo) const noexcept {
static InferenceEngine::Version ExtensionDescription = {
{ 2, 0 }, // extension API version
"2.0",
"ie-test-ext" // extension description message
};
versionInfo = &ExtensionDescription;
}
std::map<std::string, ngraph::OpSet> TestExtension::getOpSets() {
std::map<std::string, ngraph::OpSet> opsets;
ngraph::OpSet opset;
opset.insert<ExtensionTestOp>();
opsets["experimental"] = opset;
return opsets;
}
/**
* @brief Returns vector of implementation types
* @param node shared pointer to nGraph op
* @return vector of strings
*/
std::vector<std::string> TestExtension::getImplTypes(const std::shared_ptr<ngraph::Node>& node) {
if (std::dynamic_pointer_cast<ExtensionTestOp>(node)) {
return {"CPU"};
}
return {};
}
/**
* @brief Returns implementation for specific nGraph op
* @param node shared pointer to nGraph op
* @param implType implementation type
* @return shared pointer to implementation
*/
InferenceEngine::ILayerImpl::Ptr TestExtension::getImplementation(const std::shared_ptr<ngraph::Node>& node, const std::string& implType) {
if (std::dynamic_pointer_cast<ExtensionTestOp>(node) && implType == "CPU") {
return std::make_shared<FakeImplementation>();
}
return nullptr;
}
// Exported function
INFERENCE_EXTENSION_API(InferenceEngine::StatusCode) InferenceEngine::CreateExtension(InferenceEngine::IExtension*& ext,
InferenceEngine::ResponseDesc* resp) noexcept {
try {
ext = new TestExtension();
return OK;
} catch (std::exception& ex) {
if (resp) {
std::string err = ((std::string)"Couldn't create extension: ") + ex.what();
err.copy(resp->msg, 255);
}
return InferenceEngine::GENERAL_ERROR;
}
}

View File

@ -64,7 +64,7 @@ public:
void safeAddExtension(InferenceEngine::Core & ie) {
try {
auto extension = InferenceEngine::make_so_pointer<InferenceEngine::IExtension>(
FileUtils::makeSharedLibraryName<char>({}, "extension_tests"));
FileUtils::makeSharedLibraryName<char>({}, "template_extension"));
ie.AddExtension(extension);
} catch (const InferenceEngine::details::InferenceEngineException & ex) {
ASSERT_STR_CONTAINS(ex.what(), "name: experimental");

View File

@ -21,6 +21,7 @@ addIeTargetTest(
${OpenCV_LIBRARIES}
ADD_CPPLINT
DEPENDENCIES
template_extension
mock_engine
LABELS
IE

View File

@ -21,7 +21,7 @@ using ExtensionTests = ::testing::Test;
std::string getExtensionPath() {
return FileUtils::makeSharedLibraryName<char>({},
std::string("extension_tests") + IE_BUILD_POSTFIX);
std::string("template_extension") + IE_BUILD_POSTFIX);
}
TEST(ExtensionTests, testGetOpSets) {
@ -55,4 +55,4 @@ TEST(ExtensionTests, testGetImplementationThrowsIfNgraphNodeIsNullPtr) {
IExtensionPtr extension = make_so_pointer<IExtension>(getExtensionPath());
ASSERT_THROW(extension->getImplementation(std::shared_ptr<ngraph::Node> (), ""),
InferenceEngine::details::InferenceEngineException);
}
}