[LPT] Support FakeQuantize with convert on intervals (#9579)
* [LPT] Support FakeQuantize with convert on intervals * [LPT] GPU tests
This commit is contained in:
parent
12d92dfa2d
commit
b6d60a2c82
@ -29,7 +29,7 @@ public:
|
||||
const std::vector<float>& outputLowValues,
|
||||
const std::vector<float>& outputHighValues);
|
||||
|
||||
static bool outputLayoutIsSupported(std::shared_ptr<opset1::FakeQuantize> quantize);
|
||||
static bool outputLayoutIsSupported(std::shared_ptr<opset1::FakeQuantize> quantize, bool isConvertExpected = false);
|
||||
|
||||
static void getInputIntervals(
|
||||
std::shared_ptr<opset1::FakeQuantize> quantize,
|
||||
|
@ -275,7 +275,7 @@ bool ngraph::pass::low_precision::LowPrecision::isFunctionQuantized(const std::s
|
||||
|
||||
const std::shared_ptr<ngraph::opset1::FakeQuantize> fakeQuantize = ov::as_type_ptr<ngraph::opset1::FakeQuantize>(parent);
|
||||
if ((fakeQuantize != nullptr) &&
|
||||
QuantizationDetails::outputLayoutIsSupported(fakeQuantize) &&
|
||||
QuantizationDetails::outputLayoutIsSupported(fakeQuantize, true) &&
|
||||
QuantizationDetails::isSupportedLevel(fakeQuantize->get_levels())) {
|
||||
return true;
|
||||
}
|
||||
|
@ -49,11 +49,19 @@ QuantizationDetails::QuantizationDetails(const size_t levels, const std::vector<
|
||||
outputLowValues(outputLowValues),
|
||||
outputHighValues(outputHighValues) {}
|
||||
|
||||
bool QuantizationDetails::outputLayoutIsSupported(std::shared_ptr<opset1::FakeQuantize> quantize) {
|
||||
return ov::is_type<opset1::Constant>(quantize->get_input_node_ptr(1)) &&
|
||||
ov::is_type<opset1::Constant>(quantize->get_input_node_ptr(2)) &&
|
||||
ov::is_type<opset1::Constant>(quantize->get_input_node_ptr(3)) &&
|
||||
ov::is_type<opset1::Constant>(quantize->get_input_node_ptr(4));
|
||||
// Checks whether the producers of the FakeQuantize inputs have a supported form.
// Inputs are inspected starting from index 1; input 0 is not checked here
// (presumably the data input -- confirm against the FakeQuantize op specification).
// @param quantize           FakeQuantize node whose inputs are inspected.
// @param isConvertExpected  when true, an op::Convert whose own input 0 is an
//                           opset1::Constant is accepted in addition to a plain Constant.
// @return true only if every inspected input is supported; false on the first one that is not.
bool QuantizationDetails::outputLayoutIsSupported(std::shared_ptr<opset1::FakeQuantize> quantize, bool isConvertExpected) {
|
||||
const auto inputs = quantize->inputs();
|
||||
for (size_t i = 1; i < inputs.size(); ++i) {
|
||||
const auto node = inputs[i].get_source_output().get_node_shared_ptr();
|
||||
bool supported = ov::is_type<opset1::Constant>(node);
|
||||
// Fallback: tolerate exactly one Convert placed directly on top of a Constant.
if (!supported && isConvertExpected) {
|
||||
supported = ov::is_type<op::Convert>(node) && ov::is_type<opset1::Constant>(node->get_input_node_ptr(0));
|
||||
}
|
||||
if (!supported) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void QuantizationDetails::getInputIntervals(
|
||||
|
@ -25,6 +25,11 @@ const std::vector<LayerTransformation::Params> trasformationParamValues = {
|
||||
LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParamsU8I8()
|
||||
};
|
||||
|
||||
// Values for the boolean test parameter supplied to the suite instantiation below
// through ::testing::ValuesIn; both false and true are exercised.
const std::vector<bool> isConvertOnConstants = {
|
||||
false,
|
||||
true
|
||||
};
|
||||
|
||||
const std::vector<FakeQuantizeTransformationParam> fakeQuantizeOnDataValues = {
|
||||
{
|
||||
{256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}},
|
||||
@ -82,6 +87,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_LPT, FakeQuantizeTransformation,
|
||||
::testing::Values(ngraph::PartialShape({ 1, 32, 72, 48 })),
|
||||
::testing::Values(CommonTestUtils::DEVICE_CPU),
|
||||
::testing::ValuesIn(trasformationParamValues),
|
||||
::testing::ValuesIn(fakeQuantizeOnDataValues)),
|
||||
::testing::ValuesIn(fakeQuantizeOnDataValues),
|
||||
::testing::ValuesIn(isConvertOnConstants)),
|
||||
FakeQuantizeTransformation::getTestCaseName);
|
||||
} // namespace
|
||||
|
@ -17,6 +17,11 @@ const std::vector<ngraph::element::Type> netPrecisions = {
|
||||
ngraph::element::f16
|
||||
};
|
||||
|
||||
// Values for the boolean test parameter supplied to the suite instantiation below
// through ::testing::ValuesIn; both false and true are exercised.
const std::vector<bool> isConvertOnConstants = {
|
||||
false,
|
||||
true
|
||||
};
|
||||
|
||||
const std::vector<LayerTransformation::Params> trasformationParamValues = {
|
||||
// can not be passed to plugin
|
||||
// nGraph: I8 -> FP32 Convert is not supported
|
||||
@ -65,6 +70,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_LPT, FakeQuantizeTransformation,
|
||||
::testing::Values(ngraph::PartialShape({ 1, 32, 72, 48 })),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GPU),
|
||||
::testing::ValuesIn(trasformationParamValues),
|
||||
::testing::ValuesIn(fakeQuantizeOnDataValues)),
|
||||
::testing::ValuesIn(fakeQuantizeOnDataValues),
|
||||
::testing::ValuesIn(isConvertOnConstants)),
|
||||
FakeQuantizeTransformation::getTestCaseName);
|
||||
} // namespace
|
||||
|
@ -23,7 +23,8 @@ typedef std::tuple<
|
||||
ngraph::PartialShape,
|
||||
std::string,
|
||||
ngraph::pass::low_precision::LayerTransformation::Params,
|
||||
FakeQuantizeTransformationParam> FakeQuantizeTransformationParams;
|
||||
FakeQuantizeTransformationParam,
|
||||
bool> FakeQuantizeTransformationParams;
|
||||
|
||||
class FakeQuantizeTransformation :
|
||||
public testing::WithParamInterface<FakeQuantizeTransformationParams>,
|
||||
|
@ -23,10 +23,12 @@ std::string FakeQuantizeTransformation::getTestCaseName(const testing::TestParam
|
||||
std::string targetDevice;
|
||||
ngraph::pass::low_precision::LayerTransformation::Params params;
|
||||
FakeQuantizeTransformationParam testParams;
|
||||
std::tie(netPrecision, inputShape, targetDevice, params, testParams) = obj.param;
|
||||
bool isConvertOnConstants;
|
||||
std::tie(netPrecision, inputShape, targetDevice, params, testParams, isConvertOnConstants) = obj.param;
|
||||
|
||||
std::ostringstream result;
|
||||
result << getTestCaseNameByParams(netPrecision, inputShape, targetDevice, params) << "_" << testParams.fakequantize;
|
||||
result << getTestCaseNameByParams(netPrecision, inputShape, targetDevice, params) << "_" <<
|
||||
isConvertOnConstants << "_" << testParams.fakequantize;
|
||||
return result.str();
|
||||
}
|
||||
|
||||
@ -35,7 +37,10 @@ void FakeQuantizeTransformation::SetUp() {
|
||||
ngraph::PartialShape inputShape;
|
||||
ngraph::pass::low_precision::LayerTransformation::Params params;
|
||||
FakeQuantizeTransformationParam testParams;
|
||||
std::tie(netPrecision, inputShape, targetDevice, params, testParams) = this->GetParam();
|
||||
bool isConvertOnConstants;
|
||||
std::tie(netPrecision, inputShape, targetDevice, params, testParams, isConvertOnConstants) = this->GetParam();
|
||||
|
||||
testParams.fakequantize.addConverts = isConvertOnConstants;
|
||||
|
||||
function = ngraph::builder::subgraph::FakeQuantizeFunction::getOriginal(
|
||||
params,
|
||||
|
@ -71,8 +71,8 @@ public:
|
||||
const std::vector<float>& outputLowValues,
|
||||
const std::vector<float>& outputHighValues,
|
||||
const ngraph::element::Type outputPrecision = ngraph::element::undefined,
|
||||
|
||||
const std::vector<ov::Any>& attributes = {});
|
||||
const std::vector<ov::Any>& attributes = {},
|
||||
const bool addConverts = false);
|
||||
virtual ~FakeQuantizeOnDataWithConstant();
|
||||
|
||||
virtual bool empty() const;
|
||||
@ -85,6 +85,7 @@ public:
|
||||
std::vector<float> outputHighValues;
|
||||
ngraph::element::Type outputPrecision;
|
||||
std::vector<ov::Any> attributes;
|
||||
bool addConverts;
|
||||
};
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& out, const FakeQuantizeOnDataWithConstant& data) {
|
||||
|
@ -282,6 +282,9 @@ std::shared_ptr<ngraph::opset1::FakeQuantize> makeFakeQuantize(
|
||||
fqOnData.constantShapes.empty() ? ngraph::Shape{} : fqOnData.constantShapes[0],
|
||||
fqOnData.inputLowValues,
|
||||
fqOnData.inputLowValues.empty());
|
||||
if (fqOnData.addConverts) {
|
||||
inputLowNode = ngraph::builder::makeConversion(inputLowNode, ov::element::f32, ngraph::helpers::ConversionTypes::CONVERT);
|
||||
}
|
||||
|
||||
inputHighNode = ngraph::builder::makeConstant(
|
||||
constantPrecision,
|
||||
@ -290,23 +293,32 @@ std::shared_ptr<ngraph::opset1::FakeQuantize> makeFakeQuantize(
|
||||
(fqOnData.constantShapes.size() == 1 ? fqOnData.constantShapes[0] : fqOnData.constantShapes[1]),
|
||||
fqOnData.inputHighValues,
|
||||
fqOnData.inputHighValues.empty());
|
||||
if (fqOnData.addConverts) {
|
||||
inputHighNode = ngraph::builder::makeConversion(inputHighNode, ov::element::f32, ngraph::helpers::ConversionTypes::CONVERT);
|
||||
}
|
||||
}
|
||||
|
||||
const auto outputLowNode = ngraph::builder::makeConstant(
|
||||
auto outputLowNode = ngraph::builder::makeConstant(
|
||||
constantPrecision,
|
||||
fqOnData.constantShapes.empty() ?
|
||||
ngraph::Shape{} :
|
||||
(fqOnData.constantShapes.size() == 1 ? fqOnData.constantShapes[0] : fqOnData.constantShapes[2]),
|
||||
fqOnData.outputLowValues,
|
||||
fqOnData.outputLowValues.empty());
|
||||
if (fqOnData.addConverts) {
|
||||
outputLowNode = ngraph::builder::makeConversion(outputLowNode, ov::element::f32, ngraph::helpers::ConversionTypes::CONVERT);
|
||||
}
|
||||
|
||||
const auto outputHighNode = ngraph::builder::makeConstant(
|
||||
auto outputHighNode = ngraph::builder::makeConstant(
|
||||
constantPrecision,
|
||||
fqOnData.constantShapes.empty() ?
|
||||
ngraph::Shape{} :
|
||||
(fqOnData.constantShapes.size() == 1 ? fqOnData.constantShapes[0] : fqOnData.constantShapes[3]),
|
||||
fqOnData.outputHighValues,
|
||||
fqOnData.outputHighValues.empty());
|
||||
if (fqOnData.addConverts) {
|
||||
outputHighNode = ngraph::builder::makeConversion(outputHighNode, ov::element::f32, ngraph::helpers::ConversionTypes::CONVERT);
|
||||
}
|
||||
|
||||
auto fq = std::make_shared<ngraph::opset1::FakeQuantize>(input, inputLowNode, inputHighNode, outputLowNode, outputHighNode, fqOnData.quantizationLevel);
|
||||
|
||||
|
@ -58,7 +58,8 @@ FakeQuantizeOnDataWithConstant::FakeQuantizeOnDataWithConstant(
|
||||
const std::vector<float>& outputLowValues,
|
||||
const std::vector<float>& outputHighValues,
|
||||
const ngraph::element::Type outputPrecision,
|
||||
const std::vector<ov::Any>& attributes) :
|
||||
const std::vector<ov::Any>& attributes,
|
||||
const bool addConverts) :
|
||||
quantizationLevel(quantizationLevel),
|
||||
constantShapes(constantShapes),
|
||||
inputLowValues(inputLowValues),
|
||||
@ -66,7 +67,8 @@ FakeQuantizeOnDataWithConstant::FakeQuantizeOnDataWithConstant(
|
||||
outputLowValues(outputLowValues),
|
||||
outputHighValues(outputHighValues),
|
||||
outputPrecision(outputPrecision),
|
||||
attributes(attributes)
|
||||
attributes(attributes),
|
||||
addConverts(addConverts)
|
||||
{}
|
||||
|
||||
// Out-of-line destructor definition with an empty body.
FakeQuantizeOnDataWithConstant::~FakeQuantizeOnDataWithConstant() {}
|
||||
|
Loading…
Reference in New Issue
Block a user