[LPT] Removed legacy limitations on dequantization propagation for several transformations (#13048)

This commit is contained in:
Vladislav Golubev 2022-10-10 08:58:23 +02:00 committed by GitHub
parent e8c8ad19f4
commit 372fe475c9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 1849 additions and 1532 deletions

View File

@ -81,7 +81,7 @@ bool ConcatTransformation::transform(TransformationContext& context, ngraph::pat
allDequantizationShiftConvertAreNotZero = false; allDequantizationShiftConvertAreNotZero = false;
} }
// FakeQuantize constant shape must be broadcastable to the shape on data. // constant shape must be broadcastable to the shape on data.
auto broadcastElementWiseConst = [](std::shared_ptr<opset1::Constant> operation, const Shape targetShape) { auto broadcastElementWiseConst = [](std::shared_ptr<opset1::Constant> operation, const Shape targetShape) {
auto targetShapeConst = std::make_shared<opset1::Constant>(element::i64, Shape{ targetShape.size() }, targetShape); auto targetShapeConst = std::make_shared<opset1::Constant>(element::i64, Shape{ targetShape.size() }, targetShape);
auto broadcast = fold<ngraph::opset1::Broadcast>(operation, targetShapeConst); auto broadcast = fold<ngraph::opset1::Broadcast>(operation, targetShapeConst);
@ -99,11 +99,14 @@ bool ConcatTransformation::transform(TransformationContext& context, ngraph::pat
[](const FakeQuantizeDequantization& value) { return !value.isLowPrecision(); }); [](const FakeQuantizeDequantization& value) { return !value.isLowPrecision(); });
bool DqWithDifferentPrecision = someDqInLowPrecision && someDqInFpPrecision; bool DqWithDifferentPrecision = someDqInLowPrecision && someDqInFpPrecision;
const auto axis = ngraph::normalize_axis(concat->get_friendly_name(),
concat->get_axis(),
concat->get_output_partial_shape(0).rank());
OutputVector dataNodes; OutputVector dataNodes;
NodeVector convertNodes; NodeVector convertNodes;
NodeVector subtractNodes; NodeVector subConstants;
NodeVector multiplyNodes; NodeVector mulConstants;
std::shared_ptr<opset1::Convert> subtractConvert = nullptr; std::shared_ptr<opset1::Convert> subtractConvert = nullptr;
for (size_t i = 0; i < layerDequantizations.size(); ++i) { for (size_t i = 0; i < layerDequantizations.size(); ++i) {
const auto& dequantization = layerDequantizations[i]; const auto& dequantization = layerDequantizations[i];
@ -119,7 +122,7 @@ bool ConcatTransformation::transform(TransformationContext& context, ngraph::pat
} }
Shape targetShape(concat->get_input_partial_shape(i).rank().get_length(), 1ul); Shape targetShape(concat->get_input_partial_shape(i).rank().get_length(), 1ul);
targetShape[1] = concat->get_input_partial_shape(i)[1].get_length(); targetShape[axis] = concat->get_input_partial_shape(i)[axis].get_length();
if (!allDequantizationShiftAreZero) { if (!allDequantizationShiftAreZero) {
auto subtractInput = dequantization.subtract == nullptr ? auto subtractInput = dequantization.subtract == nullptr ?
@ -138,11 +141,11 @@ bool ConcatTransformation::transform(TransformationContext& context, ngraph::pat
subtractInput = foldConvert(subtractInput, dequantization.subtractConvert->get_convert_element_type()); subtractInput = foldConvert(subtractInput, dequantization.subtractConvert->get_convert_element_type());
NetworkHelper::copyInfo(dequantization.subtractConvert, subtractInput); NetworkHelper::copyInfo(dequantization.subtractConvert, subtractInput);
} }
subtractNodes.push_back(subtractInput); subConstants.push_back(subtractInput);
} }
if (!allDequantizationMultiplyAreZero) { if (!allDequantizationMultiplyAreZero) {
multiplyNodes.push_back(dequantization.multiply == nullptr ? mulConstants.push_back(dequantization.multiply == nullptr ?
std::make_shared<ngraph::opset1::Constant>(deqPrecision, targetShape, std::vector<float>({ 1.0f })) : std::make_shared<ngraph::opset1::Constant>(deqPrecision, targetShape, std::vector<float>({ 1.0f })) :
broadcastElementWiseConst(dequantization.multiplyConstant, targetShape)); broadcastElementWiseConst(dequantization.multiplyConstant, targetShape));
} }
@ -159,11 +162,10 @@ bool ConcatTransformation::transform(TransformationContext& context, ngraph::pat
lastDequantization = convert; lastDequantization = convert;
} }
// concatenation axis is 1 if (!subConstants.empty()) {
if (!subtractNodes.empty()) { std::shared_ptr<ov::Node> subtractNode = subConstants.size() == 1ul ?
std::shared_ptr<ov::Node> subtractNode = subtractNodes.size() == 1ul ? subConstants[0] :
subtractNodes[0] : ngraph::pass::low_precision::fold<ngraph::opset1::Concat>(subConstants, axis);
ngraph::pass::low_precision::fold<ngraph::opset1::Concat>(subtractNodes, 1);
if (subtractConvert != nullptr) if (subtractConvert != nullptr)
subtractNode = subtractConvert->clone_with_new_inputs({subtractNode}); subtractNode = subtractConvert->clone_with_new_inputs({subtractNode});
const auto subtract = std::make_shared<opset1::Subtract>( const auto subtract = std::make_shared<opset1::Subtract>(
@ -175,13 +177,13 @@ bool ConcatTransformation::transform(TransformationContext& context, ngraph::pat
lastDequantization = subtract; lastDequantization = subtract;
} }
if (!multiplyNodes.empty()) { if (!mulConstants.empty()) {
const auto multiply = std::make_shared<op::TypeRelaxed<opset1::Multiply>>( const auto multiply = std::make_shared<op::TypeRelaxed<opset1::Multiply>>(
opset1::Multiply( opset1::Multiply(
lastDequantization, lastDequantization,
NetworkHelper::toScalarIfPossible(multiplyNodes.size() == 1ul ? NetworkHelper::toScalarIfPossible(mulConstants.size() == 1ul ?
multiplyNodes[0] : mulConstants[0] :
ngraph::pass::low_precision::fold<ngraph::opset1::Concat>(multiplyNodes, 1))), ngraph::pass::low_precision::fold<ngraph::opset1::Concat>(mulConstants, axis))),
layerDequantizations[0].multiply->get_output_element_type(0)); layerDequantizations[0].multiply->get_output_element_type(0));
NetworkHelper::copyInfo({ concat, multiply }, multiply); NetworkHelper::copyInfo({ concat, multiply }, multiply);
@ -213,11 +215,6 @@ bool ConcatTransformation::canBeTransformed(const TransformationContext& context
} }
const size_t normalizedAxis = ngraph::normalize_axis(concat->get_friendly_name(), axis, outRank); const size_t normalizedAxis = ngraph::normalize_axis(concat->get_friendly_name(), axis, outRank);
if (normalizedAxis != 1ul) {
return false;
}
if (outPShape[normalizedAxis].is_dynamic()) { if (outPShape[normalizedAxis].is_dynamic()) {
return false; return false;
} }
@ -333,17 +330,9 @@ bool ConcatTransformation::isHandled(const TransformationContext& context, const
bool ConcatTransformation::isQuantizedStatic(const std::shared_ptr<const Node>& layer) { bool ConcatTransformation::isQuantizedStatic(const std::shared_ptr<const Node>& layer) {
const auto concat = as_type_ptr<const opset1::Concat>(layer); const auto concat = as_type_ptr<const opset1::Concat>(layer);
if (concat == nullptr) { if (concat == nullptr)
return false; return false;
} return concat->get_output_partial_shape(0).rank().is_static();
const auto outputRank = concat->get_output_partial_shape(0).rank();
if (outputRank.is_dynamic()) {
return false;
}
const auto normalizedAxis = ngraph::normalize_axis(concat->get_friendly_name(), concat->get_axis(), outputRank);
return normalizedAxis == 1;
} }
} // namespace low_precision } // namespace low_precision

View File

@ -153,20 +153,7 @@ bool SplitTransformation::isPrecisionPreserved(std::shared_ptr<Node> layer) cons
} }
bool SplitTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const { bool SplitTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const {
if (!LayerTransformation::canBeTransformed(context, layer) || NetworkHelper::getDequantization(layer, defaultPrecisions).empty()) { return !NetworkHelper::getDequantization(layer, defaultPrecisions).empty() && layer->get_input_partial_shape(0).rank().is_static();
return false;
}
const auto consumers = NetworkHelper::consumers(layer);
const auto concat = ov::as_type_ptr<opset1::Concat>(consumers[0]);
// WA to avoid propagation of dequantization if after Split all consumers are the same unsupported Concat
if (concat && concat->get_axis() != 1ul) {
const size_t id = consumers[0]->get_instance_id();
return std::any_of(consumers.begin(), consumers.end(), [&](const std::shared_ptr<Node>& node) { return node->get_instance_id() != id; });
}
return true;
} }
} // namespace low_precision } // namespace low_precision

