[LPT] GPU tests were fixed
This commit is contained in:
parent
971811c8c8
commit
8eb88d51f2
@ -93,8 +93,6 @@ public:
|
||||
|
||||
static Blob::Ptr getBlob(const CNNLayer* layer, const std::string& blobName);
|
||||
|
||||
static bool blobValuesAreEqual(const CNNLayer& layer, const std::string& blobName);
|
||||
|
||||
static std::shared_ptr<float> getFloatData(const CNNLayerPtr& layer, const std::string& blobName);
|
||||
|
||||
static std::shared_ptr<float> getFloatData(const Blob::Ptr& srcBlob);
|
||||
|
@ -68,21 +68,15 @@ bool FullyConnectedTransformation::canBeTransformed(const TransformationContext&
|
||||
return false;
|
||||
}
|
||||
|
||||
// 3D tensor custom validation
|
||||
if ((inTensorDims.size() == 3ul) &&
|
||||
((!CNNNetworkHelper::blobValuesAreEqual(*scaleShift, "weights")) || (!CNNNetworkHelper::blobValuesAreEqual(*scaleShift, "biases")))) {
|
||||
std::vector<float> dequantizationScales;
|
||||
std::vector<float> dequantizationShifts;
|
||||
fillFromDequantizationLayer(*scaleShift, dequantizationScales, dequantizationShifts);
|
||||
|
||||
if ((inTensorDims.size() == 3ul) && (!DequantizationDetails::isPerTensor(dequantizationScales, dequantizationShifts))) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const Blob::Ptr prevDequantizationScaleBlob = CNNNetworkHelper::getBlob(scaleShift, "weights");
|
||||
const size_t prevDequantizationScaleBlobSize = prevDequantizationScaleBlob->size();
|
||||
if (prevDequantizationScaleBlobSize != inTensorDims[1]) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const Blob::Ptr prevDequantizationShiftBlob = CNNNetworkHelper::getBlob(scaleShift, "biases");
|
||||
const size_t prevDequantizationShiftBlobSize = prevDequantizationShiftBlob->size();
|
||||
if (prevDequantizationShiftBlobSize != inTensorDims[1]) {
|
||||
if ((dequantizationScales.size() != inTensorDims[1]) || (dequantizationShifts.size() != inTensorDims[1])) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -615,20 +615,6 @@ Blob::Ptr CNNNetworkHelper::getBlob(const CNNLayer* layer, const std::string& bl
|
||||
return it->second;
|
||||
}
|
||||
|
||||
bool CNNNetworkHelper::blobValuesAreEqual(const CNNLayer& layer, const std::string& blobName) {
|
||||
const Blob::Ptr blob = CNNNetworkHelper::getBlob(&layer, blobName);
|
||||
const std::shared_ptr<float> buffer = CNNNetworkHelper::getFloatData(blob);
|
||||
if (!std::equal(
|
||||
buffer.get(),
|
||||
buffer.get() + blob->size(),
|
||||
buffer.get(),
|
||||
[](const float value1, const float value2) { return value1 == value2; })) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
Blob::Ptr CNNNetworkHelper::getBlob(CNNLayerPtr layer, const std::string& blobName) {
|
||||
return getBlob(layer.get(), blobName);
|
||||
}
|
||||
|
@ -20,7 +20,7 @@ const std::vector<InferenceEngine::SizeVector> dimensions = {
|
||||
};
|
||||
|
||||
const std::vector<LayerTransformation::Params> trasformationParamValues = {
|
||||
LayerTestsUtils::LayerTransformationParamsFactory::createParams()
|
||||
LayerTestsUtils::LayerTransformationParamsFactory::createParamsI8I8()
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(LPT, GemmTransformation,
|
||||
|
@ -38,6 +38,7 @@
|
||||
#include "low_precision_transformations/convolution.hpp"
|
||||
#include "low_precision_transformations/scaleshift_to_convolution.hpp"
|
||||
#include "low_precision_transformations/fully_connected.hpp"
|
||||
#include "low_precision_transformations/gemm.hpp"
|
||||
|
||||
using namespace InferenceEngine::details;
|
||||
#include "common_test_utils/common_utils.hpp"
|
||||
@ -52,7 +53,7 @@ InferenceEngine::details::LowPrecisionTransformations LayerTransformation::getLo
|
||||
return LowPrecisionTransformer::getAllTransformations(params)
|
||||
.add<FullyConnectedTransformation>(
|
||||
InferenceEngine::details::LayerTransformation::Params(params).setSupportAsymmetricQuantization(false), "FullyConnected")
|
||||
.add<FullyConnectedTransformation>(
|
||||
.add<GemmTransformation>(
|
||||
InferenceEngine::details::LayerTransformation::Params(params).setSupportAsymmetricQuantization(false), "GEMM");
|
||||
}
|
||||
|
||||
|
@ -104,10 +104,6 @@ void GemmTransformation::validate() {
|
||||
|
||||
TEST_P(GemmTransformation, CompareWithRefImpl) {
|
||||
Run();
|
||||
|
||||
if (targetDevice == std::string{CommonTestUtils::DEVICE_GPU}) {
|
||||
PluginCache::get().reset();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace LayerTestsDefinitions
|
||||
|
Loading…
Reference in New Issue
Block a user