Add transformation pipline for POT (#5328)
This commit is contained in:
parent
cc7ae5e6d1
commit
28cad9e3fb
@ -5,10 +5,14 @@ from .cimport offline_transformations_api_impl_defs as C
|
||||
from ..inference_engine.ie_api cimport IENetwork
|
||||
|
||||
from libcpp cimport bool
|
||||
from libcpp.string cimport string
|
||||
|
||||
def ApplyMOCTransformations(IENetwork network, bool cf):
|
||||
C.ApplyMOCTransformations(network.impl, cf)
|
||||
|
||||
def ApplyPOTTransformations(IENetwork network, string device):
|
||||
C.ApplyPOTTransformations(network.impl, device)
|
||||
|
||||
def ApplyLowLatencyTransformation(IENetwork network):
|
||||
C.ApplyLowLatencyTransformation(network.impl)
|
||||
|
||||
|
@ -5,6 +5,7 @@
|
||||
#include "offline_transformations_api_impl.hpp"
|
||||
|
||||
#include <moc_transformations.hpp>
|
||||
#include <pot_transformations.hpp>
|
||||
#include <pruning.hpp>
|
||||
|
||||
#include <transformations/control_flow/unroll_tensor_iterator.hpp>
|
||||
@ -21,6 +22,12 @@ void InferenceEnginePython::ApplyMOCTransformations(InferenceEnginePython::IENet
|
||||
manager.run_passes(network.actual->getFunction());
|
||||
}
|
||||
|
||||
void InferenceEnginePython::ApplyPOTTransformations(InferenceEnginePython::IENetwork network, std::string device) {
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::POTTransformations>(std::move(device));
|
||||
manager.run_passes(network.actual->getFunction());
|
||||
}
|
||||
|
||||
void InferenceEnginePython::ApplyLowLatencyTransformation(InferenceEnginePython::IENetwork network) {
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::LowLatency>();
|
||||
|
@ -4,6 +4,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "Python.h"
|
||||
#include "ie_api_impl.hpp"
|
||||
|
||||
@ -11,6 +13,8 @@ namespace InferenceEnginePython {
|
||||
|
||||
void ApplyMOCTransformations(InferenceEnginePython::IENetwork network, bool cf);
|
||||
|
||||
void ApplyPOTTransformations(InferenceEnginePython::IENetwork network, std::string device);
|
||||
|
||||
void ApplyLowLatencyTransformation(InferenceEnginePython::IENetwork network);
|
||||
|
||||
void ApplyPruningTransformation(InferenceEnginePython::IENetwork network);
|
||||
|
@ -2,11 +2,15 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from libcpp cimport bool
|
||||
from libcpp.string cimport string
|
||||
|
||||
from ..inference_engine.ie_api_impl_defs cimport IENetwork
|
||||
|
||||
cdef extern from "offline_transformations_api_impl.hpp" namespace "InferenceEnginePython":
|
||||
cdef void ApplyMOCTransformations(IENetwork network, bool cf)
|
||||
|
||||
cdef void ApplyPOTTransformations(IENetwork network, string device)
|
||||
|
||||
cdef void ApplyLowLatencyTransformation(IENetwork network)
|
||||
|
||||
cdef void ApplyPruningTransformation(IENetwork network)
|
||||
|
@ -0,0 +1,33 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include <ngraph/pass/graph_rewrite.hpp>
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
|
||||
class POTTransformations;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
/**
|
||||
* @brief This transformation is an entry point for nGraph transformations that will be
|
||||
* executed inside POT.
|
||||
*/
|
||||
|
||||
class ngraph::pass::POTTransformations: public ngraph::pass::FunctionPass {
|
||||
std::string m_device;
|
||||
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
explicit POTTransformations(std::string device) : m_device(std::move(device)) {}
|
||||
|
||||
bool run_on_function(std::shared_ptr<ngraph::Function>) override;
|
||||
};
|
@ -0,0 +1,25 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
|
||||
#include <transformations/op_conversions/lstm_cell_decomposition.hpp>
|
||||
|
||||
#include "pot_transformations.hpp"
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::POTTransformations, "POTTransformations", 0);
|
||||
|
||||
bool ngraph::pass::POTTransformations::run_on_function(std::shared_ptr<ngraph::Function> f) {
|
||||
ngraph::pass::Manager manager(get_pass_config());
|
||||
if (m_device == "CPU") {
|
||||
// TODO: register CPU passes
|
||||
// manager.register_pass<ngraph::pass::LSTMCellDecomposition>();
|
||||
} else {
|
||||
throw ngraph_error("Device name is unsupported");
|
||||
}
|
||||
manager.run_passes(f);
|
||||
return false;
|
||||
}
|
Loading…
Reference in New Issue
Block a user