[LPT] Rank limitations removed (#14785)

* [LPT] LayerTransformation: removed legacy rank checks

* [LPT] Added test cases with 1D and 6D ranks & existing tests corrected
This commit is contained in:
Vladislav Golubev 2023-01-31 01:26:59 +01:00 committed by GitHub
parent 0b5603fa98
commit d1397b7b48
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 46 additions and 115 deletions

View File

@@ -35,18 +35,6 @@ public:
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
static bool isQuantizedStatic(const std::shared_ptr<const Node>& layer);
protected:
static bool isHandled(
const TransformationContext& context,
const std::vector<std::shared_ptr<ngraph::Node>>& quantizationOperations);
void fillDequantizationNodes(
const std::vector<FakeQuantizeDequantization>& layerDequantizations,
const std::shared_ptr<Node> layer,
NodeVector& convertNodes,
NodeVector& subtractNodes,
NodeVector& multiplyNodes) const;
};
} // namespace low_precision

View File

@@ -253,81 +253,6 @@ bool ConcatTransformation::canBeTransformed(const TransformationContext& context
return true;
}
// Collects, per Concat input branch, the dequantization pieces (Convert,
// Subtract-constant, Multiply-constant) into the output NodeVectors so the
// caller can fuse them into a single dequantization applied after the Concat.
//
// layerDequantizations: one entry per input of `layer` (same order as inputs).
// layer:                the node whose input shapes size the created constants.
// convertNodes/subtractNodes/multiplyNodes: out-params; appended to, never cleared.
void ConcatTransformation::fillDequantizationNodes(
    const std::vector<FakeQuantizeDequantization>& layerDequantizations,
    const std::shared_ptr<Node> layer,
    NodeVector& convertNodes,
    NodeVector& subtractNodes,
    NodeVector& multiplyNodes) const {
    if (layerDequantizations.size() > 1ul) {
        // FakeQuantize constant shape must be broadcastable to the shape on data,
        // so branch constants are explicitly broadcast to a common per-branch shape.
        auto broadcastElementWiseConst = [](
            std::shared_ptr<ngraph::opset1::Constant> operation,
            const ngraph::Shape targetShape) -> std::shared_ptr<Node> {
            auto targetShapeConst = opset1::Constant::create(element::i64, ngraph::Shape{ targetShape.size() }, targetShape);
            auto broadcast = fold<ngraph::opset1::Broadcast>(operation, targetShapeConst);
            return broadcast;
        };

        // Despite the "AreZero" names, these flags actually mean "no branch has a
        // Subtract/Multiply node": they stay true only when every dequantization
        // lacks that operation entirely.
        bool allDequantizationShiftAreZero = true;
        bool allDequantizationMultiplyAreZero = true;
        for (const auto& dequantization : layerDequantizations) {
            if (dequantization.subtract != nullptr) {
                allDequantizationShiftAreZero = false;
            }
            if (dequantization.multiply != nullptr) {
                allDequantizationMultiplyAreZero = false;
            }
        }

        for (size_t i = 0; i < layerDequantizations.size(); ++i) {
            const auto& dequantization = layerDequantizations[i];
            const ngraph::element::Type precision = deqPrecision;
            // Per-channel constant shape [1, C, 1, ...] derived from input i.
            // NOTE(review): assumes input i has a static rank >= 2 and a static
            // dimension at axis 1 — get_length() on a dynamic value would fail;
            // presumably guaranteed by an earlier canBeTransformed check — confirm.
            ngraph::Shape targetShape(layer->get_input_partial_shape(i).rank().get_length(), 1ul);
            targetShape[1] = layer->get_input_partial_shape(i)[1].get_length();

            if (dequantization.convert != nullptr) {
                convertNodes.push_back(dequantization.convert);
            }

            // If any branch has a Subtract, every branch must contribute a constant
            // so they can be concatenated; branches without one get a neutral 0.
            if (!allDequantizationShiftAreZero) {
                subtractNodes.push_back(dequantization.subtract == nullptr ?
                    std::make_shared<ngraph::opset1::Constant>(precision, targetShape, std::vector<float>({ 0.f })) :
                    broadcastElementWiseConst(dequantization.subtractConstant, targetShape));
            }

            // Same for Multiply: the neutral constant is 1.
            if (!allDequantizationMultiplyAreZero) {
                multiplyNodes.push_back(dequantization.multiply == nullptr ?
                    std::make_shared<ngraph::opset1::Constant>(precision, targetShape, std::vector<float>({ 1.0f })) :
                    broadcastElementWiseConst(dequantization.multiplyConstant, targetShape));
            }
        }
    } else {
        // Single branch: reuse the existing constant inputs (input 1 of the
        // Subtract/Multiply nodes) directly, without broadcasting.
        // TODO: check constant shapes here - has to be scalar
        if (layerDequantizations[0].convert != nullptr) {
            convertNodes.push_back(layerDequantizations[0].convert);
        }
        if (layerDequantizations[0].subtract != nullptr) {
            subtractNodes.push_back(layerDequantizations[0].subtract->input_value(1).get_node_shared_ptr());
        }
        if (layerDequantizations[0].multiply != nullptr) {
            multiplyNodes.push_back(layerDequantizations[0].multiply->input_value(1).get_node_shared_ptr());
        }
    }
}
// Returns true when at least one of the given quantization operations was
// already processed, i.e. its friendly name has been recorded in
// context.quantizedFakeQuantizeNames by an earlier transformation.
bool ConcatTransformation::isHandled(const TransformationContext& context, const std::vector<std::shared_ptr<ngraph::Node>>& quantizationOperations) {
    const auto& handledNames = context.quantizedFakeQuantizeNames;
    for (const auto& operation : quantizationOperations) {
        const bool alreadyHandled = handledNames.find(operation->get_friendly_name()) != handledNames.end();
        if (alreadyHandled) {
            return true;
        }
    }
    return false;
}
bool ConcatTransformation::isQuantizedStatic(const std::shared_ptr<const Node>& layer) {
const auto concat = as_type_ptr<const opset1::Concat>(layer);
if (concat == nullptr)

View File

@@ -53,11 +53,10 @@ bool LayerTransformation::canBeTransformed(const TransformationContext& context,
bool LayerTransformation::canBeTransformedStatic(const std::shared_ptr<Node>& layer,
const std::vector<ngraph::element::Type>& defaultPrecisions) {
for (const auto& output : layer->outputs()) {
const auto rank = output.get_partial_shape().rank();
if (rank.is_dynamic() || rank.get_length() < 2) {
return false;
}
const auto outputs = layer->outputs();
if (std::any_of(outputs.begin(), outputs.end(),
[](const Output<Node>& out) { return out.get_partial_shape().rank().is_dynamic(); })) {
return false;
}
const auto dequantization = NetworkHelper::getDequantization(layer, defaultPrecisions);
@@ -72,8 +71,7 @@ bool LayerTransformation::canBeTransformedStatic(const std::shared_ptr<Node>& la
return false;
}
const auto dataShapeSize = static_cast<size_t>(rank.get_length());
if ((dataShapeSize - constShape.size()) == 1ul) {
if ((dataPShape.size() - constShape.size()) == 1ul) {
constShape.insert(constShape.begin(), 1ul);
}
@@ -115,18 +113,10 @@ bool LayerTransformation::canBeTransformedSpatialDimension(const TransformationC
if (!isQuantized(layer, defaultPrecisions)) {
return false;
}
for (const auto& output : layer->outputs()) {
const auto outPShape = output.get_partial_shape();
const auto rank = outPShape.rank();
if (rank.is_dynamic()) {
return false;
}
const auto size = rank.get_length();
if ((size < 2) || (size > 5)) {
return false;
}
const auto outputs = layer->outputs();
if (std::any_of(outputs.begin(), outputs.end(),
[](const Output<Node>& out) { return out.get_partial_shape().rank().is_dynamic(); })) {
return false;
}
return true;
}

View File

@@ -81,6 +81,8 @@ bool NormalizeL2Transformation::canBeTransformed(const TransformationContext& co
const Shape outputShape = scalesConst->get_shape();
const size_t size = shape_size(outputShape);
if (size != 1ul) {
if (operation->get_output_partial_shape(0).size() < 2)
return false;
const auto channelsInterval = operation->get_output_partial_shape(0)[1];
if (channelsInterval.is_dynamic() || static_cast<size_t>(channelsInterval.get_length()) != size) {
return false;

View File

@@ -211,7 +211,7 @@ bool ReshapeTransformation::canBeTransformed(const TransformationContext& contex
}
const PartialShape outputPShape = op->get_output_partial_shape(0);
if (outputPShape[1].is_dynamic()) {
if (outputPShape.size() < 2 || outputPShape[1].is_dynamic()) {
return false;
}

View File

@@ -137,7 +137,8 @@ namespace testValues1 {
const std::vector<std::pair<ngraph::PartialShape, ngraph::PartialShape>> shapes = {
{{-1, -1, -1, -1}, {-1, -1, -1, -1}},
{{1, 16, 384, 64}, {1, 16, 64, 384}},
{{4, 16, 384, 64}, {4, 16, 64, 384}}};
{{1, 1, 4, 16, 384, 64}, {1, 1, 4, 16, 64, 384}},
{{64}, {64}}};
std::vector<MatMullTransformationTestValues> testValues = {
// U8 + I8: Constant on dequantization operations on 0 branch

View File

@@ -94,10 +94,7 @@ TEST_P(PReluTransformation, CompareFunctions) {
}
namespace testValues1 {
const std::vector<ngraph::PartialShape> shapes = {
{1, 3, 16, 16},
{-1, -1, -1, -1},
};
const std::vector<ngraph::PartialShape> shapes = {{1, 3, 16, 16}, {-1, -1, -1, -1}, {1, 1, 2, 3, 4, 16}, {5}};
const std::vector<PReluTransformationTestValues> testValues = {
// U8: no subtract

View File

@@ -39,6 +39,22 @@ std::vector<MatMulTransformationTestValues> testValues = {
{ 256ul, ngraph::Shape({}), {-12.8f}, {12.7f}, {-12.8f}, {12.7f} },
"matMul_original",
"I8"
},
{
{ 1, 1, 1, 4, 12, 2 },
{ 256ul, ngraph::Shape({}), {-12.8f}, {12.7f}, {-12.8f}, {12.7f} },
{ 1, 1, 1, 4, 2, 12 },
{ 256ul, ngraph::Shape({}), {-12.8f}, {12.7f}, {-12.8f}, {12.7f} },
"matMul_original",
"I8"
},
{
{ 12 },
{ 256ul, ngraph::Shape({}), {-12.8f}, {12.7f}, {-12.8f}, {12.7f} },
{ 12 },
{ 256ul, ngraph::Shape({}), {-12.8f}, {12.7f}, {-12.8f}, {12.7f} },
"matMul_original/MM",
"I8"
}
};

View File

@@ -12,7 +12,7 @@ using namespace LayerTestsDefinitions;
namespace {
const std::vector<ngraph::element::Type> netPrecisions = {
ngraph::element::f32,
ngraph::element::f16
// ngraph::element::f16
};
const std::vector<ngraph::pass::low_precision::LayerTransformation::Params> trasformationParamValues = {
@@ -34,7 +34,7 @@ const std::vector<ReshapeTransformationParam> params = {
{ -1 },
{ 256ul, ngraph::Shape{}, { 0.f }, { 255.f }, { 0.f }, { 25.5f } },
"Reshape",
"FP32"
"U8"
},
// 4D -> 3D
{

View File

@@ -27,6 +27,18 @@ std::vector<MatMulTransformationTestValues> testValues = {
{ 256ul, ngraph::Shape({}), {-12.8f}, {12.7f}, {-12.8f}, {12.7f} },
{ 1, 4, 2, 12 },
{ 256ul, ngraph::Shape({}), {-12.8f}, {12.7f}, {-12.8f}, {12.7f} }
},
{
{ 1, 1, 1, 4, 12, 2 },
{ 256ul, ngraph::Shape({}), {-12.8f}, {12.7f}, {-12.8f}, {12.7f} },
{ 1, 1, 1, 4, 2, 12 },
{ 256ul, ngraph::Shape({}), {-12.8f}, {12.7f}, {-12.8f}, {12.7f} },
},
{
{ 12 },
{ 256ul, ngraph::Shape({}), {-12.8f}, {12.7f}, {-12.8f}, {12.7f} },
{ 12 },
{ 256ul, ngraph::Shape({}), {-12.8f}, {12.7f}, {-12.8f}, {12.7f} },
}
};

View File

@@ -34,7 +34,7 @@ const std::vector<ReshapeTransformationParam> params = {
{ -1 },
{ 256ul, ngraph::Shape{}, { 0.f }, { 255.f }, { 0.f }, { 25.5f } },
"Reshape",
"FP32"
"U8"
},
// 4D -> 3D
{