[LPT] Removed legacy limitations on dequantization propagation for several transformations (#13048)
This commit is contained in:
parent
e8c8ad19f4
commit
372fe475c9
@ -81,7 +81,7 @@ bool ConcatTransformation::transform(TransformationContext& context, ngraph::pat
|
|||||||
allDequantizationShiftConvertAreNotZero = false;
|
allDequantizationShiftConvertAreNotZero = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// FakeQuantize constant shape must be broadcastable to the shape on data.
|
// constant shape must be broadcastable to the shape on data.
|
||||||
auto broadcastElementWiseConst = [](std::shared_ptr<opset1::Constant> operation, const Shape targetShape) {
|
auto broadcastElementWiseConst = [](std::shared_ptr<opset1::Constant> operation, const Shape targetShape) {
|
||||||
auto targetShapeConst = std::make_shared<opset1::Constant>(element::i64, Shape{ targetShape.size() }, targetShape);
|
auto targetShapeConst = std::make_shared<opset1::Constant>(element::i64, Shape{ targetShape.size() }, targetShape);
|
||||||
auto broadcast = fold<ngraph::opset1::Broadcast>(operation, targetShapeConst);
|
auto broadcast = fold<ngraph::opset1::Broadcast>(operation, targetShapeConst);
|
||||||
@ -99,11 +99,14 @@ bool ConcatTransformation::transform(TransformationContext& context, ngraph::pat
|
|||||||
[](const FakeQuantizeDequantization& value) { return !value.isLowPrecision(); });
|
[](const FakeQuantizeDequantization& value) { return !value.isLowPrecision(); });
|
||||||
|
|
||||||
bool DqWithDifferentPrecision = someDqInLowPrecision && someDqInFpPrecision;
|
bool DqWithDifferentPrecision = someDqInLowPrecision && someDqInFpPrecision;
|
||||||
|
const auto axis = ngraph::normalize_axis(concat->get_friendly_name(),
|
||||||
|
concat->get_axis(),
|
||||||
|
concat->get_output_partial_shape(0).rank());
|
||||||
|
|
||||||
OutputVector dataNodes;
|
OutputVector dataNodes;
|
||||||
NodeVector convertNodes;
|
NodeVector convertNodes;
|
||||||
NodeVector subtractNodes;
|
NodeVector subConstants;
|
||||||
NodeVector multiplyNodes;
|
NodeVector mulConstants;
|
||||||
std::shared_ptr<opset1::Convert> subtractConvert = nullptr;
|
std::shared_ptr<opset1::Convert> subtractConvert = nullptr;
|
||||||
for (size_t i = 0; i < layerDequantizations.size(); ++i) {
|
for (size_t i = 0; i < layerDequantizations.size(); ++i) {
|
||||||
const auto& dequantization = layerDequantizations[i];
|
const auto& dequantization = layerDequantizations[i];
|
||||||
@ -119,7 +122,7 @@ bool ConcatTransformation::transform(TransformationContext& context, ngraph::pat
|
|||||||
}
|
}
|
||||||
|
|
||||||
Shape targetShape(concat->get_input_partial_shape(i).rank().get_length(), 1ul);
|
Shape targetShape(concat->get_input_partial_shape(i).rank().get_length(), 1ul);
|
||||||
targetShape[1] = concat->get_input_partial_shape(i)[1].get_length();
|
targetShape[axis] = concat->get_input_partial_shape(i)[axis].get_length();
|
||||||
|
|
||||||
if (!allDequantizationShiftAreZero) {
|
if (!allDequantizationShiftAreZero) {
|
||||||
auto subtractInput = dequantization.subtract == nullptr ?
|
auto subtractInput = dequantization.subtract == nullptr ?
|
||||||
@ -138,11 +141,11 @@ bool ConcatTransformation::transform(TransformationContext& context, ngraph::pat
|
|||||||
subtractInput = foldConvert(subtractInput, dequantization.subtractConvert->get_convert_element_type());
|
subtractInput = foldConvert(subtractInput, dequantization.subtractConvert->get_convert_element_type());
|
||||||
NetworkHelper::copyInfo(dequantization.subtractConvert, subtractInput);
|
NetworkHelper::copyInfo(dequantization.subtractConvert, subtractInput);
|
||||||
}
|
}
|
||||||
subtractNodes.push_back(subtractInput);
|
subConstants.push_back(subtractInput);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!allDequantizationMultiplyAreZero) {
|
if (!allDequantizationMultiplyAreZero) {
|
||||||
multiplyNodes.push_back(dequantization.multiply == nullptr ?
|
mulConstants.push_back(dequantization.multiply == nullptr ?
|
||||||
std::make_shared<ngraph::opset1::Constant>(deqPrecision, targetShape, std::vector<float>({ 1.0f })) :
|
std::make_shared<ngraph::opset1::Constant>(deqPrecision, targetShape, std::vector<float>({ 1.0f })) :
|
||||||
broadcastElementWiseConst(dequantization.multiplyConstant, targetShape));
|
broadcastElementWiseConst(dequantization.multiplyConstant, targetShape));
|
||||||
}
|
}
|
||||||
@ -159,11 +162,10 @@ bool ConcatTransformation::transform(TransformationContext& context, ngraph::pat
|
|||||||
lastDequantization = convert;
|
lastDequantization = convert;
|
||||||
}
|
}
|
||||||
|
|
||||||
// concatenation axis is 1
|
if (!subConstants.empty()) {
|
||||||
if (!subtractNodes.empty()) {
|
std::shared_ptr<ov::Node> subtractNode = subConstants.size() == 1ul ?
|
||||||
std::shared_ptr<ov::Node> subtractNode = subtractNodes.size() == 1ul ?
|
subConstants[0] :
|
||||||
subtractNodes[0] :
|
ngraph::pass::low_precision::fold<ngraph::opset1::Concat>(subConstants, axis);
|
||||||
ngraph::pass::low_precision::fold<ngraph::opset1::Concat>(subtractNodes, 1);
|
|
||||||
if (subtractConvert != nullptr)
|
if (subtractConvert != nullptr)
|
||||||
subtractNode = subtractConvert->clone_with_new_inputs({subtractNode});
|
subtractNode = subtractConvert->clone_with_new_inputs({subtractNode});
|
||||||
const auto subtract = std::make_shared<opset1::Subtract>(
|
const auto subtract = std::make_shared<opset1::Subtract>(
|
||||||
@ -175,13 +177,13 @@ bool ConcatTransformation::transform(TransformationContext& context, ngraph::pat
|
|||||||
lastDequantization = subtract;
|
lastDequantization = subtract;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!multiplyNodes.empty()) {
|
if (!mulConstants.empty()) {
|
||||||
const auto multiply = std::make_shared<op::TypeRelaxed<opset1::Multiply>>(
|
const auto multiply = std::make_shared<op::TypeRelaxed<opset1::Multiply>>(
|
||||||
opset1::Multiply(
|
opset1::Multiply(
|
||||||
lastDequantization,
|
lastDequantization,
|
||||||
NetworkHelper::toScalarIfPossible(multiplyNodes.size() == 1ul ?
|
NetworkHelper::toScalarIfPossible(mulConstants.size() == 1ul ?
|
||||||
multiplyNodes[0] :
|
mulConstants[0] :
|
||||||
ngraph::pass::low_precision::fold<ngraph::opset1::Concat>(multiplyNodes, 1))),
|
ngraph::pass::low_precision::fold<ngraph::opset1::Concat>(mulConstants, axis))),
|
||||||
layerDequantizations[0].multiply->get_output_element_type(0));
|
layerDequantizations[0].multiply->get_output_element_type(0));
|
||||||
|
|
||||||
NetworkHelper::copyInfo({ concat, multiply }, multiply);
|
NetworkHelper::copyInfo({ concat, multiply }, multiply);
|
||||||
@ -213,11 +215,6 @@ bool ConcatTransformation::canBeTransformed(const TransformationContext& context
|
|||||||
}
|
}
|
||||||
|
|
||||||
const size_t normalizedAxis = ngraph::normalize_axis(concat->get_friendly_name(), axis, outRank);
|
const size_t normalizedAxis = ngraph::normalize_axis(concat->get_friendly_name(), axis, outRank);
|
||||||
|
|
||||||
if (normalizedAxis != 1ul) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (outPShape[normalizedAxis].is_dynamic()) {
|
if (outPShape[normalizedAxis].is_dynamic()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -333,17 +330,9 @@ bool ConcatTransformation::isHandled(const TransformationContext& context, const
|
|||||||
|
|
||||||
bool ConcatTransformation::isQuantizedStatic(const std::shared_ptr<const Node>& layer) {
|
bool ConcatTransformation::isQuantizedStatic(const std::shared_ptr<const Node>& layer) {
|
||||||
const auto concat = as_type_ptr<const opset1::Concat>(layer);
|
const auto concat = as_type_ptr<const opset1::Concat>(layer);
|
||||||
if (concat == nullptr) {
|
if (concat == nullptr)
|
||||||
return false;
|
return false;
|
||||||
}
|
return concat->get_output_partial_shape(0).rank().is_static();
|
||||||
|
|
||||||
const auto outputRank = concat->get_output_partial_shape(0).rank();
|
|
||||||
if (outputRank.is_dynamic()) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
const auto normalizedAxis = ngraph::normalize_axis(concat->get_friendly_name(), concat->get_axis(), outputRank);
|
|
||||||
return normalizedAxis == 1;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace low_precision
|
} // namespace low_precision
|
||||||
|
@ -153,20 +153,7 @@ bool SplitTransformation::isPrecisionPreserved(std::shared_ptr<Node> layer) cons
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool SplitTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const {
|
bool SplitTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const {
|
||||||
if (!LayerTransformation::canBeTransformed(context, layer) || NetworkHelper::getDequantization(layer, defaultPrecisions).empty()) {
|
return !NetworkHelper::getDequantization(layer, defaultPrecisions).empty() && layer->get_input_partial_shape(0).rank().is_static();
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
const auto consumers = NetworkHelper::consumers(layer);
|
|
||||||
const auto concat = ov::as_type_ptr<opset1::Concat>(consumers[0]);
|
|
||||||
|
|
||||||
// WA to avoid propagation of dequantization if after Split all consumers are the same unsupported Concat
|
|
||||||
if (concat && concat->get_axis() != 1ul) {
|
|
||||||
const size_t id = consumers[0]->get_instance_id();
|
|
||||||
return std::any_of(consumers.begin(), consumers.end(), [&](const std::shared_ptr<Node>& node) { return node->get_instance_id() != id; });
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace low_precision
|
} // namespace low_precision
|
||||||
|
@ -164,7 +164,7 @@ namespace testValues2 {
|
|||||||
{},
|
{},
|
||||||
{},
|
{},
|
||||||
"Concatenation",
|
"Concatenation",
|
||||||
"FP32",
|
"I8",
|
||||||
-1
|
-1
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -127,6 +127,7 @@ public:
|
|||||||
result <<
|
result <<
|
||||||
LayerTransformation::getTestCaseNameByParams(precision, inputShape, testValues.params) << "_" <<
|
LayerTransformation::getTestCaseNameByParams(precision, inputShape, testValues.params) << "_" <<
|
||||||
(testValues.multiChannels ? "multiChannels_" : "notMultiChannels_") <<
|
(testValues.multiChannels ? "multiChannels_" : "notMultiChannels_") <<
|
||||||
|
"axis" << testValues.axis <<
|
||||||
testValues.actual << "_" <<
|
testValues.actual << "_" <<
|
||||||
testValues.result << "_";
|
testValues.result << "_";
|
||||||
return result.str();
|
return result.str();
|
||||||
@ -147,7 +148,7 @@ const std::vector<ngraph::element::Type> precisions = {
|
|||||||
const std::vector<ngraph::PartialShape> shapes = {
|
const std::vector<ngraph::PartialShape> shapes = {
|
||||||
{ 1, 3, 10, 10 },
|
{ 1, 3, 10, 10 },
|
||||||
{ 4, 3, 10, 10 },
|
{ 4, 3, 10, 10 },
|
||||||
{ Dimension::dynamic(), 3, Dimension::dynamic(), Dimension::dynamic() }
|
{ -1, 3, 10, -1 }
|
||||||
};
|
};
|
||||||
|
|
||||||
const std::vector<ConcatTransformationTestValues> testValues = {
|
const std::vector<ConcatTransformationTestValues> testValues = {
|
||||||
@ -171,7 +172,7 @@ const std::vector<ConcatTransformationTestValues> testValues = {
|
|||||||
{ ngraph::element::f32, {}, { 0.01f } }
|
{ ngraph::element::f32, {}, { 0.01f } }
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
// U8 with unsupported axis
|
// U8 concatenation by spatial dimension
|
||||||
{
|
{
|
||||||
LayerTransformation::createParamsU8I8(),
|
LayerTransformation::createParamsU8I8(),
|
||||||
false,
|
false,
|
||||||
@ -182,13 +183,13 @@ const std::vector<ConcatTransformationTestValues> testValues = {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f} },
|
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f} },
|
||||||
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f / 2.f}, {0.f}, {255.f} },
|
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f / 2.f}, {0.f}, {128.f} },
|
||||||
ngraph::element::u8,
|
ngraph::element::u8,
|
||||||
{{ngraph::element::f32}, {}, {0.01f}},
|
{},
|
||||||
{{ngraph::element::f32}, {}, {0.005f}},
|
{},
|
||||||
ngraph::element::f32,
|
ngraph::element::f32,
|
||||||
{{}, {}, {}},
|
{ ngraph::element::f32, {}, { 0.01f } },
|
||||||
{{}, {}, {}}
|
{ ngraph::element::f32, {}, { 0.01f } }
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
// I8
|
// I8
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -1,37 +0,0 @@
|
|||||||
// Copyright (C) 2018-2022 Intel Corporation
|
|
||||||
// SPDX-License-Identifier: Apache-2.0
|
|
||||||
//
|
|
||||||
|
|
||||||
#include "layer_transformation.hpp"
|
|
||||||
|
|
||||||
#include <gtest/gtest.h>
|
|
||||||
|
|
||||||
#include "lpt_ngraph_functions/concat_function.hpp"
|
|
||||||
#include "simple_low_precision_transformer.hpp"
|
|
||||||
|
|
||||||
using namespace ::testing;
|
|
||||||
|
|
||||||
class smoke_LPT_ConcatWithUnsupportedAxis : public Test {};
|
|
||||||
|
|
||||||
TEST_F(smoke_LPT_ConcatWithUnsupportedAxis, rtInfoCheck) {
|
|
||||||
using namespace ngraph::builder::subgraph;
|
|
||||||
|
|
||||||
const ngraph::element::Type precision = ngraph::element::f32;
|
|
||||||
const ngraph::PartialShape inputPShape = PartialShape{ 1, 3, 16, 16 };
|
|
||||||
const std::int64_t unsupportedAxis = 2;
|
|
||||||
const auto fakeQuantize = FakeQuantizeOnData{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} };
|
|
||||||
|
|
||||||
std::shared_ptr<ngraph::Function> function = ConcatFunction::getOriginalWithDifferentPrecisionOnChildren(
|
|
||||||
precision,
|
|
||||||
inputPShape,
|
|
||||||
unsupportedAxis,
|
|
||||||
fakeQuantize,
|
|
||||||
fakeQuantize);
|
|
||||||
|
|
||||||
SimpleLowPrecisionTransformer transformer;
|
|
||||||
transformer.transform(function);
|
|
||||||
|
|
||||||
const auto actualConcat = LayerTransformation::get<opset1::Concat>(function)[0];
|
|
||||||
const auto& rtInfo = actualConcat->get_rt_info();
|
|
||||||
ASSERT_TRUE(rtInfo.empty()) << "Unsupported concat mustn't contain LPT runtime attributes";
|
|
||||||
}
|
|
@ -296,7 +296,7 @@ const std::vector<MoveFakeQuantizeTransformationTestValues> testValues = {
|
|||||||
{},
|
{},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
// negative test
|
// concat by batch
|
||||||
{
|
{
|
||||||
LayerTransformation::createParamsU8I8(),
|
LayerTransformation::createParamsU8I8(),
|
||||||
false,
|
false,
|
||||||
@ -313,11 +313,11 @@ const std::vector<MoveFakeQuantizeTransformationTestValues> testValues = {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
2,
|
2,
|
||||||
{},
|
{{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}},
|
||||||
{},
|
{},
|
||||||
{},
|
{},
|
||||||
"",
|
"",
|
||||||
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}},
|
{},
|
||||||
{},
|
{},
|
||||||
{}
|
{}
|
||||||
}
|
}
|
||||||
|
@ -43,7 +43,6 @@ public:
|
|||||||
TestTransformationParams params;
|
TestTransformationParams params;
|
||||||
Actual actual;
|
Actual actual;
|
||||||
Expected expected;
|
Expected expected;
|
||||||
bool addUnsupportedConcat;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
@ -64,8 +63,7 @@ public:
|
|||||||
testValues.actual.precisionBeforeDequantization,
|
testValues.actual.precisionBeforeDequantization,
|
||||||
testValues.actual.dequantization,
|
testValues.actual.dequantization,
|
||||||
testValues.splitedAxis,
|
testValues.splitedAxis,
|
||||||
testValues.numSplits,
|
testValues.numSplits);
|
||||||
testValues.addUnsupportedConcat);
|
|
||||||
|
|
||||||
SimpleLowPrecisionTransformer transformer;
|
SimpleLowPrecisionTransformer transformer;
|
||||||
transformer.add<ngraph::pass::low_precision::SplitTransformation, ngraph::opset1::Split>(testValues.params);
|
transformer.add<ngraph::pass::low_precision::SplitTransformation, ngraph::opset1::Split>(testValues.params);
|
||||||
@ -79,8 +77,7 @@ public:
|
|||||||
testValues.expected.precisionAfterOperation,
|
testValues.expected.precisionAfterOperation,
|
||||||
testValues.expected.dequantizationAfter,
|
testValues.expected.dequantizationAfter,
|
||||||
testValues.splitedAxis,
|
testValues.splitedAxis,
|
||||||
testValues.numSplits,
|
testValues.numSplits);
|
||||||
testValues.addUnsupportedConcat);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::string getTestCaseName(testing::TestParamInfo<SplitTransformationParams> obj) {
|
static std::string getTestCaseName(testing::TestParamInfo<SplitTransformationParams> obj) {
|
||||||
@ -334,6 +331,59 @@ const std::vector<SplitTransformationTestValues> testValues = {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
// per channel quantization with different values, split by batch
|
||||||
|
{
|
||||||
|
{ 2, 3, 16, 16 }, std::int64_t{0}, size_t{2},
|
||||||
|
LayerTransformation::createParamsI8I8(),
|
||||||
|
{
|
||||||
|
ngraph::element::i8,
|
||||||
|
{{ngraph::element::f32},
|
||||||
|
{{2.f, 3.f}, ngraph::element::f32, {2, 1, 1, 1}},
|
||||||
|
{{22.f, 33.f}, ngraph::element::f32, {2, 1, 1, 1}}}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ngraph::element::i8,
|
||||||
|
{},
|
||||||
|
ngraph::element::i8,
|
||||||
|
{
|
||||||
|
{{ngraph::element::f32}, {2.f}, {22.f}},
|
||||||
|
{{ngraph::element::f32}, {3.f}, {33.f}},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
// per channel quantization with different values, split by spatial dimension
|
||||||
|
{
|
||||||
|
{ -1, -1, -1, -1 }, std::int64_t{2}, size_t{3},
|
||||||
|
LayerTransformation::createParamsI8I8(),
|
||||||
|
{
|
||||||
|
ngraph::element::i8,
|
||||||
|
{{ngraph::element::f32},
|
||||||
|
{{1.f, 2.f, 3.f, 4.f, 5.f, 6.f}, ngraph::element::f32, {1, 1, 6, 1}},
|
||||||
|
{{11.f, 22.f, 33.f, 44.f, 55.f, 66.f}, ngraph::element::f32, {1, 1, 6, 1}}}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ngraph::element::i8,
|
||||||
|
{},
|
||||||
|
ngraph::element::i8,
|
||||||
|
{
|
||||||
|
{
|
||||||
|
{ngraph::element::f32},
|
||||||
|
{{1.f, 2.f}, ngraph::element::f32, {1, 1, 2, 1}},
|
||||||
|
{{11.f, 22.f}, ngraph::element::f32, {1, 1, 2, 1}}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
{ngraph::element::f32},
|
||||||
|
{{3.f, 4.f}, ngraph::element::f32, {1, 1, 2, 1}},
|
||||||
|
{{33.f, 44.f}, ngraph::element::f32, {1, 1, 2, 1}}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
{ngraph::element::f32},
|
||||||
|
{{5.f, 6.f}, ngraph::element::f32, {1, 1, 2, 1}},
|
||||||
|
{{55.f, 66.f}, ngraph::element::f32, {1, 1, 2, 1}}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
// U8 per channel quantization with the same values
|
// U8 per channel quantization with the same values
|
||||||
{
|
{
|
||||||
{ 1, 3, 16, 16 }, std::int64_t{1}, size_t{3},
|
{ 1, 3, 16, 16 }, std::int64_t{1}, size_t{3},
|
||||||
@ -584,39 +634,6 @@ const std::vector<SplitTransformationTestValues> testValues = {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
// issue #56781: unsupported Concat after Split
|
|
||||||
{
|
|
||||||
{ 1, 4, 3, 3 }, std::int64_t{2}, size_t{3},
|
|
||||||
LayerTransformation::createParamsU8I8(),
|
|
||||||
{
|
|
||||||
ngraph::element::u8,
|
|
||||||
{{ngraph::element::f32}, {128.f}, {3.f}}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ngraph::element::u8,
|
|
||||||
{{ngraph::element::f32}, {128.f}, {3.f}},
|
|
||||||
ngraph::element::f32,
|
|
||||||
{}
|
|
||||||
},
|
|
||||||
true
|
|
||||||
},
|
|
||||||
// issue #56781: unsupported Concat after Split, dynamic channels
|
|
||||||
{
|
|
||||||
{ -1, -1, -1, -1 },
|
|
||||||
std::int64_t{2}, size_t{3},
|
|
||||||
LayerTransformation::createParamsU8I8(),
|
|
||||||
{
|
|
||||||
ngraph::element::u8,
|
|
||||||
{{ngraph::element::f32}, {128.f}, {3.f}}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ngraph::element::u8,
|
|
||||||
{{ngraph::element::f32}, {128.f}, {3.f}},
|
|
||||||
ngraph::element::f32,
|
|
||||||
{}
|
|
||||||
},
|
|
||||||
true
|
|
||||||
},
|
|
||||||
// no dequantization
|
// no dequantization
|
||||||
{
|
{
|
||||||
ngraph::Shape({ 1, 3, 4, 4 }), std::int64_t{2}, size_t{2},
|
ngraph::Shape({ 1, 3, 4, 4 }), std::int64_t{2}, size_t{2},
|
||||||
|
@ -196,6 +196,54 @@ const std::vector<VariadicSplitTransformationTestValues> testValues = {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
// U8 per channel quantization with different values, split by batch
|
||||||
|
{
|
||||||
|
{ 2, 3, 16, 16 }, std::int64_t{0}, std::vector<size_t>{ 1, 1 },
|
||||||
|
LayerTransformation::createParamsU8I8(),
|
||||||
|
{
|
||||||
|
ngraph::element::u8,
|
||||||
|
{{ngraph::element::f32},
|
||||||
|
{{2.f, 3.f}, ngraph::element::f32, {2, 1, 1, 1}},
|
||||||
|
{{22.f, 33.f}, ngraph::element::f32, {2, 1, 1, 1}}}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ngraph::element::u8,
|
||||||
|
{},
|
||||||
|
ngraph::element::u8,
|
||||||
|
{
|
||||||
|
{{ngraph::element::f32}, {2.f}, {22.f}},
|
||||||
|
{{ngraph::element::f32}, {3.f}, {33.f}}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
// U8 per channel quantization with different values, split by spatial dimension
|
||||||
|
{
|
||||||
|
{ -1, -1, -1, -1 }, std::int64_t{3}, std::vector<size_t>{ 4, 2 },
|
||||||
|
LayerTransformation::createParamsU8I8(),
|
||||||
|
{
|
||||||
|
ngraph::element::u8,
|
||||||
|
{{ngraph::element::f32},
|
||||||
|
{{1.f, 2.f, 3.f, 4.f, 5.f, 6.f}, ngraph::element::f32, {1, 1, 1, 6}},
|
||||||
|
{{11.f, 22.f, 33.f, 44.f, 55.f, 66.f}, ngraph::element::f32, {1, 1, 1, 6}}}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ngraph::element::u8,
|
||||||
|
{},
|
||||||
|
ngraph::element::u8,
|
||||||
|
{
|
||||||
|
{
|
||||||
|
{ngraph::element::f32},
|
||||||
|
{{1.f, 2.f, 3.f, 4.f}, ngraph::element::f32, {1, 1, 1, 4}},
|
||||||
|
{{11.f, 22.f, 33.f, 44.f}, ngraph::element::f32, {1, 1, 1, 4}}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
{ngraph::element::f32},
|
||||||
|
{{5.f, 6.f}, ngraph::element::f32, {1, 1, 1, 2}},
|
||||||
|
{{55.f, 66.f}, ngraph::element::f32, {1, 1, 1, 2}}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
// U8 per channel quantization with different values, dynamic shape
|
// U8 per channel quantization with different values, dynamic shape
|
||||||
{
|
{
|
||||||
{ -1, 3, -1, -1 },
|
{ -1, 3, -1, -1 },
|
||||||
|
@ -17,6 +17,15 @@ namespace subgraph {
|
|||||||
|
|
||||||
class ConcatFunction {
|
class ConcatFunction {
|
||||||
public:
|
public:
|
||||||
|
static std::shared_ptr<ov::Model> get(
|
||||||
|
const ov::element::Type inputPrecision,
|
||||||
|
const ov::element::Type deqPrecision,
|
||||||
|
const std::vector<ov::PartialShape>& inputShapes,
|
||||||
|
const std::vector<DequantizationOperations>& dequantizationsBefore,
|
||||||
|
const std::int64_t concatAxis,
|
||||||
|
const ov::element::Type precisionAfter = ov::element::undefined,
|
||||||
|
const DequantizationOperations& dequantizationAfter = {});
|
||||||
|
|
||||||
static std::shared_ptr<ngraph::Function> getOriginal(
|
static std::shared_ptr<ngraph::Function> getOriginal(
|
||||||
const ngraph::element::Type precision,
|
const ngraph::element::Type precision,
|
||||||
const ngraph::PartialShape& inputShape,
|
const ngraph::PartialShape& inputShape,
|
||||||
@ -107,6 +116,15 @@ public:
|
|||||||
const FakeQuantizeOnData& fqOnData1,
|
const FakeQuantizeOnData& fqOnData1,
|
||||||
const FakeQuantizeOnData& fqOnData2);
|
const FakeQuantizeOnData& fqOnData2);
|
||||||
|
|
||||||
|
static std::shared_ptr<ov::Model> getReference(
|
||||||
|
const ov::element::Type dequantizationPrecision,
|
||||||
|
const ov::element::Type precisionBefore,
|
||||||
|
const std::vector<ov::PartialShape>& inputShapes,
|
||||||
|
const std::vector<DequantizationOperations>& dequantizationsBefore,
|
||||||
|
const ov::element::Type precisionAfter,
|
||||||
|
const DequantizationOperations& dequantizationAfter,
|
||||||
|
const std::int64_t concatAxis);
|
||||||
|
|
||||||
static std::shared_ptr<ngraph::Function> getReference(
|
static std::shared_ptr<ngraph::Function> getReference(
|
||||||
const ngraph::element::Type precision,
|
const ngraph::element::Type precision,
|
||||||
const ngraph::Shape& inputShape,
|
const ngraph::Shape& inputShape,
|
||||||
|
@ -25,8 +25,7 @@ public:
|
|||||||
const ngraph::element::Type precisionBeforeDequantization,
|
const ngraph::element::Type precisionBeforeDequantization,
|
||||||
const ngraph::builder::subgraph::DequantizationOperations& dequantization,
|
const ngraph::builder::subgraph::DequantizationOperations& dequantization,
|
||||||
const int64_t splitedAxis,
|
const int64_t splitedAxis,
|
||||||
const size_t numSplits,
|
const size_t numSplits);
|
||||||
const bool addUnsupportedConcat = false);
|
|
||||||
|
|
||||||
static std::shared_ptr<ngraph::Function> getOriginal(
|
static std::shared_ptr<ngraph::Function> getOriginal(
|
||||||
const ngraph::element::Type originalFunctionPrecision,
|
const ngraph::element::Type originalFunctionPrecision,
|
||||||
@ -43,8 +42,7 @@ public:
|
|||||||
const ngraph::element::Type precisionAfterOperation,
|
const ngraph::element::Type precisionAfterOperation,
|
||||||
const std::vector<ngraph::builder::subgraph::DequantizationOperations>& dequantizationAfter,
|
const std::vector<ngraph::builder::subgraph::DequantizationOperations>& dequantizationAfter,
|
||||||
const int64_t splitedAxis,
|
const int64_t splitedAxis,
|
||||||
const size_t numSplit,
|
const size_t numSplits);
|
||||||
const bool addUnsupportedConcat = false);
|
|
||||||
};
|
};
|
||||||
} // namespace subgraph
|
} // namespace subgraph
|
||||||
} // namespace builder
|
} // namespace builder
|
||||||
|
@ -4,7 +4,7 @@
|
|||||||
|
|
||||||
#include "lpt_ngraph_functions/concat_function.hpp"
|
#include "lpt_ngraph_functions/concat_function.hpp"
|
||||||
|
|
||||||
#include <ngraph/opsets/opset1.hpp>
|
#include <openvino/opsets/opset1.hpp>
|
||||||
#include "ngraph_ops/type_relaxed.hpp"
|
#include "ngraph_ops/type_relaxed.hpp"
|
||||||
#include "low_precision/network_helper.hpp"
|
#include "low_precision/network_helper.hpp"
|
||||||
#include "low_precision/rt_info/precision_preserved_attribute.hpp"
|
#include "low_precision/rt_info/precision_preserved_attribute.hpp"
|
||||||
@ -23,6 +23,46 @@ namespace subgraph {
|
|||||||
|
|
||||||
using namespace ngraph::pass;
|
using namespace ngraph::pass;
|
||||||
|
|
||||||
|
std::shared_ptr<ov::Model> ConcatFunction::get(
|
||||||
|
const ov::element::Type inputPrecision,
|
||||||
|
const ov::element::Type deqPrecision,
|
||||||
|
const std::vector<ov::PartialShape>& inputShapes,
|
||||||
|
const std::vector<DequantizationOperations>& dequantizationsBefore,
|
||||||
|
const std::int64_t concatAxis,
|
||||||
|
const ov::element::Type precisionAfter,
|
||||||
|
const DequantizationOperations& dequantizationAfter) {
|
||||||
|
auto modifyDeq = [](const DequantizationOperations& deq, const ov::element::Type deqOutPrc) {
|
||||||
|
auto dequantizationStructure = deq;
|
||||||
|
if (!dequantizationStructure.multiply.empty()) {
|
||||||
|
dequantizationStructure.multiply.outPrecision = deqOutPrc;
|
||||||
|
}
|
||||||
|
return dequantizationStructure;
|
||||||
|
};
|
||||||
|
|
||||||
|
ov::ParameterVector inputs;
|
||||||
|
ov::NodeVector concatInputs;
|
||||||
|
if (inputShapes.size() != dequantizationsBefore.size()) {
|
||||||
|
throw std::runtime_error("Concat builder: input and dequantization sizes aren't equal");
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t i = 0; i < inputShapes.size(); ++i) {
|
||||||
|
const auto input = std::make_shared<ov::opset1::Parameter>(inputPrecision, inputShapes[i]);
|
||||||
|
const auto dequantization = makeDequantization(input, modifyDeq(dequantizationsBefore[i], deqPrecision));
|
||||||
|
inputs.push_back(input);
|
||||||
|
concatInputs.push_back(dequantization);
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto concat = std::make_shared<ov::opset1::Concat>(concatInputs, concatAxis);
|
||||||
|
if (precisionAfter != ov::element::undefined && (concat->get_output_element_type(0).is_real() ^ precisionAfter.is_real())) {
|
||||||
|
throw std::runtime_error("Concat builder: requested precision after operation could't be set");
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto deqAfter = makeDequantization(concat, modifyDeq(dequantizationAfter, deqPrecision));
|
||||||
|
deqAfter->set_friendly_name("Concat");
|
||||||
|
const auto result = std::make_shared<ov::opset1::Result>(deqAfter);
|
||||||
|
return std::make_shared<ov::Model>(ov::ResultVector{result}, inputs, "ConcatTransformation");
|
||||||
|
}
|
||||||
|
|
||||||
std::shared_ptr<ngraph::Function> ConcatFunction::getOriginal(
|
std::shared_ptr<ngraph::Function> ConcatFunction::getOriginal(
|
||||||
const ngraph::element::Type precision,
|
const ngraph::element::Type precision,
|
||||||
const ngraph::PartialShape& inputShape,
|
const ngraph::PartialShape& inputShape,
|
||||||
|
@ -23,25 +23,18 @@ std::shared_ptr<ngraph::Function> SplitFunction::getOriginal(
|
|||||||
const ngraph::element::Type precisionBeforeDequantization,
|
const ngraph::element::Type precisionBeforeDequantization,
|
||||||
const ngraph::builder::subgraph::DequantizationOperations& dequantization,
|
const ngraph::builder::subgraph::DequantizationOperations& dequantization,
|
||||||
const int64_t splitedAxis,
|
const int64_t splitedAxis,
|
||||||
const size_t numSplits,
|
const size_t numSplits) {
|
||||||
const bool addUnsupportedConcat) {
|
|
||||||
const auto input = std::make_shared<ngraph::opset1::Parameter>(precisionBeforeDequantization, inputShape);
|
const auto input = std::make_shared<ngraph::opset1::Parameter>(precisionBeforeDequantization, inputShape);
|
||||||
|
|
||||||
auto dequantizationStructure = dequantization;
|
auto dequantizationStructure = dequantization;
|
||||||
dequantizationStructure.multiply.outPrecision = precision;
|
dequantizationStructure.multiply.outPrecision = precision;
|
||||||
const std::shared_ptr<Node> dequantizationOp = makeDequantization(input, dequantization);
|
const auto dequantizationOp = makeDequantization(input, dequantization);
|
||||||
const auto constant = std::make_shared<ngraph::opset1::Constant>(element::i64, Shape{ }, splitedAxis);
|
const auto constant = std::make_shared<ngraph::opset1::Constant>(element::i64, Shape{ }, splitedAxis);
|
||||||
const std::shared_ptr<Node> split = std::make_shared<ngraph::opset1::Split>(dequantizationOp, constant, numSplits);
|
const auto split = std::make_shared<ngraph::opset1::Split>(dequantizationOp, constant, numSplits);
|
||||||
|
|
||||||
ngraph::ResultVector results;
|
ngraph::ResultVector results;
|
||||||
|
for (size_t i = 0; i < numSplits; ++i) {
|
||||||
if (addUnsupportedConcat) {
|
results.push_back(std::make_shared<ngraph::opset1::Result>(split->output(i)));
|
||||||
const auto concat = std::make_shared<opset1::Concat>(split->outputs(), 2ul);
|
|
||||||
results.push_back(std::make_shared<opset1::Result>(concat));
|
|
||||||
} else {
|
|
||||||
for (size_t i = 0; i < numSplits; ++i) {
|
|
||||||
results.push_back(std::make_shared<ngraph::opset1::Result>(split->output(i)));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "SplitFunction");
|
return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "SplitFunction");
|
||||||
}
|
}
|
||||||
@ -82,31 +75,23 @@ std::shared_ptr<ngraph::Function> SplitFunction::getReference(
|
|||||||
const ngraph::element::Type precisionAfterOperation,
|
const ngraph::element::Type precisionAfterOperation,
|
||||||
const std::vector<ngraph::builder::subgraph::DequantizationOperations>& dequantizationAfter,
|
const std::vector<ngraph::builder::subgraph::DequantizationOperations>& dequantizationAfter,
|
||||||
const int64_t splitedAxis,
|
const int64_t splitedAxis,
|
||||||
const size_t numSplit,
|
const size_t numSplit) {
|
||||||
const bool addUnsupportedConcat) {
|
|
||||||
const auto input = std::make_shared<ngraph::opset1::Parameter>(inputPrecision, inputShape);
|
const auto input = std::make_shared<ngraph::opset1::Parameter>(inputPrecision, inputShape);
|
||||||
|
|
||||||
const auto deqBefore = makeDequantization(input, dequantizationBefore);
|
const auto deqBefore = makeDequantization(input, dequantizationBefore);
|
||||||
|
|
||||||
std::shared_ptr<ngraph::opset1::Split> split;
|
|
||||||
const auto constant = std::make_shared<ngraph::opset1::Constant>(element::i64, Shape{ }, splitedAxis);
|
const auto constant = std::make_shared<ngraph::opset1::Constant>(element::i64, Shape{ }, splitedAxis);
|
||||||
split = std::make_shared<ngraph::opset1::Split>(deqBefore, constant, numSplit);
|
const auto split = std::make_shared<ngraph::opset1::Split>(deqBefore, constant, numSplit);
|
||||||
|
|
||||||
ngraph::ResultVector results;
|
ngraph::ResultVector results;
|
||||||
if (addUnsupportedConcat) {
|
for (size_t i = 0; i < numSplit; ++i) {
|
||||||
const auto concat = std::make_shared<opset1::Concat>(split->outputs(), 2ul);
|
if (!dequantizationAfter.empty()) {
|
||||||
results.push_back(std::make_shared<opset1::Result>(concat));
|
auto dequantizationStructure = dequantizationAfter[i];
|
||||||
} else {
|
if (!dequantizationStructure.multiply.empty()) {
|
||||||
for (size_t i = 0; i < numSplit; ++i) {
|
dequantizationStructure.multiply.outPrecision = precision;
|
||||||
if (!dequantizationAfter.empty()) {
|
|
||||||
auto dequantizationStructure = dequantizationAfter[i];
|
|
||||||
if (!dequantizationStructure.multiply.empty()) {
|
|
||||||
dequantizationStructure.multiply.outPrecision = precision;
|
|
||||||
}
|
|
||||||
results.push_back(std::make_shared<ngraph::opset1::Result>(makeDequantization(split->output(i), dequantizationAfter[i])));
|
|
||||||
} else {
|
|
||||||
results.push_back(std::make_shared<ngraph::opset1::Result>(split->output(i)));
|
|
||||||
}
|
}
|
||||||
|
results.push_back(std::make_shared<ngraph::opset1::Result>(makeDequantization(split->output(i), dequantizationAfter[i])));
|
||||||
|
} else {
|
||||||
|
results.push_back(std::make_shared<ngraph::opset1::Result>(split->output(i)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user