View File

@ -164,7 +164,7 @@ namespace testValues2 {
{}, {},
{}, {},
"Concatenation", "Concatenation",
"FP32", "I8",
-1 -1
}, },
}; };

View File

@ -127,6 +127,7 @@ public:
result << result <<
LayerTransformation::getTestCaseNameByParams(precision, inputShape, testValues.params) << "_" << LayerTransformation::getTestCaseNameByParams(precision, inputShape, testValues.params) << "_" <<
(testValues.multiChannels ? "multiChannels_" : "notMultiChannels_") << (testValues.multiChannels ? "multiChannels_" : "notMultiChannels_") <<
"axis" << testValues.axis <<
testValues.actual << "_" << testValues.actual << "_" <<
testValues.result << "_"; testValues.result << "_";
return result.str(); return result.str();
@ -147,7 +148,7 @@ const std::vector<ngraph::element::Type> precisions = {
const std::vector<ngraph::PartialShape> shapes = { const std::vector<ngraph::PartialShape> shapes = {
{ 1, 3, 10, 10 }, { 1, 3, 10, 10 },
{ 4, 3, 10, 10 }, { 4, 3, 10, 10 },
{ Dimension::dynamic(), 3, Dimension::dynamic(), Dimension::dynamic() } { -1, 3, 10, -1 }
}; };
const std::vector<ConcatTransformationTestValues> testValues = { const std::vector<ConcatTransformationTestValues> testValues = {
@ -171,7 +172,7 @@ const std::vector<ConcatTransformationTestValues> testValues = {
{ ngraph::element::f32, {}, { 0.01f } } { ngraph::element::f32, {}, { 0.01f } }
} }
}, },
// U8 with unsupported axis // U8 concatenation by spatial dimension
{ {
LayerTransformation::createParamsU8I8(), LayerTransformation::createParamsU8I8(),
false, false,
@ -182,13 +183,13 @@ const std::vector<ConcatTransformationTestValues> testValues = {
}, },
{ {
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f} }, { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f} },
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f / 2.f}, {0.f}, {255.f} }, { 256ul, ngraph::Shape({}), {0.f}, {2.55f / 2.f}, {0.f}, {128.f} },
ngraph::element::u8, ngraph::element::u8,
{{ngraph::element::f32}, {}, {0.01f}}, {},
{{ngraph::element::f32}, {}, {0.005f}}, {},
ngraph::element::f32, ngraph::element::f32,
{{}, {}, {}}, { ngraph::element::f32, {}, { 0.01f } },
{{}, {}, {}} { ngraph::element::f32, {}, { 0.01f } }
} }
}, },
// I8 // I8

