[LPT] StridedSlice support in ConcatTransformation & ConcatMultiChannelsTransformation (#3950)

* [LPT] ConcatTransformation: supported StridedSlice

* [LPT] Concat with StridedSlice functional tests
This commit is contained in:
Vladislav Golubev
2021-02-08 22:22:44 +03:00
committed by GitHub
parent e2c67bd508
commit 7aaaa293d5
7 changed files with 740 additions and 52 deletions

View File

@@ -1,4 +1,4 @@
// Copyright (C) 2020 Intel Corporation
// Copyright (C) 2020-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
@@ -32,12 +32,21 @@ private:
std::unordered_map<std::string, FakeQuantizeDequantization>& dequantizationByFakeQuantize,
std::vector<FakeQuantizeDequantization>& dequantizationsToConcatenate);
static void fillQuantization(const std::shared_ptr<ngraph::Node> layer, std::vector<std::shared_ptr<ngraph::opset1::FakeQuantize>>& fakeQuantizes);
static void fillQuantization(
const std::shared_ptr<ngraph::Node> layer,
const std::unordered_map<std::string, FakeQuantizeDequantization>& dequantizationByFakeQuantize,
std::vector<FakeQuantizeDequantization>& dequantization);
static void updateDequantizationShapesIfNecessary(
std::shared_ptr<ngraph::Node> layer,
std::vector<std::shared_ptr<ngraph::opset1::FakeQuantize>>& fakeQuantizes,
std::unordered_map<std::string, FakeQuantizeDequantization>& dequantizationByFakeQuantize);
static FakeQuantizeDequantization getConcatenatedDequantization(
const std::shared_ptr<ngraph::opset1::Concat> concat,
const std::vector<FakeQuantizeDequantization>& dequantization);
static FakeQuantizeDequantization getFoldedDequantization(
const std::shared_ptr<ngraph::Node> operation,
const FakeQuantizeDequantization& dequantization,
const size_t sourceOutputIdx);
static FakeQuantizeDequantization broadcastDequantiationConstant(const FakeQuantizeDequantization& deq);
bool isMultiChannel(const std::vector<std::shared_ptr<ngraph::opset1::Concat>>& concatLayers) const noexcept;
};

View File

