[LPT] Rank limitations removed (#14785)
* [LPT] LayerTransformation: removed legacy rank checks
* [LPT] Added test cases with 1D and 6D ranks and corrected existing tests
This commit is contained in:
parent
0b5603fa98
commit
d1397b7b48
@ -35,18 +35,6 @@ public:
|
||||
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
|
||||
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
|
||||
static bool isQuantizedStatic(const std::shared_ptr<const Node>& layer);
|
||||
|
||||
protected:
|
||||
static bool isHandled(
|
||||
const TransformationContext& context,
|
||||
const std::vector<std::shared_ptr<ngraph::Node>>& quantizationOperations);
|
||||
|
||||
void fillDequantizationNodes(
|
||||
const std::vector<FakeQuantizeDequantization>& layerDequantizations,
|
||||
const std::shared_ptr<Node> layer,
|
||||
NodeVector& convertNodes,
|
||||
NodeVector& subtractNodes,
|
||||
NodeVector& multiplyNodes) const;
|
||||
};
|
||||
|
||||
} // namespace low_precision
|
||||
|
@ -253,81 +253,6 @@ bool ConcatTransformation::canBeTransformed(const TransformationContext& context
|
||||
return true;
|
||||
}
|
||||
|
||||
void ConcatTransformation::fillDequantizationNodes(
|
||||
const std::vector<FakeQuantizeDequantization>& layerDequantizations,
|
||||
const std::shared_ptr<Node> layer,
|
||||
NodeVector& convertNodes,
|
||||
NodeVector& subtractNodes,
|
||||
NodeVector& multiplyNodes) const {
|
||||
if (layerDequantizations.size() > 1ul) {
|
||||
auto broadcastElementWiseConst = [](
|
||||
// FakeQuantize constant shape must be broadcastable to the shape on data.
|
||||
std::shared_ptr<ngraph::opset1::Constant> operation,
|
||||
const ngraph::Shape targetShape) -> std::shared_ptr<Node> {
|
||||
auto targetShapeConst = opset1::Constant::create(element::i64, ngraph::Shape{ targetShape.size() }, targetShape);
|
||||
auto broadcast = fold<ngraph::opset1::Broadcast>(operation, targetShapeConst);
|
||||
return broadcast;
|
||||
};
|
||||
|
||||
bool allDequantizationShiftAreZero = true;
|
||||
bool allDequantizationMultiplyAreZero = true;
|
||||
for (const auto& dequantization : layerDequantizations) {
|
||||
if (dequantization.subtract != nullptr) {
|
||||
allDequantizationShiftAreZero = false;
|
||||
}
|
||||
if (dequantization.multiply != nullptr) {
|
||||
allDequantizationMultiplyAreZero = false;
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < layerDequantizations.size(); ++i) {
|
||||
const auto& dequantization = layerDequantizations[i];
|
||||
const ngraph::element::Type precision = deqPrecision;
|
||||
ngraph::Shape targetShape(layer->get_input_partial_shape(i).rank().get_length(), 1ul);
|
||||
targetShape[1] = layer->get_input_partial_shape(i)[1].get_length();
|
||||
|
||||
if (dequantization.convert != nullptr) {
|
||||
convertNodes.push_back(dequantization.convert);
|
||||
}
|
||||
|
||||
if (!allDequantizationShiftAreZero) {
|
||||
subtractNodes.push_back(dequantization.subtract == nullptr ?
|
||||
std::make_shared<ngraph::opset1::Constant>(precision, targetShape, std::vector<float>({ 0.f })) :
|
||||
broadcastElementWiseConst(dequantization.subtractConstant, targetShape));
|
||||
}
|
||||
|
||||
if (!allDequantizationMultiplyAreZero) {
|
||||
multiplyNodes.push_back(dequantization.multiply == nullptr ?
|
||||
std::make_shared<ngraph::opset1::Constant>(precision, targetShape, std::vector<float>({ 1.0f })) :
|
||||
broadcastElementWiseConst(dequantization.multiplyConstant, targetShape));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// TODO: check constant shapes here - has to be scalar
|
||||
if (layerDequantizations[0].convert != nullptr) {
|
||||
convertNodes.push_back(layerDequantizations[0].convert);
|
||||
}
|
||||
|
||||
if (layerDequantizations[0].subtract != nullptr) {
|
||||
subtractNodes.push_back(layerDequantizations[0].subtract->input_value(1).get_node_shared_ptr());
|
||||
}
|
||||
|
||||
if (layerDequantizations[0].multiply != nullptr) {
|
||||
multiplyNodes.push_back(layerDequantizations[0].multiply->input_value(1).get_node_shared_ptr());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool ConcatTransformation::isHandled(const TransformationContext& context, const std::vector<std::shared_ptr<ngraph::Node>>& quantizationOperations) {
|
||||
for (const std::shared_ptr<ngraph::Node>& quantizationLayer : quantizationOperations) {
|
||||
if (context.quantizedFakeQuantizeNames.find(quantizationLayer->get_friendly_name()) != context.quantizedFakeQuantizeNames.end()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool ConcatTransformation::isQuantizedStatic(const std::shared_ptr<const Node>& layer) {
|
||||
const auto concat = as_type_ptr<const opset1::Concat>(layer);
|
||||
if (concat == nullptr)
|
||||
|
@ -53,11 +53,10 @@ bool LayerTransformation::canBeTransformed(const TransformationContext& context,
|
||||
|
||||
bool LayerTransformation::canBeTransformedStatic(const std::shared_ptr<Node>& layer,
|
||||
const std::vector<ngraph::element::Type>& defaultPrecisions) {
|
||||
for (const auto& output : layer->outputs()) {
|
||||
const auto rank = output.get_partial_shape().rank();
|
||||
if (rank.is_dynamic() || rank.get_length() < 2) {
|
||||
return false;
|
||||
}
|
||||
const auto outputs = layer->outputs();
|
||||
if (std::any_of(outputs.begin(), outputs.end(),
|
||||
[](const Output<Node>& out) { return out.get_partial_shape().rank().is_dynamic(); })) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto dequantization = NetworkHelper::getDequantization(layer, defaultPrecisions);
|
||||
@ -72,8 +71,7 @@ bool LayerTransformation::canBeTransformedStatic(const std::shared_ptr<Node>& la
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto dataShapeSize = static_cast<size_t>(rank.get_length());
|
||||
if ((dataShapeSize - constShape.size()) == 1ul) {
|
||||
if ((dataPShape.size() - constShape.size()) == 1ul) {
|
||||
constShape.insert(constShape.begin(), 1ul);
|
||||
}
|
||||
|
||||
@ -115,18 +113,10 @@ bool LayerTransformation::canBeTransformedSpatialDimension(const TransformationC
|
||||
if (!isQuantized(layer, defaultPrecisions)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (const auto& output : layer->outputs()) {
|
||||
const auto outPShape = output.get_partial_shape();
|
||||
const auto rank = outPShape.rank();
|
||||
if (rank.is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto size = rank.get_length();
|
||||
if ((size < 2) || (size > 5)) {
|
||||
return false;
|
||||
}
|
||||
const auto outputs = layer->outputs();
|
||||
if (std::any_of(outputs.begin(), outputs.end(),
|
||||
[](const Output<Node>& out) { return out.get_partial_shape().rank().is_dynamic(); })) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
@ -81,6 +81,8 @@ bool NormalizeL2Transformation::canBeTransformed(const TransformationContext& co
|
||||
const Shape outputShape = scalesConst->get_shape();
|
||||
const size_t size = shape_size(outputShape);
|
||||
if (size != 1ul) {
|
||||
if (operation->get_output_partial_shape(0).size() < 2)
|
||||
return false;
|
||||
const auto channelsInterval = operation->get_output_partial_shape(0)[1];
|
||||
if (channelsInterval.is_dynamic() || static_cast<size_t>(channelsInterval.get_length()) != size) {
|
||||
return false;
|
||||
|
@ -211,7 +211,7 @@ bool ReshapeTransformation::canBeTransformed(const TransformationContext& contex
|
||||
}
|
||||
|
||||
const PartialShape outputPShape = op->get_output_partial_shape(0);
|
||||
if (outputPShape[1].is_dynamic()) {
|
||||
if (outputPShape.size() < 2 || outputPShape[1].is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -137,7 +137,8 @@ namespace testValues1 {
|
||||
const std::vector<std::pair<ngraph::PartialShape, ngraph::PartialShape>> shapes = {
|
||||
{{-1, -1, -1, -1}, {-1, -1, -1, -1}},
|
||||
{{1, 16, 384, 64}, {1, 16, 64, 384}},
|
||||
{{4, 16, 384, 64}, {4, 16, 64, 384}}};
|
||||
{{1, 1, 4, 16, 384, 64}, {1, 1, 4, 16, 64, 384}},
|
||||
{{64}, {64}}};
|
||||
|
||||
std::vector<MatMullTransformationTestValues> testValues = {
|
||||
// U8 + I8: Constant on dequantization operations on 0 branch
|
||||
|
@ -94,10 +94,7 @@ TEST_P(PReluTransformation, CompareFunctions) {
|
||||
}
|
||||
|
||||
namespace testValues1 {
|
||||
const std::vector<ngraph::PartialShape> shapes = {
|
||||
{1, 3, 16, 16},
|
||||
{-1, -1, -1, -1},
|
||||
};
|
||||
const std::vector<ngraph::PartialShape> shapes = {{1, 3, 16, 16}, {-1, -1, -1, -1}, {1, 1, 2, 3, 4, 16}, {5}};
|
||||
|
||||
const std::vector<PReluTransformationTestValues> testValues = {
|
||||
// U8: no subtract
|
||||
|
@ -39,6 +39,22 @@ std::vector<MatMulTransformationTestValues> testValues = {
|
||||
{ 256ul, ngraph::Shape({}), {-12.8f}, {12.7f}, {-12.8f}, {12.7f} },
|
||||
"matMul_original",
|
||||
"I8"
|
||||
},
|
||||
{
|
||||
{ 1, 1, 1, 4, 12, 2 },
|
||||
{ 256ul, ngraph::Shape({}), {-12.8f}, {12.7f}, {-12.8f}, {12.7f} },
|
||||
{ 1, 1, 1, 4, 2, 12 },
|
||||
{ 256ul, ngraph::Shape({}), {-12.8f}, {12.7f}, {-12.8f}, {12.7f} },
|
||||
"matMul_original",
|
||||
"I8"
|
||||
},
|
||||
{
|
||||
{ 12 },
|
||||
{ 256ul, ngraph::Shape({}), {-12.8f}, {12.7f}, {-12.8f}, {12.7f} },
|
||||
{ 12 },
|
||||
{ 256ul, ngraph::Shape({}), {-12.8f}, {12.7f}, {-12.8f}, {12.7f} },
|
||||
"matMul_original/MM",
|
||||
"I8"
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -12,7 +12,7 @@ using namespace LayerTestsDefinitions;
|
||||
namespace {
|
||||
const std::vector<ngraph::element::Type> netPrecisions = {
|
||||
ngraph::element::f32,
|
||||
ngraph::element::f16
|
||||
// ngraph::element::f16
|
||||
};
|
||||
|
||||
const std::vector<ngraph::pass::low_precision::LayerTransformation::Params> trasformationParamValues = {
|
||||
@ -34,7 +34,7 @@ const std::vector<ReshapeTransformationParam> params = {
|
||||
{ -1 },
|
||||
{ 256ul, ngraph::Shape{}, { 0.f }, { 255.f }, { 0.f }, { 25.5f } },
|
||||
"Reshape",
|
||||
"FP32"
|
||||
"U8"
|
||||
},
|
||||
// 4D -> 3D
|
||||
{
|
||||
|
@ -27,6 +27,18 @@ std::vector<MatMulTransformationTestValues> testValues = {
|
||||
{ 256ul, ngraph::Shape({}), {-12.8f}, {12.7f}, {-12.8f}, {12.7f} },
|
||||
{ 1, 4, 2, 12 },
|
||||
{ 256ul, ngraph::Shape({}), {-12.8f}, {12.7f}, {-12.8f}, {12.7f} }
|
||||
},
|
||||
{
|
||||
{ 1, 1, 1, 4, 12, 2 },
|
||||
{ 256ul, ngraph::Shape({}), {-12.8f}, {12.7f}, {-12.8f}, {12.7f} },
|
||||
{ 1, 1, 1, 4, 2, 12 },
|
||||
{ 256ul, ngraph::Shape({}), {-12.8f}, {12.7f}, {-12.8f}, {12.7f} },
|
||||
},
|
||||
{
|
||||
{ 12 },
|
||||
{ 256ul, ngraph::Shape({}), {-12.8f}, {12.7f}, {-12.8f}, {12.7f} },
|
||||
{ 12 },
|
||||
{ 256ul, ngraph::Shape({}), {-12.8f}, {12.7f}, {-12.8f}, {12.7f} },
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -34,7 +34,7 @@ const std::vector<ReshapeTransformationParam> params = {
|
||||
{ -1 },
|
||||
{ 256ul, ngraph::Shape{}, { 0.f }, { 255.f }, { 0.f }, { 25.5f } },
|
||||
"Reshape",
|
||||
"FP32"
|
||||
"U8"
|
||||
},
|
||||
// 4D -> 3D
|
||||
{
|
||||
|
Loading…
Reference in New Issue
Block a user