View File

@ -1,37 +0,0 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "layer_transformation.hpp"
#include <gtest/gtest.h>
#include "lpt_ngraph_functions/concat_function.hpp"
#include "simple_low_precision_transformer.hpp"
using namespace ::testing;
class smoke_LPT_ConcatWithUnsupportedAxis : public Test {};
TEST_F(smoke_LPT_ConcatWithUnsupportedAxis, rtInfoCheck) {
using namespace ngraph::builder::subgraph;
const ngraph::element::Type precision = ngraph::element::f32;
const ngraph::PartialShape inputPShape = PartialShape{ 1, 3, 16, 16 };
const std::int64_t unsupportedAxis = 2;
const auto fakeQuantize = FakeQuantizeOnData{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} };
std::shared_ptr<ngraph::Function> function = ConcatFunction::getOriginalWithDifferentPrecisionOnChildren(
precision,
inputPShape,
unsupportedAxis,
fakeQuantize,
fakeQuantize);
SimpleLowPrecisionTransformer transformer;
transformer.transform(function);
const auto actualConcat = LayerTransformation::get<opset1::Concat>(function)[0];
const auto& rtInfo = actualConcat->get_rt_info();
ASSERT_TRUE(rtInfo.empty()) << "Unsupported concat mustn't contain LPT runtime attributes";
}

