Add transformation pipline for POT (#5328)
This commit is contained in:
parent
cc7ae5e6d1
commit
28cad9e3fb
@ -5,10 +5,14 @@ from .cimport offline_transformations_api_impl_defs as C
|
|||||||
from ..inference_engine.ie_api cimport IENetwork
|
from ..inference_engine.ie_api cimport IENetwork
|
||||||
|
|
||||||
from libcpp cimport bool
|
from libcpp cimport bool
|
||||||
|
from libcpp.string cimport string
|
||||||
|
|
||||||
def ApplyMOCTransformations(IENetwork network, bool cf):
|
def ApplyMOCTransformations(IENetwork network, bool cf):
|
||||||
C.ApplyMOCTransformations(network.impl, cf)
|
C.ApplyMOCTransformations(network.impl, cf)
|
||||||
|
|
||||||
|
def ApplyPOTTransformations(IENetwork network, string device):
|
||||||
|
C.ApplyPOTTransformations(network.impl, device)
|
||||||
|
|
||||||
def ApplyLowLatencyTransformation(IENetwork network):
|
def ApplyLowLatencyTransformation(IENetwork network):
|
||||||
C.ApplyLowLatencyTransformation(network.impl)
|
C.ApplyLowLatencyTransformation(network.impl)
|
||||||
|
|
||||||
|
@ -5,6 +5,7 @@
|
|||||||
#include "offline_transformations_api_impl.hpp"
|
#include "offline_transformations_api_impl.hpp"
|
||||||
|
|
||||||
#include <moc_transformations.hpp>
|
#include <moc_transformations.hpp>
|
||||||
|
#include <pot_transformations.hpp>
|
||||||
#include <pruning.hpp>
|
#include <pruning.hpp>
|
||||||
|
|
||||||
#include <transformations/control_flow/unroll_tensor_iterator.hpp>
|
#include <transformations/control_flow/unroll_tensor_iterator.hpp>
|
||||||
@ -21,6 +22,12 @@ void InferenceEnginePython::ApplyMOCTransformations(InferenceEnginePython::IENet
|
|||||||
manager.run_passes(network.actual->getFunction());
|
manager.run_passes(network.actual->getFunction());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void InferenceEnginePython::ApplyPOTTransformations(InferenceEnginePython::IENetwork network, std::string device) {
|
||||||
|
ngraph::pass::Manager manager;
|
||||||
|
manager.register_pass<ngraph::pass::POTTransformations>(std::move(device));
|
||||||
|
manager.run_passes(network.actual->getFunction());
|
||||||
|
}
|
||||||
|
|
||||||
void InferenceEnginePython::ApplyLowLatencyTransformation(InferenceEnginePython::IENetwork network) {
|
void InferenceEnginePython::ApplyLowLatencyTransformation(InferenceEnginePython::IENetwork network) {
|
||||||
ngraph::pass::Manager manager;
|
ngraph::pass::Manager manager;
|
||||||
manager.register_pass<ngraph::pass::LowLatency>();
|
manager.register_pass<ngraph::pass::LowLatency>();
|
||||||
|
@ -4,6 +4,8 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
#include "Python.h"
|
#include "Python.h"
|
||||||
#include "ie_api_impl.hpp"
|
#include "ie_api_impl.hpp"
|
||||||
|
|
||||||
@ -11,6 +13,8 @@ namespace InferenceEnginePython {
|
|||||||
|
|
||||||
void ApplyMOCTransformations(InferenceEnginePython::IENetwork network, bool cf);
|
void ApplyMOCTransformations(InferenceEnginePython::IENetwork network, bool cf);
|
||||||
|
|
||||||
|
void ApplyPOTTransformations(InferenceEnginePython::IENetwork network, std::string device);
|
||||||
|
|
||||||
void ApplyLowLatencyTransformation(InferenceEnginePython::IENetwork network);
|
void ApplyLowLatencyTransformation(InferenceEnginePython::IENetwork network);
|
||||||
|
|
||||||
void ApplyPruningTransformation(InferenceEnginePython::IENetwork network);
|
void ApplyPruningTransformation(InferenceEnginePython::IENetwork network);
|
||||||
|
@ -2,11 +2,15 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
from libcpp cimport bool
|
from libcpp cimport bool
|
||||||
|
from libcpp.string cimport string
|
||||||
|
|
||||||
from ..inference_engine.ie_api_impl_defs cimport IENetwork
|
from ..inference_engine.ie_api_impl_defs cimport IENetwork
|
||||||
|
|
||||||
cdef extern from "offline_transformations_api_impl.hpp" namespace "InferenceEnginePython":
|
cdef extern from "offline_transformations_api_impl.hpp" namespace "InferenceEnginePython":
|
||||||
cdef void ApplyMOCTransformations(IENetwork network, bool cf)
|
cdef void ApplyMOCTransformations(IENetwork network, bool cf)
|
||||||
|
|
||||||
|
cdef void ApplyPOTTransformations(IENetwork network, string device)
|
||||||
|
|
||||||
cdef void ApplyLowLatencyTransformation(IENetwork network)
|
cdef void ApplyLowLatencyTransformation(IENetwork network)
|
||||||
|
|
||||||
cdef void ApplyPruningTransformation(IENetwork network)
|
cdef void ApplyPruningTransformation(IENetwork network)
|
||||||
|
@ -0,0 +1,33 @@
|
|||||||
|
// Copyright (C) 2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include <ngraph/pass/graph_rewrite.hpp>
|
||||||
|
|
||||||
|
namespace ngraph {
|
||||||
|
namespace pass {
|
||||||
|
|
||||||
|
class POTTransformations;
|
||||||
|
|
||||||
|
} // namespace pass
|
||||||
|
} // namespace ngraph
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief This transformation is an entry point for nGraph transformations that will be
|
||||||
|
* executed inside POT.
|
||||||
|
*/
|
||||||
|
|
||||||
|
class ngraph::pass::POTTransformations: public ngraph::pass::FunctionPass {
|
||||||
|
std::string m_device;
|
||||||
|
|
||||||
|
public:
|
||||||
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
explicit POTTransformations(std::string device) : m_device(std::move(device)) {}
|
||||||
|
|
||||||
|
bool run_on_function(std::shared_ptr<ngraph::Function>) override;
|
||||||
|
};
|
@ -0,0 +1,25 @@
|
|||||||
|
// Copyright (C) 2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include <ngraph/pass/manager.hpp>
|
||||||
|
|
||||||
|
#include <transformations/op_conversions/lstm_cell_decomposition.hpp>
|
||||||
|
|
||||||
|
#include "pot_transformations.hpp"
|
||||||
|
|
||||||
|
NGRAPH_RTTI_DEFINITION(ngraph::pass::POTTransformations, "POTTransformations", 0);
|
||||||
|
|
||||||
|
bool ngraph::pass::POTTransformations::run_on_function(std::shared_ptr<ngraph::Function> f) {
|
||||||
|
ngraph::pass::Manager manager(get_pass_config());
|
||||||
|
if (m_device == "CPU") {
|
||||||
|
// TODO: register CPU passes
|
||||||
|
// manager.register_pass<ngraph::pass::LSTMCellDecomposition>();
|
||||||
|
} else {
|
||||||
|
throw ngraph_error("Device name is unsupported");
|
||||||
|
}
|
||||||
|
manager.run_passes(f);
|
||||||
|
return false;
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user