[LPT] Concat: different branch precisions support (#17330)

* [LPT] Concat: different branch precisions support
This commit is contained in:
Edward Shogulin 2023-05-07 11:38:32 +01:00 committed by GitHub
parent 8e675c71c8
commit 9c3186b243
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 207 additions and 24 deletions

View File

@ -236,7 +236,23 @@ bool ConcatTransformation::canBeTransformed(const TransformationContext& context
return dqOnlyByConcatAxis;
};
// Checks that the dequantization constants visited so far all share one element type.
//  - a null node imposes no constraint and is accepted;
//  - the first non-null node fixes the reference type stored in `const_precision`;
//  - every later non-null node must have exactly that reference type.
// `dequantization` is not read by the body; it stays in the signature because the call sites pass it.
const auto check_const_precision = [](
    const FakeQuantizeDequantization& dequantization,
    const std::shared_ptr<Node>& constant,
    ov::element::Type& const_precision) {
    if (!constant) {
        return true;
    }

    const auto& actual_precision = constant->get_element_type();
    if (const_precision != element::undefined) {
        // A reference type was already recorded: this node has to agree with it.
        return const_precision == actual_precision;
    }

    // First non-null node seen: remember its type as the reference for later calls.
    const_precision = actual_precision;
    return true;
};
element::Type precision;
element::Type const_precision;
for (size_t i = 0ul; i < concat->get_input_size(); i++) {
const FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(concat, defaultPrecisions, i);
if (dequantization.empty() || (updatePrecisions && !dequantization.isLowPrecision())) {
@ -253,6 +269,12 @@ bool ConcatTransformation::canBeTransformed(const TransformationContext& context
} else if (precision != dequantization.data.get_element_type()) {
return false;
}
if (!check_const_precision(dequantization, dequantization.subtractConvert, const_precision) ||
((dequantization.subtractConvert == nullptr) && !check_const_precision(dequantization, dequantization.subtractConstant, const_precision)) ||
!check_const_precision(dequantization, dequantization.multiplyConstant, const_precision)) {
return false;
}
}
return true;
}

View File

@ -19,38 +19,66 @@ const std::vector<ngraph::element::Type> precisions = {
const std::vector<ConcatTransformationTestValues> testValues = {
// U8
{
{},
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} }
{},
{},
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
{}
},
// I8
{
{},
{ 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} },
{ 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} }
{},
{},
{ 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} },
{}
},
// mixed: U8 + I8
{
{},
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
{ 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} }
{},
{},
{ 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} },
{}
},
// mixed: I8 + U8
{
{},
{ 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} },
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} }
{},
{},
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
{}
},
// FQ with unexpected quantizationLevels
{
{},
{ 14ul, ngraph::Shape({}), {0.f}, {15.f}, {0.f}, {1.5f} },
{ 14ul, ngraph::Shape({}), {0.f}, {15.f}, {0.f}, {1.5f} }
{},
{},
{ 14ul, ngraph::Shape({}), {0.f}, {15.f}, {0.f}, {1.5f} },
{}
},
// FQ with INT4 quantizationLevels
{
{},
{ 16ul, ngraph::Shape({}), {0.f}, {15.f}, {0.f}, {1.5f} },
{ 16ul, ngraph::Shape({}), {0.f}, {15.f}, {0.f}, {1.5f} }
{},
{},
{ 16ul, ngraph::Shape({}), {0.f}, {15.f}, {0.f}, {1.5f} },
{}
},
// FQ with INT4+INT8 quantizationLevels
{
{},
{ 16ul, ngraph::Shape({}), {0.f}, {15.f}, {0.f}, {1.5f} },
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} }
{},
{},
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
{}
},
};
@ -67,3 +95,34 @@ INSTANTIATE_TEST_SUITE_P(smoke_LPT, ConcatTransformation,
::testing::ValuesIn(testValues)),
ConcatTransformation::getTestCaseName);
} // namespace
namespace concat_transformation_mixed {
// Network precisions this instantiation is run with.
const std::vector<ngraph::element::Type> precisions = {
    ngraph::element::f16
};

// u8 constant filled with ones; its shape equals the input shape passed to the instantiation below.
const auto u8ConstantInput = std::make_shared<ngraph::opset1::Constant>(
    ov::element::u8,
    ov::Shape{1, 3, 16, 16},
    std::vector<float>(3 * 16 * 16, 1.0f));

const std::vector<ConcatTransformationTestValues> testValues = {
    // mixed dequantization: FP32 & FP16
    {
        // input_constant1: not set
        {},
        // fqOnData1
        { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
        // dequantization1: empty
        {},
        // input_constant2
        u8ConstantInput,
        // fqOnData2: empty
        {},
        // dequantization2: f16 convert, no subtract, per-channel f16 multiply
        {
            { ov::element::f16 },
            {},
            {{1.f, 2.f, 3.f}, ov::element::f16, ov::Shape{1, 3, 1, 1}},
        },
    }
};

INSTANTIATE_TEST_SUITE_P(smoke_LPT, ConcatTransformation,
    ::testing::Combine(
        ::testing::ValuesIn(precisions),
        ::testing::Values(ngraph::PartialShape({ 1, 3, 16, 16 })),
        ::testing::Values(CommonTestUtils::DEVICE_CPU),
        ::testing::ValuesIn(testValues)),
    ConcatTransformation::getTestCaseName);
}  // namespace concat_transformation_mixed

View File

@ -19,23 +19,39 @@ const std::vector<ngraph::element::Type> precisions = {
const std::vector<ConcatTransformationTestValues> testValues = {
// U8
{
{},
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} }
{},
{},
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
{}
},
// I8
{
{},
{ 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} },
{ 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} }
{},
{},
{ 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} },
{}
},
// mixed
{
{},
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
{ 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} }
{},
{},
{ 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} },
{}
},
// FQ with unexpected quantizationLevels
{
{},
{ 16ul, ngraph::Shape({}), {0.f}, {15.f}, {0.f}, {1.5f} },
{ 16ul, ngraph::Shape({}), {0.f}, {15.f}, {0.f}, {1.5f} }
{},
{},
{ 16ul, ngraph::Shape({}), {0.f}, {15.f}, {0.f}, {1.5f} },
{}
},
};
@ -47,3 +63,35 @@ INSTANTIATE_TEST_SUITE_P(smoke_LPT, ConcatTransformation,
::testing::ValuesIn(testValues)),
ConcatTransformation::getTestCaseName);
} // namespace
namespace concat_transformation_mixed {
// Network precisions this instantiation is run with.
const std::vector<ngraph::element::Type> precisions = {
    ngraph::element::f16
};

// u8 constant filled with ones; its shape equals the input shape passed to the instantiation below.
const auto u8ConstantInput = std::make_shared<ngraph::opset1::Constant>(
    ov::element::u8,
    ov::Shape{1, 3, 16, 16},
    std::vector<float>(3 * 16 * 16, 1.0f));

const std::vector<ConcatTransformationTestValues> testValues = {
    // mixed dequantization: FP32 & FP16
    {
        // input_constant1: not set
        {},
        // fqOnData1
        { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
        // dequantization1: empty
        {},
        // input_constant2
        u8ConstantInput,
        // fqOnData2: empty
        {},
        // dequantization2: f16 convert, no subtract, per-channel f16 multiply
        {
            { ov::element::f16 },
            {},
            {{1.f, 2.f, 3.f}, ov::element::f16, ov::Shape{1, 3, 1, 1}},
        },
    }
};

INSTANTIATE_TEST_SUITE_P(smoke_LPT, ConcatTransformation,
    ::testing::Combine(
        ::testing::ValuesIn(precisions),
        ::testing::Values(ngraph::PartialShape({ 1, 3, 16, 16 })),
        ::testing::Values(CommonTestUtils::DEVICE_GPU),
        ::testing::ValuesIn(testValues)),
    ConcatTransformation::getTestCaseName);
}  // namespace concat_transformation_mixed

View File

@ -8,14 +8,19 @@
#include <memory>
#include "shared_test_classes/base/low_precision_transformations/layer_transformation.hpp"
#include "lpt_ngraph_functions/common/dequantization_operations.hpp"
#include "lpt_ngraph_functions/common/fake_quantize_on_data.hpp"
namespace LayerTestsDefinitions {
// Per-test description of the two input branches of the Concat subgraph under test.
// ConcatTransformation::SetUp forwards these members, in declaration order, to
// ConcatFunction::getOriginal. Test tables brace-initialize this class positionally,
// so the member order must not change.
class ConcatTransformationTestValues {
public:
// Branch 1 source: when null, getOriginal creates a Parameter for the branch; otherwise this constant is used.
std::shared_ptr<ngraph::opset1::Constant> input_constant1;
// Branch 1 FakeQuantize description; getOriginal adds the FakeQuantize only when this is non-empty.
ngraph::builder::subgraph::FakeQuantizeOnData fqOnData1;
// Branch 1 dequantization description; getOriginal adds it (after the FakeQuantize, if any) only when non-empty.
ngraph::builder::subgraph::DequantizationOperations dequantization1;
// Branch 2 source: same meaning as input_constant1.
std::shared_ptr<ngraph::opset1::Constant> input_constant2;
// Branch 2 FakeQuantize description: same meaning as fqOnData1.
ngraph::builder::subgraph::FakeQuantizeOnData fqOnData2;
// Branch 2 dequantization description: same meaning as dequantization1.
ngraph::builder::subgraph::DequantizationOperations dequantization2;
};
typedef std::tuple<

View File

@ -26,7 +26,11 @@ std::string ConcatTransformation::getTestCaseName(const testing::TestParamInfo<C
const auto params = LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParamsU8I8();
std::ostringstream result;
result << getTestCaseNameByParams(precision, inputShapes, targetDevice, params) << testValues.fqOnData1 << testValues.fqOnData2;
result << getTestCaseNameByParams(precision, inputShapes, targetDevice, params) <<
testValues.fqOnData1 <<
testValues.dequantization1 <<
testValues.fqOnData2 <<
testValues.dequantization2;
return result.str();
}
@ -50,8 +54,12 @@ void ConcatTransformation::SetUp() {
function = ngraph::builder::subgraph::ConcatFunction::getOriginal(
precision,
inputShape,
testValues.input_constant1,
testValues.fqOnData1,
testValues.fqOnData2);
testValues.dequantization1,
testValues.input_constant2,
testValues.fqOnData2,
testValues.dequantization2);
}
TEST_P(ConcatTransformation, CompareWithRefImpl) {

View File

@ -29,8 +29,12 @@ public:
static std::shared_ptr<ngraph::Function> getOriginal(
const ngraph::element::Type precision,
const ngraph::PartialShape& inputShape,
const std::shared_ptr<ngraph::opset1::Constant>& input_constant1,
const FakeQuantizeOnData& fakeQuantize1,
const FakeQuantizeOnData& fakeQuantize2);
const DequantizationOperations& dequantization1,
const std::shared_ptr<ngraph::opset1::Constant>& input_constant2,
const FakeQuantizeOnData& fakeQuantize2,
const DequantizationOperations& dequantization2);
static std::shared_ptr<ngraph::Function> getOriginal(
const ngraph::element::Type precision,

View File

@ -66,27 +66,64 @@ std::shared_ptr<ov::Model> ConcatFunction::get(
std::shared_ptr<ngraph::Function> ConcatFunction::getOriginal(
const ngraph::element::Type precision,
const ngraph::PartialShape& inputShape,
const std::shared_ptr<ngraph::opset1::Constant>& input_constant1,
const FakeQuantizeOnData& fqOnData1,
const FakeQuantizeOnData& fqOnData2) {
const auto input1 = std::make_shared<ngraph::opset1::Parameter>(precision, inputShape);
input1->set_friendly_name("input1");
const auto fakeQuantize1 = makeFakeQuantize(input1, precision, fqOnData1);
const DequantizationOperations& dequantization1,
const std::shared_ptr<ngraph::opset1::Constant>& input_constant2,
const FakeQuantizeOnData& fqOnData2,
const DequantizationOperations& dequantization2) {
std::shared_ptr<Node> parent1;
std::shared_ptr<ngraph::opset1::Parameter> input1;
if (input_constant1 == nullptr) {
input1 = std::make_shared<ngraph::opset1::Parameter>(precision, inputShape);
input1->set_friendly_name("input1");
parent1 = input1;
} else {
parent1 = input_constant1;
}
const auto inputShape2 = inputShape;
const auto input2 = std::make_shared<ngraph::opset1::Parameter>(precision, inputShape2);
input2->set_friendly_name("input2");
const auto fakeQuantize2 = makeFakeQuantize(input2, precision, fqOnData2);
if (!fqOnData1.empty()) {
parent1 = makeFakeQuantize(parent1, precision, fqOnData1);
}
if (!dequantization1.empty()) {
parent1 = makeDequantization(parent1, dequantization1);
}
std::shared_ptr<Node> parent2;
std::shared_ptr<ngraph::opset1::Parameter> input2;
if (input_constant2 == nullptr) {
input2 = std::make_shared<ngraph::opset1::Parameter>(precision, inputShape);
input2->set_friendly_name("input2");
parent2 = input2;
} else {
parent2 = input_constant2;
}
if (!fqOnData2.empty()) {
parent2 = makeFakeQuantize(parent2, precision, fqOnData2);
}
if (!dequantization2.empty()) {
parent2 = makeDequantization(parent2, dequantization2);
}
const std::shared_ptr<ngraph::opset1::Concat> concat = std::make_shared<ngraph::opset1::Concat>(
ngraph::OutputVector{ fakeQuantize1->output(0), fakeQuantize2->output(0) }, 1);
ngraph::OutputVector{ parent1->output(0), parent2->output(0) }, 1);
concat->set_friendly_name("output");
auto& rtInfo = concat->get_rt_info();
rtInfo["Variant::std::string"] = "concat";
ngraph::ParameterVector inputs;
if (input1 != nullptr) {
inputs.push_back(input1);
}
if (input2 != nullptr) {
inputs.push_back(input2);
}
ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(concat) };
std::shared_ptr<ngraph::Function> function = std::make_shared<ngraph::Function>(
results,
ngraph::ParameterVector{ input1, input2 },
inputs,
"ConcatTransformation");
return function;