[LPT] Concat: different branch precisions support (#17330)
* [LPT] Concat: different branch precisions support
This commit is contained in:
parent
8e675c71c8
commit
9c3186b243
@ -236,7 +236,23 @@ bool ConcatTransformation::canBeTransformed(const TransformationContext& context
|
||||
return dqOnlyByConcatAxis;
|
||||
};
|
||||
|
||||
const auto check_const_precision = [](
|
||||
const FakeQuantizeDequantization& dequantization,
|
||||
const std::shared_ptr<Node>& constant,
|
||||
ov::element::Type& const_precision) {
|
||||
if (constant == nullptr) {
|
||||
return true;
|
||||
}
|
||||
if (const_precision == element::undefined) {
|
||||
const_precision = constant->get_element_type();
|
||||
return true;
|
||||
}
|
||||
return const_precision == constant->get_element_type();
|
||||
};
|
||||
|
||||
element::Type precision;
|
||||
element::Type const_precision;
|
||||
|
||||
for (size_t i = 0ul; i < concat->get_input_size(); i++) {
|
||||
const FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(concat, defaultPrecisions, i);
|
||||
if (dequantization.empty() || (updatePrecisions && !dequantization.isLowPrecision())) {
|
||||
@ -253,6 +269,12 @@ bool ConcatTransformation::canBeTransformed(const TransformationContext& context
|
||||
} else if (precision != dequantization.data.get_element_type()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!check_const_precision(dequantization, dequantization.subtractConvert, const_precision) ||
|
||||
((dequantization.subtractConvert == nullptr) && !check_const_precision(dequantization, dequantization.subtractConstant, const_precision)) ||
|
||||
!check_const_precision(dequantization, dequantization.multiplyConstant, const_precision)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
@ -19,38 +19,66 @@ const std::vector<ngraph::element::Type> precisions = {
|
||||
const std::vector<ConcatTransformationTestValues> testValues = {
|
||||
// U8
|
||||
{
|
||||
{},
|
||||
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
|
||||
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} }
|
||||
{},
|
||||
{},
|
||||
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
|
||||
{}
|
||||
},
|
||||
// I8
|
||||
{
|
||||
{},
|
||||
{ 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} },
|
||||
{ 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} }
|
||||
{},
|
||||
{},
|
||||
{ 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} },
|
||||
{}
|
||||
},
|
||||
// mixed: U8 + I8
|
||||
{
|
||||
{},
|
||||
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
|
||||
{ 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} }
|
||||
{},
|
||||
{},
|
||||
{ 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} },
|
||||
{}
|
||||
},
|
||||
// mixed: I8 + U8
|
||||
{
|
||||
{},
|
||||
{ 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} },
|
||||
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} }
|
||||
{},
|
||||
{},
|
||||
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
|
||||
{}
|
||||
},
|
||||
// FQ with unexpected quantizationLevels
|
||||
{
|
||||
{},
|
||||
{ 14ul, ngraph::Shape({}), {0.f}, {15.f}, {0.f}, {1.5f} },
|
||||
{ 14ul, ngraph::Shape({}), {0.f}, {15.f}, {0.f}, {1.5f} }
|
||||
{},
|
||||
{},
|
||||
{ 14ul, ngraph::Shape({}), {0.f}, {15.f}, {0.f}, {1.5f} },
|
||||
{}
|
||||
},
|
||||
// FQ with INT4 quantizationLevels
|
||||
{
|
||||
{},
|
||||
{ 16ul, ngraph::Shape({}), {0.f}, {15.f}, {0.f}, {1.5f} },
|
||||
{ 16ul, ngraph::Shape({}), {0.f}, {15.f}, {0.f}, {1.5f} }
|
||||
{},
|
||||
{},
|
||||
{ 16ul, ngraph::Shape({}), {0.f}, {15.f}, {0.f}, {1.5f} },
|
||||
{}
|
||||
},
|
||||
// FQ with INT4+INT8 quantizationLevels
|
||||
{
|
||||
{},
|
||||
{ 16ul, ngraph::Shape({}), {0.f}, {15.f}, {0.f}, {1.5f} },
|
||||
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} }
|
||||
{},
|
||||
{},
|
||||
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
|
||||
{}
|
||||
},
|
||||
};
|
||||
|
||||
@ -67,3 +95,34 @@ INSTANTIATE_TEST_SUITE_P(smoke_LPT, ConcatTransformation,
|
||||
::testing::ValuesIn(testValues)),
|
||||
ConcatTransformation::getTestCaseName);
|
||||
} // namespace
|
||||
|
||||
namespace concat_transformation_mixed {
|
||||
|
||||
const std::vector<ngraph::element::Type> precisions = {
|
||||
ngraph::element::f16
|
||||
};
|
||||
|
||||
const std::vector<ConcatTransformationTestValues> testValues = {
|
||||
// mixed dequantization: FP32 & FP16
|
||||
{
|
||||
{},
|
||||
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
|
||||
{},
|
||||
std::make_shared<ngraph::opset1::Constant>(ov::element::u8, ov::Shape{1, 3, 16, 16}, std::vector<float>(3 * 16 * 16, 1.0)),
|
||||
{},
|
||||
{
|
||||
{ ov::element::f16 },
|
||||
{},
|
||||
{{1.f, 2.f, 3.f}, ov::element::f16, ov::Shape{1, 3, 1, 1}},
|
||||
},
|
||||
}
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_LPT, ConcatTransformation,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(precisions),
|
||||
::testing::Values(ngraph::PartialShape({ 1, 3, 16, 16 })),
|
||||
::testing::Values(CommonTestUtils::DEVICE_CPU),
|
||||
::testing::ValuesIn(testValues)),
|
||||
ConcatTransformation::getTestCaseName);
|
||||
} // namespace concat_transformation_mixed
|
||||
|
@ -19,23 +19,39 @@ const std::vector<ngraph::element::Type> precisions = {
|
||||
const std::vector<ConcatTransformationTestValues> testValues = {
|
||||
// U8
|
||||
{
|
||||
{},
|
||||
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
|
||||
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} }
|
||||
{},
|
||||
{},
|
||||
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
|
||||
{}
|
||||
},
|
||||
// I8
|
||||
{
|
||||
{},
|
||||
{ 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} },
|
||||
{ 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} }
|
||||
{},
|
||||
{},
|
||||
{ 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} },
|
||||
{}
|
||||
},
|
||||
// mixed
|
||||
{
|
||||
{},
|
||||
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
|
||||
{ 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} }
|
||||
{},
|
||||
{},
|
||||
{ 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} },
|
||||
{}
|
||||
},
|
||||
// FQ with unexpected quantizationLevels
|
||||
{
|
||||
{},
|
||||
{ 16ul, ngraph::Shape({}), {0.f}, {15.f}, {0.f}, {1.5f} },
|
||||
{ 16ul, ngraph::Shape({}), {0.f}, {15.f}, {0.f}, {1.5f} }
|
||||
{},
|
||||
{},
|
||||
{ 16ul, ngraph::Shape({}), {0.f}, {15.f}, {0.f}, {1.5f} },
|
||||
{}
|
||||
},
|
||||
};
|
||||
|
||||
@ -47,3 +63,35 @@ INSTANTIATE_TEST_SUITE_P(smoke_LPT, ConcatTransformation,
|
||||
::testing::ValuesIn(testValues)),
|
||||
ConcatTransformation::getTestCaseName);
|
||||
} // namespace
|
||||
|
||||
|
||||
namespace concat_transformation_mixed {
|
||||
|
||||
const std::vector<ngraph::element::Type> precisions = {
|
||||
ngraph::element::f16
|
||||
};
|
||||
|
||||
const std::vector<ConcatTransformationTestValues> testValues = {
|
||||
// mixed dequantization: FP32 & FP16
|
||||
{
|
||||
{},
|
||||
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
|
||||
{},
|
||||
std::make_shared<ngraph::opset1::Constant>(ov::element::u8, ov::Shape{1, 3, 16, 16}, std::vector<float>(3 * 16 * 16, 1.0)),
|
||||
{},
|
||||
{
|
||||
{ ov::element::f16 },
|
||||
{},
|
||||
{{1.f, 2.f, 3.f}, ov::element::f16, ov::Shape{1, 3, 1, 1}},
|
||||
},
|
||||
}
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_LPT, ConcatTransformation,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(precisions),
|
||||
::testing::Values(ngraph::PartialShape({ 1, 3, 16, 16 })),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GPU),
|
||||
::testing::ValuesIn(testValues)),
|
||||
ConcatTransformation::getTestCaseName);
|
||||
} // namespace concat_transformation_mixed
|
||||
|
@ -8,14 +8,19 @@
|
||||
#include <memory>
|
||||
|
||||
#include "shared_test_classes/base/low_precision_transformations/layer_transformation.hpp"
|
||||
#include "lpt_ngraph_functions/common/dequantization_operations.hpp"
|
||||
#include "lpt_ngraph_functions/common/fake_quantize_on_data.hpp"
|
||||
|
||||
namespace LayerTestsDefinitions {
|
||||
|
||||
class ConcatTransformationTestValues {
|
||||
public:
|
||||
std::shared_ptr<ngraph::opset1::Constant> input_constant1;
|
||||
ngraph::builder::subgraph::FakeQuantizeOnData fqOnData1;
|
||||
ngraph::builder::subgraph::DequantizationOperations dequantization1;
|
||||
std::shared_ptr<ngraph::opset1::Constant> input_constant2;
|
||||
ngraph::builder::subgraph::FakeQuantizeOnData fqOnData2;
|
||||
ngraph::builder::subgraph::DequantizationOperations dequantization2;
|
||||
};
|
||||
|
||||
typedef std::tuple<
|
||||
|
@ -26,7 +26,11 @@ std::string ConcatTransformation::getTestCaseName(const testing::TestParamInfo<C
|
||||
const auto params = LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParamsU8I8();
|
||||
|
||||
std::ostringstream result;
|
||||
result << getTestCaseNameByParams(precision, inputShapes, targetDevice, params) << testValues.fqOnData1 << testValues.fqOnData2;
|
||||
result << getTestCaseNameByParams(precision, inputShapes, targetDevice, params) <<
|
||||
testValues.fqOnData1 <<
|
||||
testValues.dequantization1 <<
|
||||
testValues.fqOnData2 <<
|
||||
testValues.dequantization2;
|
||||
return result.str();
|
||||
}
|
||||
|
||||
@ -50,8 +54,12 @@ void ConcatTransformation::SetUp() {
|
||||
function = ngraph::builder::subgraph::ConcatFunction::getOriginal(
|
||||
precision,
|
||||
inputShape,
|
||||
testValues.input_constant1,
|
||||
testValues.fqOnData1,
|
||||
testValues.fqOnData2);
|
||||
testValues.dequantization1,
|
||||
testValues.input_constant2,
|
||||
testValues.fqOnData2,
|
||||
testValues.dequantization2);
|
||||
}
|
||||
|
||||
TEST_P(ConcatTransformation, CompareWithRefImpl) {
|
||||
|
@ -29,8 +29,12 @@ public:
|
||||
static std::shared_ptr<ngraph::Function> getOriginal(
|
||||
const ngraph::element::Type precision,
|
||||
const ngraph::PartialShape& inputShape,
|
||||
const std::shared_ptr<ngraph::opset1::Constant>& input_constant1,
|
||||
const FakeQuantizeOnData& fakeQuantize1,
|
||||
const FakeQuantizeOnData& fakeQuantize2);
|
||||
const DequantizationOperations& dequantization1,
|
||||
const std::shared_ptr<ngraph::opset1::Constant>& input_constant2,
|
||||
const FakeQuantizeOnData& fakeQuantize2,
|
||||
const DequantizationOperations& dequantization2);
|
||||
|
||||
static std::shared_ptr<ngraph::Function> getOriginal(
|
||||
const ngraph::element::Type precision,
|
||||
|
@ -66,27 +66,64 @@ std::shared_ptr<ov::Model> ConcatFunction::get(
|
||||
std::shared_ptr<ngraph::Function> ConcatFunction::getOriginal(
|
||||
const ngraph::element::Type precision,
|
||||
const ngraph::PartialShape& inputShape,
|
||||
const std::shared_ptr<ngraph::opset1::Constant>& input_constant1,
|
||||
const FakeQuantizeOnData& fqOnData1,
|
||||
const FakeQuantizeOnData& fqOnData2) {
|
||||
const auto input1 = std::make_shared<ngraph::opset1::Parameter>(precision, inputShape);
|
||||
input1->set_friendly_name("input1");
|
||||
const auto fakeQuantize1 = makeFakeQuantize(input1, precision, fqOnData1);
|
||||
const DequantizationOperations& dequantization1,
|
||||
const std::shared_ptr<ngraph::opset1::Constant>& input_constant2,
|
||||
const FakeQuantizeOnData& fqOnData2,
|
||||
const DequantizationOperations& dequantization2) {
|
||||
std::shared_ptr<Node> parent1;
|
||||
std::shared_ptr<ngraph::opset1::Parameter> input1;
|
||||
if (input_constant1 == nullptr) {
|
||||
input1 = std::make_shared<ngraph::opset1::Parameter>(precision, inputShape);
|
||||
input1->set_friendly_name("input1");
|
||||
parent1 = input1;
|
||||
} else {
|
||||
parent1 = input_constant1;
|
||||
}
|
||||
|
||||
const auto inputShape2 = inputShape;
|
||||
const auto input2 = std::make_shared<ngraph::opset1::Parameter>(precision, inputShape2);
|
||||
input2->set_friendly_name("input2");
|
||||
const auto fakeQuantize2 = makeFakeQuantize(input2, precision, fqOnData2);
|
||||
if (!fqOnData1.empty()) {
|
||||
parent1 = makeFakeQuantize(parent1, precision, fqOnData1);
|
||||
}
|
||||
if (!dequantization1.empty()) {
|
||||
parent1 = makeDequantization(parent1, dequantization1);
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> parent2;
|
||||
std::shared_ptr<ngraph::opset1::Parameter> input2;
|
||||
if (input_constant2 == nullptr) {
|
||||
input2 = std::make_shared<ngraph::opset1::Parameter>(precision, inputShape);
|
||||
input2->set_friendly_name("input2");
|
||||
parent2 = input2;
|
||||
} else {
|
||||
parent2 = input_constant2;
|
||||
}
|
||||
|
||||
if (!fqOnData2.empty()) {
|
||||
parent2 = makeFakeQuantize(parent2, precision, fqOnData2);
|
||||
}
|
||||
if (!dequantization2.empty()) {
|
||||
parent2 = makeDequantization(parent2, dequantization2);
|
||||
}
|
||||
|
||||
const std::shared_ptr<ngraph::opset1::Concat> concat = std::make_shared<ngraph::opset1::Concat>(
|
||||
ngraph::OutputVector{ fakeQuantize1->output(0), fakeQuantize2->output(0) }, 1);
|
||||
ngraph::OutputVector{ parent1->output(0), parent2->output(0) }, 1);
|
||||
concat->set_friendly_name("output");
|
||||
auto& rtInfo = concat->get_rt_info();
|
||||
rtInfo["Variant::std::string"] = "concat";
|
||||
|
||||
ngraph::ParameterVector inputs;
|
||||
if (input1 != nullptr) {
|
||||
inputs.push_back(input1);
|
||||
}
|
||||
if (input2 != nullptr) {
|
||||
inputs.push_back(input2);
|
||||
}
|
||||
|
||||
ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(concat) };
|
||||
std::shared_ptr<ngraph::Function> function = std::make_shared<ngraph::Function>(
|
||||
results,
|
||||
ngraph::ParameterVector{ input1, input2 },
|
||||
inputs,
|
||||
"ConcatTransformation");
|
||||
|
||||
return function;
|
||||
|
Loading…
Reference in New Issue
Block a user