[LPT] Split/VariadicSplit support (#5132)

* [LPT] Split support in ConcatTransformation

* [LPT] fixing functional problems after enabling Split/VariadicSplit transformations

* [LPT] added test case for StridedSliceTransformation, enabled tests with split

* [LPT] ConcatTransformation refactoring

* ConcatTransformation: added axis check

* [LPT] Added foldDequantizationConstant to NetworkHelper & SplitTransformation refactoring

* [LPT] Subgraph: returned and refactored quantizationPerChannel

* [LPT] foldDequantizationConstant refactoring

* [LPT] SplitTransformation refactoring

* [LPT] hbonets fix & ConcatTransformation refactoring

* [LPT] ConcatTransformation functional tests quick fix
This commit is contained in:
Vladislav Golubev 2021-04-08 08:42:59 +03:00 committed by GitHub
parent c17c92a72d
commit 0b116620e2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 523 additions and 509 deletions

View File

@ -35,6 +35,7 @@ protected:
ngraph::pass::low_precision::Subgraph& subgraph,
std::function<void(
std::shared_ptr<ngraph::Node> layer,
std::shared_ptr<ngraph::Node> child,
const std::string originalLayerName,
std::vector<FakeQuantizeDequantization>& dequantizationsToConcatenate)> getLayerDequantizationCallback) const;
@ -42,6 +43,15 @@ protected:
const TransformationContext& context,
const std::vector<std::shared_ptr<ngraph::Node>>& quantizationOperations);
void fillDequantizationNodes(
const std::vector<FakeQuantizeDequantization>& layerDequantizations,
const std::shared_ptr<Node> layer,
NodeVector& convertNodes,
NodeVector& subtractNodes,
NodeVector& multiplyNodes) const;
std::shared_ptr<Node> concatenateDeqNodes(NodeVector& nodes) const;
private:
size_t getMinQuantizationLevels(
const DataPrecision& dataPrecision,

View File

@ -27,12 +27,9 @@ public:
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
private:
// Go through the parent elements of the layer and fill dequantization collection
// with Dq operations that should be inserted before the layer.
void fillDequantization(
std::shared_ptr<ngraph::Node> layer,
std::unordered_map<std::string, FakeQuantizeDequantization>& dequantizationByFakeQuantize,
std::vector<FakeQuantizeDequantization>& dequantizationsToConcatenate) const;
void fillQuantization(
const std::shared_ptr<ngraph::Node> layer,
const std::unordered_map<std::string, FakeQuantizeDequantization>& dequantizationByFakeQuantize,
std::vector<FakeQuantizeDequantization>& dequantization) const;
@ -46,8 +43,6 @@ private:
const FakeQuantizeDequantization& dequantization,
const size_t sourceOutputIdx);
static FakeQuantizeDequantization broadcastDequantiationConstant(const FakeQuantizeDequantization& deq);
bool isMultiChannel(const std::vector<std::shared_ptr<ngraph::opset1::Concat>>& concatLayers) const noexcept;
};

View File

@ -50,6 +50,12 @@ public:
template <typename OperationType>
static std::shared_ptr<Node> setOutDataPrecision(std::shared_ptr<OperationType> operation, const element::Type& precision);
// applies constant folding of operation to constant and returns the specified output
static std::shared_ptr<opset1::Constant> foldDequantizationConstant(
const std::shared_ptr<opset1::Constant>& foldingConstant,
const std::shared_ptr<Node>& operation,
const size_t outIdx = 0);
static size_t getOutputChannelsCount(std::shared_ptr<const Node> layer, bool isOnWeights = false);
static std::vector<std::shared_ptr<Node>> getParentsRecursivelyExceptTypes(

View File

@ -24,15 +24,6 @@ public:
TransformationContext& context,
std::vector<std::shared_ptr<ngraph::Node>> lastNodes,
std::shared_ptr<ngraph::Node> originalNode) const;
protected:
ngraph::Shape getConstSplitShape(
const std::vector<size_t>& constSplitLengths,
const ngraph::Shape& constShape, const size_t axis,
const size_t idx) const;
virtual std::vector<size_t> getConstSplitLengths(
const OutputVector& inputs,
const ngraph::Shape& constShape,
const size_t outputSize) const;
};
} // namespace low_precision
} // namespace pass

View File

@ -17,11 +17,6 @@ class TRANSFORMATIONS_API VariadicSplitTransformation : public SplitTransformati
public:
VariadicSplitTransformation(const Params& params);
void registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const override;
protected:
std::vector<size_t> getConstSplitLengths(
const OutputVector& inputs,
const ngraph::Shape& constShape,
const size_t outputSize) const override;
};
} // namespace low_precision
} // namespace pass

View File