View File

@ -296,7 +296,7 @@ const std::vector<MoveFakeQuantizeTransformationTestValues> testValues = {
{}, {},
} }
}, },
// negative test // concat by batch
{ {
LayerTransformation::createParamsU8I8(), LayerTransformation::createParamsU8I8(),
false, false,
@ -313,11 +313,11 @@ const std::vector<MoveFakeQuantizeTransformationTestValues> testValues = {
}, },
{ {
2, 2,
{}, {{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}},
{}, {},
{}, {},
"", "",
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, {},
{}, {},
{} {}
} }

View File

@ -43,7 +43,6 @@ public:
TestTransformationParams params; TestTransformationParams params;
Actual actual; Actual actual;
Expected expected; Expected expected;
bool addUnsupportedConcat;
}; };
@ -64,8 +63,7 @@ public:
testValues.actual.precisionBeforeDequantization, testValues.actual.precisionBeforeDequantization,
testValues.actual.dequantization, testValues.actual.dequantization,
testValues.splitedAxis, testValues.splitedAxis,
testValues.numSplits, testValues.numSplits);
testValues.addUnsupportedConcat);
SimpleLowPrecisionTransformer transformer; SimpleLowPrecisionTransformer transformer;
transformer.add<ngraph::pass::low_precision::SplitTransformation, ngraph::opset1::Split>(testValues.params); transformer.add<ngraph::pass::low_precision::SplitTransformation, ngraph::opset1::Split>(testValues.params);
@ -79,8 +77,7 @@ public:
testValues.expected.precisionAfterOperation, testValues.expected.precisionAfterOperation,
testValues.expected.dequantizationAfter, testValues.expected.dequantizationAfter,
testValues.splitedAxis, testValues.splitedAxis,
testValues.numSplits, testValues.numSplits);
testValues.addUnsupportedConcat);
} }
static std::string getTestCaseName(testing::TestParamInfo<SplitTransformationParams> obj) { static std::string getTestCaseName(testing::TestParamInfo<SplitTransformationParams> obj) {
@ -334,6 +331,59 @@ const std::vector<SplitTransformationTestValues> testValues = {
} }
} }
}, },
// per channel quantization with different values, split by batch
{
{ 2, 3, 16, 16 }, std::int64_t{0}, size_t{2},
LayerTransformation::createParamsI8I8(),
{
ngraph::element::i8,
{{ngraph::element::f32},
{{2.f, 3.f}, ngraph::element::f32, {2, 1, 1, 1}},
{{22.f, 33.f}, ngraph::element::f32, {2, 1, 1, 1}}}
},
{
ngraph::element::i8,
{},
ngraph::element::i8,
{
{{ngraph::element::f32}, {2.f}, {22.f}},
{{ngraph::element::f32}, {3.f}, {33.f}},
}
}
},
// per channel quantization with different values, split by spatial dimension
{
{ -1, -1, -1, -1 }, std::int64_t{2}, size_t{3},
LayerTransformation::createParamsI8I8(),
{
ngraph::element::i8,
{{ngraph::element::f32},
{{1.f, 2.f, 3.f, 4.f, 5.f, 6.f}, ngraph::element::f32, {1, 1, 6, 1}},
{{11.f, 22.f, 33.f, 44.f, 55.f, 66.f}, ngraph::element::f32, {1, 1, 6, 1}}}
},
{
ngraph::element::i8,
{},
ngraph::element::i8,
{
{
{ngraph::element::f32},
{{1.f, 2.f}, ngraph::element::f32, {1, 1, 2, 1}},
{{11.f, 22.f}, ngraph::element::f32, {1, 1, 2, 1}}
},
{
{ngraph::element::f32},
{{3.f, 4.f}, ngraph::element::f32, {1, 1, 2, 1}},
{{33.f, 44.f}, ngraph::element::f32, {1, 1, 2, 1}}
},
{
{ngraph::element::f32},
{{5.f, 6.f}, ngraph::element::f32, {1, 1, 2, 1}},
{{55.f, 66.f}, ngraph::element::f32, {1, 1, 2, 1}}
},
}
}
},
// U8 per channel quantization with the same values // U8 per channel quantization with the same values
{ {
{ 1, 3, 16, 16 }, std::int64_t{1}, size_t{3}, { 1, 3, 16, 16 }, std::int64_t{1}, size_t{3},
@ -584,39 +634,6 @@ const std::vector<SplitTransformationTestValues> testValues = {
} }
} }
}, },
// issue #56781: unsupported Concat after Split
{
{ 1, 4, 3, 3 }, std::int64_t{2}, size_t{3},
LayerTransformation::createParamsU8I8(),
{
ngraph::element::u8,
{{ngraph::element::f32}, {128.f}, {3.f}}
},
{
ngraph::element::u8,
{{ngraph::element::f32}, {128.f}, {3.f}},
ngraph::element::f32,
{}
},
true
},
// issue #56781: unsupported Concat after Split, dynamic channels
{
{ -1, -1, -1, -1 },
std::int64_t{2}, size_t{3},
LayerTransformation::createParamsU8I8(),
{
ngraph::element::u8,
{{ngraph::element::f32}, {128.f}, {3.f}}
},
{
ngraph::element::u8,
{{ngraph::element::f32}, {128.f}, {3.f}},
ngraph::element::f32,
{}
},
true
},
// no dequantization // no dequantization
{ {
ngraph::Shape({ 1, 3, 4, 4 }), std::int64_t{2}, size_t{2}, ngraph::Shape({ 1, 3, 4, 4 }), std::int64_t{2}, size_t{2},

View File

@ -196,6 +196,54 @@ const std::vector<VariadicSplitTransformationTestValues> testValues = {
} }
} }
}, },
// U8 per channel quantization with different values, split by batch
{
{ 2, 3, 16, 16 }, std::int64_t{0}, std::vector<size_t>{ 1, 1 },
LayerTransformation::createParamsU8I8(),
{
ngraph::element::u8,
{{ngraph::element::f32},
{{2.f, 3.f}, ngraph::element::f32, {2, 1, 1, 1}},
{{22.f, 33.f}, ngraph::element::f32, {2, 1, 1, 1}}}
},
{
ngraph::element::u8,
{},
ngraph::element::u8,
{
{{ngraph::element::f32}, {2.f}, {22.f}},
{{ngraph::element::f32}, {3.f}, {33.f}}
}
}
},
// U8 per channel quantization with different values, split by spatial dimension
{
{ -1, -1, -1, -1 }, std::int64_t{3}, std::vector<size_t>{ 4, 2 },
LayerTransformation::createParamsU8I8(),
{
ngraph::element::u8,
{{ngraph::element::f32},
{{1.f, 2.f, 3.f, 4.f, 5.f, 6.f}, ngraph::element::f32, {1, 1, 1, 6}},
{{11.f, 22.f, 33.f, 44.f, 55.f, 66.f}, ngraph::element::f32, {1, 1, 1, 6}}}
},
{
ngraph::element::u8,
{},
ngraph::element::u8,
{
{
{ngraph::element::f32},
{{1.f, 2.f, 3.f, 4.f}, ngraph::element::f32, {1, 1, 1, 4}},
{{11.f, 22.f, 33.f, 44.f}, ngraph::element::f32, {1, 1, 1, 4}}
},
{
{ngraph::element::f32},
{{5.f, 6.f}, ngraph::element::f32, {1, 1, 1, 2}},
{{55.f, 66.f}, ngraph::element::f32, {1, 1, 1, 2}}
},
}
}
},
// U8 per channel quantization with different values, dynamic shape // U8 per channel quantization with different values, dynamic shape
{ {
{ -1, 3, -1, -1 }, { -1, 3, -1, -1 },

View File

@ -17,6 +17,15 @@ namespace subgraph {
class ConcatFunction { class ConcatFunction {
public: public:
static std::shared_ptr<ov::Model> get(
const ov::element::Type inputPrecision,
const ov::element::Type deqPrecision,
const std::vector<ov::PartialShape>& inputShapes,
const std::vector<DequantizationOperations>& dequantizationsBefore,
const std::int64_t concatAxis,
const ov::element::Type precisionAfter = ov::element::undefined,
const DequantizationOperations& dequantizationAfter = {});
static std::shared_ptr<ngraph::Function> getOriginal( static std::shared_ptr<ngraph::Function> getOriginal(
const ngraph::element::Type precision, const ngraph::element::Type precision,
const ngraph::PartialShape& inputShape, const ngraph::PartialShape& inputShape,
@ -107,6 +116,15 @@ public:
const FakeQuantizeOnData& fqOnData1, const FakeQuantizeOnData& fqOnData1,
const FakeQuantizeOnData& fqOnData2); const FakeQuantizeOnData& fqOnData2);
static std::shared_ptr<ov::Model> getReference(
const ov::element::Type dequantizationPrecision,
const ov::element::Type precisionBefore,
const std::vector<ov::PartialShape>& inputShapes,
const std::vector<DequantizationOperations>& dequantizationsBefore,
const ov::element::Type precisionAfter,
const DequantizationOperations& dequantizationAfter,
const std::int64_t concatAxis);
static std::shared_ptr<ngraph::Function> getReference( static std::shared_ptr<ngraph::Function> getReference(
const ngraph::element::Type precision, const ngraph::element::Type precision,
const ngraph::Shape& inputShape, const ngraph::Shape& inputShape,

View File

@ -25,8 +25,7 @@ public:
const ngraph::element::Type precisionBeforeDequantization, const ngraph::element::Type precisionBeforeDequantization,
const ngraph::builder::subgraph::DequantizationOperations& dequantization, const ngraph::builder::subgraph::DequantizationOperations& dequantization,
const int64_t splitedAxis, const int64_t splitedAxis,
const size_t numSplits, const size_t numSplits);
const bool addUnsupportedConcat = false);
static std::shared_ptr<ngraph::Function> getOriginal( static std::shared_ptr<ngraph::Function> getOriginal(
const ngraph::element::Type originalFunctionPrecision, const ngraph::element::Type originalFunctionPrecision,
@ -43,8 +42,7 @@ public:
const ngraph::element::Type precisionAfterOperation, const ngraph::element::Type precisionAfterOperation,
const std::vector<ngraph::builder::subgraph::DequantizationOperations>& dequantizationAfter, const std::vector<ngraph::builder::subgraph::DequantizationOperations>& dequantizationAfter,
const int64_t splitedAxis, const int64_t splitedAxis,
const size_t numSplit, const size_t numSplits);
const bool addUnsupportedConcat = false);
}; };
} // namespace subgraph } // namespace subgraph
} // namespace builder } // namespace builder

View File

@ -4,7 +4,7 @@
#include "lpt_ngraph_functions/concat_function.hpp" #include "lpt_ngraph_functions/concat_function.hpp"
#include <ngraph/opsets/opset1.hpp> #include <openvino/opsets/opset1.hpp>
#include "ngraph_ops/type_relaxed.hpp" #include "ngraph_ops/type_relaxed.hpp"
#include "low_precision/network_helper.hpp" #include "low_precision/network_helper.hpp"
#include "low_precision/rt_info/precision_preserved_attribute.hpp" #include "low_precision/rt_info/precision_preserved_attribute.hpp"
@ -23,6 +23,46 @@ namespace subgraph {
using namespace ngraph::pass; using namespace ngraph::pass;
std::shared_ptr<ov::Model> ConcatFunction::get(
const ov::element::Type inputPrecision,
const ov::element::Type deqPrecision,
const std::vector<ov::PartialShape>& inputShapes,
const std::vector<DequantizationOperations>& dequantizationsBefore,
const std::int64_t concatAxis,
const ov::element::Type precisionAfter,
const DequantizationOperations& dequantizationAfter) {
auto modifyDeq = [](const DequantizationOperations& deq, const ov::element::Type deqOutPrc) {
auto dequantizationStructure = deq;
if (!dequantizationStructure.multiply.empty()) {
dequantizationStructure.multiply.outPrecision = deqOutPrc;
}
return dequantizationStructure;
};
ov::ParameterVector inputs;
ov::NodeVector concatInputs;
if (inputShapes.size() != dequantizationsBefore.size()) {
throw std::runtime_error("Concat builder: input and dequantization sizes aren't equal");
}
for (size_t i = 0; i < inputShapes.size(); ++i) {
const auto input = std::make_shared<ov::opset1::Parameter>(inputPrecision, inputShapes[i]);
const auto dequantization = makeDequantization(input, modifyDeq(dequantizationsBefore[i], deqPrecision));
inputs.push_back(input);
concatInputs.push_back(dequantization);
}
const auto concat = std::make_shared<ov::opset1::Concat>(concatInputs, concatAxis);
if (precisionAfter != ov::element::undefined && (concat->get_output_element_type(0).is_real() ^ precisionAfter.is_real())) {
throw std::runtime_error("Concat builder: requested precision after operation could't be set");
}
const auto deqAfter = makeDequantization(concat, modifyDeq(dequantizationAfter, deqPrecision));
deqAfter->set_friendly_name("Concat");
const auto result = std::make_shared<ov::opset1::Result>(deqAfter);
return std::make_shared<ov::Model>(ov::ResultVector{result}, inputs, "ConcatTransformation");
}
std::shared_ptr<ngraph::Function> ConcatFunction::getOriginal( std::shared_ptr<ngraph::Function> ConcatFunction::getOriginal(
const ngraph::element::Type precision, const ngraph::element::Type precision,
const ngraph::PartialShape& inputShape, const ngraph::PartialShape& inputShape,

View File

@ -23,25 +23,18 @@ std::shared_ptr<ngraph::Function> SplitFunction::getOriginal(
const ngraph::element::Type precisionBeforeDequantization, const ngraph::element::Type precisionBeforeDequantization,
const ngraph::builder::subgraph::DequantizationOperations& dequantization, const ngraph::builder::subgraph::DequantizationOperations& dequantization,
const int64_t splitedAxis, const int64_t splitedAxis,
const size_t numSplits, const size_t numSplits) {
const bool addUnsupportedConcat) {
const auto input = std::make_shared<ngraph::opset1::Parameter>(precisionBeforeDequantization, inputShape); const auto input = std::make_shared<ngraph::opset1::Parameter>(precisionBeforeDequantization, inputShape);
auto dequantizationStructure = dequantization; auto dequantizationStructure = dequantization;
dequantizationStructure.multiply.outPrecision = precision; dequantizationStructure.multiply.outPrecision = precision;
const std::shared_ptr<Node> dequantizationOp = makeDequantization(input, dequantization); const auto dequantizationOp = makeDequantization(input, dequantization);
const auto constant = std::make_shared<ngraph::opset1::Constant>(element::i64, Shape{ }, splitedAxis); const auto constant = std::make_shared<ngraph::opset1::Constant>(element::i64, Shape{ }, splitedAxis);
const std::shared_ptr<Node> split = std::make_shared<ngraph::opset1::Split>(dequantizationOp, constant, numSplits); const auto split = std::make_shared<ngraph::opset1::Split>(dequantizationOp, constant, numSplits);
ngraph::ResultVector results; ngraph::ResultVector results;
for (size_t i = 0; i < numSplits; ++i) {
if (addUnsupportedConcat) { results.push_back(std::make_shared<ngraph::opset1::Result>(split->output(i)));
const auto concat = std::make_shared<opset1::Concat>(split->outputs(), 2ul);
results.push_back(std::make_shared<opset1::Result>(concat));
} else {
for (size_t i = 0; i < numSplits; ++i) {
results.push_back(std::make_shared<ngraph::opset1::Result>(split->output(i)));
}
} }
return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "SplitFunction"); return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "SplitFunction");
} }
@ -82,31 +75,23 @@ std::shared_ptr<ngraph::Function> SplitFunction::getReference(
const ngraph::element::Type precisionAfterOperation, const ngraph::element::Type precisionAfterOperation,
const std::vector<ngraph::builder::subgraph::DequantizationOperations>& dequantizationAfter, const std::vector<ngraph::builder::subgraph::DequantizationOperations>& dequantizationAfter,
const int64_t splitedAxis, const int64_t splitedAxis,
const size_t numSplit, const size_t numSplit) {
const bool addUnsupportedConcat) {
const auto input = std::make_shared<ngraph::opset1::Parameter>(inputPrecision, inputShape); const auto input = std::make_shared<ngraph::opset1::Parameter>(inputPrecision, inputShape);
const auto deqBefore = makeDequantization(input, dequantizationBefore); const auto deqBefore = makeDequantization(input, dequantizationBefore);
std::shared_ptr<ngraph::opset1::Split> split;
const auto constant = std::make_shared<ngraph::opset1::Constant>(element::i64, Shape{ }, splitedAxis); const auto constant = std::make_shared<ngraph::opset1::Constant>(element::i64, Shape{ }, splitedAxis);
split = std::make_shared<ngraph::opset1::Split>(deqBefore, constant, numSplit); const auto split = std::make_shared<ngraph::opset1::Split>(deqBefore, constant, numSplit);
ngraph::ResultVector results; ngraph::ResultVector results;
if (addUnsupportedConcat) { for (size_t i = 0; i < numSplit; ++i) {
const auto concat = std::make_shared<opset1::Concat>(split->outputs(), 2ul); if (!dequantizationAfter.empty()) {
results.push_back(std::make_shared<opset1::Result>(concat)); auto dequantizationStructure = dequantizationAfter[i];
} else { if (!dequantizationStructure.multiply.empty()) {
for (size_t i = 0; i < numSplit; ++i) { dequantizationStructure.multiply.outPrecision = precision;
if (!dequantizationAfter.empty()) {
auto dequantizationStructure = dequantizationAfter[i];
if (!dequantizationStructure.multiply.empty()) {
dequantizationStructure.multiply.outPrecision = precision;
}
results.push_back(std::make_shared<ngraph::opset1::Result>(makeDequantization(split->output(i), dequantizationAfter[i])));
} else {
results.push_back(std::make_shared<ngraph::opset1::Result>(split->output(i)));
} }
results.push_back(std::make_shared<ngraph::opset1::Result>(makeDequantization(split->output(i), dequantizationAfter[i])));
} else {
results.push_back(std::make_shared<ngraph::opset1::Result>(split->output(i)));
} }
} }