[LPT] Removed legacy limitations on dequantization propagation for several transformations (#13048)
This commit is contained in:
parent
e8c8ad19f4
commit
372fe475c9
@ -81,7 +81,7 @@ bool ConcatTransformation::transform(TransformationContext& context, ngraph::pat
|
||||
allDequantizationShiftConvertAreNotZero = false;
|
||||
}
|
||||
|
||||
// FakeQuantize constant shape must be broadcastable to the shape on data.
|
||||
// constant shape must be broadcastable to the shape on data.
|
||||
auto broadcastElementWiseConst = [](std::shared_ptr<opset1::Constant> operation, const Shape targetShape) {
|
||||
auto targetShapeConst = std::make_shared<opset1::Constant>(element::i64, Shape{ targetShape.size() }, targetShape);
|
||||
auto broadcast = fold<ngraph::opset1::Broadcast>(operation, targetShapeConst);
|
||||
@ -99,11 +99,14 @@ bool ConcatTransformation::transform(TransformationContext& context, ngraph::pat
|
||||
[](const FakeQuantizeDequantization& value) { return !value.isLowPrecision(); });
|
||||
|
||||
bool DqWithDifferentPrecision = someDqInLowPrecision && someDqInFpPrecision;
|
||||
const auto axis = ngraph::normalize_axis(concat->get_friendly_name(),
|
||||
concat->get_axis(),
|
||||
concat->get_output_partial_shape(0).rank());
|
||||
|
||||
OutputVector dataNodes;
|
||||
NodeVector convertNodes;
|
||||
NodeVector subtractNodes;
|
||||
NodeVector multiplyNodes;
|
||||
NodeVector subConstants;
|
||||
NodeVector mulConstants;
|
||||
std::shared_ptr<opset1::Convert> subtractConvert = nullptr;
|
||||
for (size_t i = 0; i < layerDequantizations.size(); ++i) {
|
||||
const auto& dequantization = layerDequantizations[i];
|
||||
@ -119,7 +122,7 @@ bool ConcatTransformation::transform(TransformationContext& context, ngraph::pat
|
||||
}
|
||||
|
||||
Shape targetShape(concat->get_input_partial_shape(i).rank().get_length(), 1ul);
|
||||
targetShape[1] = concat->get_input_partial_shape(i)[1].get_length();
|
||||
targetShape[axis] = concat->get_input_partial_shape(i)[axis].get_length();
|
||||
|
||||
if (!allDequantizationShiftAreZero) {
|
||||
auto subtractInput = dequantization.subtract == nullptr ?
|
||||
@ -138,11 +141,11 @@ bool ConcatTransformation::transform(TransformationContext& context, ngraph::pat
|
||||
subtractInput = foldConvert(subtractInput, dequantization.subtractConvert->get_convert_element_type());
|
||||
NetworkHelper::copyInfo(dequantization.subtractConvert, subtractInput);
|
||||
}
|
||||
subtractNodes.push_back(subtractInput);
|
||||
subConstants.push_back(subtractInput);
|
||||
}
|
||||
|
||||
if (!allDequantizationMultiplyAreZero) {
|
||||
multiplyNodes.push_back(dequantization.multiply == nullptr ?
|
||||
mulConstants.push_back(dequantization.multiply == nullptr ?
|
||||
std::make_shared<ngraph::opset1::Constant>(deqPrecision, targetShape, std::vector<float>({ 1.0f })) :
|
||||
broadcastElementWiseConst(dequantization.multiplyConstant, targetShape));
|
||||
}
|
||||
@ -159,11 +162,10 @@ bool ConcatTransformation::transform(TransformationContext& context, ngraph::pat
|
||||
lastDequantization = convert;
|
||||
}
|
||||
|
||||
// concatenation axis is 1
|
||||
if (!subtractNodes.empty()) {
|
||||
std::shared_ptr<ov::Node> subtractNode = subtractNodes.size() == 1ul ?
|
||||
subtractNodes[0] :
|
||||
ngraph::pass::low_precision::fold<ngraph::opset1::Concat>(subtractNodes, 1);
|
||||
if (!subConstants.empty()) {
|
||||
std::shared_ptr<ov::Node> subtractNode = subConstants.size() == 1ul ?
|
||||
subConstants[0] :
|
||||
ngraph::pass::low_precision::fold<ngraph::opset1::Concat>(subConstants, axis);
|
||||
if (subtractConvert != nullptr)
|
||||
subtractNode = subtractConvert->clone_with_new_inputs({subtractNode});
|
||||
const auto subtract = std::make_shared<opset1::Subtract>(
|
||||
@ -175,13 +177,13 @@ bool ConcatTransformation::transform(TransformationContext& context, ngraph::pat
|
||||
lastDequantization = subtract;
|
||||
}
|
||||
|
||||
if (!multiplyNodes.empty()) {
|
||||
if (!mulConstants.empty()) {
|
||||
const auto multiply = std::make_shared<op::TypeRelaxed<opset1::Multiply>>(
|
||||
opset1::Multiply(
|
||||
lastDequantization,
|
||||
NetworkHelper::toScalarIfPossible(multiplyNodes.size() == 1ul ?
|
||||
multiplyNodes[0] :
|
||||
ngraph::pass::low_precision::fold<ngraph::opset1::Concat>(multiplyNodes, 1))),
|
||||
NetworkHelper::toScalarIfPossible(mulConstants.size() == 1ul ?
|
||||
mulConstants[0] :
|
||||
ngraph::pass::low_precision::fold<ngraph::opset1::Concat>(mulConstants, axis))),
|
||||
layerDequantizations[0].multiply->get_output_element_type(0));
|
||||
|
||||
NetworkHelper::copyInfo({ concat, multiply }, multiply);
|
||||
@ -213,11 +215,6 @@ bool ConcatTransformation::canBeTransformed(const TransformationContext& context
|
||||
}
|
||||
|
||||
const size_t normalizedAxis = ngraph::normalize_axis(concat->get_friendly_name(), axis, outRank);
|
||||
|
||||
if (normalizedAxis != 1ul) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (outPShape[normalizedAxis].is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
@ -333,17 +330,9 @@ bool ConcatTransformation::isHandled(const TransformationContext& context, const
|
||||
|
||||
bool ConcatTransformation::isQuantizedStatic(const std::shared_ptr<const Node>& layer) {
|
||||
const auto concat = as_type_ptr<const opset1::Concat>(layer);
|
||||
if (concat == nullptr) {
|
||||
if (concat == nullptr)
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto outputRank = concat->get_output_partial_shape(0).rank();
|
||||
if (outputRank.is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto normalizedAxis = ngraph::normalize_axis(concat->get_friendly_name(), concat->get_axis(), outputRank);
|
||||
return normalizedAxis == 1;
|
||||
return concat->get_output_partial_shape(0).rank().is_static();
|
||||
}
|
||||
|
||||
} // namespace low_precision
|
||||
|
@ -153,20 +153,7 @@ bool SplitTransformation::isPrecisionPreserved(std::shared_ptr<Node> layer) cons
|
||||
}
|
||||
|
||||
bool SplitTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const {
|
||||
if (!LayerTransformation::canBeTransformed(context, layer) || NetworkHelper::getDequantization(layer, defaultPrecisions).empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto consumers = NetworkHelper::consumers(layer);
|
||||
const auto concat = ov::as_type_ptr<opset1::Concat>(consumers[0]);
|
||||
|
||||
// WA to avoid propagation of dequantization if after Split all consumers are the same unsupported Concat
|
||||
if (concat && concat->get_axis() != 1ul) {
|
||||
const size_t id = consumers[0]->get_instance_id();
|
||||
return std::any_of(consumers.begin(), consumers.end(), [&](const std::shared_ptr<Node>& node) { return node->get_instance_id() != id; });
|
||||
}
|
||||
|
||||
return true;
|
||||
return !NetworkHelper::getDequantization(layer, defaultPrecisions).empty() && layer->get_input_partial_shape(0).rank().is_static();
|
||||
}
|
||||
|
||||
} // namespace low_precision
|
||||
|
@ -164,7 +164,7 @@ namespace testValues2 {
|
||||
{},
|
||||
{},
|
||||
"Concatenation",
|
||||
"FP32",
|
||||
"I8",
|
||||
-1
|
||||
},
|
||||
};
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -127,6 +127,7 @@ public:
|
||||
result <<
|
||||
LayerTransformation::getTestCaseNameByParams(precision, inputShape, testValues.params) << "_" <<
|
||||
(testValues.multiChannels ? "multiChannels_" : "notMultiChannels_") <<
|
||||
"axis" << testValues.axis <<
|
||||
testValues.actual << "_" <<
|
||||
testValues.result << "_";
|
||||
return result.str();
|
||||
@ -147,7 +148,7 @@ const std::vector<ngraph::element::Type> precisions = {
|
||||
const std::vector<ngraph::PartialShape> shapes = {
|
||||
{ 1, 3, 10, 10 },
|
||||
{ 4, 3, 10, 10 },
|
||||
{ Dimension::dynamic(), 3, Dimension::dynamic(), Dimension::dynamic() }
|
||||
{ -1, 3, 10, -1 }
|
||||
};
|
||||
|
||||
const std::vector<ConcatTransformationTestValues> testValues = {
|
||||
@ -171,7 +172,7 @@ const std::vector<ConcatTransformationTestValues> testValues = {
|
||||
{ ngraph::element::f32, {}, { 0.01f } }
|
||||
}
|
||||
},
|
||||
// U8 with unsupported axis
|
||||
// U8 concatenation by spatial dimension
|
||||
{
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
false,
|
||||
@ -182,13 +183,13 @@ const std::vector<ConcatTransformationTestValues> testValues = {
|
||||
},
|
||||
{
|
||||
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f} },
|
||||
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f / 2.f}, {0.f}, {255.f} },
|
||||
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f / 2.f}, {0.f}, {128.f} },
|
||||
ngraph::element::u8,
|
||||
{{ngraph::element::f32}, {}, {0.01f}},
|
||||
{{ngraph::element::f32}, {}, {0.005f}},
|
||||
{},
|
||||
{},
|
||||
ngraph::element::f32,
|
||||
{{}, {}, {}},
|
||||
{{}, {}, {}}
|
||||
{ ngraph::element::f32, {}, { 0.01f } },
|
||||
{ ngraph::element::f32, {}, { 0.01f } }
|
||||
}
|
||||
},
|
||||
// I8
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -1,37 +0,0 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "layer_transformation.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "lpt_ngraph_functions/concat_function.hpp"
|
||||
#include "simple_low_precision_transformer.hpp"
|
||||
|
||||
using namespace ::testing;
|
||||
|
||||
class smoke_LPT_ConcatWithUnsupportedAxis : public Test {};
|
||||
|
||||
TEST_F(smoke_LPT_ConcatWithUnsupportedAxis, rtInfoCheck) {
|
||||
using namespace ngraph::builder::subgraph;
|
||||
|
||||
const ngraph::element::Type precision = ngraph::element::f32;
|
||||
const ngraph::PartialShape inputPShape = PartialShape{ 1, 3, 16, 16 };
|
||||
const std::int64_t unsupportedAxis = 2;
|
||||
const auto fakeQuantize = FakeQuantizeOnData{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} };
|
||||
|
||||
std::shared_ptr<ngraph::Function> function = ConcatFunction::getOriginalWithDifferentPrecisionOnChildren(
|
||||
precision,
|
||||
inputPShape,
|
||||
unsupportedAxis,
|
||||
fakeQuantize,
|
||||
fakeQuantize);
|
||||
|
||||
SimpleLowPrecisionTransformer transformer;
|
||||
transformer.transform(function);
|
||||
|
||||
const auto actualConcat = LayerTransformation::get<opset1::Concat>(function)[0];
|
||||
const auto& rtInfo = actualConcat->get_rt_info();
|
||||
ASSERT_TRUE(rtInfo.empty()) << "Unsupported concat mustn't contain LPT runtime attributes";
|
||||
}
|
@ -296,7 +296,7 @@ const std::vector<MoveFakeQuantizeTransformationTestValues> testValues = {
|
||||
{},
|
||||
}
|
||||
},
|
||||
// negative test
|
||||
// concat by batch
|
||||
{
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
false,
|
||||
@ -313,11 +313,11 @@ const std::vector<MoveFakeQuantizeTransformationTestValues> testValues = {
|
||||
},
|
||||
{
|
||||
2,
|
||||
{},
|
||||
{{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}},
|
||||
{},
|
||||
{},
|
||||
"",
|
||||
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}},
|
||||
{},
|
||||
{},
|
||||
{}
|
||||
}
|
||||
|
@ -43,7 +43,6 @@ public:
|
||||
TestTransformationParams params;
|
||||
Actual actual;
|
||||
Expected expected;
|
||||
bool addUnsupportedConcat;
|
||||
};
|
||||
|
||||
|
||||
@ -64,8 +63,7 @@ public:
|
||||
testValues.actual.precisionBeforeDequantization,
|
||||
testValues.actual.dequantization,
|
||||
testValues.splitedAxis,
|
||||
testValues.numSplits,
|
||||
testValues.addUnsupportedConcat);
|
||||
testValues.numSplits);
|
||||
|
||||
SimpleLowPrecisionTransformer transformer;
|
||||
transformer.add<ngraph::pass::low_precision::SplitTransformation, ngraph::opset1::Split>(testValues.params);
|
||||
@ -79,8 +77,7 @@ public:
|
||||
testValues.expected.precisionAfterOperation,
|
||||
testValues.expected.dequantizationAfter,
|
||||
testValues.splitedAxis,
|
||||
testValues.numSplits,
|
||||
testValues.addUnsupportedConcat);
|
||||
testValues.numSplits);
|
||||
}
|
||||
|
||||
static std::string getTestCaseName(testing::TestParamInfo<SplitTransformationParams> obj) {
|
||||
@ -334,6 +331,59 @@ const std::vector<SplitTransformationTestValues> testValues = {
|
||||
}
|
||||
}
|
||||
},
|
||||
// per channel quantization with different values, split by batch
|
||||
{
|
||||
{ 2, 3, 16, 16 }, std::int64_t{0}, size_t{2},
|
||||
LayerTransformation::createParamsI8I8(),
|
||||
{
|
||||
ngraph::element::i8,
|
||||
{{ngraph::element::f32},
|
||||
{{2.f, 3.f}, ngraph::element::f32, {2, 1, 1, 1}},
|
||||
{{22.f, 33.f}, ngraph::element::f32, {2, 1, 1, 1}}}
|
||||
},
|
||||
{
|
||||
ngraph::element::i8,
|
||||
{},
|
||||
ngraph::element::i8,
|
||||
{
|
||||
{{ngraph::element::f32}, {2.f}, {22.f}},
|
||||
{{ngraph::element::f32}, {3.f}, {33.f}},
|
||||
}
|
||||
}
|
||||
},
|
||||
// per channel quantization with different values, split by spatial dimension
|
||||
{
|
||||
{ -1, -1, -1, -1 }, std::int64_t{2}, size_t{3},
|
||||
LayerTransformation::createParamsI8I8(),
|
||||
{
|
||||
ngraph::element::i8,
|
||||
{{ngraph::element::f32},
|
||||
{{1.f, 2.f, 3.f, 4.f, 5.f, 6.f}, ngraph::element::f32, {1, 1, 6, 1}},
|
||||
{{11.f, 22.f, 33.f, 44.f, 55.f, 66.f}, ngraph::element::f32, {1, 1, 6, 1}}}
|
||||
},
|
||||
{
|
||||
ngraph::element::i8,
|
||||
{},
|
||||
ngraph::element::i8,
|
||||
{
|
||||
{
|
||||
{ngraph::element::f32},
|
||||
{{1.f, 2.f}, ngraph::element::f32, {1, 1, 2, 1}},
|
||||
{{11.f, 22.f}, ngraph::element::f32, {1, 1, 2, 1}}
|
||||
},
|
||||
{
|
||||
{ngraph::element::f32},
|
||||
{{3.f, 4.f}, ngraph::element::f32, {1, 1, 2, 1}},
|
||||
{{33.f, 44.f}, ngraph::element::f32, {1, 1, 2, 1}}
|
||||
},
|
||||
{
|
||||
{ngraph::element::f32},
|
||||
{{5.f, 6.f}, ngraph::element::f32, {1, 1, 2, 1}},
|
||||
{{55.f, 66.f}, ngraph::element::f32, {1, 1, 2, 1}}
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
// U8 per channel quantization with the same values
|
||||
{
|
||||
{ 1, 3, 16, 16 }, std::int64_t{1}, size_t{3},
|
||||
@ -584,39 +634,6 @@ const std::vector<SplitTransformationTestValues> testValues = {
|
||||
}
|
||||
}
|
||||
},
|
||||
// issue #56781: unsupported Concat after Split
|
||||
{
|
||||
{ 1, 4, 3, 3 }, std::int64_t{2}, size_t{3},
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{{ngraph::element::f32}, {128.f}, {3.f}}
|
||||
},
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{{ngraph::element::f32}, {128.f}, {3.f}},
|
||||
ngraph::element::f32,
|
||||
{}
|
||||
},
|
||||
true
|
||||
},
|
||||
// issue #56781: unsupported Concat after Split, dynamic channels
|
||||
{
|
||||
{ -1, -1, -1, -1 },
|
||||
std::int64_t{2}, size_t{3},
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{{ngraph::element::f32}, {128.f}, {3.f}}
|
||||
},
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{{ngraph::element::f32}, {128.f}, {3.f}},
|
||||
ngraph::element::f32,
|
||||
{}
|
||||
},
|
||||
true
|
||||
},
|
||||
// no dequantization
|
||||
{
|
||||
ngraph::Shape({ 1, 3, 4, 4 }), std::int64_t{2}, size_t{2},
|
||||
|
@ -196,6 +196,54 @@ const std::vector<VariadicSplitTransformationTestValues> testValues = {
|
||||
}
|
||||
}
|
||||
},
|
||||
// U8 per channel quantization with different values, split by batch
|
||||
{
|
||||
{ 2, 3, 16, 16 }, std::int64_t{0}, std::vector<size_t>{ 1, 1 },
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{{ngraph::element::f32},
|
||||
{{2.f, 3.f}, ngraph::element::f32, {2, 1, 1, 1}},
|
||||
{{22.f, 33.f}, ngraph::element::f32, {2, 1, 1, 1}}}
|
||||
},
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{},
|
||||
ngraph::element::u8,
|
||||
{
|
||||
{{ngraph::element::f32}, {2.f}, {22.f}},
|
||||
{{ngraph::element::f32}, {3.f}, {33.f}}
|
||||
}
|
||||
}
|
||||
},
|
||||
// U8 per channel quantization with different values, split by spatial dimension
|
||||
{
|
||||
{ -1, -1, -1, -1 }, std::int64_t{3}, std::vector<size_t>{ 4, 2 },
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{{ngraph::element::f32},
|
||||
{{1.f, 2.f, 3.f, 4.f, 5.f, 6.f}, ngraph::element::f32, {1, 1, 1, 6}},
|
||||
{{11.f, 22.f, 33.f, 44.f, 55.f, 66.f}, ngraph::element::f32, {1, 1, 1, 6}}}
|
||||
},
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{},
|
||||
ngraph::element::u8,
|
||||
{
|
||||
{
|
||||
{ngraph::element::f32},
|
||||
{{1.f, 2.f, 3.f, 4.f}, ngraph::element::f32, {1, 1, 1, 4}},
|
||||
{{11.f, 22.f, 33.f, 44.f}, ngraph::element::f32, {1, 1, 1, 4}}
|
||||
},
|
||||
{
|
||||
{ngraph::element::f32},
|
||||
{{5.f, 6.f}, ngraph::element::f32, {1, 1, 1, 2}},
|
||||
{{55.f, 66.f}, ngraph::element::f32, {1, 1, 1, 2}}
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
// U8 per channel quantization with different values, dynamic shape
|
||||
{
|
||||
{ -1, 3, -1, -1 },
|
||||
|
@ -17,6 +17,15 @@ namespace subgraph {
|
||||
|
||||
class ConcatFunction {
|
||||
public:
|
||||
static std::shared_ptr<ov::Model> get(
|
||||
const ov::element::Type inputPrecision,
|
||||
const ov::element::Type deqPrecision,
|
||||
const std::vector<ov::PartialShape>& inputShapes,
|
||||
const std::vector<DequantizationOperations>& dequantizationsBefore,
|
||||
const std::int64_t concatAxis,
|
||||
const ov::element::Type precisionAfter = ov::element::undefined,
|
||||
const DequantizationOperations& dequantizationAfter = {});
|
||||
|
||||
static std::shared_ptr<ngraph::Function> getOriginal(
|
||||
const ngraph::element::Type precision,
|
||||
const ngraph::PartialShape& inputShape,
|
||||
@ -107,6 +116,15 @@ public:
|
||||
const FakeQuantizeOnData& fqOnData1,
|
||||
const FakeQuantizeOnData& fqOnData2);
|
||||
|
||||
static std::shared_ptr<ov::Model> getReference(
|
||||
const ov::element::Type dequantizationPrecision,
|
||||
const ov::element::Type precisionBefore,
|
||||
const std::vector<ov::PartialShape>& inputShapes,
|
||||
const std::vector<DequantizationOperations>& dequantizationsBefore,
|
||||
const ov::element::Type precisionAfter,
|
||||
const DequantizationOperations& dequantizationAfter,
|
||||
const std::int64_t concatAxis);
|
||||
|
||||
static std::shared_ptr<ngraph::Function> getReference(
|
||||
const ngraph::element::Type precision,
|
||||
const ngraph::Shape& inputShape,
|
||||
|
@ -25,8 +25,7 @@ public:
|
||||
const ngraph::element::Type precisionBeforeDequantization,
|
||||
const ngraph::builder::subgraph::DequantizationOperations& dequantization,
|
||||
const int64_t splitedAxis,
|
||||
const size_t numSplits,
|
||||
const bool addUnsupportedConcat = false);
|
||||
const size_t numSplits);
|
||||
|
||||
static std::shared_ptr<ngraph::Function> getOriginal(
|
||||
const ngraph::element::Type originalFunctionPrecision,
|
||||
@ -43,8 +42,7 @@ public:
|
||||
const ngraph::element::Type precisionAfterOperation,
|
||||
const std::vector<ngraph::builder::subgraph::DequantizationOperations>& dequantizationAfter,
|
||||
const int64_t splitedAxis,
|
||||
const size_t numSplit,
|
||||
const bool addUnsupportedConcat = false);
|
||||
const size_t numSplits);
|
||||
};
|
||||
} // namespace subgraph
|
||||
} // namespace builder
|
||||
|
@ -4,7 +4,7 @@
|
||||
|
||||
#include "lpt_ngraph_functions/concat_function.hpp"
|
||||
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
#include <openvino/opsets/opset1.hpp>
|
||||
#include "ngraph_ops/type_relaxed.hpp"
|
||||
#include "low_precision/network_helper.hpp"
|
||||
#include "low_precision/rt_info/precision_preserved_attribute.hpp"
|
||||
@ -23,6 +23,46 @@ namespace subgraph {
|
||||
|
||||
using namespace ngraph::pass;
|
||||
|
||||
std::shared_ptr<ov::Model> ConcatFunction::get(
|
||||
const ov::element::Type inputPrecision,
|
||||
const ov::element::Type deqPrecision,
|
||||
const std::vector<ov::PartialShape>& inputShapes,
|
||||
const std::vector<DequantizationOperations>& dequantizationsBefore,
|
||||
const std::int64_t concatAxis,
|
||||
const ov::element::Type precisionAfter,
|
||||
const DequantizationOperations& dequantizationAfter) {
|
||||
auto modifyDeq = [](const DequantizationOperations& deq, const ov::element::Type deqOutPrc) {
|
||||
auto dequantizationStructure = deq;
|
||||
if (!dequantizationStructure.multiply.empty()) {
|
||||
dequantizationStructure.multiply.outPrecision = deqOutPrc;
|
||||
}
|
||||
return dequantizationStructure;
|
||||
};
|
||||
|
||||
ov::ParameterVector inputs;
|
||||
ov::NodeVector concatInputs;
|
||||
if (inputShapes.size() != dequantizationsBefore.size()) {
|
||||
throw std::runtime_error("Concat builder: input and dequantization sizes aren't equal");
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < inputShapes.size(); ++i) {
|
||||
const auto input = std::make_shared<ov::opset1::Parameter>(inputPrecision, inputShapes[i]);
|
||||
const auto dequantization = makeDequantization(input, modifyDeq(dequantizationsBefore[i], deqPrecision));
|
||||
inputs.push_back(input);
|
||||
concatInputs.push_back(dequantization);
|
||||
}
|
||||
|
||||
const auto concat = std::make_shared<ov::opset1::Concat>(concatInputs, concatAxis);
|
||||
if (precisionAfter != ov::element::undefined && (concat->get_output_element_type(0).is_real() ^ precisionAfter.is_real())) {
|
||||
throw std::runtime_error("Concat builder: requested precision after operation could't be set");
|
||||
}
|
||||
|
||||
const auto deqAfter = makeDequantization(concat, modifyDeq(dequantizationAfter, deqPrecision));
|
||||
deqAfter->set_friendly_name("Concat");
|
||||
const auto result = std::make_shared<ov::opset1::Result>(deqAfter);
|
||||
return std::make_shared<ov::Model>(ov::ResultVector{result}, inputs, "ConcatTransformation");
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::Function> ConcatFunction::getOriginal(
|
||||
const ngraph::element::Type precision,
|
||||
const ngraph::PartialShape& inputShape,
|
||||
|
@ -23,25 +23,18 @@ std::shared_ptr<ngraph::Function> SplitFunction::getOriginal(
|
||||
const ngraph::element::Type precisionBeforeDequantization,
|
||||
const ngraph::builder::subgraph::DequantizationOperations& dequantization,
|
||||
const int64_t splitedAxis,
|
||||
const size_t numSplits,
|
||||
const bool addUnsupportedConcat) {
|
||||
const size_t numSplits) {
|
||||
const auto input = std::make_shared<ngraph::opset1::Parameter>(precisionBeforeDequantization, inputShape);
|
||||
|
||||
auto dequantizationStructure = dequantization;
|
||||
dequantizationStructure.multiply.outPrecision = precision;
|
||||
const std::shared_ptr<Node> dequantizationOp = makeDequantization(input, dequantization);
|
||||
const auto dequantizationOp = makeDequantization(input, dequantization);
|
||||
const auto constant = std::make_shared<ngraph::opset1::Constant>(element::i64, Shape{ }, splitedAxis);
|
||||
const std::shared_ptr<Node> split = std::make_shared<ngraph::opset1::Split>(dequantizationOp, constant, numSplits);
|
||||
const auto split = std::make_shared<ngraph::opset1::Split>(dequantizationOp, constant, numSplits);
|
||||
|
||||
ngraph::ResultVector results;
|
||||
|
||||
if (addUnsupportedConcat) {
|
||||
const auto concat = std::make_shared<opset1::Concat>(split->outputs(), 2ul);
|
||||
results.push_back(std::make_shared<opset1::Result>(concat));
|
||||
} else {
|
||||
for (size_t i = 0; i < numSplits; ++i) {
|
||||
results.push_back(std::make_shared<ngraph::opset1::Result>(split->output(i)));
|
||||
}
|
||||
for (size_t i = 0; i < numSplits; ++i) {
|
||||
results.push_back(std::make_shared<ngraph::opset1::Result>(split->output(i)));
|
||||
}
|
||||
return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "SplitFunction");
|
||||
}
|
||||
@ -82,31 +75,23 @@ std::shared_ptr<ngraph::Function> SplitFunction::getReference(
|
||||
const ngraph::element::Type precisionAfterOperation,
|
||||
const std::vector<ngraph::builder::subgraph::DequantizationOperations>& dequantizationAfter,
|
||||
const int64_t splitedAxis,
|
||||
const size_t numSplit,
|
||||
const bool addUnsupportedConcat) {
|
||||
const size_t numSplit) {
|
||||
const auto input = std::make_shared<ngraph::opset1::Parameter>(inputPrecision, inputShape);
|
||||
|
||||
const auto deqBefore = makeDequantization(input, dequantizationBefore);
|
||||
|
||||
std::shared_ptr<ngraph::opset1::Split> split;
|
||||
const auto constant = std::make_shared<ngraph::opset1::Constant>(element::i64, Shape{ }, splitedAxis);
|
||||
split = std::make_shared<ngraph::opset1::Split>(deqBefore, constant, numSplit);
|
||||
const auto split = std::make_shared<ngraph::opset1::Split>(deqBefore, constant, numSplit);
|
||||
|
||||
ngraph::ResultVector results;
|
||||
if (addUnsupportedConcat) {
|
||||
const auto concat = std::make_shared<opset1::Concat>(split->outputs(), 2ul);
|
||||
results.push_back(std::make_shared<opset1::Result>(concat));
|
||||
} else {
|
||||
for (size_t i = 0; i < numSplit; ++i) {
|
||||
if (!dequantizationAfter.empty()) {
|
||||
auto dequantizationStructure = dequantizationAfter[i];
|
||||
if (!dequantizationStructure.multiply.empty()) {
|
||||
dequantizationStructure.multiply.outPrecision = precision;
|
||||
}
|
||||
results.push_back(std::make_shared<ngraph::opset1::Result>(makeDequantization(split->output(i), dequantizationAfter[i])));
|
||||
} else {
|
||||
results.push_back(std::make_shared<ngraph::opset1::Result>(split->output(i)));
|
||||
for (size_t i = 0; i < numSplit; ++i) {
|
||||
if (!dequantizationAfter.empty()) {
|
||||
auto dequantizationStructure = dequantizationAfter[i];
|
||||
if (!dequantizationStructure.multiply.empty()) {
|
||||
dequantizationStructure.multiply.outPrecision = precision;
|
||||
}
|
||||
results.push_back(std::make_shared<ngraph::opset1::Result>(makeDequantization(split->output(i), dequantizationAfter[i])));
|
||||
} else {
|
||||
results.push_back(std::make_shared<ngraph::opset1::Result>(split->output(i)));
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user