@ -201,6 +201,7 @@ bool ConcatTransformation::transform(TransformationContext& context, ngraph::pat
auto dequantizationValuesCallback = [&](
std::shared_ptr<ngraph::Node> layer,
std::shared_ptr<ngraph::Node> child,
const std::string originalLayerName,
std::vector<FakeQuantizeDequantization>& dequantizationsToConcatenate) {
dequantizationsToConcatenate.push_back(dequantization);
@ -234,15 +235,97 @@ bool ConcatTransformation::isPrecisionPreserved(std::shared_ptr<Node>) const noe
bool ConcatTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const {
std::shared_ptr<opset1::Concat> concat = as_type_ptr<opset1::Concat>(layer);
return concat && concat->get_axis() == 1ul;
if (concat == nullptr) {
return false;
}
const auto axis = concat->get_axis();
const size_t normalizedAxis = ngraph::normalize_axis(concat->get_friendly_name(), axis, concat->get_output_partial_shape(0).rank());
return normalizedAxis == 1ul;
}
void ConcatTransformation::fillDequantizationNodes(
const std::vector<FakeQuantizeDequantization>& layerDequantizations,
const std::shared_ptr<Node> layer,
NodeVector& convertNodes,
NodeVector& subtractNodes,
NodeVector& multiplyNodes) const {
if (layerDequantizations.size() > 1ul) {
auto broadcastElementWiseConst = [](
// FakeQuantize constant shape must be broadcastable to the shape on data.
std::shared_ptr<ngraph::opset1::Constant> operation,
const ngraph::Shape targetShape) -> std::shared_ptr<Node> {
auto targetShapeConst = std::make_shared<ngraph::opset1::Constant>(
element::i64, ngraph::Shape{ targetShape.size() },
targetShape);
auto broadcast = ngraph::pass::low_precision::fold<ngraph::opset1::Broadcast>(
operation,
targetShapeConst,
ngraph::op::AutoBroadcastType::NUMPY);
return broadcast;
};
bool allDequantizationShiftAreZero = true;
bool allDequantizationMultiplyAreZero = true;
for (const auto& dequantization : layerDequantizations) {
if (dequantization.subtract != nullptr) {
allDequantizationShiftAreZero = false;
}
if (dequantization.multiply != nullptr) {
allDequantizationMultiplyAreZero = false;
}
}
for (size_t i = 0; i < layerDequantizations.size(); ++i) {
const auto& dequantization = layerDequantizations[i];
const ngraph::element::Type precision = deqPrecision;
ngraph::Shape targetShape(layer->get_input_shape(i).size(), 1ul);
targetShape[1] = layer->get_input_shape(i)[1];
if (dequantization.convert != nullptr) {
convertNodes.push_back(dequantization.convert);
}
if (!allDequantizationShiftAreZero) {
subtractNodes.push_back(dequantization.subtract == nullptr ?
std::make_shared<ngraph::opset1::Constant>(precision, targetShape, std::vector<float>({ 0.f })) :
broadcastElementWiseConst(dequantization.subtractConstant, targetShape));
}
if (!allDequantizationMultiplyAreZero) {
multiplyNodes.push_back(dequantization.multiply == nullptr ?
std::make_shared<ngraph::opset1::Constant>(precision, targetShape, std::vector<float>({ 1.0f })) :
broadcastElementWiseConst(dequantization.multiplyConstant, targetShape));
}
}
} else {
// TODO: check constant shapes here - has to be scalar
if (layerDequantizations[0].convert != nullptr) {
convertNodes.push_back(layerDequantizations[0].convert);
}
if (layerDequantizations[0].subtract != nullptr) {
subtractNodes.push_back(layerDequantizations[0].subtract->input_value(1).get_node_shared_ptr());
}
if (layerDequantizations[0].multiply != nullptr) {
multiplyNodes.push_back(layerDequantizations[0].multiply->input_value(1).get_node_shared_ptr());
}
}
}
std::shared_ptr<Node> ConcatTransformation::concatenateDeqNodes(NodeVector& nodes) const {
return nodes.size() == 1ul ? nodes[0] : fold<ngraph::opset1::Concat>(nodes, 1);
}
void ConcatTransformation::addDequantizationLayers(
TransformationContext& context,
ngraph::pass::low_precision::Subgraph& subgraph,
std::function<void(
std::shared_ptr<ngraph::Node> layer,
std::shared_ptr<ngraph::Node> child,
const std::string originalLayerName,
std::vector<FakeQuantizeDequantization>& dequantizationsToConcatenate)> getLayerDequantizationCallback) const {
std::unordered_map<std::string, ngraph::Node*> outputs;
@ -269,95 +352,28 @@ void ConcatTransformation::addDequantizationLayers(
ngraph::Node& child = *childInput.get_node();
if (subgraph.layers.find(child.get_friendly_name()) == subgraph.layers.end()) {
std::shared_ptr<ngraph::Node> source = layer;
const std::shared_ptr<ngraph::Node> destination = child.shared_from_this();
if (layerDequantizations.size() == 0ul) {
// fill layerDequantizations collection
getLayerDequantizationCallback(layer, layer->get_friendly_name(), layerDequantizations);
getLayerDequantizationCallback(source, destination, source->get_friendly_name(), layerDequantizations);
}
std::shared_ptr<ngraph::Node> source = layer->shared_from_this();
{
std::vector<std::shared_ptr<ngraph::Node>> convertNodes;
std::vector<std::shared_ptr<ngraph::Node>> subtractNodes;
std::vector<std::shared_ptr<ngraph::Node>> multiplyNodes;
NodeVector convertNodes;
NodeVector subtractNodes;
NodeVector multiplyNodes;
// forming nodes for concatenation
if (layerDequantizations.size() > 1ul) {
auto broadcastElementWiseConst = [](
// FakeQuantize constant shape must be broadcastable to the shape on data.
std::shared_ptr<ngraph::opset1::Constant> operation,
const ngraph::Shape targetShape) -> std::shared_ptr<Node> {
auto targetShapeConst = std::make_shared<ngraph::opset1::Constant>(
element::i64, ngraph::Shape{ targetShape.size() },
targetShape);
auto broadcast = ngraph::pass::low_precision::fold<ngraph::opset1::Broadcast>(
operation,
targetShapeConst,
ngraph::op::AutoBroadcastType::NUMPY);
return broadcast;
};
bool allDequantizationShiftAreZero = true;
bool allDequantizationMultiplyAreZero = true;
for (FakeQuantizeDequantization dequantization : layerDequantizations) {
if (dequantization.subtract != nullptr) {
allDequantizationShiftAreZero = false;
}
if (dequantization.multiply != nullptr) {
allDequantizationMultiplyAreZero = false;
}
}
for (size_t i = 0; i < layerDequantizations.size(); ++i) {
const auto& dequantization = layerDequantizations[i];
if (dequantization.convert != nullptr) {
convertNodes.push_back(dequantization.convert);
}
const ngraph::element::Type precision = deqPrecision;
ngraph::Shape targetShape(layer->get_input_shape(i).size(), 1ul);
targetShape[1] = layer->get_input_shape(i)[1];
if (!allDequantizationShiftAreZero) {
subtractNodes.push_back(dequantization.subtract == nullptr ?
std::make_shared<ngraph::opset1::Constant>(precision, targetShape, std::vector<float>({ 0.f })) :
broadcastElementWiseConst(
as_type_ptr<ngraph::opset1::Constant>(dequantization.subtract->input_value(1).get_node_shared_ptr()),
targetShape));
}
if (!allDequantizationMultiplyAreZero) {
multiplyNodes.push_back(dequantization.multiply == nullptr ?
std::make_shared<ngraph::opset1::Constant>(precision, targetShape, std::vector<float>({ 1.0f })) :
broadcastElementWiseConst(
as_type_ptr<ngraph::opset1::Constant>(dequantization.multiply->input_value(1).get_node_shared_ptr()),
targetShape));
}
}
} else {
// TODO: check constant shapes here - has to be scalar
if (layerDequantizations[0].convert != nullptr) {
convertNodes.push_back(layerDequantizations[0].convert);
}
if (layerDequantizations[0].subtract != nullptr) {
subtractNodes.push_back(layerDequantizations[0].subtract->input_value(1).get_node_shared_ptr());
}
if (layerDequantizations[0].multiply != nullptr) {
multiplyNodes.push_back(layerDequantizations[0].multiply->input_value(1).get_node_shared_ptr());
}
}
fillDequantizationNodes(layerDequantizations, layer, convertNodes, subtractNodes, multiplyNodes);
// TODO: the second place (first is FQ decomposition) where dequantization operations are inserted
const std::shared_ptr<ngraph::Node> destination = child.shared_from_this();
if (!convertNodes.empty()) {
const size_t sourceOutputIdx = NetworkHelper::getChildInputIndex(source, destination);
std::shared_ptr<ngraph::Node> convert =
convertNodes[0]->clone_with_new_inputs({ destination->get_input_source_output(sourceOutputIdx) });
insert_new_node_between(source, destination, convert);
ngraph::copy_runtime_info({ layer, convert }, convert);
source = convert;
@ -368,9 +384,8 @@ void ConcatTransformation::addDequantizationLayers(
const size_t sourceOutputIdx = NetworkHelper::getChildInputIndex(source, destination);
std::shared_ptr<ngraph::opset1::Subtract> subtract = std::make_shared<DequantizationSubtract>(
destination->get_input_source_output(sourceOutputIdx),
NetworkHelper::toScalarIfPossible(subtractNodes.size() == 1ul ?
subtractNodes[0] :
ngraph::pass::low_precision::fold<ngraph::opset1::Concat>(subtractNodes, 1)));
NetworkHelper::toScalarIfPossible(concatenateDeqNodes(subtractNodes)));
insert_new_node_between(source, destination, subtract);
ngraph::copy_runtime_info({ layer, subtract }, subtract);
source = subtract;
@ -381,10 +396,9 @@ void ConcatTransformation::addDequantizationLayers(
std::shared_ptr<ngraph::opset1::Multiply> multiply = std::make_shared<op::TypeRelaxed<DequantizationMultiply>>(
DequantizationMultiply(
destination->get_input_source_output(sourceOutputIdx),
NetworkHelper::toScalarIfPossible(multiplyNodes.size() == 1ul ?
multiplyNodes[0] :
ngraph::pass::low_precision::fold<ngraph::opset1::Concat>(multiplyNodes, 1))),
NetworkHelper::toScalarIfPossible(concatenateDeqNodes(multiplyNodes))),
layerDequantizations[0].multiply->get_output_element_type(0));
insert_new_node_between(source, destination, multiply);
ngraph::copy_runtime_info({ layer, multiply }, multiply);
source = multiply;

View File

@ -137,6 +137,7 @@ bool ConcatMultiChannelsTransformation::transform(TransformationContext& context
auto dequantizationValuesCallback = [&](
std::shared_ptr<ngraph::Node> layer,
std::shared_ptr<ngraph::Node> child,
const std::string originalLayerName,
std::vector<FakeQuantizeDequantization>& dequantizationsToConcatenate) {
if (layer->get_friendly_name() != originalLayerName) {
@ -157,6 +158,15 @@ bool ConcatMultiChannelsTransformation::transform(TransformationContext& context
layer,
dequantizations,
dequantizationsToConcatenate);
if (!is_type<ngraph::opset1::Concat>(layer)) {
// for intermediate layers we should get Dq operations to be inserted between layer and child
assert(dequantizationsToConcatenate.size() == 1ul);
const size_t sourceOutputIdx = NetworkHelper::getParentOutputIndex(layer, child);
if (layer->get_input_shape(0)[1] != layer->get_output_shape(sourceOutputIdx)[1]) {
dequantizationsToConcatenate[0] = getFoldedDequantization(layer, dequantizationsToConcatenate[0], sourceOutputIdx);
}
}
};
addDequantizationLayers(context, subgraph, dequantizationValuesCallback);
@ -185,137 +195,66 @@ bool ConcatMultiChannelsTransformation::isPrecisionPreserved(std::shared_ptr<Nod
return true;
}
// fill dequantizationsToMerge collection for layer with using dequantizationByFakeQuantize
void ConcatMultiChannelsTransformation::fillDequantization(
std::shared_ptr<ngraph::Node> layer,
std::unordered_map<std::string, FakeQuantizeDequantization>& dequantizationByFakeQuantize,
std::vector<FakeQuantizeDequantization>& dequantizationsToConcatenate) const {
std::shared_ptr<ngraph::opset1::FakeQuantize> currentFakeQuantize = ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(layer);
if (currentFakeQuantize) {
const auto it = dequantizationByFakeQuantize.find(currentFakeQuantize->get_friendly_name());
if (it == dequantizationByFakeQuantize.end()) {
THROW_IE_LPT_EXCEPTION(*currentFakeQuantize) << "dequantization scale values are not found";
}
const FakeQuantizeDequantization& fakeQuantizeDequantization = it->second;
dequantizationsToConcatenate.push_back(broadcastDequantiationConstant(fakeQuantizeDequantization));
} else {
fillQuantization(layer, dequantizationByFakeQuantize, dequantizationsToConcatenate);
}
}
void ConcatMultiChannelsTransformation::fillQuantization(
const std::shared_ptr<ngraph::Node> layer,
const std::unordered_map<std::string, FakeQuantizeDequantization>& dequantizationByFakeQuantize,
std::vector<FakeQuantizeDequantization>& dequantization) const {
for (size_t i = 0; i < layer->get_input_size(); ++i) {
std::shared_ptr<ngraph::Node> parent = layer->get_input_node_shared_ptr(i);
const auto fillDqByFakeQuantize = [&](const std::shared_ptr<ngraph::Node>& fq) {
const auto it = dequantizationByFakeQuantize.find(fq->get_friendly_name());
if (it == dequantizationByFakeQuantize.end()) {
THROW_IE_LPT_EXCEPTION(*fq) << "dequantization scale values are not found";
}
std::shared_ptr<ngraph::opset1::FakeQuantize> fakeQuantize = ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(parent);
if (fakeQuantize) {
const auto it = dequantizationByFakeQuantize.find(fakeQuantize->get_friendly_name());
if (it == dequantizationByFakeQuantize.end()) {
THROW_IE_LPT_EXCEPTION(*fakeQuantize) << "dequantization scale values are not found";
const FakeQuantizeDequantization& fakeQuantizeDequantization = it->second;
dequantization.push_back(fakeQuantizeDequantization);
};
if (is_type<ngraph::opset1::FakeQuantize>(layer)) {
fillDqByFakeQuantize(layer);
} else {
for (size_t i = 0; i < layer->get_input_size(); ++i) {
std::shared_ptr<ngraph::Node> parent = layer->get_input_node_shared_ptr(i);
if (as_type_ptr<ngraph::opset1::Constant>(parent)) {
continue;
}
const FakeQuantizeDequantization& fakeQuantizeDequantization = it->second;
dequantization.push_back(broadcastDequantiationConstant(fakeQuantizeDequantization));
} else {
std::shared_ptr<ngraph::opset1::Concat> concat = ngraph::as_type_ptr<ngraph::opset1::Concat>(parent);
if (concat) {
std::vector<FakeQuantizeDequantization> dequantizationToConcatenate;
fillQuantization(concat, dequantizationByFakeQuantize, dequantizationToConcatenate);
// add concatenated dequantization operations to dequantization collection
dequantization.push_back(getConcatenatedDequantization(concat, dequantizationToConcatenate));
const auto fakeQuantize = ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(parent);
if (fakeQuantize) {
fillDqByFakeQuantize(fakeQuantize);
} else {
std::shared_ptr<ngraph::opset1::StridedSlice> stridedSlice = ngraph::as_type_ptr<ngraph::opset1::StridedSlice>(parent);
if (stridedSlice) {
std::vector<FakeQuantizeDequantization> dequantizationToPropagate;
fillQuantization(stridedSlice, dequantizationByFakeQuantize, dequantizationToPropagate);
const auto concat = ngraph::as_type_ptr<ngraph::opset1::Concat>(parent);
if (concat) {
std::vector<FakeQuantizeDequantization> dequantizationToConcatenate;
fillDequantization(concat, dequantizationByFakeQuantize, dequantizationToConcatenate);
const size_t sourceOutputIdx = NetworkHelper::getParentOutputIndex(parent, layer);
// add folded dequantization operations to dequantization collection
dequantization.push_back(getFoldedDequantization(stridedSlice, dequantizationToPropagate[0], sourceOutputIdx));
// add concatenated dequantization operations to dequantization collection
dequantization.push_back(getConcatenatedDequantization(concat, dequantizationToConcatenate));
} else {
fillQuantization(parent, dequantizationByFakeQuantize, dequantization);
const size_t sourceOutputIdx = NetworkHelper::getParentOutputIndex(parent, layer);
if (parent->get_input_shape(0)[1] != parent->get_output_shape(sourceOutputIdx)[1]) {
std::vector<FakeQuantizeDequantization> dequantizationToPropagate;
fillDequantization(parent, dequantizationByFakeQuantize, dequantizationToPropagate);
// add folded dequantization operations to dequantization collection
dequantization.push_back(getFoldedDequantization(parent, dequantizationToPropagate[0], sourceOutputIdx));
} else {
fillDequantization(parent, dequantizationByFakeQuantize, dequantization);
}
}
}
}
}
}
// broadcast of dequantization constants by channels
FakeQuantizeDequantization ConcatMultiChannelsTransformation::broadcastDequantiationConstant(const FakeQuantizeDequantization& deq) {
ngraph::Shape targetShape(deq.data.get_shape().size(), 1ul);
targetShape[1] = deq.data.get_shape()[1];
FakeQuantizeDequantization result;
result.data = deq.data;
result.convert = deq.convert;
const auto targetShapeConst = std::make_shared<ngraph::opset1::Constant>(
element::i64, ngraph::Shape{ targetShape.size() },
targetShape);
if (deq.subtract) {
auto broadcast = ngraph::pass::low_precision::fold<ngraph::opset1::Broadcast>(
deq.subtractConstant,
targetShapeConst,
ngraph::op::AutoBroadcastType::NUMPY);
result.subtract = deq.subtract;
result.subtractConstant = as_type_ptr<ngraph::opset1::Constant>(broadcast);
}
if (deq.multiply) {
auto broadcast = ngraph::pass::low_precision::fold<ngraph::opset1::Broadcast>(
deq.multiplyConstant,
targetShapeConst,
ngraph::op::AutoBroadcastType::NUMPY);
result.multiply = deq.multiply;
result.multiplyConstant = as_type_ptr<ngraph::opset1::Constant>(broadcast);
}
return result;
}
FakeQuantizeDequantization ConcatMultiChannelsTransformation::getConcatenatedDequantization(
const std::shared_ptr<ngraph::opset1::Concat> concat,
const std::vector<FakeQuantizeDequantization>& dequantization) const {
bool allDequantizationShiftAreZero = true;
bool allDequantizationMultiplyAreZero = true;
for (const auto& deq : dequantization) {
if (deq.subtract != nullptr) {
allDequantizationShiftAreZero = false;
}
if (deq.multiply != nullptr) {
allDequantizationMultiplyAreZero = false;
}
}
NodeVector convertNodes;
NodeVector subNodes;
NodeVector mulNodes;
//preparing to concatenate dequantization nodes
for (const auto& deq : dequantization) {
ngraph::Shape targetShape(deq.data.get_shape().size(), 1ul);
targetShape[1] = deq.data.get_shape()[1];
NodeVector subtractNodes;
NodeVector multiplyNodes;
if (deq.convert != nullptr) {
convertNodes.push_back(deq.convert);
}
if (!allDequantizationShiftAreZero) {
subNodes.push_back(deq.subtract == nullptr ?
std::make_shared<ngraph::opset1::Constant>(deqPrecision, targetShape, std::vector<float>({ 0.f })) :
deq.subtractConstant);
}
if (!allDequantizationMultiplyAreZero) {
mulNodes.push_back(deq.multiply == nullptr ?
std::make_shared<ngraph::opset1::Constant>(deqPrecision, targetShape, std::vector<float>({ 1.0f })) :
deq.multiplyConstant);
}
}
// forming nodes for concatenation
fillDequantizationNodes(dequantization, concat, convertNodes, subtractNodes, multiplyNodes);
std::shared_ptr<Node> parent = concat;
std::shared_ptr<DequantizationConvert> convert;
@ -326,20 +265,16 @@ FakeQuantizeDequantization ConcatMultiChannelsTransformation::getConcatenatedDeq
std::shared_ptr<DequantizationSubtract> subtract;
std::shared_ptr<ngraph::opset1::Constant> subConst;
if (!subNodes.empty()) {
subConst = as_type_ptr<ngraph::opset1::Constant>(
subNodes.size() == 1ul ? subNodes[0] : fold<ngraph::opset1::Concat>(subNodes, 1ul));
if (!subtractNodes.empty()) {
subConst = as_type_ptr<ngraph::opset1::Constant>(concatenateDeqNodes(subtractNodes));
subtract = std::make_shared<DequantizationSubtract>(parent, subConst);
parent = subtract;
}
std::shared_ptr<DequantizationMultiply> multiply;
std::shared_ptr<ngraph::opset1::Constant> mulConst;
if (!mulNodes.empty()) {
mulConst = as_type_ptr<ngraph::opset1::Constant>(
mulNodes.size() == 1ul ? mulNodes[0] : fold<ngraph::opset1::Concat>(mulNodes, 1ul));
if (!multiplyNodes.empty()) {
mulConst = as_type_ptr<ngraph::opset1::Constant>(concatenateDeqNodes(multiplyNodes));
multiply = std::make_shared<DequantizationMultiply>(parent, mulConst);
}
@ -352,24 +287,19 @@ FakeQuantizeDequantization ConcatMultiChannelsTransformation::getFoldedDequantiz
const size_t sourceOutputIdx) {
OutputVector inputs = operation->input_values();
OutputVector outputs(operation->get_output_size());
Output<Node> data = operation->output(sourceOutputIdx);
std::shared_ptr<Node> parent = operation;
std::shared_ptr<DequantizationConvert> convert;
if (dequantization.convert) {
convert = as_type_ptr<DequantizationConvert>(dequantization.convert->clone_with_new_inputs({ parent }));
convert = as_type_ptr<DequantizationConvert>(dequantization.convert->clone_with_new_inputs({ data }));
parent = convert;
}
std::shared_ptr<DequantizationSubtract> subtract;
std::shared_ptr<ngraph::opset1::Constant> subConst;
if (dequantization.subtract) {
inputs[0] = dequantization.subtractConstant;
const auto op = operation->clone_with_new_inputs(inputs);
// constant folding of subtract constant
op->constant_fold(outputs, inputs);
subConst = as_type_ptr<ngraph::opset1::Constant>(outputs[sourceOutputIdx].get_node_shared_ptr());
subConst = NetworkHelper::foldDequantizationConstant(dequantization.subtractConstant, operation, sourceOutputIdx);
subtract = std::make_shared<DequantizationSubtract>(parent, subConst);
parent = subtract;
}
@ -377,17 +307,11 @@ FakeQuantizeDequantization ConcatMultiChannelsTransformation::getFoldedDequantiz
std::shared_ptr<DequantizationMultiply> multiply;
std::shared_ptr<ngraph::opset1::Constant> mulConst;
if (dequantization.multiply) {
inputs[0] = dequantization.multiplyConstant;
const auto op = operation->clone_with_new_inputs(inputs);
// constant folding of multiply constant
op->constant_fold(outputs, inputs);
mulConst = as_type_ptr<ngraph::opset1::Constant>(outputs[sourceOutputIdx].get_node_shared_ptr());
mulConst = NetworkHelper::foldDequantizationConstant(dequantization.multiplyConstant, operation, sourceOutputIdx);
multiply = std::make_shared<DequantizationMultiply>(parent, mulConst);
}
return FakeQuantizeDequantization(operation->output(sourceOutputIdx), convert, subtract, nullptr, subConst, multiply, mulConst);
return FakeQuantizeDequantization(data, convert, subtract, nullptr, subConst, multiply, mulConst);
}
} // namespace low_precision

View File

@ -87,6 +87,31 @@ bool NetworkHelper::isConstantPath(const std::shared_ptr<Node>& op) {
return true;
}
std::shared_ptr<opset1::Constant> NetworkHelper::foldDequantizationConstant(
const std::shared_ptr<opset1::Constant>& foldingConstant,
const std::shared_ptr<Node>& operation,
const size_t outIdx) {
OutputVector inputs = operation->input_values();
OutputVector outputs(operation->get_output_size());
if (shape_size(foldingConstant->get_shape()) == 1ul) {
return toScalar(foldingConstant);
} else {
inputs[0] = foldingConstant;
const auto op = operation->clone_with_new_inputs(inputs);
// constant folding of constant
op->constant_fold(outputs, inputs);
const auto result = as_type_ptr<opset1::Constant>(outputs[outIdx].get_node_shared_ptr());
if (result == nullptr) {
THROW_IE_LPT_EXCEPTION(*result) << "result of constant folding is not constant";
}
return result;
}
}
size_t NetworkHelper::getOutputChannelsCount(std::shared_ptr<const Node> layer, bool isOnWeights) {
if (layer->outputs().size() == 0) {
THROW_TRANSFORMATION_EXCEPTION << "Layer " << layer->get_friendly_name() << " doesn't have output tensors";

View File

@ -5,6 +5,7 @@
#include "low_precision/split.hpp"
#include "ngraph/node.hpp"
#include "low_precision/network_helper.hpp"
#include "low_precision/common/dequantization_op.hpp"
namespace ngraph {
namespace pass {
@ -22,81 +23,68 @@ bool SplitTransformation::transform(TransformationContext& context, ngraph::patt
return false;
}
const std::shared_ptr<Node> split = NetworkHelper::separateInStandaloneBranch(m.get_match_root());
auto dequantization = NetworkHelper::getDequantization(split);
const auto split = NetworkHelper::separateInStandaloneBranch(m.get_match_root());
const auto dequantization = NetworkHelper::getDequantization(split);
OutputVector inputs(split->get_input_size());
for (size_t i = 0; i < split->get_input_size(); ++i) {
inputs[i] = split->get_input_node_shared_ptr(i);
}
OutputVector inputs = split->input_values();
inputs[0] = dequantization.data;
const size_t dequantizationIndex = NetworkHelper::getChildInputIndex(dequantization.multiply, split);
inputs[dequantizationIndex] = dequantization.data;
std::shared_ptr<ngraph::Node> newSplit = split->clone_with_new_inputs(inputs);
const auto newSplit = split->clone_with_new_inputs(inputs);
newSplit->set_friendly_name(split->get_friendly_name());
ngraph::copy_runtime_info(split, newSplit);
const ngraph::Shape subConstShape = dequantization.subtract ?
dequantization.subtract->get_input_node_shared_ptr(1)->get_shape() : Shape{};
std::vector<float> subValues = dequantization.subtract ? as_type_ptr<opset1::Constant>(
dequantization.subtract->get_input_node_shared_ptr(1))->cast_vector<float>() : std::vector<float>();
const int64_t axis = as_type_ptr<opset1::Constant>(split->get_input_node_shared_ptr(1))->cast_vector<int64_t>()[0];
const size_t normalizedAxis = normalize_axis(split->get_friendly_name(), axis, split->get_input_partial_shape(0).rank());
const size_t outputSize = newSplit->get_output_size();
const ngraph::Shape mulConstShape = dequantization.multiply->get_input_node_shared_ptr(1)->get_shape();
std::vector<float> mulValues = as_type_ptr<opset1::Constant>(
dequantization.multiply->get_input_node_shared_ptr(1))->cast_vector<float>();
const auto splitConstant = [&](const std::shared_ptr<Node> operation) {
// if batch is absent in constant shape - add batch
const auto normalizedConstant = NetworkHelper::normalizeDequantizationShape(operation);
const auto constantShape = normalizedConstant->get_shape();
int64_t SplitedAxis = as_type_ptr<opset1::Constant>(split->get_input_node_shared_ptr(1))->cast_vector<int64_t>()[0];
size_t axis = SplitedAxis > 0 ? SplitedAxis : split->get_input_shape(0).size() + SplitedAxis;
size_t outputSize = newSplit->get_output_size();
const auto subSplitLengths = getConstSplitLengths(inputs, subConstShape, outputSize);
const auto mulSplitLengths = getConstSplitLengths(inputs, mulConstShape, outputSize);
std::vector<std::shared_ptr<ngraph::Node>> lastNodes(outputSize);
ngraph::OutputVector replacement;
for (size_t i = 0; i < outputSize; ++i) {
Output<Node> previous = newSplit->output(i);
if (dequantization.convert != nullptr) {
const std::shared_ptr<ngraph::Node> convert =
dequantization.convert->clone_with_new_inputs({ newSplit->output(i) });
previous = convert;
}
if (dequantization.subtract != nullptr) {
std::shared_ptr<ngraph::opset1::Constant> subConst;
if (!subSplitLengths.empty()) {
const auto newSubConstShape = getConstSplitShape(subSplitLengths, subConstShape, axis, i);
std::vector<float> newSubValues(
subValues.begin() + subSplitLengths[i],
subValues.begin() + subSplitLengths[i + 1]);
subConst = as_type_ptr<ngraph::opset1::Constant>(std::make_shared<ngraph::opset1::Constant>(
dequantization.subtract->get_input_element_type(1),
newSubConstShape,
newSubValues));
} else {
subConst = as_type_ptr<ngraph::opset1::Constant>(dequantization.subtract->get_input_node_shared_ptr(1)->clone_with_new_inputs({}));
}
const std::shared_ptr<ngraph::Node> subtract = std::make_shared<ngraph::opset1::Subtract>(previous, subConst);
previous = subtract;
}
std::shared_ptr<ngraph::opset1::Constant> mulConst;
if (!mulSplitLengths.empty()) {
const auto newMulConstShape = getConstSplitShape(mulSplitLengths, mulConstShape, axis, i);
std::vector<float> newMulValues(
mulValues.begin() + mulSplitLengths[i],
mulValues.begin() + mulSplitLengths[i + 1]);
mulConst = as_type_ptr<ngraph::opset1::Constant>(std::make_shared<ngraph::opset1::Constant>(
dequantization.multiply->get_input_element_type(1), newMulConstShape, newMulValues));
OutputVector results(outputSize);
if ((shape_size(constantShape) == 1ul) || (constantShape[normalizedAxis] == 1ul)) {
std::for_each(results.begin(), results.end(), [&](Output<Node>& elem) { elem = normalizedConstant->clone_with_new_inputs({}); });
} else {
mulConst = as_type_ptr<ngraph::opset1::Constant>(dequantization.multiply->get_input_node_shared_ptr(1)->clone_with_new_inputs({}));
// prepare new inputs for constant folding
OutputVector inputs = newSplit->input_values();
inputs[0] = normalizedConstant;
const auto foldSplit = newSplit->clone_with_new_inputs(inputs);
// fold and fill results
foldSplit->constant_fold(results, inputs);
}
const std::shared_ptr<ngraph::Node> multiply = std::make_shared<ngraph::opset1::Multiply>(previous, mulConst);
for (auto& result : results) {
result = NetworkHelper::toScalarIfPossible(result.get_node_shared_ptr());
}
return results;
};
// get split dequantization constants
OutputVector splitedSub = dequantization.subtract ? splitConstant(dequantization.subtract) : OutputVector{};
OutputVector splitedMul = splitConstant(dequantization.multiply);
NodeVector lastNodes;
OutputVector replacement;
for (size_t i = 0; i < outputSize; ++i) {
Output<Node> parent = newSplit->output(i);
if (dequantization.convert) {
const auto convert = dequantization.convert->clone_with_new_inputs({ newSplit->output(i) });
copy_runtime_info({ newSplit, convert }, convert);
parent = convert;
}
if (dequantization.subtract) {
const auto subtract = std::make_shared<DequantizationSubtract>(parent, splitedSub[i]);
copy_runtime_info({ newSplit, subtract }, subtract);
parent = subtract;
}
const auto multiply = std::make_shared<DequantizationMultiply>(parent, splitedMul[i]);
copy_runtime_info({ newSplit, multiply }, multiply);
lastNodes.push_back(multiply);
replacement.push_back(multiply);
@ -107,33 +95,6 @@ bool SplitTransformation::transform(TransformationContext& context, ngraph::patt
return true;
}
// Computes cumulative boundaries used to slice a dequantization constant along
// the split axis: result[i]..result[i+1] is the constant range that belongs to
// the i-th Split output. Split produces equal-sized chunks, so every partition
// has length constShape[splitedAxis] / outputSize.
//
// \param inputs      Split inputs: inputs[0] is data, inputs[1] is the axis constant.
// \param constShape  Shape of the dequantization constant to be split.
// \param outputSize  Number of Split outputs.
// \return Vector of (outputSize + 1) offsets, or an empty vector when the constant
//         is broadcastable along the split axis (per-tensor value) and needs no slicing.
std::vector<size_t> SplitTransformation::getConstSplitLengths(
    const OutputVector& inputs,
    const ngraph::Shape& constShape,
    const size_t outputSize) const {
    const int64_t axis = as_type_ptr<opset1::Constant>(inputs[1].get_node_shared_ptr())->cast_vector<int64_t>()[0];
    // Normalize a negative axis against the data rank.
    // NOTE: the comparison must be '>= 0' — with '> 0' an axis of 0 would be treated
    // as negative, producing splitedAxis == rank and an out-of-bounds constShape access.
    const size_t splitedAxis = axis >= 0 ? static_cast<size_t>(axis) : inputs[0].get_shape().size() + axis;

    if ((!constShape.empty()) && (constShape[splitedAxis] != 1)) {
        std::vector<size_t> result(outputSize + 1);
        result[0] = 0;
        for (size_t i = 1; i < result.size(); ++i) {
            result[i] = result[i - 1] + constShape[splitedAxis] / outputSize;
        }
        return result;
    } else {
        // per-tensor constant: no per-output slicing required
        return std::vector<size_t>();
    }
}
// Builds the shape of the idx-th slice of a dequantization constant:
// identical to the original constant shape except along the split axis,
// where the extent is the distance between two adjacent split boundaries.
//
// \param constSplitLengths  Cumulative boundaries produced by getConstSplitLengths.
// \param constShape         Shape of the constant being sliced.
// \param axis               Normalized split axis.
// \param idx                Index of the Split output this slice belongs to.
ngraph::Shape SplitTransformation::getConstSplitShape(
    const std::vector<size_t>& constSplitLengths,
    const ngraph::Shape& constShape, const size_t axis,
    const size_t idx) const {
    ngraph::Shape splitShape(constShape);
    const size_t sliceLength = constSplitLengths[idx + 1] - constSplitLengths[idx];
    splitShape[axis] = sliceLength;
    return splitShape;
}
void SplitTransformation::updateOutputs(
TransformationContext& context,

View File

@ -23,7 +23,7 @@ std::shared_ptr<Node> stridedSliceDeqConstant(
//}
const auto stridedSliceShape = strSlice->get_input_shape(0);
const auto constantShape = constant->get_shape();
auto constantShape = constant->get_shape();
if (stridedSliceShape.size() != constantShape.size()) {
ngraph::Shape newConstantShape;
if (ngraph::shape_size(constantShape) == 1) {
@ -37,6 +37,7 @@ std::shared_ptr<Node> stridedSliceDeqConstant(
newConstantShape.insert(newConstantShape.begin(), stridedSliceShape[0]);
}
}
constantShape = newConstantShape;
const auto newConstant = fold<ngraph::opset1::Broadcast>(
constant,
@ -45,13 +46,24 @@ std::shared_ptr<Node> stridedSliceDeqConstant(
}
const auto stridedSlice = as_type_ptr<ngraph::opset1::StridedSlice>(strSlice);
auto beginMask = stridedSlice->get_begin_mask();
auto endMask = stridedSlice->get_end_mask();
for (size_t i = 0; i < constantShape.size(); ++i) {
// don't slice constant if current dimension is 1
if (constantShape[i] == 1ul) {
beginMask[i] = 1ul;
endMask[i] = 1ul;
}
}
const auto result = fold<ngraph::opset1::StridedSlice>(
constant,
stridedSlice->get_input_node_shared_ptr(1),
stridedSlice->get_input_node_shared_ptr(2),
stridedSlice->get_input_node_shared_ptr(3),
stridedSlice->get_begin_mask(),
stridedSlice->get_end_mask(),
beginMask,
endMask,
stridedSlice->get_new_axis_mask(),
stridedSlice->get_shrink_axis_mask(),
stridedSlice->get_ellipsis_mask());

View File

@ -22,16 +22,15 @@ namespace ngraph {
namespace pass {
namespace low_precision {
bool isQuantizationPerChannel(const std::shared_ptr<ngraph::Node>& node) {
if (node->outputs().size() > 1ul) {
return false;
}
//WA to support StridedSlice in ConcatTransformation
if (ngraph::is_type<opset1::StridedSlice>(node)) {
bool operationIsSupportedInConcat(const std::shared_ptr<ngraph::Node>& node) {
// list of operations, which change channels, but supported in ConcatTransformation
if (ngraph::is_type<opset1::StridedSlice>(node) ||
ngraph::is_type<opset1::Split>(node) ||
ngraph::is_type<opset1::VariadicSplit>(node)) {
return true;
}
// operations which change channels are usually not supported in ConcatTransformation
const auto inputs = node->input_values();
for (const auto& input : inputs) {
if (ngraph::is_type<opset1::Constant>(input.get_node())) {
@ -82,7 +81,7 @@ bool Subgraph::fillSubgraphForQuantization(
if (fakeQuantizeChild != nullptr) {
//
} else {
if (layerTransformationsManager->isPrecisionPreserved(child) && isQuantizationPerChannel(child)) {
if (layerTransformationsManager->isPrecisionPreserved(child) && operationIsSupportedInConcat(child)) {
if (!fillSubgraphForIntermediate(child, handledLayers)) {
return false;
}
@ -104,7 +103,7 @@ bool Subgraph::atLeastOneIsIntermediate(const std::shared_ptr<ngraph::Node>& nod
return true;
}
if (!layerTransformationsManager->isPrecisionPreserved(child) || !isQuantizationPerChannel(child)) {
if (!layerTransformationsManager->isPrecisionPreserved(child) || !operationIsSupportedInConcat(child)) {
// child branch is out of subgraph
continue;
}
@ -144,10 +143,6 @@ bool Subgraph::fill(const std::shared_ptr<ngraph::Node>& layer, std::unordered_s
return false;
}
} else {
// WA: issue #46906
if (parent->get_output_size() != 1ul) {
return false;
}
const FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(parent, 0, true);
const std::shared_ptr<ngraph::opset1::FakeQuantize> fakeQuantizeParent = dequantization.empty() ?
ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(parent) :
@ -161,7 +156,7 @@ bool Subgraph::fill(const std::shared_ptr<ngraph::Node>& layer, std::unordered_s
if (constant != nullptr) {
//
} else {
if (layerTransformationsManager->isPrecisionPreserved(parent) && isQuantizationPerChannel(parent)) {
if (layerTransformationsManager->isPrecisionPreserved(parent) && operationIsSupportedInConcat(parent)) {
if (!fillSubgraphForIntermediate(parent, handledLayers)) {
return false;
}
@ -197,7 +192,7 @@ bool Subgraph::fill(const std::shared_ptr<ngraph::Node>& layer, std::unordered_s
const std::shared_ptr<ngraph::opset1::FakeQuantize> fakeQuantizeChild = ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(child);
if (fakeQuantizeChild != nullptr) {
//
} else if (layerTransformationsManager->isPrecisionPreserved(child) && isQuantizationPerChannel(child)) {
} else if (layerTransformationsManager->isPrecisionPreserved(child) && operationIsSupportedInConcat(child)) {
if (!fillSubgraphForIntermediate(child, handledLayers)) {
return false;
}
@ -221,6 +216,13 @@ bool Subgraph::empty() const {
}
bool Subgraph::fillSubgraphForConcat(const std::shared_ptr<ngraph::opset1::Concat>& concat, std::unordered_set<std::string>& handledLayers) {
const auto axis = concat->get_axis();
const size_t normalizedAxis = ngraph::normalize_axis(concat->get_friendly_name(), axis, concat->get_output_partial_shape(0).rank());
// supported only per-channel concat
if (normalizedAxis != 1ul) {
return false;
}
concatLayers.push_back(concat);
handledLayers.insert(concat->get_friendly_name());
layers.emplace(concat->get_friendly_name(), concat);

View File

@ -229,9 +229,11 @@ LowPrecisionTransformations LowPrecisionTransformer::getAllTransformations(const
add<ReluTransformation, opset1::Relu>(params).
add<ReshapeTransformation, opset1::Reshape>(params).
add<SqueezeTransformation, opset1::Squeeze>(params).
add<SplitTransformation, opset1::Split>(params).
add<StridedSliceTransformation, opset1::StridedSlice>(params).
add<TransposeTransformation, opset1::Transpose>(params).
add<UnsqueezeTransformation, opset1::Unsqueeze>(params).
add<VariadicSplitTransformation, opset1::VariadicSplit>(params).
addCleanup<FoldConvertTransformation, opset1::Subtract>(params).
addCleanup<FuseConvertTransformation, opset1::Multiply>(params).

View File

@ -20,26 +20,6 @@ void VariadicSplitTransformation::registerMatcherIn(GraphRewrite& pass, Transfor
make_op_label<opset1::Constant>() }));
}
// Computes cumulative boundaries used to slice a dequantization constant along
// the split axis for VariadicSplit: result[i]..result[i+1] is the constant range
// that belongs to the i-th output. Unlike Split, the partition sizes come from
// the split_lengths input (inputs[2]).
//
// \param inputs      VariadicSplit inputs: inputs[0] data, inputs[1] axis, inputs[2] split_lengths.
// \param constShape  Shape of the dequantization constant to be split.
// \param outputSize  Number of VariadicSplit outputs.
// \return Vector of (outputSize + 1) offsets, or an empty vector when the constant
//         is broadcastable along the split axis (per-tensor value) and needs no slicing.
std::vector<size_t> VariadicSplitTransformation::getConstSplitLengths(
    const OutputVector& inputs,
    const ngraph::Shape& constShape,
    const size_t outputSize) const {
    // NOTE(review): split_lengths may legally contain -1 ("take the rest") in the
    // VariadicSplit spec; casting to size_t assumes only non-negative lengths here.
    const std::vector<size_t> lengths = as_type_ptr<opset1::Constant>(inputs[2].get_node_shared_ptr())->cast_vector<size_t>();
    const int64_t axis = as_type_ptr<opset1::Constant>(inputs[1].get_node_shared_ptr())->cast_vector<int64_t>()[0];
    // Normalize a negative axis against the data rank.
    // NOTE: the comparison must be '>= 0' — with '> 0' an axis of 0 would be treated
    // as negative, producing splitedAxis == rank and an out-of-bounds constShape access.
    const size_t splitedAxis = axis >= 0 ? static_cast<size_t>(axis) : inputs[0].get_shape().size() + axis;

    if ((!constShape.empty()) && (constShape[splitedAxis] != 1)) {
        std::vector<size_t> result(outputSize + 1);
        result[0] = 0;
        for (size_t i = 1; i < result.size(); ++i) {
            result[i] = result[i - 1] + lengths[i - 1];
        }
        return result;
    } else {
        // per-tensor constant: no per-output slicing required
        return std::vector<size_t>();
    }
}
}
} // namespace low_precision
} // namespace pass
} // namespace ngraph

View File

@ -74,6 +74,7 @@ class ConcatTransformationTestValues {
public:
ngraph::pass::low_precision::LayerTransformation::Params params;
bool multiChannels;
std::int64_t axis;
ConcatTransformationActualValues actual;
ConcatTransformationResultValues result;
};
@ -114,7 +115,8 @@ public:
testValues.actual.convert2,
testValues.actual.dequantization2,
ngraph::element::undefined,
{});
{},
testValues.axis);
SimpleLowPrecisionTransformer transform;
if (testValues.multiChannels) {
@ -146,7 +148,8 @@ public:
testValues.result.convert2,
testValues.result.dequantization2,
testValues.result.precisionAfterOperation,
testValues.result.dequantizationAfter);
testValues.result.dequantizationAfter,
testValues.axis);
}
static std::string getTestCaseName(testing::TestParamInfo<ConcatTransformationParams> obj) {
@ -158,6 +161,7 @@ public:
result <<
LayerTransformation::getTestCaseNameByParams(precision, shape, testValues.params) << "_" <<
(testValues.multiChannels ? "multiChannels_" : "notMultiChannels_") <<
"axis_" << testValues.axis << "_" <<
testValues.actual << "_" <<
testValues.result << "_";
return result.str();
@ -180,6 +184,7 @@ const std::vector<ConcatTransformationTestValues> testValues = {
{
LayerTransformation::createParamsU8I8(),
false,
1,
{
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
{},
@ -201,6 +206,7 @@ const std::vector<ConcatTransformationTestValues> testValues = {
{
LayerTransformation::createParamsU8I8(),
false,
1,
{
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f} },
{ ngraph::element::u8 },
@ -232,6 +238,7 @@ const std::vector<ConcatTransformationTestValues> testValues = {
{
LayerTransformation::createParamsU8I8(),
true,
1,
{
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f} },
{ ngraph::element::u8 },
@ -263,6 +270,7 @@ const std::vector<ConcatTransformationTestValues> testValues = {
{
LayerTransformation::createParamsU8I8(),
false,
1,
{
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
{},
@ -290,6 +298,7 @@ const std::vector<ConcatTransformationTestValues> testValues = {
{
LayerTransformation::createParamsU8I8(),
true,
1,
{
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
{},
@ -317,6 +326,7 @@ const std::vector<ConcatTransformationTestValues> testValues = {
{
LayerTransformation::createParamsU8I8(),
false,
1,
{
{ 256ul, {{1}, {1}, {1}, {1}}, {0.f}, {2.55f}, {0.f}, {2.55f} },
{},
@ -340,6 +350,7 @@ const std::vector<ConcatTransformationTestValues> testValues = {
{
LayerTransformation::createParamsU8I8(),
false,
1,
{
{ 256ul, {{1, 1, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}}, {0.f}, {2.55f}, {0.f}, {2.55f} },
{},
@ -363,6 +374,7 @@ const std::vector<ConcatTransformationTestValues> testValues = {
{
LayerTransformation::createParamsU8I8(),
true,
1,
{
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
{},
@ -386,6 +398,7 @@ const std::vector<ConcatTransformationTestValues> testValues = {
{
LayerTransformation::createParamsU8I8(),
true,
1,
{
{ 256ul, {{1}, {1}, {1}, {1}}, {0.f}, {2.55f}, {0.f}, {2.55f} },
{},
@ -409,6 +422,7 @@ const std::vector<ConcatTransformationTestValues> testValues = {
{
LayerTransformation::createParamsU8I8(),
true,
1,
{
{
256ul,
@ -450,6 +464,7 @@ const std::vector<ConcatTransformationTestValues> testValues = {
{
LayerTransformation::createParamsU8I8(),
true,
1,
{
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
{},
@ -477,6 +492,7 @@ const std::vector<ConcatTransformationTestValues> testValues = {
{
LayerTransformation::createParamsI8I8(),
false,
1,
{
{ 256ul, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f} },
{},
@ -500,6 +516,7 @@ const std::vector<ConcatTransformationTestValues> testValues = {
{
LayerTransformation::createParamsU8I8(),
false,
1,
{
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
{},
@ -523,6 +540,7 @@ const std::vector<ConcatTransformationTestValues> testValues = {
{
LayerTransformation::createParamsU8I8(),
true,
1,
{
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
{},
@ -546,6 +564,7 @@ const std::vector<ConcatTransformationTestValues> testValues = {
{
LayerTransformation::createParamsU8I8(),
false,
1,
{
{ 256ul, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f} },
{},
@ -569,6 +588,7 @@ const std::vector<ConcatTransformationTestValues> testValues = {
{
LayerTransformation::createParamsU8I8(),
false,
1,
{
{ 256ul, {}, {-1.28f}, {1.27f}, {0.f}, {2.3007815f} },
{},
@ -588,10 +608,61 @@ const std::vector<ConcatTransformationTestValues> testValues = {
{ ngraph::element::f32, { 128 }, { 0.0302619f } }
}
},
// U8: concat multi channels with subtract, negative axis
{
LayerTransformation::createParamsU8I8(),
true,
-3,
{
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
{},
{},
{ 256ul, {}, {1.275f}, {2.55f}, {1.275f}, {2.55f} },
{},
{}
},
{
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
{},
{},
{ 256ul, {}, {1.275f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
{},
{},
ngraph::element::u8,
{
ngraph::element::f32,
{{ 0.f, 0.f, 0.f, -255.f, -255.f, -255.f }},
{{ 0.01f, 0.01f, 0.01f, 0.005f, 0.005f, 0.005f }}
}
}
},
// U8: concat multi channels with subtract, not supported axis
{
LayerTransformation::createParamsU8I8(),
true,
0,
{
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
{},
{},
{ 256ul, {}, {1.275f}, {2.55f}, {1.275f}, {2.55f} },
{},
{}
},
{
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
{},
{},
{ 256ul, {}, {1.275f}, {2.55f}, {1.275f}, {2.55f} },
{},
{}
},
},
// not update precisions
{
LayerTransformation::createParamsU8I8().setUpdatePrecisions(false),
false,
1,
{
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
{},
@ -615,6 +686,7 @@ const std::vector<ConcatTransformationTestValues> testValues = {
{
LayerTransformation::createParamsU8I8(),
false,
1,
{
{ 16ul, {}, {0.f}, {1.5f}, {0.f}, {15.f} },
{},
@ -638,6 +710,7 @@ const std::vector<ConcatTransformationTestValues> testValues = {
{
LayerTransformation::createParamsU8I8(),
true,
1,
{
{ 16ul, {}, {0.f}, {1.5f}, {0.f}, {15.f} },
{},

View File

@ -217,6 +217,40 @@ const std::vector<ConcatTransformationTestValues> testValues = {
{ ngraph::element::f32, {}, { 0.005f } }
}
},
// U8: concat multi channels with per-channel quantization
{
{ 1, 6, 10, 10 },
LayerTransformation::createParamsU8I8(),
true,
{
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f / 2.f}, {0.f}, {2.55f / 2.f} },
{
256ul,
ngraph::Shape({ 1, 6, 1, 1 }),
{0.f, 0.f, 0.f, 0.f, 0.f, 0.f},
{255.f, 25.5f, 2.55f, 25.5f, 255.f, 2.55f},
{0.f, 0.f, 0.f, 0.f, 0.f, 0.f},
{255.f, 25.5f, 2.55f, 25.5f, 255.f, 2.55f}
}
},
{
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f / 2.f}, {0.f}, {255.f}},
{
256ul,
ngraph::Shape({ 1, 6, 1, 1 }),
{0.f, 0.f, 0.f, 0.f, 0.f, 0.f},
{255.f, 25.5f, 2.55f, 25.5f, 255.f, 2.55f},
{0.f},
{255.f}
},
ngraph::element::u8,
{{}, {}, {}},
{{}, {}, {}},
ngraph::element::u8,
{ ngraph::element::f32, {}, {{ 0.005f, 0.005f, 0.005f, 1.f, 0.1f, 0.01f }} },
{ ngraph::element::f32, {}, {{ 0.1f, 1.f, 0.01f }} }
}
},
// I8: concat multi channels
{
{ 1, 6, 10, 10 },
@ -259,9 +293,8 @@ const std::vector<ConcatTransformationTestValues> testValues = {
},
};
// TODO: Split/VariadicSplit operations are not supported in ConcatTransformation
INSTANTIATE_TEST_CASE_P(
DISABLED_smoke_LPT,
smoke_LPT,
ConcatWithSplitTransformation,
::testing::Combine(
::testing::ValuesIn(precisions),

View File

@ -160,21 +160,30 @@ const std::vector<SplitTransformationTestValues> testValues = {
{},
ngraph::element::u8,
{
{
{ngraph::element::f32},
{{1.f}, ngraph::element::f32, {1, 1, 1, 1}},
{{11.f}, ngraph::element::f32, {1, 1, 1, 1}}
},
{
{ngraph::element::f32},
{{2.f}, ngraph::element::f32, {1, 1, 1, 1}},
{{22.f}, ngraph::element::f32, {1, 1, 1, 1}}
},
{
{ngraph::element::f32},
{{3.f}, ngraph::element::f32, {1, 1, 1, 1}},
{{33.f}, ngraph::element::f32, {1, 1, 1, 1}}
},
{{ngraph::element::f32}, {1.f}, {11.f}},
{{ngraph::element::f32}, {2.f}, {22.f}},
{{ngraph::element::f32}, {3.f}, {33.f}},
}
}
},
// U8 per channel quantization with different values (constants without batch)
{
ngraph::Shape({ 1, 3, 16, 16 }), std::int64_t{-3}, size_t{3},
LayerTransformation::createParamsU8I8(),
{
ngraph::element::u8,
{{ngraph::element::f32},
{{1.f, 2.f, 3.f}, ngraph::element::f32, {3, 1, 1}},
{{11.f, 22.f, 33.f}, ngraph::element::f32, {3, 1, 1}}}
},
{
ngraph::element::u8,
{},
ngraph::element::u8,
{
{{ngraph::element::f32}, {1.f}, {11.f}},
{{ngraph::element::f32}, {2.f}, {22.f}},
{{ngraph::element::f32}, {3.f}, {33.f}},
}
}
},
@ -193,21 +202,9 @@ const std::vector<SplitTransformationTestValues> testValues = {
{},
ngraph::element::i8,
{
{
{ngraph::element::f32},
{{1.f}, ngraph::element::f32, {1, 1, 1, 1}},
{{11.f}, ngraph::element::f32, {1, 1, 1, 1}}
},
{
{ngraph::element::f32},
{{2.f}, ngraph::element::f32, {1, 1, 1, 1}},
{{22.f}, ngraph::element::f32, {1, 1, 1, 1}}
},
{
{ngraph::element::f32},
{{3.f}, ngraph::element::f32, {1, 1, 1, 1}},
{{33.f}, ngraph::element::f32, {1, 1, 1, 1}}
},
{{ngraph::element::f32}, {1.f}, {11.f}},
{{ngraph::element::f32}, {2.f}, {22.f}},
{{ngraph::element::f32}, {3.f}, {33.f}},
}
}
},
@ -226,21 +223,9 @@ const std::vector<SplitTransformationTestValues> testValues = {
{},
ngraph::element::u8,
{
{
{ngraph::element::f32},
{{1.f}, ngraph::element::f32, {1, 1, 1, 1}},
{{11.f}, ngraph::element::f32, {1, 1, 1, 1}}
},
{
{ngraph::element::f32},
{{1.f}, ngraph::element::f32, {1, 1, 1, 1}},
{{11.f}, ngraph::element::f32, {1, 1, 1, 1}}
},
{
{ngraph::element::f32},
{{1.f}, ngraph::element::f32, {1, 1, 1, 1}},
{{11.f}, ngraph::element::f32, {1, 1, 1, 1}}
},
{{ngraph::element::f32}, {1.f}, {11.f}},
{{ngraph::element::f32}, {1.f}, {11.f}},
{{ngraph::element::f32}, {1.f}, {11.f}},
}
}
},
@ -259,21 +244,9 @@ const std::vector<SplitTransformationTestValues> testValues = {
{},
ngraph::element::i8,
{
{
{ngraph::element::f32},
{{1.f}, ngraph::element::f32, {1, 1, 1, 1}},
{{11.f}, ngraph::element::f32, {1, 1, 1, 1}}
},
{
{ngraph::element::f32},
{{1.f}, ngraph::element::f32, {1, 1, 1, 1}},
{{11.f}, ngraph::element::f32, {1, 1, 1, 1}}
},
{
{ngraph::element::f32},
{{1.f}, ngraph::element::f32, {1, 1, 1, 1}},
{{11.f}, ngraph::element::f32, {1, 1, 1, 1}}
},
{{ngraph::element::f32}, {1.f}, {11.f}},
{{ngraph::element::f32}, {1.f}, {11.f}},
{{ngraph::element::f32}, {1.f}, {11.f}}
}
}
},
@ -358,21 +331,9 @@ const std::vector<SplitTransformationTestValues> testValues = {
{},
ngraph::element::u8,
{
{
{ngraph::element::f32},
{},
{{11.f}, ngraph::element::f32, {1, 1, 1, 1}}
},
{
{ngraph::element::f32},
{},
{{22.f}, ngraph::element::f32, {1, 1, 1, 1}}
},
{
{ngraph::element::f32},
{},
{{33.f}, ngraph::element::f32, {1, 1, 1, 1}}
},
{{ngraph::element::f32}, {}, {11.f}},
{{ngraph::element::f32}, {}, {22.f}},
{{ngraph::element::f32}, {}, {33.f}},
}
}
},
@ -391,21 +352,9 @@ const std::vector<SplitTransformationTestValues> testValues = {
{},
ngraph::element::i8,
{
{
{ngraph::element::f32},
{},
{{11.f}, ngraph::element::f32, {1, 1, 1, 1}}
},
{
{ngraph::element::f32},
{},
{{22.f}, ngraph::element::f32, {1, 1, 1, 1}}
},
{
{ngraph::element::f32},
{},
{{33.f}, ngraph::element::f32, {1, 1, 1, 1}}
},
{{ngraph::element::f32}, {}, {11.f}},
{{ngraph::element::f32}, {}, {22.f}},
{{ngraph::element::f32}, {}, {33.f}},
}
}
},

View File

@ -150,6 +150,17 @@ StridedSliceTransformationTestValues::LayerParams specialDimensionSlice = {
{}
};
StridedSliceTransformationTestValues::LayerParams specialDimensionEndSlice = {
{ 0, 0, 20, 0 },
{ 1, 3, 24, 24 },
{ 1, 1, 1, 1 },
{ 1, 1, 0, 1 },
{ 1, 1, 0, 1 },
{},
{},
{}
};
const std::vector<StridedSliceTransformationTestValues> stridedSliceTransformationTestValues = {
// U8: channel slice, per-tensor quantization
{
@ -311,6 +322,38 @@ const std::vector<StridedSliceTransformationTestValues> stridedSliceTransformati
{{ngraph::element::f32}, {{ 32.f, 64.f, 32.f }}, {{ 0.1f, 0.01f, 1.f }}}
}
},
// I8: special dimension end slice, per-channel quantization with different values
{
ngraph::Shape{1, 3, 24, 24},
LayerTransformation::createParamsI8I8(),
specialDimensionEndSlice,
{
ngraph::element::i8,
{{ngraph::element::f32}, {{ 32.f, 64.f, 32.f }}, {{ 0.1f, 0.01f, 1.f }}}
},
{
ngraph::element::i8,
{},
ngraph::element::i8,
{{ngraph::element::f32}, {{ 32.f, 64.f, 32.f }}, {{ 0.1f, 0.01f, 1.f }}}
}
},
// I8: special dimension end slice, per-tensor quantization with different values
{
ngraph::Shape{1, 3, 24, 24},
LayerTransformation::createParamsI8I8(),
specialDimensionEndSlice,
{
ngraph::element::i8,
{{ngraph::element::f32}, { 32.f }, { 0.1f }}
},
{
ngraph::element::i8,
{},
ngraph::element::i8,
{{ngraph::element::f32}, { 32.f }, { 0.1f }}
}
},
// I8: channel slice, quantization by special dimension
{
ngraph::Shape{1, 3, 4, 4},

View File

@ -177,11 +177,31 @@ const std::vector<VariadicSplitTransformationTestValues> testValues = {
{{1.f, 2.f}, ngraph::element::f32, {1, 2, 1, 1}},
{{11.f, 22.f}, ngraph::element::f32, {1, 2, 1, 1}}
},
{{ngraph::element::f32}, {3.f}, {33.f}}
}
}
},
// U8 per channel quantization with different values (constants without batch)
{
ngraph::Shape({ 1, 3, 16, 16 }), std::int64_t{ -3 }, std::vector<size_t>{ 2, 1 },
LayerTransformation::createParamsU8I8(),
{
ngraph::element::u8,
{{ngraph::element::f32},
{{1.f, 2.f, 3.f}, ngraph::element::f32, {3, 1, 1}},
{{11.f, 22.f, 33.f}, ngraph::element::f32, {3, 1, 1}}}
},
{
ngraph::element::u8,
{},
ngraph::element::u8,
{
{
{ngraph::element::f32},
{{3.f}, ngraph::element::f32, {1, 1, 1, 1}},
{{33.f}, ngraph::element::f32, {1, 1, 1, 1}}
}
{{1.f, 2.f}, ngraph::element::f32, {1, 2, 1, 1}},
{{11.f, 22.f}, ngraph::element::f32, {1, 2, 1, 1}}
},
{{ngraph::element::f32}, {3.f}, {33.f}}
}
}
},
@ -205,11 +225,7 @@ const std::vector<VariadicSplitTransformationTestValues> testValues = {
{{1.f, 2.f}, ngraph::element::f32, {1, 2, 1, 1}},
{{11.f, 22.f}, ngraph::element::f32, {1, 2, 1, 1}}
},
{
{ngraph::element::f32},
{{3.f}, ngraph::element::f32, {1, 1, 1, 1}},
{{33.f}, ngraph::element::f32, {1, 1, 1, 1}}
}
{{ngraph::element::f32}, {3.f}, {33.f}}
}
}
},
@ -228,16 +244,8 @@ const std::vector<VariadicSplitTransformationTestValues> testValues = {
{},
ngraph::element::u8,
{
{
{ngraph::element::f32},
{{1.f, 1.f}, ngraph::element::f32, {1, 2, 1, 1}},
{{11.f, 11.f}, ngraph::element::f32, {1, 2, 1, 1}}
},
{
{ngraph::element::f32},
{{1.f}, ngraph::element::f32, {1, 1, 1, 1}},
{{11.f}, ngraph::element::f32, {1, 1, 1, 1}}
}
{{ngraph::element::f32}, {1.f}, {11.f}},
{{ngraph::element::f32}, {1.f}, {11.f}}
}
}
},
@ -256,16 +264,8 @@ const std::vector<VariadicSplitTransformationTestValues> testValues = {
{},
ngraph::element::i8,
{
{
{ngraph::element::f32},
{{1.f, 1.f}, ngraph::element::f32, {1, 2, 1, 1}},
{{11.f, 11.f}, ngraph::element::f32, {1, 2, 1, 1}}
},
{
{ngraph::element::f32},
{{1.f}, ngraph::element::f32, {1, 1, 1, 1}},
{{11.f}, ngraph::element::f32, {1, 1, 1, 1}}
}
{{ngraph::element::f32}, {1.f}, {11.f}},
{{ngraph::element::f32}, {1.f}, {11.f}}
}
}
},
@ -322,21 +322,13 @@ const std::vector<VariadicSplitTransformationTestValues> testValues = {
{},
ngraph::element::i8,
{
{
{ngraph::element::f32},
{{1.f}, ngraph::element::f32, {1, 1, 1, 1}},
{{11.f}, ngraph::element::f32, {1, 1, 1, 1}}
},
{{ngraph::element::f32}, {1.f}, {11.f}},
{
{ngraph::element::f32},
{{2.f, 3.f}, ngraph::element::f32, {1, 2, 1, 1}},
{{22.f, 33.f}, ngraph::element::f32, {1, 2, 1, 1}}
},
{
{ngraph::element::f32},
{{4.f}, ngraph::element::f32, {1, 1, 1, 1}},
{{44.f}, ngraph::element::f32, {1, 1, 1, 1}}
}
{{ngraph::element::f32}, {4.f}, {44.f}}
}
}
},

View File

@ -45,8 +45,7 @@ const std::vector<ConcatWithSplitTransformationParam> testValues = {
}
};
// TODO: Split/VariadicSplit operations are not supported in ConcatTransformation
INSTANTIATE_TEST_CASE_P(DISABLED_smoke_LPT, ConcatWithSplitTransformation,
INSTANTIATE_TEST_CASE_P(smoke_LPT, ConcatWithSplitTransformation,
::testing::Combine(
::testing::ValuesIn(netPrecisions),
::testing::Values(ngraph::Shape({ 1, 6, 10, 10 })),

View File

@ -45,8 +45,7 @@ const std::vector<ConcatWithSplitTransformationParam> testValues = {
}
};
// TODO: Split/VariadicSplit operations are not supported in ConcatTransformation
INSTANTIATE_TEST_CASE_P(DISABLED_smoke_LPT, ConcatWithSplitTransformation,
INSTANTIATE_TEST_CASE_P(smoke_LPT, ConcatWithSplitTransformation,
::testing::Combine(
::testing::ValuesIn(netPrecisions),
::testing::Values(ngraph::Shape({ 1, 6, 10, 10 })),

View File

@ -114,7 +114,8 @@ public:
const DequantizationOperations::Convert& convert2,
const DequantizationOperations& dequantization2,
const ngraph::element::Type precisionAfterOperation,
const DequantizationOperations& dequantizationAfter);
const DequantizationOperations& dequantizationAfter,
const std::int64_t& axis);
static std::shared_ptr<ngraph::Function> getReferenceWithNeighbors(
const ngraph::element::Type precision,

View File

@ -752,7 +752,8 @@ std::shared_ptr<ngraph::Function> ConcatFunction::get(
const DequantizationOperations::Convert& convert2,
const DequantizationOperations& dequantization2,
const ngraph::element::Type precisionAfterOperation,
const DequantizationOperations& dequantizationAfter) {
const DequantizationOperations& dequantizationAfter,
const std::int64_t& axis) {
const auto input1 = std::make_shared<ngraph::opset1::Parameter>(inputPrecision, inputShape);
input1->set_friendly_name("input1");
@ -775,7 +776,7 @@ std::shared_ptr<ngraph::Function> ConcatFunction::get(
parent2 = makeDequantization(parent2, dequantization2);
}
const std::shared_ptr<ngraph::opset1::Concat> concat = std::make_shared<ngraph::opset1::Concat>(ngraph::OutputVector{ parent1, parent2 }, 1);
const std::shared_ptr<ngraph::opset1::Concat> concat = std::make_shared<ngraph::opset1::Concat>(ngraph::OutputVector{ parent1, parent2 }, axis);
auto& rtInfo = concat->get_rt_info();
rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("concat");
@ -989,6 +990,13 @@ std::shared_ptr<ngraph::Function> ConcatFunction::getReferenceWithSplitedInterme
input2->set_friendly_name("input2");
const auto fakeQuantize2 = makeFakeQuantizeTypeRelaxed(input2, precision, fqOnData2);
replace_node(
fakeQuantize2->get_input_node_shared_ptr(3),
ngraph::pass::low_precision::NetworkHelper::toScalarIfPossible(fakeQuantize2->get_input_node_shared_ptr(3)));
replace_node(
fakeQuantize2->get_input_node_shared_ptr(4),
ngraph::pass::low_precision::NetworkHelper::toScalarIfPossible(fakeQuantize2->get_input_node_shared_ptr(4)));
fakeQuantize2->set_friendly_name("fakeQuantize2");
low_precision::NetworkHelper::setOutDataPrecisionForTypeRelaxed(fakeQuantize2, precisionAfterOperation);
const auto deqBefore2 = makeDequantization(fakeQuantize2, dequantizationBefore1);