[LPT] INT4 FakeQuantize is not transformed (#5082)

This commit is contained in:
Vladimir Zinoviev 2021-04-29 18:24:21 +03:00 committed by GitHub
parent 68ed12cb98
commit 19afae3638
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 105 additions and 38 deletions

View File

@ -25,10 +25,31 @@ const std::vector<LayerTransformation::Params> trasformationParamValues = {
LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParamsU8I8()
};
const std::vector<ngraph::builder::subgraph::FakeQuantizeOnData> fakeQuantizeOnDataValues = {
{ 256ul, {}, { 0.f }, { 2.55f }, { 0.f }, { 2.55f } },
{ 256ul, { 1ul }, { 0.f }, { 2.55f }, { 0.f }, { 2.55f } },
{ 256ul, {}, { 0.f }, { 2.55f }, { 2.55f }, { 2.55f } },
const std::vector<FakeQuantizeTransformationParam> fakeQuantizeOnDataValues = {
{
{256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}},
"Pooling", "U8"
},
{
{ 256ul, { 1ul }, { 0.f }, { 2.55f }, { 0.f }, { 2.55f } },
"Pooling", "U8"
},
{
{ 256ul, {}, { 0.f }, { 2.55f }, { -1.28f }, { 1.27f } },
"Pooling", "I8"
},
{
{ 256ul, {}, { 0.f }, { 2.55f }, { 2.55f }, { 2.55f } },
"Pooling", "U8"
},
{
{ 16ul, {}, { 0.f }, { 1.5f }, { 0.f }, { 1.5f } },
"Pooling", "FP32"
},
{
{ 16ul, {}, { -8.f }, { 7.f }, { -0.8f }, { 0.7f } },
"Pooling", "FP32"
},
// nGraph: I8->FP32 Convert is not supported
// { 256ul, {}, { -1.28f} , { 1.27f }, { -1.28f} , { 1.27f } },
// { 256ul, { 1ul }, { -1.28f} , { 1.27f } }

View File

@ -25,9 +25,31 @@ const std::vector<LayerTransformation::Params> trasformationParamValues = {
LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParamsU8I8()
};
const std::vector<ngraph::builder::subgraph::FakeQuantizeOnData> fakeQuantizeOnDataValues = {
{ 256ul, {}, { 0.f }, { 2.55f }, { 0.f }, { 2.55f } },
{ 256ul, { 1ul }, { 0.f }, { 2.55f }, { 0.f }, { 2.55f } },
const std::vector<FakeQuantizeTransformationParam> fakeQuantizeOnDataValues = {
{
{256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}},
"Pooling", "U8"
},
{
{ 256ul, { 1ul }, { 0.f }, { 2.55f }, { 0.f }, { 2.55f } },
"Pooling", "U8"
},
{
{ 256ul, {}, { 0.f }, { 2.55f }, { -1.28f }, { 1.27f } },
"Pooling", "I8"
},
{
{ 256ul, {}, { 0.f }, { 2.55f }, { 2.55f }, { 2.55f } },
"Pooling", "U8"
},
{
{ 16ul, {}, { 0.f }, { 1.5f }, { 0.f }, { 1.5f } },
"Pooling", "FP32"
},
{
{ 16ul, {}, { -8.f }, { 7.f }, { -0.8f }, { 0.7f } },
"Pooling", "FP32"
},
// nGraph: I8->FP32 Convert is not supported
// { 256ul, {}, { -1.28f} , { 1.27f }, { -1.28f} , { 1.27f } },
// { 256ul, { 1ul }, { -1.28f} , { 1.27f } }

View File

@ -10,13 +10,20 @@
#include "shared_test_classes/base/low_precision_transformations/layer_transformation.hpp"
namespace LayerTestsDefinitions {
// Parameter bundle for a single FakeQuantize LPT test case: the FakeQuantize
// configuration to build, plus the expectations checked after inference.
class FakeQuantizeTransformationParam {
public:
// FakeQuantize intervals/levels used to construct the test subgraph.
ngraph::builder::subgraph::FakeQuantizeOnData fakequantize;
// Friendly name of the layer whose runtime precision is queried
// (e.g. "Pooling" — see getRuntimePrecisionByType usage in Run()).
std::string layerName;
// Expected runtime kernel precision for that layer ("U8", "I8", "FP32";
// "FP32" is rewritten to "FP16" when the net precision is f16).
std::string expectedKernelType;
};
typedef std::tuple<
ngraph::element::Type,
ngraph::Shape,
std::string,
ngraph::pass::low_precision::LayerTransformation::Params,
ngraph::builder::subgraph::FakeQuantizeOnData> FakeQuantizeTransformationParams;
FakeQuantizeTransformationParam> FakeQuantizeTransformationParams;
class FakeQuantizeTransformation :
public testing::WithParamInterface<FakeQuantizeTransformationParams>,
@ -27,8 +34,7 @@ public:
protected:
void SetUp() override;
private:
void validate();
void Run() override;
};
} // namespace LayerTestsDefinitions

View File

@ -22,11 +22,11 @@ std::string FakeQuantizeTransformation::getTestCaseName(testing::TestParamInfo<F
ngraph::Shape inputShape;
std::string targetDevice;
ngraph::pass::low_precision::LayerTransformation::Params params;
ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantizeOnData;
std::tie(netPrecision, inputShape, targetDevice, params, fakeQuantizeOnData) = obj.param;
FakeQuantizeTransformationParam testParams;
std::tie(netPrecision, inputShape, targetDevice, params, testParams) = obj.param;
std::ostringstream result;
result << getTestCaseNameByParams(netPrecision, inputShape, targetDevice, params) << "_" << fakeQuantizeOnData;
result << getTestCaseNameByParams(netPrecision, inputShape, targetDevice, params) << "_" << testParams.fakequantize;
return result.str();
}
@ -34,37 +34,25 @@ void FakeQuantizeTransformation::SetUp() {
ngraph::element::Type netPrecision;
ngraph::Shape inputShape;
ngraph::pass::low_precision::LayerTransformation::Params params;
ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantizeOnData;
std::tie(netPrecision, inputShape, targetDevice, params, fakeQuantizeOnData) = this->GetParam();
FakeQuantizeTransformationParam testParams;
std::tie(netPrecision, inputShape, targetDevice, params, testParams) = this->GetParam();
function = ngraph::builder::subgraph::FakeQuantizeFunction::getOriginal(
function = ngraph::builder::subgraph::FakeQuantizeFunction::getOriginalWithMaxPool(
netPrecision,
inputShape,
fakeQuantizeOnData);
ngraph::pass::InitNodeInfo().run_on_function(function);
validate();
testParams.fakequantize);
}
void FakeQuantizeTransformation::validate() {
ngraph::element::Type precision;
ngraph::Shape inputShapes;
std::string targetDevice;
ngraph::pass::low_precision::LayerTransformation::Params params;
ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantizeOnData;
std::tie(precision, inputShapes, targetDevice, params, fakeQuantizeOnData) = this->GetParam();
void FakeQuantizeTransformation::Run() {
LayerTestsCommon::Run();
auto transformations = getLowPrecisionTransformationsNGraph(params);
transformations.removeStandaloneCleanup<ngraph::pass::low_precision::FuseSubtractToFakeQuantizeTransformation, ngraph::opset1::Subtract>();
transformations.removeStandaloneCleanup<ngraph::pass::low_precision::FuseMultiplyToFakeQuantizeTransformation, ngraph::opset1::Multiply>();
const auto transformed = transformNGraph(params, transformations);
EXPECT_EQ(1ul, transformed->get_output_size());
const auto output = transformed->get_output_op(0);
const auto scaleShift = output->get_input_node_shared_ptr(0);
const std::string typeName = scaleShift->get_type_name();
ASSERT_EQ("ScaleShiftIE", typeName);
const auto params = std::get<4>(GetParam());
const auto actualPrecision = getRuntimePrecisionByType(params.layerName);
auto expectedPrecision = params.expectedKernelType;
if (expectedPrecision == "FP32" && std::get<0>(GetParam()) == ngraph::element::f16) {
expectedPrecision = "FP16";
}
EXPECT_EQ(actualPrecision, expectedPrecision);
}
TEST_P(FakeQuantizeTransformation, CompareWithRefImpl) {

View File

@ -23,6 +23,11 @@ public:
const ngraph::Shape& inputShape,
const FakeQuantizeOnData& fakeQuantizeOnData);
static std::shared_ptr<ngraph::Function> getOriginalWithMaxPool(
const ngraph::element::Type precision,
const ngraph::Shape& inputShape,
const FakeQuantizeOnData& fakeQuantizeOnData);
static std::shared_ptr<ngraph::Function> getReference(
const ngraph::element::Type precision,
const ngraph::Shape& inputShape,

View File

@ -20,6 +20,31 @@ namespace subgraph {
using namespace ngraph::pass;
// Builds the test subgraph Parameter -> FakeQuantize -> MaxPool -> Result.
// The MaxPool consumer lets LPT propagate the quantization through a
// precision-preserving layer, so tests can query the pooling kernel precision.
//
// @param precision          network input element type (f32/f16).
// @param inputShape         shape of the single Parameter.
// @param fakeQuantizeOnData levels, constant shape, and in/out intervals for
//                           the FakeQuantize node.
// @return a Function named "FakeQuantizeFunction" with one input and one output.
std::shared_ptr<ngraph::Function> FakeQuantizeFunction::getOriginalWithMaxPool(
const ngraph::element::Type precision,
const ngraph::Shape& inputShape,
const FakeQuantizeOnData& fakeQuantizeOnData) {
const auto input = std::make_shared<ngraph::opset1::Parameter>(precision, ngraph::Shape(inputShape));
input->set_friendly_name("input");
// FakeQuantize constants are created as f32 regardless of `precision`.
const auto fakeQuantize = ngraph::builder::makeFakeQuantize(
input, element::f32, fakeQuantizeOnData.quantizationLevel, fakeQuantizeOnData.constantShape,
fakeQuantizeOnData.inputLowValues, fakeQuantizeOnData.inputHighValues, fakeQuantizeOnData.outputLowValues, fakeQuantizeOnData.outputHighValues);
// opset1::MaxPool argument order: strides, pads_begin, pads_end, kernel.
// NOTE(review): pads_begin {1,1} with pads_end {0,0} is asymmetric padding
// for the 2x2 kernel — presumably intentional for this test; confirm.
const auto maxPool = std::make_shared<opset1::MaxPool>(
fakeQuantize,
Strides{ 1, 1 },
Shape{ 1, 1 },
Shape{ 0, 0 },
Shape{ 2, 2 });
fakeQuantize->set_friendly_name("fakeQuantize");
// Tag the node in rt_info so it can be identified after transformations.
auto& rtInfo = fakeQuantize->get_rt_info();
rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("fakeQuantize");
ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(maxPool) };
return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "FakeQuantizeFunction");
}
std::shared_ptr<ngraph::Function> FakeQuantizeFunction::getOriginal(
const ngraph::element::Type precision,
const ngraph::Shape& inputShape,