[CPU] Move streams calculation before transformation pipeline (#18911)
This commit is contained in:
parent
2b12780588
commit
0cad2f1324
@ -67,9 +67,7 @@ static MemBandwidthPressure MemBandwidthPressureTolerance(
|
|||||||
output.get_partial_shape().is_static()) {
|
output.get_partial_shape().is_static()) {
|
||||||
const auto& shapeInput0 = input0.get_shape();
|
const auto& shapeInput0 = input0.get_shape();
|
||||||
const auto& shapeInput1 = input1.get_shape();
|
const auto& shapeInput1 = input1.get_shape();
|
||||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
const auto non_const = !ov::op::util::is_on_constant_path(node->input_value(1));
|
||||||
const auto non_const = !get_constant_from_source(node->input_value(1));
|
|
||||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
|
||||||
const auto& shapeOutput = output.get_shape();
|
const auto& shapeOutput = output.get_shape();
|
||||||
const auto dataSizeInput0 =
|
const auto dataSizeInput0 =
|
||||||
std::accumulate(shapeInput0.begin(), shapeInput0.end(), size_t(1), std::multiplies<size_t>());
|
std::accumulate(shapeInput0.begin(), shapeInput0.end(), size_t(1), std::multiplies<size_t>());
|
||||||
@ -88,11 +86,14 @@ static MemBandwidthPressure MemBandwidthPressureTolerance(
|
|||||||
const auto input = node->input(0);
|
const auto input = node->input(0);
|
||||||
const auto output = node->output(0);
|
const auto output = node->output(0);
|
||||||
const auto kernels = node->input(1);
|
const auto kernels = node->input(1);
|
||||||
const auto& shape = kernels.get_shape();
|
|
||||||
total_convs++;
|
total_convs++;
|
||||||
if (shape.size() >= 4 /* conventional 2D/3D conv */ && shape[2] >= 3 && shape[3] >= 3) {
|
if (kernels.get_partial_shape().is_static()) {
|
||||||
compute_convs++;
|
const auto& shape = kernels.get_shape();
|
||||||
continue;
|
if (shape.size() >= 4 /* conventional 2D/3D conv */ && shape[2] >= 3 && shape[3] >= 3) {
|
||||||
|
compute_convs++;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (input.get_partial_shape().is_static() && output.get_partial_shape().is_static()) {
|
if (input.get_partial_shape().is_static() && output.get_partial_shape().is_static()) {
|
||||||
const auto& shapeInput = input.get_shape();
|
const auto& shapeInput = input.get_shape();
|
||||||
|
@ -491,7 +491,18 @@ Engine::LoadExeNetworkImpl(const InferenceEngine::CNNNetwork &network, const std
|
|||||||
|
|
||||||
DEBUG_LOG(PrintableModel(*nGraphFunc, "org_"));
|
DEBUG_LOG(PrintableModel(*nGraphFunc, "org_"));
|
||||||
|
|
||||||
Transformations transformations(nGraphFunc, enableLPT, inferencePrecision, isLegacyAPI(), snippetsMode, engConfig);
|
if (!is_cpu_map_available()) {
|
||||||
|
ApplyPerformanceHints(config, nGraphFunc);
|
||||||
|
}
|
||||||
|
|
||||||
|
// update the props after the perf mode translated to configs
|
||||||
|
// TODO: Clarify the behavior of SetConfig method. Skip eng_config or not?
|
||||||
|
Config conf = engConfig;
|
||||||
|
|
||||||
|
conf.readProperties(config);
|
||||||
|
CalculateStreams(conf, nGraphFunc);
|
||||||
|
|
||||||
|
Transformations transformations(nGraphFunc, enableLPT, inferencePrecision, isLegacyAPI(), snippetsMode, conf);
|
||||||
transformations.UpToCpuSpecificOpSet();
|
transformations.UpToCpuSpecificOpSet();
|
||||||
|
|
||||||
// need to check that all outputs have static shapes
|
// need to check that all outputs have static shapes
|
||||||
@ -504,20 +515,10 @@ Engine::LoadExeNetworkImpl(const InferenceEngine::CNNNetwork &network, const std
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!is_cpu_map_available()) {
|
|
||||||
ApplyPerformanceHints(config, nGraphFunc);
|
|
||||||
}
|
|
||||||
transformations.CpuSpecificOpSet();
|
transformations.CpuSpecificOpSet();
|
||||||
|
|
||||||
DEBUG_LOG(PrintableModel(*nGraphFunc, "cpu_"));
|
DEBUG_LOG(PrintableModel(*nGraphFunc, "cpu_"));
|
||||||
|
|
||||||
// update the props after the perf mode translated to configs
|
|
||||||
// TODO: Clarify the behavior of SetConfig method. Skip eng_config or not?
|
|
||||||
Config conf = engConfig;
|
|
||||||
|
|
||||||
conf.readProperties(config);
|
|
||||||
CalculateStreams(conf, nGraphFunc);
|
|
||||||
|
|
||||||
// SSE runtime check is needed for some ATOM machine, which is x86-64 but w/o SSE
|
// SSE runtime check is needed for some ATOM machine, which is x86-64 but w/o SSE
|
||||||
static Xbyak::util::Cpu cpu;
|
static Xbyak::util::Cpu cpu;
|
||||||
if (cpu.has(Xbyak::util::Cpu::tSSE)) {
|
if (cpu.has(Xbyak::util::Cpu::tSSE)) {
|
||||||
|
Loading…
Reference in New Issue
Block a user