[CPU] Add GraphContext & EnforceBF16 for sub-graphs (#14695)

Tingqian Li 2023-01-09 19:28:48 +08:00 committed by GitHub
parent 80f0ffbb49
commit 29b8c9e7af
192 changed files with 612 additions and 617 deletions

View File

@@ -5,6 +5,7 @@
 #include <ie_metric_helpers.hpp>
 #include <precision_utils.h>
 #include "exec_network.h"
+#include <low_precision/low_precision.hpp>
 #include "async_infer_request.h"
 #include "infer_request.h"
@@ -118,7 +119,6 @@ ExecNetwork::ExecNetwork(const InferenceEngine::CNNNetwork &network,
     } else {
         _callbackExecutor = _taskExecutor;
     }
     int streams = std::max(1, _cfg.streamExecutorConfig._streams);
     std::vector<Task> tasks; tasks.resize(streams);
     _graphs.resize(streams);
@@ -177,12 +177,21 @@ ExecNetwork::GraphGuard::Lock ExecNetwork::GetGraph() const {
         std::exception_ptr exception;
         auto makeGraph = [&] {
             try {
+                GraphContext::Ptr ctx;
                 {
                     std::lock_guard<std::mutex> lock{*_mutex.get()};
-                    graphLock._graph.setConfig(_cfg);
+                    // disable weights caching if graph was created only once
+                    auto weightsCache =
+                        _cfg.streamExecutorConfig._streams != 1 ? _numaNodesWeights[numaNodeId] : nullptr;
+                    auto isQuantizedFlag =
+                        (_cfg.lpTransformsMode == Config::On) &&
+                        ngraph::pass::low_precision::LowPrecision::isFunctionQuantized(_network.getFunction());
+                    ctx = std::make_shared<GraphContext>(_cfg, extensionManager, weightsCache, _mutex, isQuantizedFlag);
                 }
-                graphLock._graph.CreateGraph(_network, extensionManager, _numaNodesWeights[numaNodeId], _mutex);
-            } catch(...) {
+                graphLock._graph.CreateGraph(_network, ctx);
+            } catch (...) {
                 exception = std::current_exception();
             }
         };
@@ -198,19 +207,6 @@ ExecNetwork::GraphGuard::Lock ExecNetwork::GetGraph() const {
     return graphLock;
 }
-void ExecNetwork::setProperty(const std::map<std::string, std::string> &properties) {
-    {
-        std::lock_guard<std::mutex> lock{*_mutex.get()};
-        _cfg.readProperties(properties);
-    }
-    for (auto& g : _graphs) {
-        auto graphLock = GraphGuard::Lock(g);
-        if (graphLock._graph.IsReady()) {
-            graphLock._graph.setProperty(properties);
-        }
-    }
-}
 InferenceEngine::IInferRequestInternal::Ptr ExecNetwork::CreateInferRequest() {
     return CreateAsyncInferRequestFromSync<AsyncInferRequest>();
 }
@@ -235,7 +231,7 @@ Parameter ExecNetwork::GetConfigLegacy(const std::string &name) const {
         IE_THROW() << "No graph was found";
     /* legacy implementation return all the parameters which is actually not correct
      * since they are not reconfigurable. Fixed for new API */
-    Config engConfig = GetGraph()._graph.getProperty();
+    Config engConfig = GetGraph()._graph.getConfig();
     auto option = engConfig._config.find(name);
     if (option != engConfig._config.end()) {
         return option->second;
@@ -268,12 +264,12 @@ InferenceEngine::Parameter ExecNetwork::GetMetricLegacy(const std::string &name,
         IE_SET_METRIC_RETURN(SUPPORTED_METRICS, metrics);
     } else if (name == METRIC_KEY(SUPPORTED_CONFIG_KEYS)) {
         std::vector<std::string> configKeys;
-        for (auto && key : graph.getProperty()._config) {
+        for (auto && key : graph.getConfig()._config) {
             configKeys.push_back(key.first);
         }
         IE_SET_METRIC_RETURN(SUPPORTED_CONFIG_KEYS, configKeys);
     } else if (name == METRIC_KEY(OPTIMAL_NUMBER_OF_INFER_REQUESTS)) {
-        Config engConfig = graph.getProperty();
+        Config engConfig = graph.getConfig();
         auto option = engConfig._config.find(CONFIG_KEY(CPU_THROUGHPUT_STREAMS));
         IE_ASSERT(option != engConfig._config.end());
         auto streams = std::stoi(option->second);
@@ -290,7 +286,7 @@ InferenceEngine::Parameter ExecNetwork::GetMetric(const std::string &name) const
     // @todo Can't we just use local copy (_cfg) instead?
     auto graphLock = GetGraph();
     const auto& graph = graphLock._graph;
-    const auto& config = graph.getProperty();
+    const auto& config = graph.getConfig();
     if (isLegacyAPI()) {
         return GetMetricLegacy(name, graph);
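
The GetGraph() change above is the heart of this refactor: instead of configuring each stream graph in place (setConfig) and handing the extension manager, weights cache, and mutex to CreateGraph one by one, all shared state is gathered once, under the lock, into a GraphContext. A minimal sketch of the resulting flow, reusing only names visible in this diff (numaNodeId, graphLock, and _network are assumed to be in scope inside GetGraph(), as in the hunks above):

    // Build the per-graph shared context once, under the shared mutex.
    auto weightsCache = _cfg.streamExecutorConfig._streams != 1
                            ? _numaNodesWeights[numaNodeId]  // share weights between stream graphs
                            : nullptr;                       // single graph: caching brings no benefit
    auto isQuantizedFlag =
        (_cfg.lpTransformsMode == Config::On) &&
        ngraph::pass::low_precision::LowPrecision::isFunctionQuantized(_network.getFunction());
    auto ctx = std::make_shared<GraphContext>(_cfg, extensionManager, weightsCache, _mutex, isQuantizedFlag);

    // The graph now receives exactly two things: the network and its context.
    graphLock._graph.CreateGraph(_network, ctx);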

View File

@@ -9,6 +9,7 @@
 #include "graph.h"
 #include "extension_mngr.h"
+#include "graph_context.h"
 #include <threading/ie_thread_local.hpp>
 #include <vector>
@@ -38,8 +39,6 @@ public:
                 const ExtensionManager::Ptr &extMgr,
                 const std::shared_ptr<InferenceEngine::IInferencePlugin>& plugin);
-    void setProperty(const std::map<std::string, std::string> &properties);
-
     InferenceEngine::Parameter GetConfig(const std::string &name) const override;
     InferenceEngine::Parameter GetMetric(const std::string &name) const override;

View File

@@ -67,27 +67,20 @@ namespace intel_cpu {
 typedef std::unordered_set<EdgePtr> edge_cluster_t;
 typedef std::vector<edge_cluster_t> edge_clusters_t;

-dnnl::engine Graph::eng(dnnl::engine::kind::cpu, 0);
-
 Graph::~Graph() {
     CPU_DEBUG_CAP_ENABLE(summary_perf(*this));
 }

 template<typename NET>
-void Graph::CreateGraph(NET &net, const ExtensionManager::Ptr& extMgr,
-                        WeightsSharing::Ptr &w_cache, const std::shared_ptr<std::mutex>& mutex) {
+void Graph::CreateGraph(NET &net, const GraphContext::CPtr ctx) {
     OV_ITT_SCOPE(FIRST_INFERENCE, itt::domains::intel_cpu_LT, "CreateGraph");

     if (IsReady())
         ForgetGraphData();
-    // disable weights caching if graph was created only once
-    weightsCache = config.streamExecutorConfig._streams != 1 ? w_cache : nullptr;

-    rtParamsCache = std::make_shared<MultiCache>(config.rtCacheCapacity);
-    sharedMutex = mutex;
-    rtScratchPad = std::make_shared<DnnlScratchPad>(getEngine());
+    context = ctx;

-    Replicate(net, extMgr);
+    Replicate(net);

     InitGraph();
@@ -96,15 +89,12 @@ void Graph::CreateGraph(NET &net, const ExtensionManager::Ptr& extMgr,
 void Graph::CreateGraph(const std::vector<NodePtr> &graphNodes,
                         const std::vector<EdgePtr> &graphEdges,
-                        WeightsSharing::Ptr &w_cache,
+                        const GraphContext::CPtr ctx,
                         std::string name) {
     if (IsReady())
         ForgetGraphData();
-    // disable weights caching if graph was created only once
-    weightsCache = config.streamExecutorConfig._streams != 1 ? w_cache : nullptr;

-    rtParamsCache = std::make_shared<MultiCache>(config.rtCacheCapacity);
-    rtScratchPad = std::make_shared<DnnlScratchPad>(getEngine());
+    context = ctx;

     this->_name = std::move(name);
     this->reuse_io_tensors = false;
@@ -125,18 +115,13 @@ void Graph::CreateGraph(const std::vector<NodePtr> &graphNodes,
     CPU_DEBUG_CAP_ENABLE(serialize(*this));
 }

-template void Graph::CreateGraph(const std::shared_ptr<const ngraph::Function>&,
-                                 const ExtensionManager::Ptr&, WeightsSharing::Ptr&, const std::shared_ptr<std::mutex>& mutex);
-template void Graph::CreateGraph(const CNNNetwork&,
-                                 const ExtensionManager::Ptr&, WeightsSharing::Ptr&, const std::shared_ptr<std::mutex>& mutex);
+template void Graph::CreateGraph(const std::shared_ptr<const ngraph::Function>&, const GraphContext::CPtr);
+template void Graph::CreateGraph(const CNNNetwork&, const GraphContext::CPtr);

-void Graph::Replicate(const std::shared_ptr<const ov::Model> &subgraph, const ExtensionManager::Ptr& extMgr) {
+void Graph::Replicate(const std::shared_ptr<const ov::Model> &subgraph) {
     this->_name = "subgraph";
     this->reuse_io_tensors = false;

-    isQuantizedFlag = (config.lpTransformsMode == Config::On) &&
-                      ngraph::pass::low_precision::LowPrecision::isFunctionQuantized(subgraph);
-
     // Map data object onto producer node
     std::map<std::shared_ptr<ov::Node>, NodePtr> op2node;
@@ -156,14 +141,7 @@ void Graph::Replicate(const std::shared_ptr<const ov::Model> &subgraph, const Ex
     };

     for (const auto op : subgraph->get_ordered_ops()) {
-        const NodePtr node {Node::factory().create(op, getEngine(), extMgr, weightsCache)};
-        if (isQuantized()) {
-            node->setQuantizedGraphFlag(true);
-        }
-        node->setRuntimeCache(rtParamsCache);
-        node->setSharedMutex(sharedMutex);
-        node->setRuntimeScratchPad(rtScratchPad);
+        const NodePtr node {Node::factory().create(op, context)};

         graphNodes.push_back(node);
@@ -208,15 +186,18 @@ void Graph::Replicate(const std::shared_ptr<const ov::Model> &subgraph, const Ex
         const auto nodeName = std::string("stub_") + std::to_string(unusedOutput.get_index()) + "_" + parentNode->getName();
         const NodePtr outNode = std::make_shared<node::Input>(parentNode->outputShapes[port],
                                                               parentNode->getOriginalOutputPrecisionAtPort(port),
-                                                              nodeName, "Result", getEngine(), weightsCache);
+                                                              nodeName, "Result", context);
         EdgePtr edge(new Edge(parentNode, outNode, port, 0));
         outNode->addEdge(edge);
         graphEdges.push_back(edge);
         graphNodes.push_back(outNode);
     }
+
+    if (getConfig().enforceBF16)
+        EnforceBF16();
 }

-void Graph::Replicate(const CNNNetwork &network, const ExtensionManager::Ptr& extMgr) {
+void Graph::Replicate(const CNNNetwork &network) {
     OV_ITT_SCOPE_CHAIN(FIRST_INFERENCE, taskChain, itt::domains::intel_cpu_LT, "Graph::Replicate", "CNNNetwork");

     InputsDataMap inputsInfo = network.getInputsInfo();
@@ -228,12 +209,12 @@ void Graph::Replicate(const CNNNetwork &network, const ExtensionManager::Ptr& ex
     // we perform model cloning and reshaping on Replicate stage to preserve input/output information
     // it help to perform a graph compilation like in static case
     // and handle dynamic batch case in inference stage with minimal code changes
-    if (config.isNewApi && config.batchLimit > 0) {
+    if (getConfig().isNewApi && getConfig().batchLimit > 0) {
         auto upperBoundModel = ngraph::clone_function(*network.getFunction());
         std::map<ov::Output<ov::Node>, ov::PartialShape> newInShape;
         for (const auto& in : upperBoundModel->get_parameters()) {
             auto newShape = in->get_output_partial_shape(0);
-            newShape[0] = config.batchLimit;
+            newShape[0] = getConfig().batchLimit;
             newInShape[in] = newShape;
         }
         upperBoundModel->reshape(newInShape);
@@ -247,9 +228,6 @@ void Graph::Replicate(const CNNNetwork &network, const ExtensionManager::Ptr& ex
         IE_THROW() << "Function pointer inside CNNNetwork is nullptr";
     }

-    isQuantizedFlag = (config.lpTransformsMode == Config::On) &&
-                      ngraph::pass::low_precision::LowPrecision::isFunctionQuantized(func);
-
     auto orderedOps = func->get_ordered_ops();

     // TODO [NM]: unordered_map is preferred from performance perspective. Needs hash for ngraph::Node
@@ -271,14 +249,7 @@ void Graph::Replicate(const CNNNetwork &network, const ExtensionManager::Ptr& ex
     // Replicate All Nodes in topological order
     for (const auto& op : orderedOps) {
-        const NodePtr node(Node::factory().create(op, getEngine(), extMgr, weightsCache));
-        if (isQuantized()) {
-            node->setQuantizedGraphFlag(true);
-        }
-        node->setRuntimeCache(rtParamsCache);
-        node->setSharedMutex(sharedMutex);
-        node->setRuntimeScratchPad(rtScratchPad);
+        const NodePtr node(Node::factory().create(op, context));

         graphNodes.push_back(node);
@@ -331,19 +302,16 @@ void Graph::Replicate(const CNNNetwork &network, const ExtensionManager::Ptr& ex
         const auto nodeName = std::string("stub_") + std::to_string(unusedOutput.get_index()) + "_" + parentNode->getName();
         const NodePtr outNode = std::make_shared<node::Input>(parentNode->outputShapes[port],
                                                               parentNode->getOriginalOutputPrecisionAtPort(port),
-                                                              nodeName, "Result", getEngine(), weightsCache);
+                                                              nodeName, "Result", context);
         EdgePtr edge(new Edge(parentNode, outNode, port, 0));
         outNode->addEdge(edge);
         graphEdges.push_back(edge);
         graphNodes.push_back(outNode);
     }

-    if (config.enforceBF16)
+    if (getConfig().enforceBF16)
         EnforceBF16();

-    if (config.fcSparseWeiDecompressionRate < 1.0f)
-        setMinSparseRate(config.fcSparseWeiDecompressionRate);
-
     auto hasSubgraphConsumers = [] (const NodePtr& node) -> bool {
         const auto & childEdges = node->getChildEdges();
         return std::any_of(childEdges.begin(), childEdges.end(),
@@ -469,6 +437,7 @@
             if (inputNode)
                 inputNode->withMeanImage();
         }
+
         OV_ITT_SCOPE_NEXT(FIRST_INFERENCE, taskChain, node->profiling.getSupportedDescriptors);
         node->getSupportedDescriptors();
@@ -524,7 +493,7 @@ void Graph::ExtractConstantAndExecutableNodes() {
 void Graph::ExecuteConstantNodesOnly() const {
     OV_ITT_SCOPE(FIRST_INFERENCE, itt::domains::intel_cpu_LT, "Graph::ExecuteConstantNodesOnly");
-    dnnl::stream stream(eng);
+    dnnl::stream stream(getEngine());

     using shared_memory_ptr = WeightsSharing::SharedMemory::Ptr;
@@ -537,7 +506,7 @@ void Graph::ExecuteConstantNodesOnly() const {
             auto edgePtr = node->getChildEdgeAt(i);
             if (edgePtr) {
                 if (edgePtr->isUseExternalMemory()) {
-                    auto ptr = weightsCache->get(edgePtr->name());
+                    auto ptr = context->getWeightsCache()->get(edgePtr->name());
                     outputs.emplace_back(ptr);
                     if (!ptr->isValid())
                         hasExternalInvalidEdges = true;
@@ -551,7 +520,7 @@ void Graph::ExecuteConstantNodesOnly() const {
     };

     for (const auto &node : constantGraphNodes) {
-        if (weightsCache) {
+        if (context->getWeightsCache()) {
             auto sharedOutputs = acquireSharedOutputs(node);

             if (std::get<0>(sharedOutputs) || std::get<1>(sharedOutputs)) {
@@ -636,7 +605,7 @@
                     inDesc.getPrecision().name() + "_" + outDesc.getPrecision().name();
                 auto convertNode = std::make_shared<node::Convert>(inDesc.getShape(), inDesc.getPrecision(), outDesc.getPrecision(),
-                                                                   convertName, this->getEngine(), this->weightsCache);
+                                                                   convertName, context);
                 convertNode->setDescs(inDesc, outDesc);
                 InsertNode(edge, convertNode, true);
@@ -720,7 +689,7 @@
                 auto constNode = std::static_pointer_cast<node::Input>(edge->getParent());
                 edge->reuse(std::const_pointer_cast<Memory>(constNode->getMemoryPtr()));
             } else {
-                edge->externalAllocate(weightsCache);
+                edge->externalAllocate(context->getWeightsCache());
             }
             erase = true;
         }
@@ -790,7 +759,7 @@
     MemorySolver staticMemSolver(definedBoxes);
     size_t total_size = static_cast<size_t>(staticMemSolver.solve()) * alignment;

-    memWorkspace = std::make_shared<Memory>(eng);
+    memWorkspace = std::make_shared<Memory>(getEngine());
     memWorkspace->Create(DnnlBlockedMemoryDesc(InferenceEngine::Precision::I8, Shape(InferenceEngine::SizeVector{total_size})));

     if (edge_clusters.empty())
@@ -935,15 +904,15 @@
         if (ext_data_ptr != inter_data_ptr) {
             auto ext_tdesc = MemoryDescUtils::convertToDnnlBlockedMemoryDesc(in->getTensorDesc());
-            Memory ext_mem(eng);
+            Memory ext_mem(getEngine());
             ext_mem.Create(ext_tdesc, ext_data_ptr, false);

             // branch for handling dynamic batch feature in new API
-            if (getProperty().isNewApi && getProperty().batchLimit > 0 && ext_mem.getStaticDims()[0] != childEdge->getMemory().getStaticDims()[0]) {
+            if (getConfig().isNewApi && getConfig().batchLimit > 0 && ext_mem.getStaticDims()[0] != childEdge->getMemory().getStaticDims()[0]) {
                 auto newDims = childEdge->getMemory().getStaticDims();
                 newDims[0] = ext_mem.getStaticDims()[0];

-                Memory tmpMem(eng);
+                Memory tmpMem(getEngine());
                 auto newDesc = childEdge->getMemory().getDesc().cloneWithNewDims(newDims, true);
                 tmpMem.Create(newDesc, childEdge->getMemory().GetData(), false);
@@ -1006,7 +975,7 @@
             if (expectedDesc.getLayout() == InferenceEngine::Layout::BLOCKED) {
                 expectedDesc = TensorDesc(expectedDesc.getPrecision(), expectedDesc.getLayout());
             }
-            if (getProperty().isNewApi && getProperty().batchLimit > 0) {
+            if (getConfig().isNewApi && getConfig().batchLimit > 0) {
                 outDims[0] = node->batchToProcess();
             }
             out[name]->setShape(outDims);
@@ -1020,7 +989,7 @@
         auto srcPrec = actualDesc.getPrecision();
         auto dstPrec = expectedDesc.getPrecision();
-        if ((getProperty().isNewApi && !getProperty().batchLimit) && srcPrec == dstPrec && ext_blob->byteSize() != intr_blob.GetSize())
+        if ((getConfig().isNewApi && !getConfig().batchLimit) && srcPrec == dstPrec && ext_blob->byteSize() != intr_blob.GetSize())
             IE_THROW() << "Output blob byte size is not equal network output byte size ("
                        << ext_blob->byteSize() << "!=" << intr_blob.GetSize() << ").";
@@ -1036,15 +1005,15 @@
             auto outBlobDesc = expectedDesc.getLayout() == InferenceEngine::Layout::ANY
                                 ? DnnlBlockedMemoryDesc(expectedDesc.getPrecision(), Shape(expectedDesc.getDims()))
                                 : MemoryDescUtils::convertToDnnlBlockedMemoryDesc(expectedDesc);
-            Memory outBloMem(eng);
+            Memory outBloMem(getEngine());
             outBloMem.Create(outBlobDesc, ext_blob_ptr, false);

             // branch for handling dynamic batch feature in new API
-            if (getProperty().isNewApi && getProperty().batchLimit > 0 && outBloMem.getStaticDims()[0] != intr_blob.getStaticDims()[0]) {
+            if (getConfig().isNewApi && getConfig().batchLimit > 0 && outBloMem.getStaticDims()[0] != intr_blob.getStaticDims()[0]) {
                 auto newDims = intr_blob.getStaticDims();
                 newDims[0] = outBloMem.getStaticDims()[0];

-                Memory tmpMem(eng);
+                Memory tmpMem(getEngine());
                 auto newDesc = intr_blob.getDesc().cloneWithNewDims(newDims, true);
                 tmpMem.Create(newDesc, intr_blob.GetData(), false);
@@ -1056,8 +1025,8 @@
     size_t size_to_copy = intr_blob.GetDescWithType<BlockedMemoryDesc>()->getPaddedElementsCount();
     // TODO: Should we support InferenceEngine::PluginConfigParams::KEY_DYN_BATCH_LIMIT???
     // TODO [DS]: phase 2: should we support this behaviour? Looks obsolete in the dynamic shapes paradigm
-    if (getProperty().batchLimit) {
-        if (node->isDynamicNode() && !getProperty().isNewApi) {
+    if (getConfig().batchLimit) {
+        if (node->isDynamicNode() && !getConfig().isNewApi) {
             IE_THROW(NotImplemented) << "[DS] not implemented dynamic batch for node with dynamic shape";
         }
         int MB_to_process = node->batchToProcess();
@@ -1070,11 +1039,11 @@
 }

 void Graph::InferStatic(InferRequestBase* request) {
-    dnnl::stream stream(eng);
+    dnnl::stream stream(getEngine());

     for (const auto& node : executableGraphNodes) {
-        VERBOSE(node, config.debugCaps.verbose);
-        PERF(node, config.collectPerfCounters);
+        VERBOSE(node, getConfig().debugCaps.verbose);
+        PERF(node, getConfig().collectPerfCounters);

         if (request)
             request->ThrowIfCanceled();
@@ -1083,7 +1052,7 @@
 }

 void Graph::InferDynamic(InferRequestBase* request) {
-    dnnl::stream stream(eng);
+    dnnl::stream stream(getEngine());

     std::set<size_t> syncIndsWorkSet;
     for (const auto& nodeIndx : syncNodesInds) {
@@ -1160,8 +1129,8 @@
             updateNodes(stopIndx);
             for (; inferCounter < stopIndx; ++inferCounter) {
                 auto& node = executableGraphNodes[inferCounter];
-                VERBOSE(node, config.debugCaps.verbose);
-                PERF(node, config.collectPerfCounters);
+                VERBOSE(node, getConfig().debugCaps.verbose);
+                PERF(node, getConfig().collectPerfCounters);

                 if (request)
                     request->ThrowIfCanceled();
@@ -1171,7 +1140,8 @@
 }

 inline void Graph::ExecuteNode(const NodePtr& node, const dnnl::stream& stream) const {
-    DUMP(node, config.debugCaps, infer_count);
+    DUMP(node, getConfig().debugCaps, infer_count);
+
     OV_ITT_SCOPED_TASK(itt::domains::intel_cpu, node->profiling.execute);

     if (node->isDynamicNode()) {
@@ -1316,22 +1286,6 @@ void Graph::GetPerfData(std::map<std::string, InferenceEngine::InferenceEnginePr
     }
 }

-void Graph::setConfig(const Config &cfg) {
-    config = cfg;
-}
-
-const Config& Graph::getConfig() const {
-    return config;
-}
-
-void Graph::setProperty(const std::map<std::string, std::string>& properties) {
-    config.readProperties(properties);
-}
-
-Config Graph::getProperty() const {
-    return config;
-}
-
 void Graph::RemoveEdge(EdgePtr& edge) {
     for (auto it = graphEdges.begin(); it != graphEdges.end(); it++) {
         if ((*it) == edge) {
@@ -1479,7 +1433,7 @@ void Graph::RemoveDroppedEdges() {
 NodePtr Graph::InsertReorder(EdgePtr edge, std::string layerName, const MemoryDesc& inDesc, const MemoryDesc& outDesc,
                              bool isOptimized, const std::vector<int> & src_perm) {
-    NodePtr newReorder(new node::Reorder(layerName, getEngine(), weightsCache));
+    NodePtr newReorder(new node::Reorder(layerName, context));
     auto *reorderPtr = dynamic_cast<node::Reorder *>(newReorder.get());
     if (reorderPtr == nullptr) {
         IE_THROW() << "Graph::InsertReorder: Cannot cast to Reorder";
@@ -1529,12 +1483,6 @@ bool Graph::InsertNode(NodePtr parent, NodePtr child, NodePtr node, int parentPo
     afterNode->getParent()->childEdges.push_back(afterNode);
     child->parentEdges.push_back(afterNode);

-    if (isQuantized()) {
-        node->setQuantizedGraphFlag(true);
-    }
-    node->setRuntimeCache(rtParamsCache);
-    node->setRuntimeScratchPad(rtScratchPad);
-
     if (initNode) {
         node->getSupportedDescriptors();
         node->initSupportedPrimitiveDescriptors();
@@ -1553,7 +1501,7 @@
 void Graph::EnforceBF16() {
     // Floating point parts of FP32 + INT8 or FP32 + BIN mixed precision models will be executed in BF16 precision
     // only if enforceBF16 flag was set manually because current performance is not good enough to enable it by default
-    if (!implication(isQuantized(), config.manualEnforceBF16))
+    if (!implication(context->isGraphQuantized(), getConfig().manualEnforceBF16))
         return;

     std::function<void(const NodePtr&, std::unordered_set<NodePtr>& skipNodes)> searchForNodesToSkip;
@@ -1594,6 +1542,9 @@
             continue;

         if (node->getType() != Type::Input && node->getType() != Type::Output) {
+            DEBUG_LOG("#", node->getExecIndex(),
+                      " ", node->getName(),
+                      " is enforced to use BF16\n");
             for (size_t i = 0; i < node->getOriginalInputsNumber(); i++) {
                 const auto &parent = node->getParentEdgesAtPort(i)[0]->getParent();
                 /* Skip BF16 enforcement for nodes after Constant Inputs for maintaining precision for fusing.
@@ -1616,14 +1567,6 @@
         }
     }
 }

-void Graph::setMinSparseRate(float minSparseRate) {
-    for (const auto &node : graphNodes) {
-        if (auto fcNodePtr = std::dynamic_pointer_cast<node::FullyConnected>(node)) {
-            fcNodePtr->setMinSparseRate(minSparseRate);
-        }
-    }
-}
-
 std::shared_ptr<ngraph::Function> Graph::dump() const {
     return dump_graph_as_ie_ngraph_net(*this);
 }
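
The BF16 gate in Graph::EnforceBF16() above relies on the plugin's implication() helper. Assuming its conventional definition (logical implication; the definition itself is not part of this diff), the guard reads "if the graph is quantized, BF16 is enforced only when manualEnforceBF16 is set":

    // Assumed helper (logical implication): implication(a, b) == !a || b.
    constexpr bool implication(bool cause, bool cond) {
        return !cause || cond;
    }

    // Hence the guard
    //     if (!implication(context->isGraphQuantized(), getConfig().manualEnforceBF16))
    //         return;
    // bails out exactly for quantized graphs without manual BF16 enforcement;
    // non-quantized graphs always proceed when enforceBF16 is on.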

View File

@@ -12,6 +12,7 @@
 #include "edge.h"
 #include "cache/multi_cache.h"
 #include "dnnl_scratch_pad.h"
+#include "graph_context.h"
 #include <map>
 #include <string>
 #include <vector>
@@ -27,7 +28,6 @@ class InferRequest;
 class Graph {
 public:
     typedef std::shared_ptr<Graph> Ptr;
-    WeightsSharing::Ptr weightsCache;

     enum class Status {
         NotReady = 0,
@@ -42,21 +42,16 @@
         return (status != Status::NotReady);
     }

-    void setConfig(const Config &cfg);
-    const Config& getConfig() const;
-
-    void setProperty(const std::map<std::string, std::string> &properties);
-    Config getProperty() const;
+    const Config & getConfig() const {
+        return context->getConfig();
+    }

     template<typename NET>
-    void CreateGraph(NET &network,
-                     const ExtensionManager::Ptr& extMgr,
-                     WeightsSharing::Ptr &w_cache,
-                     const std::shared_ptr<std::mutex>& mutex);
+    void CreateGraph(NET &network, const GraphContext::CPtr ctx);

     void CreateGraph(const std::vector<NodePtr> &graphNodes,
                      const std::vector<EdgePtr> &graphEdges,
-                     WeightsSharing::Ptr &w_cache,
+                     const GraphContext::CPtr ctx,
                      std::string name);

     bool hasMeanImageFor(const std::string& name) {
@@ -111,7 +106,11 @@
     }

     dnnl::engine getEngine() const {
-        return eng;
+        return context->getEngine();
+    }
+
+    GraphContext::CPtr getGraphContext() const {
+        return context;
     }

     void GetPerfData(std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> &perfMap) const;
@@ -187,10 +186,6 @@
     void SortTopologically();

-    bool isQuantized() const {
-        return isQuantizedFlag;
-    }
-
     bool hasDynamicInput() const {
         return graphHasDynamicInput;
     }
@@ -200,7 +195,6 @@
     void ForgetGraphData() {
         status = Status::NotReady;
-        eng = dnnl::engine(dnnl::engine::kind::cpu, 0);

         inputNodesMap.clear();
         outputNodesMap.clear();
@@ -210,7 +204,6 @@
         syncNodesInds.clear();
     }

     Status status { Status::NotReady };
-    Config config;

     // For dumping purposes. -1 - no counting, all other positive
     // values mean increment it within each Infer() call
@@ -226,13 +219,10 @@
     std::map<std::string, NormalizePreprocess> _normalizePreprocMap;
     std::string _name;

-    bool isQuantizedFlag = false;
     bool graphHasDynamicInput = false;

-    static dnnl::engine eng;
-
-    void Replicate(const InferenceEngine::CNNNetwork &network, const ExtensionManager::Ptr& extMgr);
-    void Replicate(const std::shared_ptr<const ov::Model> &subgraph, const ExtensionManager::Ptr& extMgr);
+    void Replicate(const InferenceEngine::CNNNetwork &network);
+    void Replicate(const std::shared_ptr<const ov::Model> &subgraph);
     void InitGraph();
     void InitNodes();
     void InitDescriptors();
@@ -263,13 +253,11 @@
     std::vector<NodePtr> constantGraphNodes;
     std::vector<NodePtr> executableGraphNodes;

-    MultiCachePtr rtParamsCache;
-    std::shared_ptr<std::mutex> sharedMutex = nullptr;
-    DnnlScratchPadPtr rtScratchPad;
-
     std::unordered_map<Node*, size_t> syncNodesInds;

+    GraphContext::CPtr context;
+
     void EnforceBF16();
-    void setMinSparseRate(float minSparseRate);
 };

 }   // namespace intel_cpu

View File

@@ -0,0 +1,13 @@
+// Copyright (C) 2018-2022 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+#include <dnnl_types.h>
+#include "graph_context.h"
+
+namespace ov {
+namespace intel_cpu {
+
+dnnl::engine GraphContext::eng(dnnl::engine::kind::cpu, 0);
+
+}   // namespace intel_cpu
+}   // namespace ov

View File

@@ -0,0 +1,82 @@
+// Copyright (C) 2018-2022 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+#pragma once
+
+#include "cache/multi_cache.h"
+#include "config.h"
+#include "dnnl_scratch_pad.h"
+#include "extension_mngr.h"
+#include "weights_cache.hpp"
+
+namespace ov {
+namespace intel_cpu {
+
+class GraphContext {
+public:
+    typedef std::shared_ptr<GraphContext> Ptr;
+    typedef std::shared_ptr<const GraphContext> CPtr;
+
+    GraphContext(const Config& config,
+                 ExtensionManager::Ptr extensionManager,
+                 WeightsSharing::Ptr w_cache,
+                 std::shared_ptr<std::mutex> sharedMutex,
+                 bool isGraphQuantized)
+        : config(config),
+          extensionManager(extensionManager),
+          weightsCache(w_cache),
+          sharedMutex(sharedMutex),
+          isGraphQuantizedFlag(isGraphQuantized) {
+        rtParamsCache = std::make_shared<MultiCache>(config.rtCacheCapacity);
+        rtScratchPad = std::make_shared<DnnlScratchPad>(eng);
+    }
+
+    const Config& getConfig() const {
+        return config;
+    }
+
+    ExtensionManager::Ptr getExtensionManager() const {
+        return extensionManager;
+    }
+
+    WeightsSharing::Ptr getWeightsCache() const {
+        return weightsCache;
+    }
+
+    std::shared_ptr<std::mutex> getSharedMutex() const {
+        return sharedMutex;
+    }
+
+    MultiCachePtr getParamsCache() const {
+        return rtParamsCache;
+    }
+
+    DnnlScratchPadPtr getScratchPad() const {
+        return rtScratchPad;
+    }
+
+    dnnl::engine getEngine() const {
+        return eng;
+    }
+
+    bool isGraphQuantized() const {
+        return isGraphQuantizedFlag;
+    }
+
+private:
+    Config config;  // network-level config
+
+    ExtensionManager::Ptr extensionManager;
+    WeightsSharing::Ptr weightsCache;         // per NUMA node caches for sharing weights data
+    std::shared_ptr<std::mutex> sharedMutex;  // mutex for protection of type-relaxed Op in clone_model()
+
+    MultiCachePtr rtParamsCache;     // primitive cache
+    DnnlScratchPadPtr rtScratchPad;  // scratch pad
+
+    bool isGraphQuantizedFlag = false;
+
+    static dnnl::engine eng;  // onednn engine (singleton)
+};
+
+}  // namespace intel_cpu
+}  // namespace ov
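
GraphContext replaces half a dozen per-node setters (setRuntimeCache, setSharedMutex, setRuntimeScratchPad, setQuantizedGraphFlag) with a single immutable bundle that nodes receive at construction time. A sketch of how a consumer reads it (configureFromContext is a hypothetical helper for illustration; the getters are the ones declared above):

    #include "graph_context.h"

    void configureFromContext(const ov::intel_cpu::GraphContext::CPtr& ctx) {
        const auto& cfg = ctx->getConfig();          // network-level Config, read-only
        auto paramsCache = ctx->getParamsCache();    // shared primitive-params cache
        auto scratchPad = ctx->getScratchPad();      // shared oneDNN scratch pad
        auto weightsCache = ctx->getWeightsCache();  // may be nullptr for a single-stream graph
        bool quantized = ctx->isGraphQuantized();    // replaces Node::setQuantizedGraphFlag()
        (void)cfg; (void)paramsCache; (void)scratchPad; (void)weightsCache; (void)quantized;
    }

Because the context travels as a CPtr (shared_ptr to const), sub-graphs built inside nodes can simply reuse the parent graph's context, which is what allows the Replicate(subgraph) path above to run EnforceBF16 for sub-graphs as well.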

View File

@@ -2206,10 +2206,10 @@ void GraphOptimizer::reshapeRnnSeq(Graph &graph) {
                                                                  parentNode->getOutputShapeAtPort(0).toPartialShape()), secondInput);
             unsqueeze->set_friendly_name(parentNode->getName() + "_abc_a1bc_" + std::to_string(j));

-            const auto cpuUnsqueeze = std::make_shared<Reshape>(unsqueeze, graph.getEngine(), graph.weightsCache);
+            const auto cpuUnsqueeze = std::make_shared<Reshape>(unsqueeze, graph.getGraphContext());
             graph.InsertNode(parentNode, childNode, cpuUnsqueeze, edge->getInputNum(), edge->getOutputNum(), false);

-            const auto cpuConstant = std::make_shared<node::Input>(secondInput, graph.getEngine(), graph.weightsCache);
+            const auto cpuConstant = std::make_shared<node::Input>(secondInput, graph.getGraphContext());
             EdgePtr newEdge(new Edge(cpuConstant, cpuUnsqueeze, 0, 1));
             cpuUnsqueeze->addEdge(newEdge);
             auto &graphEdges = graph.GetEdges();

View File

@@ -160,7 +160,7 @@ void InferRequestBase::InferImpl() {
     if (graph->hasDynamicInput()) {
         redefineMemoryForInputNodes();
-    } else if (graph->getProperty().isNewApi && graph->getProperty().batchLimit > 0) {
+    } else if (graph->getConfig().isNewApi && graph->getConfig().batchLimit > 0) {
         const auto batch = _inputs.begin()->second->getTensorDesc().getDims()[0];
         SetBatch(batch);
     }
@@ -358,10 +358,10 @@ void LegacyInferRequest::initBlobs() {
 }

 void LegacyInferRequest::SetBatch(int new_batch) {
-    if (!graph->getProperty().enableDynamicBatch)
+    if (!graph->getConfig().enableDynamicBatch)
         IE_THROW() << "Dynamic batch is not enabled.";

-    if (new_batch < 1 || new_batch > graph->getProperty().batchLimit) {
+    if (new_batch < 1 || new_batch > graph->getConfig().batchLimit) {
         IE_THROW() << "Invalid dynamic batch size " << new_batch <<
             " for this request.";
     }
@@ -433,7 +433,7 @@ void LegacyInferRequest::SetBlob(const std::string& name, const InferenceEngine:
     auto pBlobDesc = MemoryDescUtils::interpretAsBlobDesc(graph->getInputNodeByName(name)->getChildEdgesAtPort(0)[0]->getMemory());
     if (data->getTensorDesc() == pBlobDesc &&
-        graph->_normalizePreprocMap.find(name) == graph->_normalizePreprocMap.end() && !graph->getProperty().batchLimit) {
+        graph->_normalizePreprocMap.find(name) == graph->_normalizePreprocMap.end() && !graph->getConfig().batchLimit) {
         externalPtr[name] = data->buffer();
     } else if (externalPtr.find(name) != externalPtr.end()) {
         externalPtr.erase(name);
@@ -467,7 +467,7 @@ void LegacyInferRequest::SetBlob(const std::string& name, const InferenceEngine:
     auto pBlobDesc = MemoryDescUtils::interpretAsBlobDesc(graph->getOutputNodeByName(name)->getParentEdgesAtPort(0)[0]->getMemory());
     if (data->getTensorDesc() == pBlobDesc &&
-        !graph->getProperty().batchLimit) {
+        !graph->getConfig().batchLimit) {
         externalPtr[name] = data->buffer();
     } else if (externalPtr.find(name) != externalPtr.end()) {
         externalPtr.erase(name);
@@ -509,7 +509,7 @@ InferenceEngine::Blob::Ptr LegacyInferRequest::GetBlob(const std::string& name)
             _inputs[name] = make_blob_with_precision(desc);
             _inputs[name]->allocate();
             if (pBlobDesc == desc &&
-                graph->_normalizePreprocMap.find(name) == graph->_normalizePreprocMap.end() && !graph->getProperty().batchLimit) {
+                graph->_normalizePreprocMap.find(name) == graph->_normalizePreprocMap.end() && !graph->getConfig().batchLimit) {
                 externalPtr[name] = _inputs[name]->buffer();
             }
         }
@@ -571,7 +571,7 @@ InferenceEngine::Blob::Ptr LegacyInferRequest::GetBlob(const std::string& name)
         }
         _outputs[name] = data;
-        if (!externalPtr.count(name) && data->getTensorDesc() == pBlobDesc && !graph->getProperty().batchLimit) {
+        if (!externalPtr.count(name) && data->getTensorDesc() == pBlobDesc && !graph->getConfig().batchLimit) {
             externalPtr[name] = data->buffer();
         }
     }
@@ -627,11 +627,11 @@ void InferRequest::initBlobs() {
 }

 void InferRequest::SetBatch(int new_batch) {
-    if (!graph->getProperty().batchLimit || modelInputsMap.begin()->second->get_output_partial_shape(0).is_static()) {
+    if (!graph->getConfig().batchLimit || modelInputsMap.begin()->second->get_output_partial_shape(0).is_static()) {
         IE_THROW() << "Can't set batch for model that can't be executed via legacy dynamic batch or for static model";
     }

-    if (new_batch < 1 || new_batch > graph->getProperty().batchLimit) {
+    if (new_batch < 1 || new_batch > graph->getConfig().batchLimit) {
         IE_THROW() << "Can't set batch that is bigger than upper bound";
     }
@@ -704,7 +704,7 @@ void InferRequest::SetBlob(const std::string& name, const InferenceEngine::Blob:
                                          blobDesc.getDims());
         }
         if (actualDesc->isCompatible(MemoryDescUtils::convertToCpuBlockedMemoryDesc(blobDesc)) &&
-            graph->_normalizePreprocMap.find(name) == graph->_normalizePreprocMap.end() && !graph->getProperty().batchLimit) {
+            graph->_normalizePreprocMap.find(name) == graph->_normalizePreprocMap.end() && !graph->getConfig().batchLimit) {
             externalPtr[name] = data->buffer();
         } else if (externalPtr.find(name) != externalPtr.end()) {
             externalPtr.erase(name);
@@ -736,7 +736,7 @@ void InferRequest::SetBlob(const std::string& name, const InferenceEngine::Blob:
         }
         const auto &desc = graph->getOutputNodeByName(name)->getParentEdgesAtPort(0)[0]->getMemory().getDesc();
-        if (!isDynamic && blobDesc == MemoryDescUtils::convertToTensorDesc(desc) && !graph->getProperty().batchLimit) {
+        if (!isDynamic && blobDesc == MemoryDescUtils::convertToTensorDesc(desc) && !graph->getConfig().batchLimit) {
             externalPtr[name] = data->buffer();
         } else if (externalPtr.find(name) != externalPtr.end()) {
             externalPtr.erase(name);
@@ -784,7 +784,7 @@ InferenceEngine::Blob::Ptr InferRequest::GetBlob(const std::string& name) {
             if (!isDynamic &&
                 desc == MemoryDescUtils::convertToTensorDesc(graph->getInputNodeByName(name)->getChildEdgesAtPort(0)[0]->getMemory().getDesc()) &&
-                graph->_normalizePreprocMap.find(name) == graph->_normalizePreprocMap.end() && !graph->getProperty().batchLimit) {
+                graph->_normalizePreprocMap.find(name) == graph->_normalizePreprocMap.end() && !graph->getConfig().batchLimit) {
                 externalPtr[name] = _inputs[name]->buffer();
             }
         } else {
@@ -841,7 +841,7 @@ InferenceEngine::Blob::Ptr InferRequest::GetBlob(const std::string& name) {
             _outputs[name] = data;
             if (!isDynamic && !externalPtr.count(name) &&
                 data->getTensorDesc() == MemoryDescUtils::convertToTensorDesc(output->second->getParentEdgesAtPort(0)[0]->getMemory().getDesc()) &&
-                !graph->getProperty().batchLimit) {
+                !graph->getConfig().batchLimit) {
                 externalPtr[name] = data->buffer();
             }
         } else {

View File

@@ -77,10 +77,19 @@ Node::NodesFactory & Node::factory() {
     return factoryInstance;
 }

-Node::Node(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &w_cache, const ShapeInferFactory& shapeInferFactory)
-    : selectedPrimitiveDescriptorIndex(-1), permanent(false), temporary(false), constant(ConstantType::Unknown),
-      weightCache(w_cache), engine(eng), name(op->get_friendly_name()), typeStr(op->get_type_name()),
-      type(TypeFromName(op->get_type_name())), profiling(op->get_friendly_name()) {
+Node::Node(const std::shared_ptr<ngraph::Node>& op,
+           const GraphContext::CPtr ctx,
+           const ShapeInferFactory& shapeInferFactory)
+    : selectedPrimitiveDescriptorIndex(-1),
+      permanent(false),
+      temporary(false),
+      constant(ConstantType::Unknown),
+      context(ctx),
+      engine(ctx->getEngine()),
+      name(op->get_friendly_name()),
+      typeStr(op->get_type_name()),
+      type(TypeFromName(op->get_type_name())),
+      profiling(op->get_friendly_name()) {
     algorithm = Algorithm::Default;
     fusingPort = -1;
     const std::string errorPrefix = "Ngraph operation " + std::string(op->get_type_name()) + " with name " + op->get_friendly_name();
@@ -170,10 +179,18 @@
     }
 }

-Node::Node(const std::string& type, const std::string& name, const dnnl::engine& eng, WeightsSharing::Ptr &w_cache)
-    : selectedPrimitiveDescriptorIndex(-1), permanent(false), temporary(false), constant(ConstantType::Unknown),
-      weightCache(w_cache), engine(eng), fusingPort(-1), name(name), typeStr(type),
-      type(TypeFromName(type)), profiling(name) {
+Node::Node(const std::string& type, const std::string& name, const GraphContext::CPtr ctx)
+    : selectedPrimitiveDescriptorIndex(-1),
+      permanent(false),
+      temporary(false),
+      constant(ConstantType::Unknown),
+      context(ctx),
+      engine(ctx->getEngine()),
+      fusingPort(-1),
+      name(name),
+      typeStr(type),
+      type(TypeFromName(type)),
+      profiling(name) {
     // TODO [NM]: What about filling inDims and outDims?
 }
@@ -795,6 +812,7 @@ void Node::prepareMemory(const std::vector<DnnlMemoryDescPtr>& intDescs) {
     };

     MemoryPtr ptr;
+    auto weightCache = context->getWeightsCache();
     if (weightCache != nullptr) {
         const uint64_t data_hash = weightCache->GetHashFunc().hash(
             internalBlob->buffer(), internalBlob->byteSize());
@@ -1206,8 +1224,7 @@ InferenceEngine::Precision Node::getRuntimePrecision() const {
     return runtimePrecision;
 }

-Node* Node::NodesFactory::create(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng,
-                                 const ExtensionManager::Ptr& extMgr, WeightsSharing::Ptr &w_cache) {
+Node* Node::NodesFactory::create(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context) {
     // getExceptionDescWithoutStatus removes redundant information from the exception message. For instance, the NotImplemented
     // exception is generated in the form: full_path_to_src_file:line_number [ NOT_IMPLEMENTED ] reason.
     // An example for gather node:
@@ -1229,15 +1246,15 @@
     Node *newNode = nullptr;
     std::string errorMessage;
     {
-        std::unique_ptr<Node> ol(createNodeIfRegistered(intel_cpu, Type::Generic, op, eng, w_cache));
-        if (ol != nullptr && ol->created(extMgr))
+        std::unique_ptr<Node> ol(createNodeIfRegistered(intel_cpu, Type::Generic, op, context));
+        if (ol != nullptr && ol->created(context->getExtensionManager()))
             newNode = ol.release();
     }

     if (newNode == nullptr) {
         try {
-            std::unique_ptr<Node> ol(createNodeIfRegistered(intel_cpu, TypeFromName(op->get_type_name()), op, eng, w_cache));
-            if (ol != nullptr && ol->created(extMgr))
+            std::unique_ptr<Node> ol(createNodeIfRegistered(intel_cpu, TypeFromName(op->get_type_name()), op, context));
+            if (ol != nullptr && ol->created(context->getExtensionManager()))
                 newNode = ol.release();
         } catch (const InferenceEngine::Exception& ex) {
             if (dynamic_cast<const InferenceEngine::NotImplemented*>(&ex) != nullptr) {
@@ -1250,8 +1267,8 @@
     if (newNode == nullptr) {
         try {
-            std::unique_ptr<Node> ol(new Reference(op, eng, w_cache, errorMessage));
-            if (ol != nullptr && ol->created(extMgr))
+            std::unique_ptr<Node> ol(new Reference(op, context, errorMessage));
+            if (ol != nullptr && ol->created(context->getExtensionManager()))
                 newNode = ol.release();
         } catch (const InferenceEngine::Exception& ex) {
             if (dynamic_cast<const InferenceEngine::NotImplemented*>(&ex) != nullptr) {
@@ -1264,19 +1281,6 @@
         }
     }

-    // WA-start : TI node requires all attributes to construct internal subgpath
-    // including extManager, socket and dnnl::eng.
-    if (newNode) {
-        if (newNode->getType() == Type::TensorIterator) {
-            if (auto ti = dynamic_cast<TensorIterator*>(newNode))
-                ti->setExtManager(extMgr);
-        } else if (newNode->getType() == Type::If) {
-            if (auto ifNode = dynamic_cast<If*>(newNode))
-                ifNode->setExtManager(extMgr);
-        }
-    }
-    // // WA-end
-
     if (!newNode) {
         std::string errorDetails;
         if (!errorMessage.empty()) {

View File

@ -37,6 +37,7 @@
#include "utils/debug_capabilities.h" #include "utils/debug_capabilities.h"
#include "dnnl_postops_composer.h" #include "dnnl_postops_composer.h"
#include "graph_context.h"
namespace ov { namespace ov {
namespace intel_cpu { namespace intel_cpu {
@ -527,10 +528,6 @@ public:
return false; return false;
} }
void setQuantizedGraphFlag(bool flag) {
isInQuantizedGraph = flag;
}
bool canBePerformedAsScaleShift(const Node *parentNode = nullptr) const; bool canBePerformedAsScaleShift(const Node *parentNode = nullptr) const;
bool isDynamicNode() const { bool isDynamicNode() const {
@ -573,18 +570,6 @@ public:
virtual void appendPostOps(dnnl::post_ops& ops, const VectorDims& postOpDims, std::unordered_map<int, MemoryPtr>& postOpsMem, const int channelAxis = 1); virtual void appendPostOps(dnnl::post_ops& ops, const VectorDims& postOpDims, std::unordered_map<int, MemoryPtr>& postOpsMem, const int channelAxis = 1);
virtual void appendPostOps(dnnl::post_ops& ops, const VectorDims& postOpDims, std::vector<const void*>& postOpsMem, const int channelAxis = 1); virtual void appendPostOps(dnnl::post_ops& ops, const VectorDims& postOpDims, std::vector<const void*>& postOpsMem, const int channelAxis = 1);
void setRuntimeCache(MultiCachePtr cache) {
rtParamsCache = cache;
}
void setRuntimeScratchPad(DnnlScratchPadPtr scratchPad) {
rtScratchPad = scratchPad;
}
void setSharedMutex(const std::shared_ptr<std::mutex>& mutex) {
sharedMutex = mutex;
}
protected: protected:
bool canFuseSimpleOperation(const NodePtr& node) const; bool canFuseSimpleOperation(const NodePtr& node) const;
@ -618,8 +603,8 @@ protected:
std::string originalLayers; // contains names of the original layers separated by comma std::string originalLayers; // contains names of the original layers separated by comma
Node(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &w_cache, const ShapeInferFactory& shapeInferFactory); Node(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr ctx, const ShapeInferFactory& shapeInferFactory);
Node(const std::string& type, const std::string& name, const dnnl::engine& eng, WeightsSharing::Ptr &w_cache); Node(const std::string& type, const std::string& name, const GraphContext::CPtr ctx);
int selectedPrimitiveDescriptorIndex = -1; int selectedPrimitiveDescriptorIndex = -1;
bool permanent = false; bool permanent = false;
@ -645,12 +630,10 @@ protected:
Primitive prim; Primitive prim;
std::vector<DnnlDesriptor> descs; std::vector<DnnlDesriptor> descs;
WeightsSharing::Ptr weightCache; const GraphContext::CPtr context;
Algorithm algorithm = Algorithm::Default; Algorithm algorithm = Algorithm::Default;
bool isInQuantizedGraph = false;
friend class Edge; friend class Edge;
friend class Graph; friend class Graph;
friend class GraphOptimizer; friend class GraphOptimizer;
@ -715,17 +698,9 @@ protected:
IE_THROW(NotImplemented) << "[DS] prapareParams not implemented for node with type " << NameFromType(getType()); IE_THROW(NotImplemented) << "[DS] prapareParams not implemented for node with type " << NameFromType(getType());
} }
MultiCachePtr getRuntimeCache() const {
return rtParamsCache;
}
DnnlScratchPadPtr getRuntimeScratchPad() const {
return rtScratchPad;
}
MemoryPtr getScratchPadMem(const const_dnnl_primitive_desc_t& pd) { MemoryPtr getScratchPadMem(const const_dnnl_primitive_desc_t& pd) {
auto scratchpadMemoryDesc = DnnlExtensionUtils::query_md(pd, dnnl::query::scratchpad_md); auto scratchpadMemoryDesc = DnnlExtensionUtils::query_md(pd, dnnl::query::scratchpad_md);
scratchpadMem = getRuntimeScratchPad()->createScratchPadMem(scratchpadMemoryDesc); scratchpadMem = context->getScratchPad()->createScratchPadMem(scratchpadMemoryDesc);
return scratchpadMem; return scratchpadMem;
} }
@ -733,8 +708,6 @@ protected:
std::shared_ptr<IShapeInfer> shapeInference; std::shared_ptr<IShapeInfer> shapeInference;
std::shared_ptr<std::mutex> sharedMutex = nullptr;
private: private:
std::vector<EdgeWeakPtr> parentEdges; std::vector<EdgeWeakPtr> parentEdges;
std::vector<EdgeWeakPtr> childEdges; std::vector<EdgeWeakPtr> childEdges;
@ -744,7 +717,7 @@ private:
int fusingPort; int fusingPort;
dnnl::engine engine; const dnnl::engine engine;
std::string name; std::string name;
std::string typeStr; std::string typeStr;
@ -756,8 +729,6 @@ private:
PerfCount perfCounter; PerfCount perfCounter;
PerfCounters profiling; PerfCounters profiling;
MultiCachePtr rtParamsCache;
DnnlScratchPadPtr rtScratchPad;
MemoryPtr scratchpadMem; MemoryPtr scratchpadMem;
bool isEdgesEmpty(const std::vector<EdgeWeakPtr>& edges) const; bool isEdgesEmpty(const std::vector<EdgeWeakPtr>& edges) const;
@ -796,19 +767,17 @@ constexpr uint64_t PortMask(int n, T... rest) {
class Node::NodesFactory : public openvino::cc::Factory<Type, class Node::NodesFactory : public openvino::cc::Factory<Type,
Node*(const std::shared_ptr<ngraph::Node>& op, Node*(const std::shared_ptr<ngraph::Node>& op,
const dnnl::engine &, const GraphContext::CPtr)> {
WeightsSharing::Ptr &)> {
public: public:
NodesFactory(); NodesFactory();
Node* create(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, Node* create(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context);
const ExtensionManager::Ptr& extMgr, WeightsSharing::Ptr &w_cache);
}; };
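
With the engine, extension manager, and weights cache all folded into the context, the factory above shrinks to a two-argument create(). A hypothetical call site under the new signature (the factory and variable names are illustrative, not taken from this patch):

// Hypothetical caller of the simplified factory.
// Before: factory.create(op, eng, extMgr, w_cache);
Node* node = factory.create(op, context);  // everything else travels inside the context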
template<typename NodeType> template<typename NodeType>
struct NodeImpl : public NodeType { struct NodeImpl : public NodeType {
NodeImpl(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache) NodeImpl(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context)
: NodeType(op, eng, cache) { : NodeType(op, context) {
NodeType::perfCounters().template buildClassCounters<NodeType>(NameFromType(NodeType::getType())); NodeType::perfCounters().template buildClassCounters<NodeType>(NameFromType(NodeType::getType()));
} }
}; };
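
Collecting the accessors this patch relies on (getConfig, getWeightsCache, getParamsCache, getScratchPad, isGraphQuantized), the new GraphContext must look roughly like the sketch below. This is reconstructed from the call sites, not copied from the real header; the placeholder aliases, member names, and defaults are assumptions.

#include <memory>

class Config;                                     // plugin configuration (real type, defined elsewhere)
class WeightsSharing;                             // constant-weights cache (real type, defined elsewhere)
using MultiCachePtr = std::shared_ptr<void>;      // stand-in for the runtime params-cache type
using DnnlScratchPadPtr = std::shared_ptr<void>;  // stand-in for the scratchpad allocator type

class GraphContext {
public:
    using Ptr = std::shared_ptr<GraphContext>;
    using CPtr = std::shared_ptr<const GraphContext>;

    const Config& getConfig() const { return *config; }
    std::shared_ptr<WeightsSharing> getWeightsCache() const { return weightsCache; }
    MultiCachePtr getParamsCache() const { return rtParamsCache; }
    DnnlScratchPadPtr getScratchPad() const { return rtScratchPad; }
    bool isGraphQuantized() const { return quantizedGraph; }

private:
    std::shared_ptr<Config> config;
    std::shared_ptr<WeightsSharing> weightsCache;  // may be null when weight sharing is off
    MultiCachePtr rtParamsCache;                   // replaces Node::rtParamsCache
    DnnlScratchPadPtr rtScratchPad;                // replaces Node::rtScratchPad
    bool quantizedGraph = false;                   // replaces Node::isInQuantizedGraph
};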


@ -98,8 +98,8 @@ bool AdaptivePooling::isSupportedOperation(const std::shared_ptr<const ngraph::N
return true; return true;
} }
AdaptivePooling::AdaptivePooling(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, AdaptivePooling::AdaptivePooling(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context)
WeightsSharing::Ptr &cache) : Node(op, eng, cache, AdaptivePoolingShapeInferFactory(op)) { : Node(op, context, AdaptivePoolingShapeInferFactory(op)) {
std::string errorMessage; std::string errorMessage;
if (isSupportedOperation(op, errorMessage)) { if (isSupportedOperation(op, errorMessage)) {
errorPrefix = "Adaptive Pooling layer with name '" + getName() + "' "; errorPrefix = "Adaptive Pooling layer with name '" + getName() + "' ";


@ -16,7 +16,7 @@ namespace node {
class AdaptivePooling : public Node { class AdaptivePooling : public Node {
public: public:
AdaptivePooling(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache); AdaptivePooling(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context);
void getSupportedDescriptors() override; void getSupportedDescriptors() override;
void initSupportedPrimitiveDescriptors() override; void initSupportedPrimitiveDescriptors() override;


@ -36,8 +36,8 @@ bool BatchToSpace::isSupportedOperation(const std::shared_ptr<const ngraph::Node
return true; return true;
} }
BatchToSpace::BatchToSpace(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, BatchToSpace::BatchToSpace(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context)
WeightsSharing::Ptr &cache) : Node(op, eng, cache, NgraphShapeInferFactory(op, PortMask(1, 2, 3))) { : Node(op, context, NgraphShapeInferFactory(op, PortMask(1, 2, 3))) {
std::string errorMessage; std::string errorMessage;
if (!isSupportedOperation(op, errorMessage)) { if (!isSupportedOperation(op, errorMessage)) {
IE_THROW(NotImplemented) << errorMessage; IE_THROW(NotImplemented) << errorMessage;


@ -16,7 +16,7 @@ namespace node {
class BatchToSpace : public Node { class BatchToSpace : public Node {
public: public:
BatchToSpace(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache); BatchToSpace(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context);
void getSupportedDescriptors() override {}; void getSupportedDescriptors() override {};
void initSupportedPrimitiveDescriptors() override; void initSupportedPrimitiveDescriptors() override;


@ -897,9 +897,8 @@ bool BinaryConvolution::isSupportedOperation(const std::shared_ptr<const ngraph:
return true; return true;
} }
BinaryConvolution::BinaryConvolution(const std::shared_ptr<ngraph::Node>& op, BinaryConvolution::BinaryConvolution(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context)
const dnnl::engine& eng, WeightsSharing::Ptr &cache) : Node(op, context, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) {
: Node(op, eng, cache, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) {
std::string errorMessage; std::string errorMessage;
if (isSupportedOperation(op, errorMessage)) { if (isSupportedOperation(op, errorMessage)) {
errorPrefix = "BinaryConvolution node with name '" + getName() + "' "; errorPrefix = "BinaryConvolution node with name '" + getName() + "' ";


@ -77,7 +77,7 @@ struct jit_uni_bin_conv_kernel {
class BinaryConvolution : public Node { class BinaryConvolution : public Node {
public: public:
BinaryConvolution(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache); BinaryConvolution(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context);
void getSupportedDescriptors() override; void getSupportedDescriptors() override;
void createPrimitive() override; void createPrimitive() override;


@ -49,8 +49,8 @@ bool Broadcast::isSupportedOperation(const std::shared_ptr<const ov::Node>& op,
return true; return true;
} }
Broadcast::Broadcast(const std::shared_ptr<ov::Node>& op, const dnnl::engine& eng, Broadcast::Broadcast(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr context)
WeightsSharing::Ptr &cache) : Node(op, eng, cache, NgraphShapeInferFactory(op, PortMask(TARGET_SHAPE_IDX, AXES_MAPPING_IDX))) { : Node(op, context, NgraphShapeInferFactory(op, PortMask(TARGET_SHAPE_IDX, AXES_MAPPING_IDX))) {
std::string errorMessage; std::string errorMessage;
if (!isSupportedOperation(op, errorMessage)) { if (!isSupportedOperation(op, errorMessage)) {
IE_THROW(NotImplemented) << errorMessage; IE_THROW(NotImplemented) << errorMessage;


@ -16,7 +16,7 @@ namespace node {
class Broadcast : public Node, public TileBroadcastCommon { class Broadcast : public Node, public TileBroadcastCommon {
public: public:
Broadcast(const std::shared_ptr<ov::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache); Broadcast(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr context);
void getSupportedDescriptors() override; void getSupportedDescriptors() override;
void initSupportedPrimitiveDescriptors() override; void initSupportedPrimitiveDescriptors() override;


@ -30,8 +30,8 @@ bool Bucketize::isSupportedOperation(const std::shared_ptr<const ngraph::Node>&
return true; return true;
} }
Bucketize::Bucketize(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, Bucketize::Bucketize(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context)
WeightsSharing::Ptr &cache) : Node(op, eng, cache, PassThroughShapeInferFactory()) { : Node(op, context, PassThroughShapeInferFactory()) {
std::string errorMessage; std::string errorMessage;
if (!isSupportedOperation(op, errorMessage)) { if (!isSupportedOperation(op, errorMessage)) {
IE_THROW(NotImplemented) << errorMessage; IE_THROW(NotImplemented) << errorMessage;


@ -13,7 +13,7 @@ namespace node {
class Bucketize : public Node { class Bucketize : public Node {
public: public:
Bucketize(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache); Bucketize(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context);
void getSupportedDescriptors() override {}; void getSupportedDescriptors() override {};
void initSupportedPrimitiveDescriptors() override; void initSupportedPrimitiveDescriptors() override;


@ -1037,10 +1037,8 @@ bool ColorConvert::isSupportedOperation(const std::shared_ptr<const ngraph::Node
return alg != Algorithm::Default; return alg != Algorithm::Default;
} }
ColorConvert::ColorConvert(const std::shared_ptr<ngraph::Node>& op, ColorConvert::ColorConvert(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context)
const dnnl::engine& eng, : Node(op, context, ColorConvertShapeInferFactory(op)) {
WeightsSharing::Ptr &cache)
: Node(op, eng, cache, ColorConvertShapeInferFactory(op)) {
std::string errorMessage; std::string errorMessage;
std::tie(algorithm, errorMessage) = getAlgorithmFor(op); std::tie(algorithm, errorMessage) = getAlgorithmFor(op);
if (algorithm == Algorithm::Default) if (algorithm == Algorithm::Default)


@ -16,9 +16,7 @@ namespace node {
class ColorConvert : public Node { class ColorConvert : public Node {
public: public:
ColorConvert(const std::shared_ptr<ngraph::Node>& op, ColorConvert(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context);
const dnnl::engine& eng,
WeightsSharing::Ptr &cache);
class Converter; class Converter;
public: public:


@ -49,8 +49,8 @@ bool Concat::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op,
return true; return true;
} }
Concat::Concat(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache) Concat::Concat(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context)
: Node(op, eng, cache, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) { : Node(op, context, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) {
std::string errorMessage; std::string errorMessage;
if (!isSupportedOperation(op, errorMessage)) { if (!isSupportedOperation(op, errorMessage)) {
IE_THROW(NotImplemented) << errorMessage; IE_THROW(NotImplemented) << errorMessage;
@ -282,7 +282,7 @@ void Concat::selectOptimalPrimitiveDescriptor() {
maxCount = it.second; maxCount = it.second;
convertTo = it.first; convertTo = it.first;
} else if (it.second == maxCount) { } else if (it.second == maxCount) {
if (isInQuantizedGraph && it.first == LayoutType::nspc) { if (context->isGraphQuantized() && it.first == LayoutType::nspc) {
convertTo = it.first; convertTo = it.first;
} else if (it.first == LayoutType::nCsp8c || it.first == LayoutType::nCsp16c) { } else if (it.first == LayoutType::nCsp8c || it.first == LayoutType::nCsp16c) {
convertTo = it.first; convertTo = it.first;
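
The Concat hunk above shows what replaces the setQuantizedGraphFlag()/isInQuantizedGraph pair deleted from node.h: instead of the graph pushing a flag into every node after construction, each node queries the immutable context it was built with. A self-contained toy of the idiom (toy types, not the plugin's):

#include <iostream>
#include <memory>

struct GraphContext {
    using CPtr = std::shared_ptr<const GraphContext>;
    bool quantized = false;
};

struct Node {
    explicit Node(GraphContext::CPtr ctx) : context(std::move(ctx)) {}
    const char* preferredLayout() const {
        // the old code consulted a per-node isInQuantizedGraph member here
        return context->quantized ? "nspc" : "nCsp16c";
    }
    GraphContext::CPtr context;
};

int main() {
    auto ctx = std::make_shared<GraphContext>();
    ctx->quantized = true;  // decided once, before any node exists
    std::cout << Node(ctx).preferredLayout() << "\n";  // prints "nspc"
}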


@ -15,7 +15,7 @@ namespace node {
class Concat : public Node { class Concat : public Node {
public: public:
Concat(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache); Concat(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context);
static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept; static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept;
void getSupportedDescriptors() override; void getSupportedDescriptors() override;


@ -103,7 +103,7 @@ bool ConvKey::operator==(const ConvKey &rhs) const {
class Convolution::FusedSubgraph { class Convolution::FusedSubgraph {
public: public:
FusedSubgraph(const std::vector<NodePtr> &opList, const Convolution &conv, WeightsSharing::Ptr weightCache) { FusedSubgraph(const std::vector<NodePtr> &opList, const Convolution &conv, const GraphContext::CPtr context) {
_graph = std::unique_ptr<Graph>(new Graph()); _graph = std::unique_ptr<Graph>(new Graph());
std::unordered_set<NodePtr> nodesSet; std::unordered_set<NodePtr> nodesSet;
@ -119,11 +119,11 @@ public:
//Make inputs //Make inputs
const auto &inpMemDesc1 = conv.getBaseMemDescAtOutputPort(0); const auto &inpMemDesc1 = conv.getBaseMemDescAtOutputPort(0);
auto inp0 = std::make_shared<Input>(inpMemDesc1, "inp0", "Parameter", conv.getEngine(), weightCache); auto inp0 = std::make_shared<Input>(inpMemDesc1, "inp0", "Parameter", context);
inputs.push_back(inp0); inputs.push_back(inp0);
const size_t sumPortNum = conv.getParentEdges().size() - 1; const size_t sumPortNum = conv.getParentEdges().size() - 1;
const auto &inpMemDesc2 = conv.getBaseMemDescAtInputPort(sumPortNum); const auto &inpMemDesc2 = conv.getBaseMemDescAtInputPort(sumPortNum);
auto inp1 = std::make_shared<Input>(inpMemDesc2, "inp1", "Parameter", conv.getEngine(), weightCache); auto inp1 = std::make_shared<Input>(inpMemDesc2, "inp1", "Parameter", context);
inputs.push_back(inp1); inputs.push_back(inp1);
auto itr = std::find_if(opList.begin(), opList.end(), [](const NodePtr &node) { auto itr = std::find_if(opList.begin(), opList.end(), [](const NodePtr &node) {
@ -162,13 +162,13 @@ public:
//Make output //Make output
const auto &outMemDesc = conv.getBaseMemDescAtOutputPort(0); const auto &outMemDesc = conv.getBaseMemDescAtOutputPort(0);
auto out = std::make_shared<Input>(outMemDesc, "out", "Result", conv.getEngine(), weightCache); auto out = std::make_shared<Input>(outMemDesc, "out", "Result", context);
addEdge(*parentItr, out, 0, 0); addEdge(*parentItr, out, 0, 0);
outputs.push_back(out); outputs.push_back(out);
std::vector<NodePtr> nodes(nodesSet.begin(), nodesSet.end()); std::vector<NodePtr> nodes(nodesSet.begin(), nodesSet.end());
_graph->CreateGraph(nodes, edges, weightCache, "fused_subgraph"); _graph->CreateGraph(nodes, edges, context, "fused_subgraph");
} }
std::shared_ptr<Input> getInput(size_t idx) const { std::shared_ptr<Input> getInput(size_t idx) const {
@ -222,8 +222,8 @@ bool Convolution::isSupportedOperation(const std::shared_ptr<const ngraph::Node>
return true; return true;
} }
Convolution::Convolution(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache) Convolution::Convolution(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context)
: Node(op, eng, cache, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)), withBiases(false), withSum(false), withDWConv(false), : Node(op, context, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)), withBiases(false), withSum(false), withDWConv(false),
isGrouped(false), dw_conv_oc(0), dw_conv_ih(0), dw_conv_iw(0), dw_conv_in_dt(memory::data_type::undef), isGrouped(false), dw_conv_oc(0), dw_conv_ih(0), dw_conv_iw(0), dw_conv_in_dt(memory::data_type::undef),
groupNum(1lu), IC(1), groupIC(1), groupOC(1), eltwisePrecision(Precision::FP32) { groupNum(1lu), IC(1), groupIC(1), groupOC(1), eltwisePrecision(Precision::FP32) {
std::string errorMessage; std::string errorMessage;
@ -1165,7 +1165,7 @@ bool Convolution::isNspcAvailable() const {
using impl::cpu::x64::mayiuse; using impl::cpu::x64::mayiuse;
// do not use in non-quantized networks until it is enforced externally // do not use in non-quantized networks until it is enforced externally
if (!isInQuantizedGraph) { if (!context->isGraphQuantized()) {
auto predicate = [](memory::format_tag tag) { auto predicate = [](memory::format_tag tag) {
return one_of(tag, memory::format_tag::nwc, memory::format_tag::nhwc, memory::format_tag::ndhwc); return one_of(tag, memory::format_tag::nwc, memory::format_tag::nhwc, memory::format_tag::ndhwc);
}; };
@ -1426,7 +1426,7 @@ void Convolution::prepareParams() {
}; };
execPtr = nullptr; execPtr = nullptr;
auto cache = getRuntimeCache(); auto cache = context->getParamsCache();
auto result = cache->getOrCreate(key, builder); auto result = cache->getOrCreate(key, builder);
execPtr = result.first; execPtr = result.first;
@ -1524,7 +1524,7 @@ void Convolution::redefineOutputMemory(const std::vector<VectorDims> &newOutputS
if (newOutputShapes.front() != sumInpMem.getStaticDims()) { if (newOutputShapes.front() != sumInpMem.getStaticDims()) {
withSumBroadcast = true; withSumBroadcast = true;
if (!subgraph) { if (!subgraph) {
subgraph = std::make_shared<FusedSubgraph>(fusedWith, *this, weightCache); subgraph = std::make_shared<FusedSubgraph>(fusedWith, *this, context);
} }
auto inp0 = subgraph->getInput(0); auto inp0 = subgraph->getInput(0);
inp0->redefineOutputMemory(newOutputShapes); inp0->redefineOutputMemory(newOutputShapes);
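
The FusedSubgraph changes are the sub-graph case: the inner Graph built at runtime for broadcasted-sum fusion now inherits the parent's whole context rather than just its weights cache, so both graphs resolve the params cache, scratchpad, and quantization flag against the same object. A minimal sketch of that sharing (again toy types, not the plugin's):

#include <cassert>
#include <memory>

struct GraphContext {
    using CPtr = std::shared_ptr<const GraphContext>;
};

struct Graph {
    void CreateGraph(GraphContext::CPtr ctx) { context = std::move(ctx); }
    GraphContext::CPtr context;
};

int main() {
    auto ctx = std::make_shared<const GraphContext>();
    Graph outer, fused;
    outer.CreateGraph(ctx);
    fused.CreateGraph(outer.context);        // the sub-graph reuses the parent's context
    assert(outer.context == fused.context);  // one set of caches and flags for both
}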


@ -19,7 +19,7 @@ class Eltwise;
class Convolution : public Node { class Convolution : public Node {
public: public:
Convolution(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache); Convolution(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context);
static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept; static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept;
void getSupportedDescriptors() override; void getSupportedDescriptors() override;


@ -31,8 +31,8 @@ bool Convert::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op
return true; return true;
} }
Convert::Convert(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache) Convert::Convert(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context)
: Node(op, eng, cache, PassThroughShapeInferFactory()) { : Node(op, context, PassThroughShapeInferFactory()) {
std::string errorMessage; std::string errorMessage;
if (isSupportedOperation(op, errorMessage)) { if (isSupportedOperation(op, errorMessage)) {
errorPrefix = "Convert node with name '" + getName() + "'"; errorPrefix = "Convert node with name '" + getName() + "'";
@ -45,8 +45,8 @@ Convert::Convert(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& en
} }
Convert::Convert(const Shape &shape, const InferenceEngine::Precision &inPrc, const InferenceEngine::Precision &outPrc, Convert::Convert(const Shape &shape, const InferenceEngine::Precision &inPrc, const InferenceEngine::Precision &outPrc,
const std::string &nodeName, const dnnl::engine& eng, WeightsSharing::Ptr &cache) const std::string &nodeName, const GraphContext::CPtr context)
: Node("Convert", nodeName, eng, cache) : Node("Convert", nodeName, context)
, origPrc(outPrc) { , origPrc(outPrc) {
inputShapes.push_back(shape); inputShapes.push_back(shape);
addOriginalInputPrecision(inPrc); addOriginalInputPrecision(inPrc);


@ -15,9 +15,9 @@ namespace node {
class Convert : public Node { class Convert : public Node {
public: public:
Convert(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache); Convert(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context);
Convert(const Shape &shape, const InferenceEngine::Precision &inPrc, const InferenceEngine::Precision &outPrc, Convert(const Shape &shape, const InferenceEngine::Precision &inPrc, const InferenceEngine::Precision &outPrc,
const std::string &nodeName, const dnnl::engine& eng, WeightsSharing::Ptr &cache); const std::string &nodeName, const GraphContext::CPtr context);
void getSupportedDescriptors() override; void getSupportedDescriptors() override;
void initSupportedPrimitiveDescriptors() override; void initSupportedPrimitiveDescriptors() override;


@ -28,8 +28,8 @@ bool CTCGreedyDecoder::isSupportedOperation(const std::shared_ptr<const ngraph::
return true; return true;
} }
CTCGreedyDecoder::CTCGreedyDecoder(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, CTCGreedyDecoder::CTCGreedyDecoder(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context)
WeightsSharing::Ptr &cache) : Node(op, eng, cache, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) { : Node(op, context, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) {
std::string errorMessage; std::string errorMessage;
if (!isSupportedOperation(op, errorMessage)) { if (!isSupportedOperation(op, errorMessage)) {
IE_THROW(NotImplemented) << errorMessage; IE_THROW(NotImplemented) << errorMessage;


@ -13,7 +13,7 @@ namespace node {
class CTCGreedyDecoder : public Node { class CTCGreedyDecoder : public Node {
public: public:
CTCGreedyDecoder(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache); CTCGreedyDecoder(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context);
void getSupportedDescriptors() override {}; void getSupportedDescriptors() override {};
void initSupportedPrimitiveDescriptors() override; void initSupportedPrimitiveDescriptors() override;


@ -28,8 +28,8 @@ bool CTCGreedyDecoderSeqLen::isSupportedOperation(const std::shared_ptr<const ng
return true; return true;
} }
CTCGreedyDecoderSeqLen::CTCGreedyDecoderSeqLen(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, CTCGreedyDecoderSeqLen::CTCGreedyDecoderSeqLen(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context)
WeightsSharing::Ptr &cache) : Node(op, eng, cache, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) { : Node(op, context, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) {
std::string errorMessage; std::string errorMessage;
if (!isSupportedOperation(op, errorMessage)) { if (!isSupportedOperation(op, errorMessage)) {
IE_THROW(NotImplemented) << errorMessage; IE_THROW(NotImplemented) << errorMessage;


@ -13,7 +13,7 @@ namespace node {
class CTCGreedyDecoderSeqLen : public Node { class CTCGreedyDecoderSeqLen : public Node {
public: public:
CTCGreedyDecoderSeqLen(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache); CTCGreedyDecoderSeqLen(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context);
void getSupportedDescriptors() override {}; void getSupportedDescriptors() override {};
void initSupportedPrimitiveDescriptors() override; void initSupportedPrimitiveDescriptors() override;


@ -27,8 +27,8 @@ bool CTCLoss::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op
return true; return true;
} }
CTCLoss::CTCLoss(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, CTCLoss::CTCLoss(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context)
WeightsSharing::Ptr &cache) : Node(op, eng, cache, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) { : Node(op, context, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) {
std::string errorMessage; std::string errorMessage;
if (!isSupportedOperation(op, errorMessage)) { if (!isSupportedOperation(op, errorMessage)) {
IE_THROW(NotImplemented) << errorMessage; IE_THROW(NotImplemented) << errorMessage;


@ -13,7 +13,7 @@ namespace node {
class CTCLoss : public Node { class CTCLoss : public Node {
public: public:
CTCLoss(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache); CTCLoss(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context);
void getSupportedDescriptors() override {}; void getSupportedDescriptors() override {};
void initSupportedPrimitiveDescriptors() override; void initSupportedPrimitiveDescriptors() override;


@ -32,8 +32,7 @@ bool CumSum::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op,
return true; return true;
} }
CumSum::CumSum(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, CumSum::CumSum(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context) : Node(op, context, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) {
WeightsSharing::Ptr &cache) : Node(op, eng, cache, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) {
std::string errorMessage; std::string errorMessage;
if (!isSupportedOperation(op, errorMessage)) { if (!isSupportedOperation(op, errorMessage)) {
IE_THROW(NotImplemented) << errorMessage; IE_THROW(NotImplemented) << errorMessage;


@ -13,7 +13,7 @@ namespace node {
class CumSum : public Node { class CumSum : public Node {
public: public:
CumSum(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache); CumSum(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context);
void getSupportedDescriptors() override {}; void getSupportedDescriptors() override {};
void initSupportedPrimitiveDescriptors() override; void initSupportedPrimitiveDescriptors() override;


@ -151,7 +151,7 @@ bool Deconvolution::isSupportedOperation(const std::shared_ptr<const ngraph::Nod
} }
Deconvolution::Deconvolution(const std::shared_ptr<ngraph::Node>& op, Deconvolution::Deconvolution(const std::shared_ptr<ngraph::Node>& op,
const dnnl::engine& eng, WeightsSharing::Ptr &cache) : Node(op, eng, cache, DeconfolutionShapeInferFactory(op)) { const GraphContext::CPtr context) : Node(op, context, DeconfolutionShapeInferFactory(op)) {
std::string errorMessage; std::string errorMessage;
if (isSupportedOperation(op, errorMessage)) { if (isSupportedOperation(op, errorMessage)) {
errorPrefix = "Deconvolution node with name '" + getName() + "'"; errorPrefix = "Deconvolution node with name '" + getName() + "'";
@ -928,7 +928,7 @@ void Deconvolution::prepareParams() {
}; };
execPtr = nullptr; execPtr = nullptr;
auto cache = getRuntimeCache(); auto cache = context->getParamsCache();
auto result = cache->getOrCreate(key, builder); auto result = cache->getOrCreate(key, builder);
execPtr = result.first; execPtr = result.first;
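
Deconvolution above and DeformableConvolution, DepthToSpace, Eltwise, ExtractImagePatches, FakeQuantize, and FullyConnected below all repeat the same executor-memoization change: the per-node rtParamsCache member is gone, and prepareParams() goes through the cache owned by the shared context. The shape of the pattern, using the key and builder from the DeformableConvolution hunk below:

execPtr = nullptr;
auto cache = context->getParamsCache();
auto result = cache->getOrCreate(key, [](const DefConvKey& key) {
    // the builder runs only on a cache miss; a hit reuses an existing executor
    return std::make_shared<DefConvRefExecutor>(key.defConvAttr, key.descVector);
});
execPtr = result.first;  // may be shared with any node that built an equal key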


@ -17,7 +17,7 @@ namespace node {
class Deconvolution : public Node { class Deconvolution : public Node {
public: public:
Deconvolution(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache); Deconvolution(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context);
void getSupportedDescriptors() override; void getSupportedDescriptors() override;
void createDescriptor(const std::vector<MemoryDescPtr>& inputDesc, void createDescriptor(const std::vector<MemoryDescPtr>& inputDesc,


@ -740,8 +740,8 @@ bool DefConvKey::operator==(const DefConvKey &rhs) const {
} // namespace } // namespace
DeformableConvolution::DeformableConvolution(const std::shared_ptr<ngraph::Node>& op, DeformableConvolution::DeformableConvolution(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context)
const dnnl::engine& eng, WeightsSharing::Ptr &cache) : Node(op, eng, cache, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) { : Node(op, context, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) {
std::string errorMessage; std::string errorMessage;
if (!isSupportedOperation(op, errorMessage)) { if (!isSupportedOperation(op, errorMessage)) {
IE_THROW(NotImplemented) << errorMessage; IE_THROW(NotImplemented) << errorMessage;
@ -1221,7 +1221,7 @@ void DeformableConvolution::prepareParams() {
execPtr = nullptr; execPtr = nullptr;
auto cache = getRuntimeCache(); auto cache = context->getParamsCache();
auto result = cache->getOrCreate(key, [] (const DefConvKey& key) -> std::shared_ptr<DefConvExecutor> { auto result = cache->getOrCreate(key, [] (const DefConvKey& key) -> std::shared_ptr<DefConvExecutor> {
if (key.implType == impl_desc_type::ref) { if (key.implType == impl_desc_type::ref) {
return std::make_shared<DefConvRefExecutor>(key.defConvAttr, key.descVector); return std::make_shared<DefConvRefExecutor>(key.defConvAttr, key.descVector);


@ -71,7 +71,7 @@ struct jit_uni_def_conv_kernel {
class DeformableConvolution : public Node { class DeformableConvolution : public Node {
public: public:
DeformableConvolution(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache); DeformableConvolution(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context);
static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept; static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept;
void getSupportedDescriptors() override; void getSupportedDescriptors() override;


@ -67,8 +67,8 @@ bool DepthToSpace::isSupportedOperation(const std::shared_ptr<const ngraph::Node
return true; return true;
} }
DepthToSpace::DepthToSpace(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache) DepthToSpace::DepthToSpace(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context)
: Node(op, eng, cache, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) { : Node(op, context, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) {
std::string errorMessage; std::string errorMessage;
if (!isSupportedOperation(op, errorMessage)) { if (!isSupportedOperation(op, errorMessage)) {
IE_THROW(NotImplemented) << errorMessage; IE_THROW(NotImplemented) << errorMessage;
@ -191,7 +191,7 @@ void DepthToSpace::prepareParams() {
return std::make_shared<DepthToSpaceExecutor>(key); return std::make_shared<DepthToSpaceExecutor>(key);
}; };
auto cache = getRuntimeCache(); auto cache = context->getParamsCache();
auto result = cache->getOrCreate(attrs, builder); auto result = cache->getOrCreate(attrs, builder);
if (!result.first) { if (!result.first) {
IE_THROW() << "DepthToSpaceExecutor was not found for node " << getName() << "."; IE_THROW() << "DepthToSpaceExecutor was not found for node " << getName() << ".";


@ -15,7 +15,7 @@ namespace node {
class DepthToSpace : public Node { class DepthToSpace : public Node {
public: public:
DepthToSpace(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache); DepthToSpace(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context);
static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept; static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept;
void getSupportedDescriptors() override; void getSupportedDescriptors() override;


@ -51,8 +51,8 @@ bool DetectionOutput::isSupportedOperation(const std::shared_ptr<const ov::Node>
return true; return true;
} }
DetectionOutput::DetectionOutput(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, DetectionOutput::DetectionOutput(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context)
WeightsSharing::Ptr &cache) : Node(op, eng, cache, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) { : Node(op, context, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) {
std::string errorMessage; std::string errorMessage;
if (!isSupportedOperation(op, errorMessage)) { if (!isSupportedOperation(op, errorMessage)) {
IE_THROW(NotImplemented) << errorMessage; IE_THROW(NotImplemented) << errorMessage;


@ -14,7 +14,7 @@ namespace node {
class DetectionOutput : public Node { class DetectionOutput : public Node {
public: public:
DetectionOutput(const std::shared_ptr<ov::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache); DetectionOutput(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr context);
void getSupportedDescriptors() override {}; void getSupportedDescriptors() override {};
void initSupportedPrimitiveDescriptors() override; void initSupportedPrimitiveDescriptors() override;


@ -44,8 +44,8 @@ bool DFT::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, st
return true; return true;
} }
DFT::DFT(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache) : DFT::DFT(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context) :
Node(op, eng, cache, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) { Node(op, context, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) {
std::string errorMessage; std::string errorMessage;
if (!isSupportedOperation(op, errorMessage)) { if (!isSupportedOperation(op, errorMessage)) {
IE_THROW(NotImplemented) << errorMessage; IE_THROW(NotImplemented) << errorMessage;


@ -16,7 +16,7 @@ namespace node {
class DFT : public Node { class DFT : public Node {
public: public:
DFT(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache); DFT(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context);
~DFT() override = default; ~DFT() override = default;
void getSupportedDescriptors() override; void getSupportedDescriptors() override;


@ -1666,8 +1666,8 @@ bool Eltwise::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op
return true; return true;
} }
Eltwise::Eltwise(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache) : Eltwise::Eltwise(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context) :
Node(op, eng, cache, EltwiseShapeInferFactory()), broadcastingPolicy(Undefined) { Node(op, context, EltwiseShapeInferFactory()), broadcastingPolicy(Undefined) {
std::string errorMessage; std::string errorMessage;
if (!isSupportedOperation(op, errorMessage)) { if (!isSupportedOperation(op, errorMessage)) {
IE_THROW(NotImplemented) << errorMessage; IE_THROW(NotImplemented) << errorMessage;
@ -2050,7 +2050,7 @@ void Eltwise::prepareParams() {
} }
} }
auto cache = getRuntimeCache(); auto cache = context->getParamsCache();
auto result = cache->getOrCreate(key, buildExecutor); auto result = cache->getOrCreate(key, buildExecutor);
execPtr = result.first; execPtr = result.first;
} }


@ -90,7 +90,7 @@ public:
using executorPtr = std::shared_ptr<IEltwiseExecutor>; using executorPtr = std::shared_ptr<IEltwiseExecutor>;
public: public:
Eltwise(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache); Eltwise(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context);
void getSupportedDescriptors() override; void getSupportedDescriptors() override;
void initSupportedPrimitiveDescriptors() override; void initSupportedPrimitiveDescriptors() override;


@ -27,8 +27,9 @@ bool EmbeddingBagOffsetSum::isSupportedOperation(const std::shared_ptr<const ngr
return true; return true;
} }
EmbeddingBagOffsetSum::EmbeddingBagOffsetSum(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, EmbeddingBagOffsetSum::EmbeddingBagOffsetSum(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context)
WeightsSharing::Ptr &cache) : Node(op, eng, cache, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)), EmbeddingBagSum(op, 3lu, 1lu, 4lu, 3lu) { : Node(op, context, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)),
EmbeddingBagSum(op, 3lu, 1lu, 4lu, 3lu) {
std::string errorMessage; std::string errorMessage;
if (!isSupportedOperation(op, errorMessage)) { if (!isSupportedOperation(op, errorMessage)) {
IE_THROW(NotImplemented) << errorMessage; IE_THROW(NotImplemented) << errorMessage;


@ -17,7 +17,7 @@ namespace node {
class EmbeddingBagOffsetSum : public Node, public EmbeddingBagSum { class EmbeddingBagOffsetSum : public Node, public EmbeddingBagSum {
public: public:
EmbeddingBagOffsetSum(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache); EmbeddingBagOffsetSum(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context);
void getSupportedDescriptors() override {}; void getSupportedDescriptors() override {};
void initSupportedPrimitiveDescriptors() override; void initSupportedPrimitiveDescriptors() override;


@ -27,8 +27,9 @@ bool EmbeddingBagPackedSum::isSupportedOperation(const std::shared_ptr<const ngr
return true; return true;
} }
EmbeddingBagPackedSum::EmbeddingBagPackedSum(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, EmbeddingBagPackedSum::EmbeddingBagPackedSum(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context)
WeightsSharing::Ptr &cache) : Node(op, eng, cache, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)), EmbeddingBagSum(op, 2lu, 1lu, 2lu, 3lu) { : Node(op, context, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)),
EmbeddingBagSum(op, 2lu, 1lu, 2lu, 3lu) {
std::string errorMessage; std::string errorMessage;
if (!isSupportedOperation(op, errorMessage)) { if (!isSupportedOperation(op, errorMessage)) {
IE_THROW(NotImplemented) << errorMessage; IE_THROW(NotImplemented) << errorMessage;


@ -17,7 +17,7 @@ namespace node {
class EmbeddingBagPackedSum : public Node, public EmbeddingBagSum { class EmbeddingBagPackedSum : public Node, public EmbeddingBagSum {
public: public:
EmbeddingBagPackedSum(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache); EmbeddingBagPackedSum(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context);
void getSupportedDescriptors() override {}; void getSupportedDescriptors() override {};
void initSupportedPrimitiveDescriptors() override; void initSupportedPrimitiveDescriptors() override;


@ -27,8 +27,9 @@ bool EmbeddingSegmentsSum::isSupportedOperation(const std::shared_ptr<const ngra
return true; return true;
} }
EmbeddingSegmentsSum::EmbeddingSegmentsSum(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, EmbeddingSegmentsSum::EmbeddingSegmentsSum(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context)
WeightsSharing::Ptr &cache) : Node(op, eng, cache, NgraphShapeInferFactory(op, PortMask(NUM_SEGMENTS_IDX))), EmbeddingBagSum(op, 4lu, 1lu, 5lu, 4lu) { : Node(op, context, NgraphShapeInferFactory(op, PortMask(NUM_SEGMENTS_IDX))),
EmbeddingBagSum(op, 4lu, 1lu, 5lu, 4lu) {
std::string errorMessage; std::string errorMessage;
if (!isSupportedOperation(op, errorMessage)) { if (!isSupportedOperation(op, errorMessage)) {
IE_THROW(NotImplemented) << errorMessage; IE_THROW(NotImplemented) << errorMessage;


@ -17,7 +17,7 @@ namespace node {
class EmbeddingSegmentsSum : public Node, public EmbeddingBagSum { class EmbeddingSegmentsSum : public Node, public EmbeddingBagSum {
public: public:
EmbeddingSegmentsSum(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache); EmbeddingSegmentsSum(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context);
void getSupportedDescriptors() override {}; void getSupportedDescriptors() override {};
void initSupportedPrimitiveDescriptors() override; void initSupportedPrimitiveDescriptors() override;


@ -236,9 +236,9 @@ bool ExperimentalDetectronDetectionOutput::isSupportedOperation(const std::share
return true; return true;
} }
ExperimentalDetectronDetectionOutput::ExperimentalDetectronDetectionOutput ExperimentalDetectronDetectionOutput::ExperimentalDetectronDetectionOutput(const std::shared_ptr<ngraph::Node>& op,
(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, const GraphContext::CPtr context)
WeightsSharing::Ptr &cache) : Node(op, eng, cache, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) { : Node(op, context, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) {
std::string errorMessage; std::string errorMessage;
if (!isSupportedOperation(op, errorMessage)) { if (!isSupportedOperation(op, errorMessage)) {
IE_THROW(NotImplemented) << errorMessage; IE_THROW(NotImplemented) << errorMessage;


@ -13,7 +13,7 @@ namespace node {
class ExperimentalDetectronDetectionOutput : public Node { class ExperimentalDetectronDetectionOutput : public Node {
public: public:
ExperimentalDetectronDetectionOutput(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache); ExperimentalDetectronDetectionOutput(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context);
void getSupportedDescriptors() override {}; void getSupportedDescriptors() override {};
void initSupportedPrimitiveDescriptors() override; void initSupportedPrimitiveDescriptors() override;


@ -285,9 +285,10 @@ bool ExperimentalDetectronGenerateProposalsSingleImage::isSupportedOperation
return true; return true;
} }
ExperimentalDetectronGenerateProposalsSingleImage::ExperimentalDetectronGenerateProposalsSingleImage ExperimentalDetectronGenerateProposalsSingleImage::ExperimentalDetectronGenerateProposalsSingleImage(
(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, const std::shared_ptr<ngraph::Node>& op,
WeightsSharing::Ptr &cache) : Node(op, eng, cache, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) { const GraphContext::CPtr context)
: Node(op, context, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) {
std::string errorMessage; std::string errorMessage;
if (!isSupportedOperation(op, errorMessage)) { if (!isSupportedOperation(op, errorMessage)) {
IE_THROW(NotImplemented) << errorMessage; IE_THROW(NotImplemented) << errorMessage;


@ -14,7 +14,7 @@ namespace node {
class ExperimentalDetectronGenerateProposalsSingleImage : public Node { class ExperimentalDetectronGenerateProposalsSingleImage : public Node {
public: public:
ExperimentalDetectronGenerateProposalsSingleImage(const std::shared_ptr<ngraph::Node>& op, ExperimentalDetectronGenerateProposalsSingleImage(const std::shared_ptr<ngraph::Node>& op,
const dnnl::engine& eng, WeightsSharing::Ptr &cache); const GraphContext::CPtr context);
void getSupportedDescriptors() override {}; void getSupportedDescriptors() override {};
void initSupportedPrimitiveDescriptors() override; void initSupportedPrimitiveDescriptors() override;


@ -28,9 +28,10 @@ bool ExperimentalDetectronPriorGridGenerator::isSupportedOperation(const std::sh
return true; return true;
} }
ExperimentalDetectronPriorGridGenerator::ExperimentalDetectronPriorGridGenerator ExperimentalDetectronPriorGridGenerator::ExperimentalDetectronPriorGridGenerator(
(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, const std::shared_ptr<ngraph::Node>& op,
WeightsSharing::Ptr &cache) : Node(op, eng, cache, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) { const GraphContext::CPtr context)
: Node(op, context, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) {
std::string errorMessage; std::string errorMessage;
if (!isSupportedOperation(op, errorMessage)) { if (!isSupportedOperation(op, errorMessage)) {
IE_THROW(NotImplemented) << errorMessage; IE_THROW(NotImplemented) << errorMessage;


@ -13,7 +13,7 @@ namespace node {
class ExperimentalDetectronPriorGridGenerator : public Node { class ExperimentalDetectronPriorGridGenerator : public Node {
public: public:
ExperimentalDetectronPriorGridGenerator(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache); ExperimentalDetectronPriorGridGenerator(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context);
void getSupportedDescriptors() override {}; void getSupportedDescriptors() override {};
void initSupportedPrimitiveDescriptors() override; void initSupportedPrimitiveDescriptors() override;


@ -324,9 +324,10 @@ bool ExperimentalDetectronROIFeatureExtractor::isSupportedOperation(const std::s
return true; return true;
} }
ExperimentalDetectronROIFeatureExtractor::ExperimentalDetectronROIFeatureExtractor ExperimentalDetectronROIFeatureExtractor::ExperimentalDetectronROIFeatureExtractor(
(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, const std::shared_ptr<ngraph::Node>& op,
WeightsSharing::Ptr &cache) : Node(op, eng, cache, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) { const GraphContext::CPtr context)
: Node(op, context, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) {
std::string errorMessage; std::string errorMessage;
if (!isSupportedOperation(op, errorMessage)) { if (!isSupportedOperation(op, errorMessage)) {
IE_THROW(NotImplemented) << errorMessage; IE_THROW(NotImplemented) << errorMessage;


@ -13,7 +13,7 @@ namespace node {
class ExperimentalDetectronROIFeatureExtractor : public Node { class ExperimentalDetectronROIFeatureExtractor : public Node {
public: public:
ExperimentalDetectronROIFeatureExtractor(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache); ExperimentalDetectronROIFeatureExtractor(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context);
void getSupportedDescriptors() override {}; void getSupportedDescriptors() override {};
void initSupportedPrimitiveDescriptors() override; void initSupportedPrimitiveDescriptors() override;


@ -30,8 +30,9 @@ bool ExperimentalDetectronTopKROIs::isSupportedOperation(const std::shared_ptr<c
return true; return true;
} }
ExperimentalDetectronTopKROIs::ExperimentalDetectronTopKROIs(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, ExperimentalDetectronTopKROIs::ExperimentalDetectronTopKROIs(const std::shared_ptr<ngraph::Node>& op,
WeightsSharing::Ptr &cache) : Node(op, eng, cache, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) { const GraphContext::CPtr context)
: Node(op, context, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) {
std::string errorMessage; std::string errorMessage;
if (!isSupportedOperation(op, errorMessage)) { if (!isSupportedOperation(op, errorMessage)) {
IE_THROW(NotImplemented) << errorMessage; IE_THROW(NotImplemented) << errorMessage;


@ -13,7 +13,7 @@ namespace node {
class ExperimentalDetectronTopKROIs : public Node { class ExperimentalDetectronTopKROIs : public Node {
public: public:
ExperimentalDetectronTopKROIs(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache); ExperimentalDetectronTopKROIs(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context);
void getSupportedDescriptors() override {}; void getSupportedDescriptors() override {};
void initSupportedPrimitiveDescriptors() override; void initSupportedPrimitiveDescriptors() override;


@ -327,8 +327,8 @@ bool ExtractImagePatchesKey::operator==(const ExtractImagePatchesKey& rhs) const
} }
} // namespace } // namespace
ExtractImagePatches::ExtractImagePatches(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, ExtractImagePatches::ExtractImagePatches(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context)
WeightsSharing::Ptr &cache) : Node(op, eng, cache, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) { : Node(op, context, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) {
std::string errorMessage; std::string errorMessage;
if (!isSupportedOperation(op, errorMessage)) { if (!isSupportedOperation(op, errorMessage)) {
IE_THROW(NotImplemented) << errorMessage; IE_THROW(NotImplemented) << errorMessage;
@ -398,7 +398,7 @@ void ExtractImagePatches::prepareParams() {
key.prcSize); key.prcSize);
} }
}; };
auto cache = getRuntimeCache(); auto cache = context->getParamsCache();
auto result = cache->getOrCreate(key, buildExecutor); auto result = cache->getOrCreate(key, buildExecutor);
execPtr = result.first; execPtr = result.first;
} }


@ -44,7 +44,7 @@ struct jit_uni_extract_image_patches_kernel {
class ExtractImagePatches : public Node { class ExtractImagePatches : public Node {
public: public:
ExtractImagePatches(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache); ExtractImagePatches(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context);
void getSupportedDescriptors() override {}; void getSupportedDescriptors() override {};
void initSupportedPrimitiveDescriptors() override; void initSupportedPrimitiveDescriptors() override;


@ -48,8 +48,7 @@ private:
}; };
} // namespace } // namespace
Eye::Eye(const std::shared_ptr<ov::Node>& op, const dnnl::engine& eng, Eye::Eye(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr context) : Node(op, context, EyeShapeInferFactory(op)) {
WeightsSharing::Ptr &cache) : Node(op, eng, cache, EyeShapeInferFactory(op)) {
std::string errorMessage; std::string errorMessage;
if (!isSupportedOperation(op, errorMessage)) { if (!isSupportedOperation(op, errorMessage)) {
IE_THROW(NotImplemented) << errorMessage; IE_THROW(NotImplemented) << errorMessage;


@ -22,7 +22,7 @@ public:
static constexpr size_t BATCH_SHAPE = 3lu; static constexpr size_t BATCH_SHAPE = 3lu;
public: public:
Eye(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache); Eye(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context);
void getSupportedDescriptors() override; void getSupportedDescriptors() override;
void initSupportedPrimitiveDescriptors() override; void initSupportedPrimitiveDescriptors() override;


@ -914,8 +914,8 @@ struct FakeQuantKey {
}; };
} // namespace } // namespace
FakeQuantize::FakeQuantize(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache) : FakeQuantize::FakeQuantize(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context) :
Node(op, eng, cache, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) { Node(op, context, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) {
std::string errorMessage; std::string errorMessage;
if (isSupportedOperation(op, errorMessage)) { if (isSupportedOperation(op, errorMessage)) {
algorithm = Algorithm::FQCommon; algorithm = Algorithm::FQCommon;
@ -1433,7 +1433,7 @@ void FakeQuantize::prepareParams() {
key.jqp.is_planar = srcDesc->hasLayoutType(LayoutType::ncsp) && one_of(srcDesc->getShape().getRank(), 3, 4, 5); key.jqp.is_planar = srcDesc->hasLayoutType(LayoutType::ncsp) && one_of(srcDesc->getShape().getRank(), 3, 4, 5);
key.jqp.op_type = getAlgorithm(); key.jqp.op_type = getAlgorithm();
auto cache = getRuntimeCache(); auto cache = context->getParamsCache();
auto buildExecutor = [](const FakeQuantKey& key) { auto buildExecutor = [](const FakeQuantKey& key) {
return std::make_shared<FakeQuantizeJitExecutor>(key.jqp); return std::make_shared<FakeQuantizeJitExecutor>(key.jqp);
}; };


@ -68,7 +68,7 @@ struct jit_uni_quantize_kernel {
class FakeQuantize : public Node { class FakeQuantize : public Node {
public: public:
FakeQuantize(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache); FakeQuantize(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context);
void initSupportedPrimitiveDescriptors() override; void initSupportedPrimitiveDescriptors() override;
void getSupportedDescriptors() override; void getSupportedDescriptors() override;


@ -113,13 +113,16 @@ bool FullyConnected::isSupportedOperation(const std::shared_ptr<const ngraph::No
return true; return true;
} }
FullyConnected::FullyConnected(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache) FullyConnected::FullyConnected(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context)
: Node(op, eng, cache, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)), withBiases(false) { : Node(op, context, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)), withBiases(false) {
std::string errorMessage; std::string errorMessage;
if (isSupportedOperation(op, errorMessage)) { if (isSupportedOperation(op, errorMessage)) {
errorPrefix = "FullyConnected node with name '" + getName() + "'"; errorPrefix = "FullyConnected node with name '" + getName() + "'";
withBiases = inputShapes.size() == 3; withBiases = inputShapes.size() == 3;
if (context->getConfig().fcSparseWeiDecompressionRate < 1.0f)
minSparseRate = context->getConfig().fcSparseWeiDecompressionRate;
} else { } else {
IE_THROW(NotImplemented) << errorMessage; IE_THROW(NotImplemented) << errorMessage;
} }
@ -330,7 +333,7 @@ void FullyConnected::prepareParams() {
return execPtr; return execPtr;
}; };
auto cache = getRuntimeCache(); auto cache = context->getParamsCache();
auto result = cache->getOrCreate(key, builder); auto result = cache->getOrCreate(key, builder);
if (!result.first) { if (!result.first) {
@ -867,7 +870,7 @@ MemoryPtr FullyConnected::prepareWeightMemory(DnnlMemoryDescPtr weightDesc) {
MemoryPtr _ptr = std::make_shared<Memory>(getEngine()); MemoryPtr _ptr = std::make_shared<Memory>(getEngine());
_ptr->Create(weightDesc); _ptr->Create(weightDesc);
node::Reorder::reorderData(srcMemory, *_ptr, getRuntimeCache()); node::Reorder::reorderData(srcMemory, *_ptr, context->getParamsCache());
return _ptr; return _ptr;
}; };
@ -878,6 +881,7 @@ MemoryPtr FullyConnected::prepareWeightMemory(DnnlMemoryDescPtr weightDesc) {
if (privateWeightCache.end() != itr) { if (privateWeightCache.end() != itr) {
ptr = itr->second; ptr = itr->second;
} else { } else {
auto weightCache = context->getWeightsCache();
if (weightCache != nullptr) { if (weightCache != nullptr) {
const std::string string_hash = getName() + "_" + format const std::string string_hash = getName() + "_" + format
+ "_" + std::to_string(blob->GetSize()) + "_" + std::to_string(blob->GetSize())


@ -17,7 +17,7 @@ namespace node {
class FullyConnected : public Node { class FullyConnected : public Node {
public: public:
FullyConnected(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache); FullyConnected(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context);
std::vector<dnnl::memory::format_tag> getAvailableFormatsForDims(const Shape &dims) const override; std::vector<dnnl::memory::format_tag> getAvailableFormatsForDims(const Shape &dims) const override;
void getSupportedDescriptors() override; void getSupportedDescriptors() override;
@ -59,8 +59,6 @@ public:
void setDynamicBatchLim(int lim) override; void setDynamicBatchLim(int lim) override;
void setMinSparseRate(float sparseRate) { minSparseRate = sparseRate; }
private: private:
void createDescriptorInternal(const dnnl::memory::desc &inputDesc, void createDescriptorInternal(const dnnl::memory::desc &inputDesc,
const dnnl::memory::desc &outputDesc); const dnnl::memory::desc &outputDesc);


@ -41,8 +41,9 @@ bool Gather::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std
return true; return true;
} }
Gather::Gather(const std::shared_ptr<ov::Node>& op, const dnnl::engine& eng, Gather::Gather(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr context)
WeightsSharing::Ptr &cache) : Node(op, eng, cache, NgraphShapeInferFactory(op, PortMask(GATHER_AXIS))), batchDims(0) { : Node(op, context, NgraphShapeInferFactory(op, PortMask(GATHER_AXIS))),
batchDims(0) {
std::string errorMessage; std::string errorMessage;
if (!isSupportedOperation(op, errorMessage)) { if (!isSupportedOperation(op, errorMessage)) {
IE_THROW(NotImplemented) << errorMessage; IE_THROW(NotImplemented) << errorMessage;


@ -17,7 +17,7 @@ namespace node {
class Gather : public Node { class Gather : public Node {
public: public:
Gather(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache); Gather(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context);
void getSupportedDescriptors() override {}; void getSupportedDescriptors() override {};
void initSupportedPrimitiveDescriptors() override; void initSupportedPrimitiveDescriptors() override;


@ -32,8 +32,8 @@ bool GatherElements::isSupportedOperation(const std::shared_ptr<const ov::Node>&
return true; return true;
} }
GatherElements::GatherElements(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, GatherElements::GatherElements(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context)
WeightsSharing::Ptr &cache) : Node(op, eng, cache, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) { : Node(op, context, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) {
std::string errorMessage; std::string errorMessage;
if (!isSupportedOperation(op, errorMessage)) { if (!isSupportedOperation(op, errorMessage)) {
IE_THROW(NotImplemented) << errorMessage; IE_THROW(NotImplemented) << errorMessage;


@ -16,7 +16,7 @@ namespace node {
class GatherElements : public Node { class GatherElements : public Node {
public: public:
GatherElements(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache); GatherElements(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context);
void getSupportedDescriptors() override {}; void getSupportedDescriptors() override {};
void initSupportedPrimitiveDescriptors() override; void initSupportedPrimitiveDescriptors() override;


@@ -34,8 +34,8 @@ bool GatherND::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& o
     return true;
 }
 
-GatherND::GatherND(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng,
-                   WeightsSharing::Ptr &cache) : Node(op, eng, cache, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) {
+GatherND::GatherND(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context)
+    : Node(op, context, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) {
     std::string errorMessage;
     if (!isSupportedOperation(op, errorMessage)) {
         IE_THROW(NotImplemented) << errorMessage;


@@ -16,7 +16,7 @@ namespace node {
 
 class GatherND : public Node {
 public:
-    GatherND(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache);
+    GatherND(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context);
     void getSupportedDescriptors() override {};
     void initSupportedPrimitiveDescriptors() override;


@@ -30,8 +30,8 @@ bool GatherTree::isSupportedOperation(const std::shared_ptr<const ngraph::Node>&
     return true;
 }
 
-GatherTree::GatherTree(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng,
-                       WeightsSharing::Ptr &cache) : Node(op, eng, cache, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) {
+GatherTree::GatherTree(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context)
+    : Node(op, context, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) {
     std::string errorMessage;
     if (!isSupportedOperation(op, errorMessage)) {
         IE_THROW(NotImplemented) << errorMessage;


@@ -13,7 +13,7 @@ namespace node {
 
 class GatherTree : public Node {
 public:
-    GatherTree(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache);
+    GatherTree(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context);
     void getSupportedDescriptors() override {};
     void initSupportedPrimitiveDescriptors() override;


@@ -289,8 +289,8 @@ bool GenerateProposals::isSupportedOperation
     return true;
 }
 
-GenerateProposals::GenerateProposals(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng,
-                                     WeightsSharing::Ptr &cache) : Node(op, eng, cache, InternalDynShapeInferFactory()) {
+GenerateProposals::GenerateProposals(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context)
+    : Node(op, context, InternalDynShapeInferFactory()) {
     std::string errorMessage;
     if (!isSupportedOperation(op, errorMessage)) {
         IE_THROW(NotImplemented) << errorMessage;


@@ -13,8 +13,7 @@ namespace node {
 
 class GenerateProposals : public Node {
 public:
-    GenerateProposals(const std::shared_ptr<ngraph::Node>& op,
-                      const dnnl::engine& eng, WeightsSharing::Ptr &cache);
+    GenerateProposals(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context);
     void getSupportedDescriptors() override {};
     void initSupportedPrimitiveDescriptors() override;


@@ -43,8 +43,8 @@ public:
 };
 } // namespace
 
-Generic::Generic(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache)
-    : Node(op, eng, cache, GenericShapeInferFactory()), ngraphOp(op) {
+Generic::Generic(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context)
+    : Node(op, context, GenericShapeInferFactory()), ngraphOp(op) {
 }
 
 void Generic::getSupportedDescriptors() {


@@ -18,7 +18,7 @@ namespace node {
 
 class Generic : public Node {
 public:
-    Generic(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache);
+    Generic(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context);
     ~Generic() = default;
     void getSupportedDescriptors() override;


@@ -34,8 +34,8 @@ bool GridSample::isSupportedOperation(const std::shared_ptr<const ov::Node>& op,
     return true;
 }
 
-GridSample::GridSample(const std::shared_ptr<ov::Node>& op, const dnnl::engine& eng,
-                       WeightsSharing::Ptr &cache) : Node(op, eng, cache, NgraphShapeInferFactory(op, PortMask(1))) {
+GridSample::GridSample(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr context)
+    : Node(op, context, NgraphShapeInferFactory(op, PortMask(1))) {
     std::string errorMessage;
     if (!isSupportedOperation(op, errorMessage)) {
         IE_THROW(NotImplemented) << errorMessage;


@@ -17,7 +17,7 @@ namespace node {
 
 class GridSample : public Node {
 public:
-    GridSample(const std::shared_ptr<ov::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache);
+    GridSample(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr context);
     static bool isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept;
     void getSupportedDescriptors() override {};


@@ -27,8 +27,8 @@ bool GRN::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, st
     return true;
 }
 
-GRN::GRN(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng,
-         WeightsSharing::Ptr &cache) : Node(op, eng, cache, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) {
+GRN::GRN(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context)
+    : Node(op, context, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) {
     std::string errorMessage;
     if (!isSupportedOperation(op, errorMessage)) {
         IE_THROW(NotImplemented) << errorMessage;


@@ -13,7 +13,7 @@ namespace node {
 
 class GRN : public Node {
 public:
-    GRN(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache);
+    GRN(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context);
     void getSupportedDescriptors() override {};
     void initSupportedPrimitiveDescriptors() override;


@@ -58,8 +58,8 @@ bool If::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::st
     return true;
 }
 
-If::If(const std::shared_ptr<ov::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache) :
-        Node(op, eng, cache, InternalDynShapeInferFactory()), ovOp(op) {
+If::If(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr context) :
+        Node(op, context, InternalDynShapeInferFactory()), ovOp(op) {
     std::string errorMessage;
     if (!isSupportedOperation(op, errorMessage)) {
         IE_THROW(NotImplemented) << errorMessage;
@@ -71,8 +71,8 @@ void If::getSupportedDescriptors() {
     const std::shared_ptr<const ov::Model>& thenBody = ifOp->get_then_body();
     const std::shared_ptr<const ov::Model>& elseBody = ifOp->get_else_body();
 
-    subGraphThen.CreateGraph(thenBody, ext_mng, weightCache, sharedMutex);
-    subGraphElse.CreateGraph(elseBody, ext_mng, weightCache, sharedMutex);
+    subGraphThen.CreateGraph(thenBody, context);
+    subGraphElse.CreateGraph(elseBody, context);
 
     const auto &inMapThen = subGraphThen.GetInputNodesMap();
     for (const auto &param : ifOp->get_then_body()->get_parameters()) {


@@ -17,7 +17,7 @@ namespace node {
 
 class If : public Node {
 public:
-    If(const std::shared_ptr<ov::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache);
+    If(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr context);
     static bool isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept;
     void initSupportedPrimitiveDescriptors() override;
@@ -27,8 +27,6 @@ public:
     void execute(dnnl::stream strm) override;
     bool isExecutable() const override { return true; }
 
-    void inline setExtManager(const ExtensionManager::Ptr& extMgr) { ext_mng = extMgr; }
-
 protected:
     void executeDynamicImpl(dnnl::stream strm) override;
     bool needPrepareParams() const override { return false; };
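Composite nodes gain the most from the shared context: If now builds its then/else bodies from the node's own GraphContext, which is what makes the separate setExtManager() plumbing above removable. A hedged sketch of the propagation, with hypothetical Demo* types standing in for the plugin classes:

#include <memory>
#include <utility>
#include <vector>

struct DemoContext {};  // placeholder for GraphContext
using DemoContextCPtr = std::shared_ptr<const DemoContext>;

struct DemoGraph {
    // The sub-graph keeps the parent's context, so caches and the extension
    // manager are shared instead of being re-plumbed for every body.
    void CreateGraph(const std::vector<int>& body, DemoContextCPtr ctx) {
        context = std::move(ctx);
        nodes = body;  // stand-in for real node construction
    }
    DemoContextCPtr context;
    std::vector<int> nodes;
};

struct DemoIf {
    explicit DemoIf(DemoContextCPtr ctx) : context(std::move(ctx)) {}

    void buildBodies(const std::vector<int>& thenBody, const std::vector<int>& elseBody) {
        // Both branches receive the same context the If node itself was given.
        subGraphThen.CreateGraph(thenBody, context);
        subGraphElse.CreateGraph(elseBody, context);
    }

    DemoContextCPtr context;
    DemoGraph subGraphThen;
    DemoGraph subGraphElse;
};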


@@ -230,8 +230,8 @@ jit_has_subnormals_base::fn_t jit_has_subnormals_function() {
 } // namespace
 
-Input::Input(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache)
-    : Node(op, eng, cache, PassThroughShapeInferFactory()) {
+Input::Input(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context)
+    : Node(op, context, PassThroughShapeInferFactory()) {
     if (!one_of(op->get_type_info(),
             v0::Parameter::get_type_info_static(),
             v0::Constant::get_type_info_static(),
@@ -351,6 +351,7 @@ void Input::cloneBlobIfRequired() {
             + "_" + ptr;
     };
 
+    auto weightCache = context->getWeightsCache();
     if (weightCache) {
         MemoryPtr ptr = *weightCache->findOrCreate(blobKey(), cloneBlob);
         memoryPtr = std::const_pointer_cast<const Memory>(ptr);
@@ -363,9 +364,12 @@ void Input::cloneBlobIfRequired() {
     }
 }
 
-Input::Input(const Shape& shape, const InferenceEngine::Precision &prc, const std::string &name,
-             const std::string &type, const dnnl::engine& eng, WeightsSharing::Ptr &cache)
-    : Node(type, name, eng, cache) {
+Input::Input(const Shape& shape,
+             const InferenceEngine::Precision& prc,
+             const std::string& name,
+             const std::string& type,
+             const GraphContext::CPtr context)
+    : Node(type, name, context) {
     constant = ConstantType::NoConst;
     if (getType() == Type::Input) {
         outputShapes.emplace_back(shape);
@@ -376,9 +380,8 @@ Input::Input(const Shape& shape, const InferenceEngine::Precision &prc, const st
     }
 }
 
-Input::Input(MemoryDescPtr memDesc, const std::string &name, const std::string &type,
-             const dnnl::engine &eng, WeightsSharing::Ptr &cache) :
-        Input(memDesc->getShape(), memDesc->getPrecision(), name, type, eng, cache) {
+Input::Input(MemoryDescPtr memDesc, const std::string& name, const std::string& type, const GraphContext::CPtr context)
+    : Input(memDesc->getShape(), memDesc->getPrecision(), name, type, context) {
     extMemDesc = memDesc;
 }


@@ -15,11 +15,13 @@ namespace node {
 
 class Input : public Node {
 public:
-    Input(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache);
-    Input(const Shape& shape, const InferenceEngine::Precision &prc, const std::string &name,
-          const std::string &type, const dnnl::engine& eng, WeightsSharing::Ptr &cache);
-    Input(MemoryDescPtr memDesc, const std::string &name, const std::string &type, const dnnl::engine& eng,
-          WeightsSharing::Ptr &cache);
+    Input(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context);
+    Input(const Shape& shape,
+          const InferenceEngine::Precision& prc,
+          const std::string& name,
+          const std::string& type,
+          const GraphContext::CPtr context);
+    Input(MemoryDescPtr memDesc, const std::string& name, const std::string& type, const GraphContext::CPtr context);
 
     void getSupportedDescriptors() override;
     void initSupportedPrimitiveDescriptors() override;
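The cloneBlobIfRequired() hunk above shows the consumer side of the context: constants are deduplicated through context->getWeightsCache(), where the first node asking for a blob key materializes it and later nodes receive the shared copy. A simplified sketch of the findOrCreate idiom, assuming a plain map — the real WeightsSharing adds per-NUMA-node instances, locking, and reference counting that are omitted here:

#include <functional>
#include <map>
#include <memory>
#include <string>

template <typename T>
class DemoWeightsCache {
public:
    // Returns the cached value for `key`, invoking `create` only on first use.
    std::shared_ptr<T> findOrCreate(const std::string& key,
                                    const std::function<std::shared_ptr<T>()>& create) {
        auto it = entries.find(key);
        if (it == entries.end())
            it = entries.emplace(key, create()).first;  // clone the blob once
        return it->second;                              // share it afterwards
    }

private:
    std::map<std::string, std::shared_ptr<T>> entries;
};

// Usage analogous to the diff above:
//   auto blob = cache.findOrCreate(blobKey(), cloneBlob);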


@@ -155,8 +155,8 @@ private:
     std::unordered_map<size_t, std::unique_ptr<jit_emitter>> emitters;
 };
 
-Interaction::Interaction(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache)
-    : Node(op, eng, cache, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) {
+Interaction::Interaction(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context)
+    : Node(op, context, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) {
     std::string errorMessage;
     if (!isSupportedOperation(op, errorMessage)) {
         IE_THROW(NotImplemented) << errorMessage;


@@ -46,7 +46,7 @@ struct jit_uni_move_scale_kernel {
 
 class Interaction : public Node {
 public:
-    Interaction(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache);
+    Interaction(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context);
     void getSupportedDescriptors() override {};
     void initSupportedPrimitiveDescriptors() override;
     void execute(dnnl::stream strm) override;


@@ -1576,8 +1576,8 @@ private:
 };
 } // namespace
 
-Interpolate::Interpolate(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache)
-    : Node(op, eng, cache, InterpolateShapeInferFactory(op)) {
+Interpolate::Interpolate(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context)
+    : Node(op, context, InterpolateShapeInferFactory(op)) {
     std::string errorMessage;
     if (isSupportedOperation(op, errorMessage)) {
         errorPrefix = "Interpolate node with name '" + getName() + "'";
@@ -1920,7 +1920,7 @@ void Interpolate::prepareParams() {
         return executor;
     };
 
-    auto cache = getRuntimeCache();
+    auto cache = context->getParamsCache();
     auto result = cache->getOrCreate(key, buildExecutor);
     execPtr = result.first;
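prepareParams() now pulls its executor cache from the context rather than from a per-node getRuntimeCache(), so work compiled for one shape-and-attribute key can be reused across the graph; the Lrn hunk below uses the same idiom. A sketch of the getOrCreate contract as the diff uses it (result.first is the cached or freshly built value), assuming a plain map where the real cache adds eviction and thread safety:

#include <functional>
#include <memory>
#include <unordered_map>
#include <utility>

template <typename Key, typename Value, typename Hash = std::hash<Key>>
class DemoParamsCache {
public:
    using ValuePtr = std::shared_ptr<Value>;
    using Builder = std::function<ValuePtr(const Key&)>;

    // Returns {value, true} on a cache hit, {freshly built value, false} otherwise.
    std::pair<ValuePtr, bool> getOrCreate(const Key& key, const Builder& build) {
        auto it = entries.find(key);
        if (it != entries.end())
            return {it->second, true};
        ValuePtr value = build(key);  // may legitimately be null if the key is unsupported
        entries.emplace(key, value);
        return {value, false};
    }

private:
    std::unordered_map<Key, ValuePtr, Hash> entries;
};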


@@ -102,7 +102,7 @@ public:
     static constexpr int CUBIC_GRID_LEN = 4;
 
 public:
-    Interpolate(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache);
+    Interpolate(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context);
     void getSupportedDescriptors() override;
     void initSupportedPrimitiveDescriptors() override;


@@ -27,8 +27,8 @@ bool LogSoftmax::isSupportedOperation(const std::shared_ptr<const ngraph::Node>&
     return true;
 }
 
-LogSoftmax::LogSoftmax(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng,
-                       WeightsSharing::Ptr &cache) : Node(op, eng, cache, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) {
+LogSoftmax::LogSoftmax(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context)
+    : Node(op, context, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) {
     std::string errorMessage;
     if (!isSupportedOperation(op, errorMessage)) {
         IE_THROW(NotImplemented) << errorMessage;


@@ -13,8 +13,7 @@ namespace node {
 
 class LogSoftmax : public Node {
 public:
-    LogSoftmax(const std::shared_ptr<ngraph::Node>& op,
-               const dnnl::engine& eng, WeightsSharing::Ptr &cache);
+    LogSoftmax(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context);
     void getSupportedDescriptors() override {};
     void initSupportedPrimitiveDescriptors() override;


@@ -106,8 +106,8 @@ bool Lrn::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, st
     return true;
 }
 
-Lrn::Lrn(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache) :
-        Node(op, eng, cache, PassThroughShapeInferFactory()) {
+Lrn::Lrn(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context) :
+        Node(op, context, PassThroughShapeInferFactory()) {
     std::string errorMessage;
     if (isSupportedOperation(op, errorMessage)) {
         errorPrefix = "LRN node with name '" + getName() + "'";
@@ -195,7 +195,7 @@ void Lrn::prepareParams() {
         return std::make_shared<dnnl::lrn_forward>(prim_desc);
     };
 
-    auto cache = getRuntimeCache();
+    auto cache = context->getParamsCache();
     auto result = cache->getOrCreate(key, builder);
     if (!result.first) {
         IE_THROW() << "Primitive descriptor was not found for node " << getName() << ".";


@@ -16,7 +16,7 @@ namespace node {
 
 class Lrn : public Node {
 public:
-    Lrn(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache);
+    Lrn(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context);
     void getSupportedDescriptors() override;
     void createDescriptor(const std::vector<MemoryDescPtr>& inputDesc,

Some files were not shown because too many files have changed in this diff.