@@ -267,6 +267,7 @@ void ConcatTransformation::addDequantizationLayers(
if (subgraph.layers.find(child.get_friendly_name()) == subgraph.layers.end()) {
if (layerDequantizations.size() == 0ul) {
// fill layerDequantizations collection
getLayerDequantizationCallback(layer, layer->get_friendly_name(), layerDequantizations);
}
@@ -276,6 +277,7 @@ void ConcatTransformation::addDequantizationLayers(
std::vector<std::shared_ptr<ngraph::Node>> subtractNodes;
std::vector<std::shared_ptr<ngraph::Node>> multiplyNodes;
// forming nodes for concatenation
if (layerDequantizations.size() > 1ul) {
auto broadcastElementWiseConst = [](
// FakeQuantize constant shape must be broadcastable to the shape on data.
@@ -312,12 +314,8 @@ void ConcatTransformation::addDequantizationLayers(
}
const ngraph::element::Type precision = dequantization.data.get_element_type();
ngraph::Shape targetShape = dequantization.data.get_shape();
targetShape[0] = 1ul;
for (size_t i = 2; i < targetShape.size(); ++i) {
targetShape[i] = 1ul;
}
ngraph::Shape targetShape(dequantization.data.get_shape().size(), 1ul);
targetShape[1] = dequantization.data.get_shape()[1];
if (!allDequantizationShiftAreZero) {
subtractNodes.push_back(dequantization.subtract == nullptr ?

View File

@@ -14,6 +14,7 @@
#include <ngraph/opsets/opset1.hpp>
#include "low_precision/common/fake_quantize_dequantization.hpp"
#include "low_precision/common/dequantization_op.hpp"
#include "low_precision/common/ie_lpt_exception.hpp"
#include "low_precision/common/subgraph.hpp"
#include "low_precision/network_helper.hpp"
@@ -184,63 +185,207 @@ void ConcatMultiChannelsTransformation::fillDequantization(
std::shared_ptr<ngraph::Node> layer,
std::unordered_map<std::string, FakeQuantizeDequantization>& dequantizationByFakeQuantize,
std::vector<FakeQuantizeDequantization>& dequantizationsToConcatenate) {
std::vector<std::shared_ptr<ngraph::opset1::FakeQuantize>> fakeQuantizes;
std::shared_ptr<ngraph::opset1::FakeQuantize> currentFakeQuantize = ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(layer);
if (currentFakeQuantize != nullptr) {
fakeQuantizes.push_back(currentFakeQuantize);
} else {
fillQuantization(layer, fakeQuantizes);
if (fakeQuantizes.size() == layer->get_input_size()) {
updateDequantizationShapesIfNecessary(layer, fakeQuantizes, dequantizationByFakeQuantize);
}
}
for (const auto& fakeQuantize : fakeQuantizes) {
const auto it = dequantizationByFakeQuantize.find(fakeQuantize->get_friendly_name());
if (currentFakeQuantize) {
const auto it = dequantizationByFakeQuantize.find(currentFakeQuantize->get_friendly_name());
if (it == dequantizationByFakeQuantize.end()) {
THROW_IE_LPT_EXCEPTION(*fakeQuantize) << "dequantization scale values are not found";
THROW_IE_LPT_EXCEPTION(*currentFakeQuantize) << "dequantization scale values are not found";
}
const FakeQuantizeDequantization& fakeQuantizeDequantization = it->second;
dequantizationsToConcatenate.push_back(fakeQuantizeDequantization);
}
}
void ConcatMultiChannelsTransformation::updateDequantizationShapesIfNecessary(
std::shared_ptr<ngraph::Node> layer,
std::vector<std::shared_ptr<ngraph::opset1::FakeQuantize>>& fakeQuantizes,
std::unordered_map<std::string, FakeQuantizeDequantization>& dequantizationByFakeQuantize) {
for (size_t i = 0; i < fakeQuantizes.size(); ++i) {
ngraph::Shape inputShape = layer->get_input_shape(i);
ngraph::Shape dequantizationShape = fakeQuantizes[i]->get_shape();
if (inputShape[1] != dequantizationShape[1]) {
FakeQuantizeDequantization replacedDequantization = dequantizationByFakeQuantize[fakeQuantizes[i]->get_friendly_name()];
const float scale = as_type_ptr<ngraph::opset1::Constant>(replacedDequantization.multiply->get_input_node_shared_ptr(1))->cast_vector<float>()[0];
const float shift = replacedDequantization.subtract ? replacedDequantization.subtractConstant->cast_vector<float>()[0] : 0.f;
const auto precisionBefore = replacedDequantization.data.get_element_type();
const auto precisionAfter = replacedDequantization.multiply->get_element_type();
auto newDequantization = ngraph::pass::low_precision::NetworkHelper::makeDequantization(
scale, shift, precisionBefore, inputShape, precisionAfter, 0.f, 5.f);
dequantizationByFakeQuantize[fakeQuantizes[i]->get_friendly_name()] = newDequantization;
}
dequantizationsToConcatenate.push_back(broadcastDequantiationConstant(fakeQuantizeDequantization));
} else {
fillQuantization(layer, dequantizationByFakeQuantize, dequantizationsToConcatenate);
}
}
// Walks the inputs of `layer` and appends, per input, the dequantization that has to be
// concatenated into `dequantization`:
//   * FakeQuantize parent  -> its entry from dequantizationByFakeQuantize, broadcast per channel;
//   * Concat parent        -> the dequantizations of its own inputs, merged by getConcatenatedDequantization;
//   * StridedSlice parent  -> the dequantization of its input, folded through the slice by getFoldedDequantization;
//   * any other parent     -> recursion into that parent, appending to the same collection.
// Throws (THROW_IE_LPT_EXCEPTION) when a FakeQuantize has no entry in dequantizationByFakeQuantize.
// NOTE(review): this span looks like a rendered diff in which lines of the previous signature/body
// (the `fakeQuantizes` vector variant) are interleaved with the new ones — confirm against the real file.
void ConcatMultiChannelsTransformation::fillQuantization(
const std::shared_ptr<ngraph::Node> layer,
std::vector<std::shared_ptr<ngraph::opset1::FakeQuantize>>& fakeQuantizes) {
const std::unordered_map<std::string, FakeQuantizeDequantization>& dequantizationByFakeQuantize,
std::vector<FakeQuantizeDequantization>& dequantization) {
for (size_t i = 0; i < layer->get_input_size(); ++i) {
std::shared_ptr<ngraph::Node> parent = layer->get_input_node_shared_ptr(i);
std::shared_ptr<ngraph::opset1::FakeQuantize> fakeQuantize = ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(parent);
if (fakeQuantize != nullptr) {
fakeQuantizes.push_back(fakeQuantize);
if (fakeQuantize) {
// the parent is a FakeQuantize: take its registered dequantization
const auto it = dequantizationByFakeQuantize.find(fakeQuantize->get_friendly_name());
if (it == dequantizationByFakeQuantize.end()) {
THROW_IE_LPT_EXCEPTION(*fakeQuantize) << "dequantization scale values are not found";
}
const FakeQuantizeDequantization& fakeQuantizeDequantization = it->second;
dequantization.push_back(broadcastDequantiationConstant(fakeQuantizeDequantization));
} else {
fillQuantization(parent, fakeQuantizes);
std::shared_ptr<ngraph::opset1::Concat> concat = ngraph::as_type_ptr<ngraph::opset1::Concat>(parent);
if (concat) {
// nested Concat: gather the dequantizations of its inputs into a separate collection first
std::vector<FakeQuantizeDequantization> dequantizationToConcatenate;
fillQuantization(concat, dequantizationByFakeQuantize, dequantizationToConcatenate);
// add concatenated dequantization operations to dequantization collection
dequantization.push_back(getConcatenatedDequantization(concat, dequantizationToConcatenate));
} else {
std::shared_ptr<ngraph::opset1::StridedSlice> stridedSlice = ngraph::as_type_ptr<ngraph::opset1::StridedSlice>(parent);
if (stridedSlice) {
std::vector<FakeQuantizeDequantization> dequantizationToPropagate;
fillQuantization(stridedSlice, dequantizationByFakeQuantize, dequantizationToPropagate);
// index of the StridedSlice output that feeds `layer`
const size_t sourceOutputIdx = NetworkHelper::getParentOutputIndex(parent, layer);
// add folded dequantization operations to dequantization collection
// NOTE(review): dequantizationToPropagate[0] is read without checking that the recursion
// above produced at least one element — confirm this cannot be empty.
dequantization.push_back(getFoldedDequantization(stridedSlice, dequantizationToPropagate[0], sourceOutputIdx));
} else {
// any other operation: look through it and keep appending to the caller's collection
fillQuantization(parent, dequantizationByFakeQuantize, dequantization);
}
}
}
}
}
// Returns a copy of `deq` (data, convert, subtract and multiply are carried over) whose subtract
// and multiply constants are broadcast to a per-channel shape: every dimension is 1 except
// dimension 1, which takes the size of dimension 1 of deq.data.
FakeQuantizeDequantization ConcatMultiChannelsTransformation::broadcastDequantiationConstant(const FakeQuantizeDequantization& deq) {
    const ngraph::Shape dataShape = deq.data.get_shape();
    ngraph::Shape perChannelShape(dataShape.size(), 1ul);
    perChannelShape[1] = dataShape[1];

    FakeQuantizeDequantization broadcasted;
    broadcasted.data = deq.data;
    broadcasted.convert = deq.convert;

    // target shape of the broadcast, passed to Broadcast as an i64 constant
    const auto shapeConst = std::make_shared<ngraph::opset1::Constant>(
        element::i64,
        ngraph::Shape{ perChannelShape.size() },
        perChannelShape);

    // folds Broadcast(constant, shapeConst, NUMPY) and returns the result as a Constant
    const auto toPerChannel = [&shapeConst](const std::shared_ptr<ngraph::opset1::Constant>& constant) {
        return as_type_ptr<ngraph::opset1::Constant>(
            ngraph::pass::low_precision::fold<ngraph::opset1::Broadcast>(
                constant,
                shapeConst,
                ngraph::op::AutoBroadcastType::NUMPY));
    };

    if (deq.subtract) {
        broadcasted.subtract = deq.subtract;
        broadcasted.subtractConstant = toPerChannel(deq.subtractConstant);
    }

    if (deq.multiply) {
        broadcasted.multiply = deq.multiply;
        broadcasted.multiplyConstant = toPerChannel(deq.multiplyConstant);
    }

    return broadcasted;
}
// Builds the dequantization that follows `concat` out of the dequantizations of its inputs.
//
// concat          - the Concat the resulting dequantization chain is attached to.
// dequantization  - one entry per concatenated input, in input order.
//
// Subtract / multiply constants are concatenated along axis 1. When at least one input has a
// subtract (multiply) and another one does not, the missing one is replaced by a per-channel
// constant of zeros (ones) so that the concatenated constant covers every channel.
// The Convert, when any input has one, is cloned from the first input that provides it.
FakeQuantizeDequantization ConcatMultiChannelsTransformation::getConcatenatedDequantization(
    const std::shared_ptr<ngraph::opset1::Concat> concat,
    const std::vector<FakeQuantizeDequantization>& dequantization) {
    bool hasSubtract = false;
    bool hasMultiply = false;
    // First convert found among the inputs. Cloning dequantization[0].convert unconditionally
    // would dereference a null pointer when only a later input carries a convert.
    std::shared_ptr<Node> convertPrototype;
    for (const auto& deq : dequantization) {
        hasSubtract = hasSubtract || (deq.subtract != nullptr);
        hasMultiply = hasMultiply || (deq.multiply != nullptr);
        if ((convertPrototype == nullptr) && (deq.convert != nullptr)) {
            convertPrototype = deq.convert;
        }
    }

    // preparing to concatenate dequantization nodes
    NodeVector subNodes;
    NodeVector mulNodes;
    for (const auto& deq : dequantization) {
        const ngraph::element::Type precision = deq.data.get_element_type();
        // per-channel shape of the neutral constant: 1 everywhere except dimension 1
        ngraph::Shape targetShape(deq.data.get_shape().size(), 1ul);
        targetShape[1] = deq.data.get_shape()[1];

        if (hasSubtract) {
            if (deq.subtract == nullptr) {
                subNodes.push_back(std::make_shared<ngraph::opset1::Constant>(precision, targetShape, std::vector<float>({ 0.f })));
            } else {
                subNodes.push_back(deq.subtractConstant);
            }
        }

        if (hasMultiply) {
            if (deq.multiply == nullptr) {
                mulNodes.push_back(std::make_shared<ngraph::opset1::Constant>(precision, targetShape, std::vector<float>({ 1.0f })));
            } else {
                mulNodes.push_back(deq.multiplyConstant);
            }
        }
    }

    std::shared_ptr<Node> parent = concat;

    std::shared_ptr<DequantizationConvert> convert;
    if (convertPrototype != nullptr) {
        convert = as_type_ptr<DequantizationConvert>(convertPrototype->clone_with_new_inputs({ parent }));
        parent = convert;
    }

    std::shared_ptr<DequantizationSubtract> subtract;
    std::shared_ptr<ngraph::opset1::Constant> subConst;
    if (!subNodes.empty()) {
        subConst = as_type_ptr<ngraph::opset1::Constant>(
            subNodes.size() == 1ul ? subNodes[0] : fold<ngraph::opset1::Concat>(subNodes, 1ul));
        subtract = std::make_shared<DequantizationSubtract>(parent, subConst);
        parent = subtract;
    }

    std::shared_ptr<DequantizationMultiply> multiply;
    std::shared_ptr<ngraph::opset1::Constant> mulConst;
    if (!mulNodes.empty()) {
        mulConst = as_type_ptr<ngraph::opset1::Constant>(
            mulNodes.size() == 1ul ? mulNodes[0] : fold<ngraph::opset1::Concat>(mulNodes, 1ul));
        multiply = std::make_shared<DequantizationMultiply>(parent, mulConst);
    }

    return FakeQuantizeDequantization(concat, convert, subtract, nullptr, subConst, multiply, mulConst);
}
// Propagates `dequantization` through `operation` by constant folding.
//
// operation        - the operation the dequantization is moved behind.
// dequantization   - dequantization found on the operation's first input.
// sourceOutputIdx  - index of the operation output the new dequantization chain is built on.
//
// For the subtract and the multiply constants a clone of `operation` is created whose first
// input is the constant, and the clone is constant-folded; the folded value of output
// `sourceOutputIdx` becomes the new constant. The Convert, if present, is cloned on top of
// `operation`.
// Throws (THROW_IE_LPT_EXCEPTION) when the operation has no inputs, the output index is out of
// range, a required constant is missing, or folding does not produce a Constant.
FakeQuantizeDequantization ConcatMultiChannelsTransformation::getFoldedDequantization(
    const std::shared_ptr<ngraph::Node> operation,
    const FakeQuantizeDequantization& dequantization,
    const size_t sourceOutputIdx) {
    OutputVector inputs = operation->input_values();
    if (inputs.empty()) {
        THROW_IE_LPT_EXCEPTION(*operation) << "operation without inputs can not be used for dequantization folding";
    }
    if (sourceOutputIdx >= operation->get_output_size()) {
        THROW_IE_LPT_EXCEPTION(*operation) << "source output index " << sourceOutputIdx << " is out of range";
    }

    // Replaces the first input with `constant`, folds a clone of the operation and returns the
    // folded value of the requested output. A fresh output vector is used for every fold so that
    // a failed fold can never pick up the result of a previous one.
    const auto foldThroughOperation = [&](
        const std::shared_ptr<ngraph::opset1::Constant>& constant) -> std::shared_ptr<ngraph::opset1::Constant> {
        if (constant == nullptr) {
            THROW_IE_LPT_EXCEPTION(*operation) << "dequantization constant is absent";
        }

        inputs[0] = constant;
        const auto op = operation->clone_with_new_inputs(inputs);

        OutputVector outputs(operation->get_output_size());
        if (!op->constant_fold(outputs, inputs)) {
            THROW_IE_LPT_EXCEPTION(*operation) << "dequantization constant can not be folded";
        }

        const auto foldedNode = outputs[sourceOutputIdx].get_node_shared_ptr();
        const auto foldedConstant = foldedNode == nullptr ?
            std::shared_ptr<ngraph::opset1::Constant>() :
            as_type_ptr<ngraph::opset1::Constant>(foldedNode);
        if (foldedConstant == nullptr) {
            THROW_IE_LPT_EXCEPTION(*operation) << "folded dequantization value is not a constant";
        }
        return foldedConstant;
    };

    std::shared_ptr<Node> parent = operation;

    std::shared_ptr<DequantizationConvert> convert;
    if (dequantization.convert) {
        convert = as_type_ptr<DequantizationConvert>(dequantization.convert->clone_with_new_inputs({ parent }));
        parent = convert;
    }

    std::shared_ptr<DequantizationSubtract> subtract;
    std::shared_ptr<ngraph::opset1::Constant> subConst;
    if (dequantization.subtract) {
        // constant folding of subtract constant
        subConst = foldThroughOperation(dequantization.subtractConstant);
        subtract = std::make_shared<DequantizationSubtract>(parent, subConst);
        parent = subtract;
    }

    std::shared_ptr<DequantizationMultiply> multiply;
    std::shared_ptr<ngraph::opset1::Constant> mulConst;
    if (dequantization.multiply) {
        // constant folding of multiply constant
        mulConst = foldThroughOperation(dequantization.multiplyConstant);
        multiply = std::make_shared<DequantizationMultiply>(parent, mulConst);
    }

    return FakeQuantizeDequantization(operation->output(sourceOutputIdx), convert, subtract, nullptr, subConst, multiply, mulConst);
}
} // namespace low_precision
} // namespace pass
} // namespace ngraph

View File

@@ -27,6 +27,11 @@ bool isQuantizationPerChannel(const std::shared_ptr<ngraph::Node>& node) {
return false;
}
//WA to support StridedSlice in ConcatTransformation
if (ngraph::is_type<opset1::StridedSlice>(node)) {
return true;
}
const auto inputs = ngraph::pass::low_precision::NetworkHelper::getInputs(node);
for (const auto& input : inputs) {
if (ngraph::is_type<opset1::Constant>(input.get_node())) {

View File

@@ -0,0 +1,284 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "layer_transformation.hpp"
#include <string>
#include <sstream>
#include <memory>
#include <gtest/gtest.h>
#include <transformations/utils/utils.hpp>
#include <transformations/init_node_info.hpp>
#include <low_precision/transformer.hpp>
#include <low_precision/concat.hpp>
#include <low_precision/concat_multi_channels.hpp>
#include <low_precision/max_pool.hpp>
#include <low_precision/strided_slice.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
#include "lpt_ngraph_functions/concat_function.hpp"
#include "lpt_ngraph_functions/common/fake_quantize_on_data.hpp"
#include "simple_low_precision_transformer.hpp"
using namespace testing;
using namespace ngraph;
using namespace ngraph::pass;
namespace {
// FakeQuantize descriptions used to build the original (not yet transformed) function.
struct ConcatTransformationActualValues {
    // passed as the first FakeQuantize description to getOriginalWithStridedSlice
    ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize1;
    // passed as the second FakeQuantize description to getOriginalWithStridedSlice
    ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize2;
};
// Streams both FakeQuantize descriptions, each preceded by "_" (used in test case names).
inline std::ostream& operator<<(std::ostream& out, const ConcatTransformationActualValues& values) {
    out << "_" << values.fakeQuantize1;
    out << "_" << values.fakeQuantize2;
    return out;
}
// Expected values used to build the reference function; all of them are forwarded to
// ConcatFunction::getReferenceWithStridedSlice.
struct ConcatTransformationResultValues {
    // expected FakeQuantize descriptions
    ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize1;
    ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize2;
    // expected dequantization passed as `deqBefore`
    ngraph::builder::subgraph::DequantizationOperations dequantizationBefore;
    // expected precisions passed as `precisionBeforeConcat` / `precisionAfterConcat`
    ngraph::element::Type precisionBeforeConcat;
    ngraph::element::Type precisionAfterConcat;
    // expected dequantizations passed as `deqAfter1` / `deqAfter2`
    ngraph::builder::subgraph::DequantizationOperations dequantizationAfter1;
    ngraph::builder::subgraph::DequantizationOperations dequantizationAfter2;
};
// Streams the expected FakeQuantize descriptions and the two "after" dequantizations,
// each preceded by "_" (used in test case names).
inline std::ostream& operator<<(std::ostream& out, const ConcatTransformationResultValues& values) {
    out << "_" << values.fakeQuantize1;
    out << "_" << values.fakeQuantize2;
    out << "_" << values.dequantizationAfter1;
    out << "_" << values.dequantizationAfter2;
    return out;
}
// One test scenario: transformation parameters, topology switches, inputs and expected results.
struct ConcatTransformationTestValues {
    // parameters handed to every transformation registered in the test
    ngraph::pass::low_precision::LayerTransformation::Params params;
    // true: ConcatMultiChannelsTransformation is registered, false: ConcatTransformation
    bool multiChannels;
    // forwarded to the graph builders as `ssBeforeConcat`
    bool ssBeforeConcat;
    // forwarded to the graph builders as `ssAfterConcat`
    bool ssAfterConcat;
    ConcatTransformationActualValues actual;
    ConcatTransformationResultValues result;
};
// Streams the multi-channel flag followed by the actual and the expected values,
// each preceded by "_" (used in test case names).
inline std::ostream& operator<<(std::ostream& out, const ConcatTransformationTestValues& values) {
    out << "_" << values.multiChannels;
    out << "_" << values.actual;
    out << "_" << values.result;
    return out;
}
// Test parameter: <network precision, input shape, test scenario>.
using ConcatTransformationParams = std::tuple<
    ngraph::element::Type,
    ngraph::Shape,
    ConcatTransformationTestValues>;
// Fixture of the Concat + StridedSlice low-precision tests.
// SetUp builds the original function, runs the low-precision transformations on it and builds
// the reference function; the comparison itself is done in the TEST_P body.
class ConcatWithStridedSliceTransformation : public LayerTransformation, public testing::WithParamInterface<ConcatTransformationParams> {
public:
    void SetUp() override {
        // GetParam() returns a reference that lives for the whole test, so the tuple members
        // are bound by reference instead of being copied.
        const ConcatTransformationParams& param = GetParam();
        const ngraph::element::Type precision = std::get<0>(param);
        const ngraph::Shape& shape = std::get<1>(param);
        const ConcatTransformationTestValues& testValues = std::get<2>(param);

        actualFunction = ngraph::builder::subgraph::ConcatFunction::getOriginalWithStridedSlice(
            precision,
            shape,
            testValues.actual.fakeQuantize1,
            testValues.actual.fakeQuantize2,
            testValues.ssBeforeConcat,
            testValues.ssAfterConcat);

        SimpleLowPrecisionTransformer transform;
        // exactly one of the two Concat transformations is registered for opset1::Concat
        if (testValues.multiChannels) {
            transform.add<ngraph::pass::low_precision::ConcatMultiChannelsTransformation, ngraph::opset1::Concat>(testValues.params);
        } else {
            transform.add<ngraph::pass::low_precision::ConcatTransformation, ngraph::opset1::Concat>(testValues.params);
        }
        transform.add<ngraph::pass::low_precision::MaxPoolTransformation, ngraph::opset1::MaxPool>(testValues.params);
        transform.add<ngraph::pass::low_precision::StridedSliceTransformation, ngraph::opset1::StridedSlice>(testValues.params);
        transform.transform(actualFunction);

        referenceFunction = ngraph::builder::subgraph::ConcatFunction::getReferenceWithStridedSlice(
            precision,
            shape,
            testValues.result.fakeQuantize1,
            testValues.result.fakeQuantize2,
            testValues.result.dequantizationBefore,
            testValues.result.precisionBeforeConcat,
            testValues.result.precisionAfterConcat,
            testValues.ssBeforeConcat,
            testValues.ssAfterConcat,
            testValues.result.dequantizationAfter1,
            testValues.result.dequantizationAfter2);
    }

    // Builds the test case name from the precision, shape, transformation parameters,
    // topology switches and the actual / expected values.
    static std::string getTestCaseName(testing::TestParamInfo<ConcatTransformationParams> obj) {
        const ngraph::element::Type precision = std::get<0>(obj.param);
        const ngraph::Shape& shape = std::get<1>(obj.param);
        const ConcatTransformationTestValues& testValues = std::get<2>(obj.param);

        std::ostringstream result;
        result <<
            LayerTransformation::getTestCaseNameByParams(precision, shape, testValues.params) << "_" <<
            (testValues.multiChannels ? "multiChannels_" : "notMultiChannels_") <<
            (testValues.ssBeforeConcat ? "SS_before_concat_" : "") <<
            (testValues.ssAfterConcat ? "SS_after_concat_" : "") <<
            testValues.actual << "_" <<
            testValues.result << "_";
        return result.str();
    }
};
// Re-validates the transformed function and compares it with the reference function.
TEST_P(ConcatWithStridedSliceTransformation, CompareFunctions) {
    actualFunction->validate_nodes_and_infer_types();

    // compare_functions returns a pair: comparison verdict and a diagnostic message
    const auto comparison = compare_functions(referenceFunction, actualFunction, true);
    ASSERT_TRUE(comparison.first) << comparison.second;
}
// Network precisions the test suite is instantiated for.
// f16 is listed but commented out, so only f32 is exercised (the reason is not recorded here).
const std::vector<ngraph::element::Type> precisions = {
ngraph::element::f32,
// ngraph::element::f16
};
// Test scenarios. Each entry is brace-initialized in the member order of
// ConcatTransformationTestValues:
//   params, multiChannels, ssBeforeConcat, ssAfterConcat,
//   actual { fakeQuantize1, fakeQuantize2 },
//   result { fakeQuantize1, fakeQuantize2, dequantizationBefore, precisionBeforeConcat,
//            precisionAfterConcat, dequantizationAfter1, dequantizationAfter2 }.
// dequantizationAfter1 is applied on the first branch after Concat (the StridedSlice when
// ssAfterConcat is set, otherwise Concat itself), dequantizationAfter2 on the MaxPool branch.
// The per-channel value counts below match the 4-channel input shapes used by this suite:
// each StridedSlice ends at (channels - 2) on dimension 1.
const std::vector<ConcatTransformationTestValues> testValues = {
// FQ with the same values, ss before concat, ss after concat
{
LayerTransformation::createParamsU8I8(),
true,
true,
true,
{
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} }
},
{
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f} },
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f} },
{ngraph::element::f32, {}, { 0.01f }},
ngraph::element::u8,
ngraph::element::u8,
{ngraph::element::f32, {}, { 0.01f }},
{ngraph::element::f32, {}, { 0.01f }}
}
},
// FQ with different values, ss before concat, ss after concat
// (expected: 4 per-channel scales after the trailing StridedSlice, 6 after MaxPool)
{
LayerTransformation::createParamsU8I8(),
true,
true,
true,
{
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
{ 256ul, ngraph::Shape({}), {0.f}, {25.5f}, {0.f}, {25.5f} }
},
{
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f} },
{ 256ul, ngraph::Shape({}), {0.f}, {25.5f}, {0.f}, {255.f} },
{ngraph::element::f32, {}, { 0.01f }},
ngraph::element::u8,
ngraph::element::u8,
{ngraph::element::f32, {}, { {0.01f, 0.01f, 0.1f, 0.1f} }},
{ngraph::element::f32, {}, { {0.01f, 0.01f, 0.1f, 0.1f, 0.1f, 0.1f} }}
}
},
// FQ with different values, ss after concat
// (expected: 6 per-channel scales after the trailing StridedSlice, 8 after MaxPool)
{
LayerTransformation::createParamsU8I8(),
true,
false,
true,
{
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
{ 256ul, ngraph::Shape({}), {0.f}, {25.5f}, {0.f}, {25.5f} }
},
{
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f} },
{ 256ul, ngraph::Shape({}), {0.f}, {25.5f}, {0.f}, {255.f} },
{ngraph::element::f32, {}, { 0.01f }},
ngraph::element::u8,
ngraph::element::u8,
{ngraph::element::f32, {}, { {0.01f, 0.01f, 0.01f, 0.01f, 0.1f, 0.1f} }},
{ngraph::element::f32, {}, { {0.01f, 0.01f, 0.01f, 0.01f, 0.1f, 0.1f, 0.1f, 0.1f} }}
}
},
// FQ with different values, ss before concat
// (no StridedSlice after Concat: both branches expect the same 6 per-channel scales)
{
LayerTransformation::createParamsU8I8(),
true,
true,
false,
{
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
{ 256ul, ngraph::Shape({}), {0.f}, {25.5f}, {0.f}, {25.5f} }
},
{
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f} },
{ 256ul, ngraph::Shape({}), {0.f}, {25.5f}, {0.f}, {255.f} },
{ngraph::element::f32, {}, { 0.01f }},
ngraph::element::u8,
ngraph::element::u8,
{ngraph::element::f32, {}, { {0.01f, 0.01f, 0.1f, 0.1f, 0.1f, 0.1f} }},
{ngraph::element::f32, {}, { {0.01f, 0.01f, 0.1f, 0.1f, 0.1f, 0.1f} }}
}
},
// FQ with zero-point, ss before concat, ss after concat
// (the expected dequantizations carry per-channel subtract values in addition to the scales)
{
LayerTransformation::createParamsU8I8(),
true,
true,
true,
{
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
{ 256ul, {}, {1.275f}, {2.55f}, {1.275f}, {2.55f} }
},
{
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f} },
{ 256ul, {}, {1.275f}, {2.55f}, {0.f}, {255.f} },
{ngraph::element::f32, {}, { 0.01f }},
ngraph::element::u8,
ngraph::element::u8,
{ngraph::element::f32, { {0.f, 0.f, -255.f, -255.f} }, { {0.01f, 0.01f, 0.005f, 0.005f} }},
{ngraph::element::f32, { {0.f, 0.f, -255.f, -255.f, -255.f, -255.f} }, { {0.01f, 0.01f, 0.005f, 0.005f, 0.005f, 0.005f} }}
}
},
// not multi channels concat, ss before concat, ss after concat
// (multiChannels == false: ConcatTransformation is used and a single shift/scale is expected)
{
LayerTransformation::createParamsU8I8(),
false,
true,
true,
{
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
{ 256ul, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f} }
},
{
{ 256ul, {}, {0.f}, {2.55f}, {85.f}, {255.f} },
{ 256ul, {}, {-1.28f}, {1.27f}, {0.f}, {170.f} },
{ngraph::element::f32, { 85 }, { 0.015f } },
ngraph::element::u8,
ngraph::element::u8,
{ngraph::element::f32, { 85 }, { 0.015f } },
{ngraph::element::f32, { 85 }, { 0.015f } }
}
},
};
// Input shapes the suite is instantiated for: 4-D, 4 channels, batch size 1 and 4.
const std::vector<ngraph::Shape> shapes = {
{ 1, 4, 9, 9 },
{ 4, 4, 9, 9 }
};
// Instantiates the suite for every combination of precisions x shapes x testValues;
// test names are produced by ConcatWithStridedSliceTransformation::getTestCaseName.
INSTANTIATE_TEST_CASE_P(
smoke_LPT,
ConcatWithStridedSliceTransformation,
::testing::Combine(
::testing::ValuesIn(precisions),
::testing::ValuesIn(shapes),
::testing::ValuesIn(testValues)),
ConcatWithStridedSliceTransformation::getTestCaseName);
} // namespace

