[CPU] Optimal number of streams calculation moved after LPT (#19313)

This commit is contained in:
Vladislav Golubev 2023-08-23 14:28:42 +02:00 committed by GitHub
parent 25e89a754d
commit 982d0f43c4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 17 additions and 18 deletions

View File

@@ -499,6 +499,9 @@ Engine::LoadExeNetworkImpl(const InferenceEngine::CNNNetwork &network, const std
DEBUG_LOG(PrintableModel(*nGraphFunc, "org_")); DEBUG_LOG(PrintableModel(*nGraphFunc, "org_"));
Transformations transformations(nGraphFunc, enableLPT, inferencePrecision, isLegacyAPI(), snippetsMode, engConfig);
transformations.UpToLpt();
if (!is_cpu_map_available()) { if (!is_cpu_map_available()) {
ApplyPerformanceHints(config, nGraphFunc); ApplyPerformanceHints(config, nGraphFunc);
} }
@@ -510,8 +513,8 @@ Engine::LoadExeNetworkImpl(const InferenceEngine::CNNNetwork &network, const std
conf.readProperties(config, modelType); conf.readProperties(config, modelType);
CalculateStreams(conf, nGraphFunc); CalculateStreams(conf, nGraphFunc);
Transformations transformations(nGraphFunc, enableLPT, inferencePrecision, isLegacyAPI(), snippetsMode, conf); transformations.PostLpt();
transformations.UpToCpuSpecificOpSet(); transformations.Snippets();
// need to check that all outputs have static shapes // need to check that all outputs have static shapes
// checking that all inputs have static shapes is performed in the common part // checking that all inputs have static shapes is performed in the common part
@@ -783,7 +786,9 @@ QueryNetworkResult Engine::QueryNetwork(const CNNNetwork& network, const std::ma
auto supported = GetSupportedNodes(model, auto supported = GetSupportedNodes(model,
[&](std::shared_ptr<ov::Model>& model) { [&](std::shared_ptr<ov::Model>& model) {
Transformations transformation(model, enableLPT, conf.inferencePrecision, isLegacyAPI(), snippetsMode, engConfig); Transformations transformation(model, enableLPT, conf.inferencePrecision, isLegacyAPI(), snippetsMode, engConfig);
transformation.UpToCpuSpecificOpSet(); transformation.UpToLpt();
transformation.PostLpt();
transformation.Snippets();
transformation.CpuSpecificOpSet(); transformation.CpuSpecificOpSet();
}, },
[&](const std::shared_ptr<ngraph::Node>& op) { [&](const std::shared_ptr<ngraph::Node>& op) {

View File

@@ -157,14 +157,11 @@ bool Transformations::fuse_type_to_convert(const std::shared_ptr<ngraph::Node>&
return false; return false;
} }
void Transformations::UpToCpuSpecificOpSet() { void Transformations::UpToLpt() {
const bool useLpt = enableLpt && const bool useLpt = enableLpt &&
ngraph::pass::low_precision::LowPrecision::isFunctionQuantized(model) && ngraph::pass::low_precision::LowPrecision::isFunctionQuantized(model) &&
CPU_DEBUG_CAP_IS_TRANSFORMATION_ENABLED(config.debugCaps, Lpt); CPU_DEBUG_CAP_IS_TRANSFORMATION_ENABLED(config.debugCaps, Lpt);
const bool useSnippets = snippetsMode != Config::SnippetsMode::Disable &&
CPU_DEBUG_CAP_IS_TRANSFORMATION_ENABLED(config.debugCaps, Snippets);
auto defaultPrecisions = useLpt ? ngraph::pass::low_precision::precision_set::int8_support : std::vector<ov::element::Type>{}; auto defaultPrecisions = useLpt ? ngraph::pass::low_precision::precision_set::int8_support : std::vector<ov::element::Type>{};
bool hasINT16orINT32Levels = false; bool hasINT16orINT32Levels = false;
@@ -183,11 +180,6 @@ void Transformations::UpToCpuSpecificOpSet() {
if (useLpt) if (useLpt)
Lpt(hasINT16orINT32Levels, defaultPrecisions); Lpt(hasINT16orINT32Levels, defaultPrecisions);
PostLpt();
if (useSnippets)
Snippets();
} }
void Transformations::CpuSpecificOpSet(void) { void Transformations::CpuSpecificOpSet(void) {
@@ -731,8 +723,12 @@ void Transformations::PostSnippets(void) {
} }
void Transformations::Snippets(void) { void Transformations::Snippets(void) {
CPU_DEBUG_CAP_TRANSFORMATION_SCOPE(this, Snippets); const bool useSnippets = snippetsMode != Config::SnippetsMode::Disable &&
CPU_DEBUG_CAP_IS_TRANSFORMATION_ENABLED(config.debugCaps, Snippets);
if (!useSnippets)
return;
CPU_DEBUG_CAP_TRANSFORMATION_SCOPE(this, Snippets);
MainSnippets(); MainSnippets();
PostSnippets(); PostSnippets();
} }

View File

@@ -39,8 +39,10 @@ public:
CPU_DEBUG_CAPS_MAYBE_UNUSED(this->config); CPU_DEBUG_CAPS_MAYBE_UNUSED(this->config);
} }
void UpToCpuSpecificOpSet(); void UpToLpt();
void CpuSpecificOpSet(); void CpuSpecificOpSet();
void PostLpt();
void Snippets(void);
private: private:
std::shared_ptr<ov::Model> model; std::shared_ptr<ov::Model> model;
@@ -54,14 +56,10 @@ private:
void Lpt(const bool hasINT16orINT32Levels, const std::vector<ov::element::Type>& defaultPrecisions); void Lpt(const bool hasINT16orINT32Levels, const std::vector<ov::element::Type>& defaultPrecisions);
void PostLpt();
void MainSnippets(void); void MainSnippets(void);
void PostSnippets(void); void PostSnippets(void);
void Snippets(void);
static bool fuse_type_to_convert(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions); static bool fuse_type_to_convert(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions);
}; };