Add transformation pipeline for POT (#5328)

Gleb Kazantaev 2021-04-22 19:13:14 +03:00 committed by GitHub
parent cc7ae5e6d1
commit 28cad9e3fb
6 changed files with 77 additions and 0 deletions

offline_transformations_api.pyx

@@ -5,10 +5,14 @@ from . cimport offline_transformations_api_impl_defs as C
from ..inference_engine.ie_api cimport IENetwork
from libcpp cimport bool
from libcpp.string cimport string
def ApplyMOCTransformations(IENetwork network, bool cf):
    C.ApplyMOCTransformations(network.impl, cf)


def ApplyPOTTransformations(IENetwork network, string device):
    C.ApplyPOTTransformations(network.impl, device)


def ApplyLowLatencyTransformation(IENetwork network):
    C.ApplyLowLatencyTransformation(network.impl)
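For reference, a minimal usage sketch of the new binding. The import paths and the IECore/read_network calls are assumptions based on the usual openvino Python package layout, not part of this diff:

    from openvino.inference_engine import IECore
    from openvino.offline_transformations import ApplyPOTTransformations

    ie = IECore()
    net = ie.read_network(model="model.xml", weights="model.bin")

    # Run the POT-specific pipeline for the target device. Depending on the
    # Cython string-conversion settings, the device name may need to be
    # passed as bytes rather than str.
    ApplyPOTTransformations(net, b"CPU")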

offline_transformations_api_impl.cpp

@@ -5,6 +5,7 @@
#include "offline_transformations_api_impl.hpp"
#include <moc_transformations.hpp>
#include <pot_transformations.hpp>
#include <pruning.hpp>
#include <transformations/control_flow/unroll_tensor_iterator.hpp>
@@ -21,6 +22,12 @@ void InferenceEnginePython::ApplyMOCTransformations(InferenceEnginePython::IENet
    manager.run_passes(network.actual->getFunction());
}

void InferenceEnginePython::ApplyPOTTransformations(InferenceEnginePython::IENetwork network, std::string device) {
    ngraph::pass::Manager manager;
    manager.register_pass<ngraph::pass::POTTransformations>(std::move(device));
    manager.run_passes(network.actual->getFunction());
}

void InferenceEnginePython::ApplyLowLatencyTransformation(InferenceEnginePython::IENetwork network) {
    ngraph::pass::Manager manager;
    manager.register_pass<ngraph::pass::LowLatency>();

offline_transformations_api_impl.hpp

@@ -4,6 +4,8 @@
#pragma once
#include <string>
#include "Python.h"
#include "ie_api_impl.hpp"
@@ -11,6 +13,8 @@ namespace InferenceEnginePython {
void ApplyMOCTransformations(InferenceEnginePython::IENetwork network, bool cf);
void ApplyPOTTransformations(InferenceEnginePython::IENetwork network, std::string device);
void ApplyLowLatencyTransformation(InferenceEnginePython::IENetwork network);
void ApplyPruningTransformation(InferenceEnginePython::IENetwork network);

offline_transformations_api_impl_defs.pxd

@@ -2,11 +2,15 @@
# SPDX-License-Identifier: Apache-2.0
from libcpp cimport bool
from libcpp.string cimport string
from ..inference_engine.ie_api_impl_defs cimport IENetwork
cdef extern from "offline_transformations_api_impl.hpp" namespace "InferenceEnginePython":
    cdef void ApplyMOCTransformations(IENetwork network, bool cf)

    cdef void ApplyPOTTransformations(IENetwork network, string device)

    cdef void ApplyLowLatencyTransformation(IENetwork network)

    cdef void ApplyPruningTransformation(IENetwork network)

pot_transformations.hpp

@@ -0,0 +1,33 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <memory>
#include <string>
#include <ngraph/pass/graph_rewrite.hpp>
namespace ngraph {
namespace pass {
class POTTransformations;
} // namespace pass
} // namespace ngraph
/**
 * @brief This transformation is an entry point for nGraph transformations that will be
 * executed inside POT.
 */
class ngraph::pass::POTTransformations : public ngraph::pass::FunctionPass {
    std::string m_device;

public:
    NGRAPH_RTTI_DECLARATION;

    explicit POTTransformations(std::string device) : m_device(std::move(device)) {}

    bool run_on_function(std::shared_ptr<ngraph::Function>) override;
};

pot_transformations.cpp

@@ -0,0 +1,25 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <memory>
#include <ngraph/pass/manager.hpp>
#include <transformations/op_conversions/lstm_cell_decomposition.hpp>
#include "pot_transformations.hpp"
NGRAPH_RTTI_DEFINITION(ngraph::pass::POTTransformations, "POTTransformations", 0);
bool ngraph::pass::POTTransformations::run_on_function(std::shared_ptr<ngraph::Function> f) {
    ngraph::pass::Manager manager(get_pass_config());

    if (m_device == "CPU") {
        // TODO: register CPU passes
        // manager.register_pass<ngraph::pass::LSTMCellDecomposition>();
    } else {
        throw ngraph_error("Device name is unsupported");
    }

    manager.run_passes(f);
    return false;
}
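A short sketch of the resulting dispatch behavior as seen from Python, assuming the binding usage above and that the C++ exception is translated to a Python RuntimeError (e.g. via Cython's except + machinery; the .pxd in this diff does not show that explicitly):

    try:
        ApplyPOTTransformations(net, b"GPU")  # only "CPU" is handled so far
    except RuntimeError as err:
        print(err)  # "Device name is unsupported"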