[LPT] bfloat enabling fix (#2819)

Edward Shogulin 2020-10-26 16:02:11 +03:00 committed by GitHub
parent 0267cbd286
commit 5007cba70a

@@ -54,8 +54,8 @@ MKLDNNExecNetwork::MKLDNNExecNetwork(const InferenceEngine::ICNNNetwork &network
     // we are cloning network if we have statistics and we can transform network.
     _clonedNetwork = cloneNet(network);
 
-#ifdef USE_CNNNETWORK_LPT
     if (_cfg.lpTransformsMode == Config::LPTransformsMode::On) {
+#ifdef USE_CNNNETWORK_LPT
         auto params = LayerTransformation::Params(true,  // updatePrecisions
                                                   true,  // quantizeOutputs
                                                   true,  // weightsToConst
@@ -70,6 +70,7 @@ MKLDNNExecNetwork::MKLDNNExecNetwork(const InferenceEngine::ICNNNetwork &network
             LayerTransformation::Params(params).setPrecisionsOnActivations({ Precision::U8 }),
             "ScaleShift"));
         transformer.transform(*_clonedNetwork);
+#endif
 
         // Check if network is INT8 or Binary.
         // BF16 transformations were disabled since CPU plug-in doesn't support mixed precision execution:
@@ -98,7 +99,6 @@ MKLDNNExecNetwork::MKLDNNExecNetwork(const InferenceEngine::ICNNNetwork &network
             bf16Transformer.convertToFloat(cnnetwork);
         }
     }
-#endif
 
     MKLDNNGraph::ApplyUnrollPasses(static_cast<ICNNNetwork&>(*_clonedNetwork));
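
Taken together, the three hunks narrow the USE_CNNNETWORK_LPT guard: the #ifdef moves inside the if block and the #endif moves up to just after transformer.transform(), so only the LPT-specific body stays conditional. Before the fix, the BF16 handling further down sat inside the same guard and was compiled out entirely in builds without USE_CNNNETWORK_LPT, which is what the commit title's "bfloat enabling fix" addresses. Below is a minimal standalone sketch of the resulting guard layout; the function names are hypothetical stand-ins for the plugin's transformation calls, not its real API.

// Minimal sketch, assuming hypothetical stand-in functions (not the plugin's
// real API). It illustrates the guard layout the fix arrives at: the #ifdef
// wraps the LPT body alone, so the BF16 fallback still compiles and runs in
// builds where USE_CNNNETWORK_LPT is not defined.
#include <iostream>

// #define USE_CNNNETWORK_LPT  // uncomment to emulate a build with legacy LPT

void runLowPrecisionTransformations() { std::cout << "LPT applied\n"; }     // hypothetical
void convertBF16NetworkToFloat()      { std::cout << "BF16 converted\n"; }  // hypothetical

int main() {
    // Stands in for _cfg.lpTransformsMode == Config::LPTransformsMode::On.
    const bool lpTransformsOn = true;

    if (lpTransformsOn) {
#ifdef USE_CNNNETWORK_LPT
        runLowPrecisionTransformations();  // only this part is LPT-specific
#endif
        // Before the fix this call sat inside the #ifdef as well and was
        // compiled out together with LPT; now it always runs.
        convertBF16NetworkToFloat();
    }
    return 0;
}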