Add transformation pipeline for POT (#5328)

This commit is contained in:
Gleb Kazantaev 2021-04-22 19:13:14 +03:00 committed by GitHub
parent cc7ae5e6d1
commit 28cad9e3fb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 77 additions and 0 deletions

View File

@ -5,10 +5,14 @@ from .cimport offline_transformations_api_impl_defs as C
from ..inference_engine.ie_api cimport IENetwork from ..inference_engine.ie_api cimport IENetwork
from libcpp cimport bool from libcpp cimport bool
from libcpp.string cimport string
def ApplyMOCTransformations(IENetwork network, bool cf): def ApplyMOCTransformations(IENetwork network, bool cf):
C.ApplyMOCTransformations(network.impl, cf) C.ApplyMOCTransformations(network.impl, cf)
def ApplyPOTTransformations(IENetwork network, string device):
C.ApplyPOTTransformations(network.impl, device)
def ApplyLowLatencyTransformation(IENetwork network): def ApplyLowLatencyTransformation(IENetwork network):
C.ApplyLowLatencyTransformation(network.impl) C.ApplyLowLatencyTransformation(network.impl)

View File

@ -5,6 +5,7 @@
#include "offline_transformations_api_impl.hpp" #include "offline_transformations_api_impl.hpp"
#include <moc_transformations.hpp> #include <moc_transformations.hpp>
#include <pot_transformations.hpp>
#include <pruning.hpp> #include <pruning.hpp>
#include <transformations/control_flow/unroll_tensor_iterator.hpp> #include <transformations/control_flow/unroll_tensor_iterator.hpp>
@ -21,6 +22,12 @@ void InferenceEnginePython::ApplyMOCTransformations(InferenceEnginePython::IENet
manager.run_passes(network.actual->getFunction()); manager.run_passes(network.actual->getFunction());
} }
// Runs the POT (Post-training Optimization Tool) transformation set on the
// network's nGraph function. `device` selects the device-specific pipeline
// inside POTTransformations; ownership of the string is transferred to the pass.
void InferenceEnginePython::ApplyPOTTransformations(InferenceEnginePython::IENetwork network, std::string device) {
    ngraph::pass::Manager pass_manager;
    pass_manager.register_pass<ngraph::pass::POTTransformations>(std::move(device));
    pass_manager.run_passes(network.actual->getFunction());
}
void InferenceEnginePython::ApplyLowLatencyTransformation(InferenceEnginePython::IENetwork network) { void InferenceEnginePython::ApplyLowLatencyTransformation(InferenceEnginePython::IENetwork network) {
ngraph::pass::Manager manager; ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::LowLatency>(); manager.register_pass<ngraph::pass::LowLatency>();

View File

@ -4,6 +4,8 @@
#pragma once #pragma once
#include <string>
#include "Python.h" #include "Python.h"
#include "ie_api_impl.hpp" #include "ie_api_impl.hpp"
@ -11,6 +13,8 @@ namespace InferenceEnginePython {
void ApplyMOCTransformations(InferenceEnginePython::IENetwork network, bool cf); void ApplyMOCTransformations(InferenceEnginePython::IENetwork network, bool cf);
void ApplyPOTTransformations(InferenceEnginePython::IENetwork network, std::string device);
void ApplyLowLatencyTransformation(InferenceEnginePython::IENetwork network); void ApplyLowLatencyTransformation(InferenceEnginePython::IENetwork network);
void ApplyPruningTransformation(InferenceEnginePython::IENetwork network); void ApplyPruningTransformation(InferenceEnginePython::IENetwork network);

View File

@ -2,11 +2,15 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from libcpp cimport bool from libcpp cimport bool
from libcpp.string cimport string
from ..inference_engine.ie_api_impl_defs cimport IENetwork from ..inference_engine.ie_api_impl_defs cimport IENetwork
cdef extern from "offline_transformations_api_impl.hpp" namespace "InferenceEnginePython": cdef extern from "offline_transformations_api_impl.hpp" namespace "InferenceEnginePython":
cdef void ApplyMOCTransformations(IENetwork network, bool cf) cdef void ApplyMOCTransformations(IENetwork network, bool cf)
cdef void ApplyPOTTransformations(IENetwork network, string device)
cdef void ApplyLowLatencyTransformation(IENetwork network) cdef void ApplyLowLatencyTransformation(IENetwork network)
cdef void ApplyPruningTransformation(IENetwork network) cdef void ApplyPruningTransformation(IENetwork network)

View File

@ -0,0 +1,33 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <memory>
#include <string>
#include <ngraph/pass/graph_rewrite.hpp>
namespace ngraph {
namespace pass {
class POTTransformations;
} // namespace pass
} // namespace ngraph
/**
* @brief This transformation is an entry point for nGraph transformations that will be
* executed inside POT.
*/
class ngraph::pass::POTTransformations : public ngraph::pass::FunctionPass {
public:
    NGRAPH_RTTI_DECLARATION;

    /// \brief Builds the pass for a particular target device (e.g. "CPU").
    ///        The device name chooses which sub-pipeline is executed.
    explicit POTTransformations(std::string device) : m_device(std::move(device)) {}

    bool run_on_function(std::shared_ptr<ngraph::Function>) override;

private:
    // Target device the POT pipeline is specialized for.
    std::string m_device;
};

View File

@ -0,0 +1,25 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <memory>
#include <ngraph/pass/manager.hpp>
#include <transformations/op_conversions/lstm_cell_decomposition.hpp>
#include "pot_transformations.hpp"
NGRAPH_RTTI_DEFINITION(ngraph::pass::POTTransformations, "POTTransformations", 0);
bool ngraph::pass::POTTransformations::run_on_function(std::shared_ptr<ngraph::Function> f) {
    // Nested manager that inherits this pass's configuration (enable/disable state).
    ngraph::pass::Manager manager(get_pass_config());

    // Guard clause: only a CPU pipeline exists (as a placeholder) so far.
    if (m_device != "CPU") {
        throw ngraph_error("Device name is unsupported");
    }
    // TODO: register CPU passes
    // manager.register_pass<ngraph::pass::LSTMCellDecomposition>();

    manager.run_passes(f);
    // Conservatively report "function unchanged" until real passes are registered.
    return false;
}