View File

@@ -62,6 +62,14 @@ public:
const FakeQuantizeOnData& fqOnData1,
const FakeQuantizeOnData& fqOnData2);
// Builds the original (not transformed) function: a FakeQuantize on the input feeding a Concat
// both directly (optionally through a StridedSlice) and through Clamp -> FakeQuantize; the
// Concat feeds an optional StridedSlice and a MaxPool.
// precision       - element type of the input Parameter and of the FakeQuantize operations.
// inputShape      - shape of the input Parameter.
// fq1, fq2        - descriptions of the first and the second FakeQuantize.
// ssBeforeConcat  - insert a StridedSlice between the first FakeQuantize and the Concat.
// ssAfterConcat   - insert a StridedSlice between the Concat and the first Result.
static std::shared_ptr<ngraph::Function> getOriginalWithStridedSlice(
const ngraph::element::Type precision,
const ngraph::Shape inputShape,
const FakeQuantizeOnData& fq1,
const FakeQuantizeOnData& fq2,
const bool ssBeforeConcat,
const bool ssAfterConcat);
static std::shared_ptr<ngraph::Function> getOriginalWithDifferentPrecisionOnChilds(
const ngraph::element::Type precision,
const ngraph::Shape& inputShape,
@@ -151,6 +159,19 @@ public:
const DequantizationOperations& dequantizationOperations1,
const DequantizationOperations& dequantizationOperations2);
// Builds the reference (expected after transformation) function for the topology created by
// getOriginalWithStridedSlice.
// inputPrecision         - element type of the input Parameter.
// inputShape             - shape of the input Parameter.
// fq1, fq2               - descriptions of the first and the second FakeQuantize.
// deqBefore              - dequantization placed after the first FakeQuantize, in front of Clamp.
// precisionBeforeConcat  - output precision set on both FakeQuantize operations.
// precisionAfterConcat   - not used by the definition visible in this change — TODO confirm.
// ssBeforeConcat         - insert a StridedSlice between the first FakeQuantize and the Concat.
// ssAfterConcat          - insert a StridedSlice between the Concat and the first dequantization.
// deqAfter1              - dequantization on the first branch after Concat.
// deqAfter2              - dequantization after MaxPool.
static std::shared_ptr<ngraph::Function> getReferenceWithStridedSlice(
const ngraph::element::Type inputPrecision,
const ngraph::Shape inputShape,
const FakeQuantizeOnData& fq1,
const FakeQuantizeOnData& fq2,
const DequantizationOperations& deqBefore,
const ngraph::element::Type precisionBeforeConcat,
const ngraph::element::Type precisionAfterConcat,
const bool ssBeforeConcat,
const bool ssAfterConcat,
const DequantizationOperations& deqAfter1,
const DequantizationOperations& deqAfter2);
static std::shared_ptr<ngraph::Function> getReferenceWithDifferentPrecisionOnChilds(
const ngraph::element::Type precision,
const ngraph::Shape& inputShape,

View File

@@ -375,6 +375,121 @@ std::shared_ptr<ngraph::Function> ConcatFunction::getOriginalSelectionWithInterm
return function;
}
/*
(SS) - optional

       Input
         /
        FQ
       /  \
    (SS)  Clamp
     |      |
     |      FQ
      \    /
      Concat
        /\
       /  \
    (SS) MaxPool
*/
// Builds the original (not transformed) test function:
// Input -> FQ_1 -> [StridedSlice_1] -> Concat <- FQ_2 <- Clamp <- FQ_1,
// Concat -> [StridedSlice_2] -> Result_1 and Concat -> MaxPool -> Result_2.
// The StridedSlice operations are created only when the matching flag is set. They are built
// from 4-D masks and read shape dimensions 0..3, so a 4-D shape is expected.
std::shared_ptr<ngraph::Function> ConcatFunction::getOriginalWithStridedSlice(
    const ngraph::element::Type precision,
    const ngraph::Shape inputShape,
    const FakeQuantizeOnData& fq1,
    const FakeQuantizeOnData& fq2,
    const bool ssBeforeConcat,
    const bool ssAfterConcat) {
    // Creates a StridedSlice on `source` that starts at 0 and ends at (channels - 2) on
    // dimension 1; the begin/end masks are set for every other dimension.
    const auto makeChannelSlice = [](
        const std::shared_ptr<ngraph::Node>& source,
        const ngraph::Shape& sourceShape,
        const char* friendlyName) -> std::shared_ptr<ngraph::Node> {
        const auto begin = ngraph::op::Constant::create(
            ngraph::element::i64,
            ngraph::Shape{ sourceShape.size() },
            std::vector<int64_t>(sourceShape.size(), 0));
        const auto end = ngraph::op::Constant::create(
            ngraph::element::i64,
            ngraph::Shape{ sourceShape.size() },
            std::vector<size_t>{ sourceShape[0], sourceShape[1] - 2ul, sourceShape[2], sourceShape[3] });
        const std::vector<int64_t> beginMask{ 1, 0, 1, 1 };
        const std::vector<int64_t> endMask{ 1, 0, 1, 1 };

        const auto slice = std::make_shared<ngraph::opset1::StridedSlice>(source, begin, end, beginMask, endMask);
        slice->set_friendly_name(friendlyName);
        return slice;
    };

    const auto input = std::make_shared<ngraph::opset1::Parameter>(precision, inputShape);
    input->set_friendly_name("input");

    const auto fakeQuantize1 = makeFakeQuantize(input, precision, fq1);
    fakeQuantize1->set_friendly_name("FakeQuantize_1");

    // first Concat input: FakeQuantize_1, optionally sliced
    std::shared_ptr<ngraph::Node> firstConcatInput = fakeQuantize1;
    if (ssBeforeConcat) {
        firstConcatInput = makeChannelSlice(firstConcatInput, inputShape, "StridedSlice_1");
    }

    // second Concat input: FakeQuantize_1 -> Clamp -> FakeQuantize_2
    const auto clamp = std::make_shared<ngraph::opset1::Clamp>(fakeQuantize1, 0.0, 6.0);
    clamp->set_friendly_name("Clamp");
    const auto fakeQuantize2 = makeFakeQuantize(clamp, precision, fq2);
    fakeQuantize2->set_friendly_name("FakeQuantize_2");

    const auto concat = std::make_shared<ngraph::opset1::Concat>(NodeVector{ firstConcatInput, fakeQuantize2 }, 1);
    concat->set_friendly_name("Concat");

    ngraph::ResultVector results;

    // first branch: Concat, optionally sliced, goes straight to Result_1
    std::shared_ptr<ngraph::Node> firstBranch = concat;
    if (ssAfterConcat) {
        firstBranch = makeChannelSlice(concat, concat->get_output_shape(0), "StridedSlice_2");
    }
    const auto result1 = std::make_shared<ngraph::opset1::Result>(firstBranch);
    result1->set_friendly_name("Result_1");
    results.push_back(result1);

    // second branch: Concat -> MaxPool -> Result_2
    const std::vector<size_t> kernel = { 3, 3 };
    const std::vector<size_t> stride = { 1, 1 };
    const std::vector<size_t> padBegin = { 0, 0 };
    const std::vector<size_t> padEnd = { 0, 0 };
    const ngraph::op::PadType padType = ngraph::op::PadType::NOTSET;
    const ngraph::op::RoundingType roundingType = ngraph::op::RoundingType::FLOOR;

    const auto maxPool = std::make_shared<ngraph::opset1::MaxPool>(
        concat,
        stride,
        padBegin,
        padEnd,
        kernel,
        roundingType,
        padType);
    maxPool->set_friendly_name("MaxPool");

    const auto result2 = std::make_shared<ngraph::opset1::Result>(maxPool);
    result2->set_friendly_name("Result_2");
    results.push_back(result2);

    return std::make_shared<ngraph::Function>(
        results,
        ngraph::ParameterVector{ input },
        "ConcatWithDifferentChildsTransformation");
}
std::shared_ptr<ngraph::Function> ConcatFunction::getOriginalWithDifferentPrecisionOnChilds(
const ngraph::element::Type precision,
const ngraph::Shape& inputShape,
@@ -985,6 +1100,117 @@ std::shared_ptr<ngraph::Function> ConcatFunction::getReferenceSelectionWithInter
return function;
}
// Builds the reference function for the Concat + StridedSlice topology:
// Input -> FQ_1 (type relaxed) -> [StridedSlice_1] -> Concat <- FQ_2 (type relaxed) <- Clamp <- deqBefore <- FQ_1,
// Concat -> [StridedSlice_2] -> deqAfter1 -> Result_1 and Concat -> MaxPool -> deqAfter2 -> Result_2.
// Both FakeQuantize outputs are set to precisionBeforeConcat. The StridedSlice operations are
// created only when the matching flag is set; they are built from 4-D masks and read shape
// dimensions 0..3, so a 4-D shape is expected.
std::shared_ptr<ngraph::Function> ConcatFunction::getReferenceWithStridedSlice(
    const ngraph::element::Type inputPrecision,
    const ngraph::Shape inputShape,
    const FakeQuantizeOnData& fq1,
    const FakeQuantizeOnData& fq2,
    const DequantizationOperations& deqBefore,
    const ngraph::element::Type precisionBeforeConcat,
    const ngraph::element::Type precisionAfterConcat,
    const bool ssBeforeConcat,
    const bool ssAfterConcat,
    const DequantizationOperations& deqAfter1,
    const DequantizationOperations& deqAfter2) {
    // Creates a StridedSlice on `source` that starts at 0 and ends at (channels - 2) on
    // dimension 1; the begin/end masks are set for every other dimension.
    const auto makeChannelSlice = [](
        const std::shared_ptr<ngraph::Node>& source,
        const ngraph::Shape& sourceShape,
        const char* friendlyName) -> std::shared_ptr<ngraph::Node> {
        const auto begin = ngraph::op::Constant::create(
            ngraph::element::i64,
            ngraph::Shape{ sourceShape.size() },
            std::vector<int64_t>(sourceShape.size(), 0));
        const auto end = ngraph::op::Constant::create(
            ngraph::element::i64,
            ngraph::Shape{ sourceShape.size() },
            std::vector<size_t>{ sourceShape[0], sourceShape[1] - 2ul, sourceShape[2], sourceShape[3] });
        const std::vector<int64_t> beginMask{ 1, 0, 1, 1 };
        const std::vector<int64_t> endMask{ 1, 0, 1, 1 };

        const auto slice = std::make_shared<ngraph::opset1::StridedSlice>(source, begin, end, beginMask, endMask);
        slice->set_friendly_name(friendlyName);
        return slice;
    };

    const auto input = std::make_shared<ngraph::opset1::Parameter>(inputPrecision, inputShape);
    input->set_friendly_name("input1");

    const auto fakeQuantize1 = makeFakeQuantizeTypeRelaxed(input, inputPrecision, fq1);
    low_precision::NetworkHelper::setOutDataPrecisionForTypeRelaxed(fakeQuantize1, precisionBeforeConcat);
    fakeQuantize1->set_friendly_name("FakeQuantize_1");

    // first Concat input: FakeQuantize_1, optionally sliced
    std::shared_ptr<ngraph::Node> firstConcatInput = fakeQuantize1;
    if (ssBeforeConcat) {
        firstConcatInput = makeChannelSlice(firstConcatInput, inputShape, "StridedSlice_1");
    }

    // second Concat input: FakeQuantize_1 -> dequantization -> Clamp -> FakeQuantize_2
    const auto dequantizationBefore = makeDequantization(fakeQuantize1, deqBefore);
    const auto clamp = std::make_shared<ngraph::opset1::Clamp>(dequantizationBefore, 0.0, 6.0);
    clamp->set_friendly_name("Clamp");

    const auto fakeQuantize2 = makeFakeQuantizeTypeRelaxed(clamp, inputPrecision, fq2);
    low_precision::NetworkHelper::setOutDataPrecisionForTypeRelaxed(fakeQuantize2, precisionBeforeConcat);
    fakeQuantize2->set_friendly_name("FakeQuantize_2");

    const auto concat = std::make_shared<ngraph::opset1::Concat>(NodeVector{ firstConcatInput, fakeQuantize2 }, 1);
    concat->set_friendly_name("Concat");

    ngraph::ResultVector results;

    // first branch: Concat, optionally sliced, then dequantized into Result_1
    std::shared_ptr<ngraph::Node> firstBranch = concat;
    if (ssAfterConcat) {
        firstBranch = makeChannelSlice(concat, concat->get_output_shape(0), "StridedSlice_2");
    }
    const auto dequantizationAfter1 = makeDequantization(firstBranch, deqAfter1);
    const auto result1 = std::make_shared<ngraph::opset1::Result>(dequantizationAfter1);
    result1->set_friendly_name("Result_1");
    results.push_back(result1);

    // second branch: Concat -> MaxPool -> dequantization -> Result_2
    const std::vector<size_t> kernel = { 3, 3 };
    const std::vector<size_t> stride = { 1, 1 };
    const std::vector<size_t> padBegin = { 0, 0 };
    const std::vector<size_t> padEnd = { 0, 0 };
    const ngraph::op::PadType padType = ngraph::op::PadType::NOTSET;
    const ngraph::op::RoundingType roundingType = ngraph::op::RoundingType::FLOOR;

    const auto maxPool = std::make_shared<ngraph::opset1::MaxPool>(
        concat,
        stride,
        padBegin,
        padEnd,
        kernel,
        roundingType,
        padType);
    maxPool->set_friendly_name("MaxPool");

    const auto dequantizationAfter2 = makeDequantization(maxPool, deqAfter2);
    const auto result2 = std::make_shared<ngraph::opset1::Result>(dequantizationAfter2);
    result2->set_friendly_name("Result_2");
    results.push_back(result2);

    return std::make_shared<ngraph::Function>(
        results,
        ngraph::ParameterVector{ input },
        "ConcatWithDifferentChildsTransformation");
}
std::shared_ptr<ngraph::Function> ConcatFunction::getReferenceWithDifferentPrecisionOnChilds(
const ngraph::element::Type precision,
const ngraph::Shape& inputShape,