[LPT] Split & VariadicSplit support (#4195)

* [LPT] Split support in ConcatTransformation

* [LPT] fixing functional problems after enabling Split/VariadicSplit transformations

* [LPT] added test case for StridedSliceTransformation, enabled tests with split

* [LPT] ConcatTransformation refactoring

* ConcatTransformation: added axis check

* [LPT] Added foldDequantizationConstant to NetworkHelper & SplitTransformation refactoring

* [LPT] Subgraph: returned and refactored quantizationPerChannel

* [LPT] foldDequantizationConstant refactoring

* [LPT] SplitTransformation refactoring

* [LPT] hbonets fix & ConcatTrasnformation refactoring
This commit is contained in:
Vladislav Golubev 2021-04-07 11:02:43 +03:00 committed by GitHub
parent 6c290a506f
commit 7877287301
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 521 additions and 509 deletions

View File

@ -35,6 +35,7 @@ protected:
ngraph::pass::low_precision::Subgraph& subgraph, ngraph::pass::low_precision::Subgraph& subgraph,
std::function<void( std::function<void(
std::shared_ptr<ngraph::Node> layer, std::shared_ptr<ngraph::Node> layer,
std::shared_ptr<ngraph::Node> child,
const std::string originalLayerName, const std::string originalLayerName,
std::vector<FakeQuantizeDequantization>& dequantizationsToConcatenate)> getLayerDequantizationCallback) const; std::vector<FakeQuantizeDequantization>& dequantizationsToConcatenate)> getLayerDequantizationCallback) const;
@ -42,6 +43,15 @@ protected:
const TransformationContext& context, const TransformationContext& context,
const std::vector<std::shared_ptr<ngraph::Node>>& quantizationOperations); const std::vector<std::shared_ptr<ngraph::Node>>& quantizationOperations);
void fillDequantizationNodes(
const std::vector<FakeQuantizeDequantization>& layerDequantizations,
const std::shared_ptr<Node> layer,
NodeVector& convertNodes,
NodeVector& subtractNodes,
NodeVector& multiplyNodes) const;
std::shared_ptr<Node> concatenateDeqNodes(NodeVector& nodes) const;
private: private:
size_t getMinQuantizationLevels( size_t getMinQuantizationLevels(
const DataPrecision& dataPrecision, const DataPrecision& dataPrecision,

View File

@ -27,12 +27,9 @@ public:
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override; bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
private: private:
// Go through the parent elements of the layer and fill dequantization collection
// with Dq operations that should be inserted before the layer.
void fillDequantization( void fillDequantization(
std::shared_ptr<ngraph::Node> layer,
std::unordered_map<std::string, FakeQuantizeDequantization>& dequantizationByFakeQuantize,
std::vector<FakeQuantizeDequantization>& dequantizationsToConcatenate) const;
void fillQuantization(
const std::shared_ptr<ngraph::Node> layer, const std::shared_ptr<ngraph::Node> layer,
const std::unordered_map<std::string, FakeQuantizeDequantization>& dequantizationByFakeQuantize, const std::unordered_map<std::string, FakeQuantizeDequantization>& dequantizationByFakeQuantize,
std::vector<FakeQuantizeDequantization>& dequantization) const; std::vector<FakeQuantizeDequantization>& dequantization) const;
@ -46,8 +43,6 @@ private:
const FakeQuantizeDequantization& dequantization, const FakeQuantizeDequantization& dequantization,
const size_t sourceOutputIdx); const size_t sourceOutputIdx);
static FakeQuantizeDequantization broadcastDequantiationConstant(const FakeQuantizeDequantization& deq);
bool isMultiChannel(const std::vector<std::shared_ptr<ngraph::opset1::Concat>>& concatLayers) const noexcept; bool isMultiChannel(const std::vector<std::shared_ptr<ngraph::opset1::Concat>>& concatLayers) const noexcept;
}; };

View File

@ -50,6 +50,12 @@ public:
template <typename OperationType> template <typename OperationType>
static std::shared_ptr<Node> setOutDataPrecision(std::shared_ptr<OperationType> operation, const element::Type& precision); static std::shared_ptr<Node> setOutDataPrecision(std::shared_ptr<OperationType> operation, const element::Type& precision);
// applies constant folding of operation to constant and returns the specified output
static std::shared_ptr<opset1::Constant> foldDequantizationConstant(
const std::shared_ptr<opset1::Constant>& foldingConstant,
const std::shared_ptr<Node>& operation,
const size_t outIdx = 0);
static size_t getOutputChannelsCount(std::shared_ptr<const Node> layer, bool isOnWeights = false); static size_t getOutputChannelsCount(std::shared_ptr<const Node> layer, bool isOnWeights = false);
static std::vector<std::shared_ptr<Node>> getParentsRecursivelyExceptTypes( static std::vector<std::shared_ptr<Node>> getParentsRecursivelyExceptTypes(

View File

@ -24,15 +24,6 @@ public:
TransformationContext& context, TransformationContext& context,
std::vector<std::shared_ptr<ngraph::Node>> lastNodes, std::vector<std::shared_ptr<ngraph::Node>> lastNodes,
std::shared_ptr<ngraph::Node> originalNode) const; std::shared_ptr<ngraph::Node> originalNode) const;
protected:
ngraph::Shape getConstSplitShape(
const std::vector<size_t>& constSplitLengths,
const ngraph::Shape& constShape, const size_t axis,
const size_t idx) const;
virtual std::vector<size_t> getConstSplitLengths(
const OutputVector& inputs,
const ngraph::Shape& constShape,
const size_t outputSize) const;
}; };
} // namespace low_precision } // namespace low_precision
} // namespace pass } // namespace pass

View File

@ -17,11 +17,6 @@ class TRANSFORMATIONS_API VariadicSplitTransformation : public SplitTransformati
public: public:
VariadicSplitTransformation(const Params& params); VariadicSplitTransformation(const Params& params);
void registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const override; void registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const override;
protected:
std::vector<size_t> getConstSplitLengths(
const OutputVector& inputs,
const ngraph::Shape& constShape,
const size_t outputSize) const override;
}; };
} // namespace low_precision } // namespace low_precision
} // namespace pass } // namespace pass

View File

@ -201,6 +201,7 @@ bool ConcatTransformation::transform(TransformationContext& context, ngraph::pat
auto dequantizationValuesCallback = [&]( auto dequantizationValuesCallback = [&](
std::shared_ptr<ngraph::Node> layer, std::shared_ptr<ngraph::Node> layer,
std::shared_ptr<ngraph::Node> child,
const std::string originalLayerName, const std::string originalLayerName,
std::vector<FakeQuantizeDequantization>& dequantizationsToConcatenate) { std::vector<FakeQuantizeDequantization>& dequantizationsToConcatenate) {
dequantizationsToConcatenate.push_back(dequantization); dequantizationsToConcatenate.push_back(dequantization);
@ -234,15 +235,97 @@ bool ConcatTransformation::isPrecisionPreserved(std::shared_ptr<Node>) const noe
bool ConcatTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const { bool ConcatTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const {
std::shared_ptr<opset1::Concat> concat = as_type_ptr<opset1::Concat>(layer); std::shared_ptr<opset1::Concat> concat = as_type_ptr<opset1::Concat>(layer);
return concat && concat->get_axis() == 1ul; if (concat == nullptr) {
return false;
}
const auto axis = concat->get_axis();
const size_t normalizedAxis = ngraph::normalize_axis(concat->get_friendly_name(), axis, concat->get_output_partial_shape(0).rank());
return normalizedAxis == 1ul;
} }
void ConcatTransformation::fillDequantizationNodes(
const std::vector<FakeQuantizeDequantization>& layerDequantizations,
const std::shared_ptr<Node> layer,
NodeVector& convertNodes,
NodeVector& subtractNodes,
NodeVector& multiplyNodes) const {
if (layerDequantizations.size() > 1ul) {
auto broadcastElementWiseConst = [](
// FakeQuantize constant shape must be broadcastable to the shape on data.
std::shared_ptr<ngraph::opset1::Constant> operation,
const ngraph::Shape targetShape) -> std::shared_ptr<Node> {
auto targetShapeConst = std::make_shared<ngraph::opset1::Constant>(
element::i64, ngraph::Shape{ targetShape.size() },
targetShape);
auto broadcast = ngraph::pass::low_precision::fold<ngraph::opset1::Broadcast>(
operation,
targetShapeConst,
ngraph::op::AutoBroadcastType::NUMPY);
return broadcast;
};
bool allDequantizationShiftAreZero = true;
bool allDequantizationMultiplyAreZero = true;
for (const auto& dequantization : layerDequantizations) {
if (dequantization.subtract != nullptr) {
allDequantizationShiftAreZero = false;
}
if (dequantization.multiply != nullptr) {
allDequantizationMultiplyAreZero = false;
}
}
for (size_t i = 0; i < layerDequantizations.size(); ++i) {
const auto& dequantization = layerDequantizations[i];
const ngraph::element::Type precision = deqPrecision;
ngraph::Shape targetShape(layer->get_input_shape(i).size(), 1ul);
targetShape[1] = layer->get_input_shape(i)[1];
if (dequantization.convert != nullptr) {
convertNodes.push_back(dequantization.convert);
}
if (!allDequantizationShiftAreZero) {
subtractNodes.push_back(dequantization.subtract == nullptr ?
std::make_shared<ngraph::opset1::Constant>(precision, targetShape, std::vector<float>({ 0.f })) :
broadcastElementWiseConst(dequantization.subtractConstant, targetShape));
}
if (!allDequantizationMultiplyAreZero) {
multiplyNodes.push_back(dequantization.multiply == nullptr ?
std::make_shared<ngraph::opset1::Constant>(precision, targetShape, std::vector<float>({ 1.0f })) :
broadcastElementWiseConst(dequantization.multiplyConstant, targetShape));
}
}
} else {
// TODO: check constant shapes here - has to be scalar
if (layerDequantizations[0].convert != nullptr) {
convertNodes.push_back(layerDequantizations[0].convert);
}
if (layerDequantizations[0].subtract != nullptr) {
subtractNodes.push_back(layerDequantizations[0].subtract->input_value(1).get_node_shared_ptr());
}
if (layerDequantizations[0].multiply != nullptr) {
multiplyNodes.push_back(layerDequantizations[0].multiply->input_value(1).get_node_shared_ptr());
}
}
}
std::shared_ptr<Node> ConcatTransformation::concatenateDeqNodes(NodeVector& nodes) const {
return nodes.size() == 1ul ? nodes[0] : fold<ngraph::opset1::Concat>(nodes, 1);
}
void ConcatTransformation::addDequantizationLayers( void ConcatTransformation::addDequantizationLayers(
TransformationContext& context, TransformationContext& context,
ngraph::pass::low_precision::Subgraph& subgraph, ngraph::pass::low_precision::Subgraph& subgraph,
std::function<void( std::function<void(
std::shared_ptr<ngraph::Node> layer, std::shared_ptr<ngraph::Node> layer,
std::shared_ptr<ngraph::Node> child,
const std::string originalLayerName, const std::string originalLayerName,
std::vector<FakeQuantizeDequantization>& dequantizationsToConcatenate)> getLayerDequantizationCallback) const { std::vector<FakeQuantizeDequantization>& dequantizationsToConcatenate)> getLayerDequantizationCallback) const {
std::unordered_map<std::string, ngraph::Node*> outputs; std::unordered_map<std::string, ngraph::Node*> outputs;
@ -269,95 +352,28 @@ void ConcatTransformation::addDequantizationLayers(
ngraph::Node& child = *childInput.get_node(); ngraph::Node& child = *childInput.get_node();
if (subgraph.layers.find(child.get_friendly_name()) == subgraph.layers.end()) { if (subgraph.layers.find(child.get_friendly_name()) == subgraph.layers.end()) {
std::shared_ptr<ngraph::Node> source = layer;
const std::shared_ptr<ngraph::Node> destination = child.shared_from_this();
if (layerDequantizations.size() == 0ul) { if (layerDequantizations.size() == 0ul) {
// fill layerDequantizations collection // fill layerDequantizations collection
getLayerDequantizationCallback(layer, layer->get_friendly_name(), layerDequantizations); getLayerDequantizationCallback(source, destination, source->get_friendly_name(), layerDequantizations);
} }
std::shared_ptr<ngraph::Node> source = layer->shared_from_this();
{ {
std::vector<std::shared_ptr<ngraph::Node>> convertNodes; NodeVector convertNodes;
std::vector<std::shared_ptr<ngraph::Node>> subtractNodes; NodeVector subtractNodes;
std::vector<std::shared_ptr<ngraph::Node>> multiplyNodes; NodeVector multiplyNodes;
// forming nodes for concatenation // forming nodes for concatenation
if (layerDequantizations.size() > 1ul) { fillDequantizationNodes(layerDequantizations, layer, convertNodes, subtractNodes, multiplyNodes);
auto broadcastElementWiseConst = [](
// FakeQuantize constant shape must be broadcastable to the shape on data.
std::shared_ptr<ngraph::opset1::Constant> operation,
const ngraph::Shape targetShape) -> std::shared_ptr<Node> {
auto targetShapeConst = std::make_shared<ngraph::opset1::Constant>(
element::i64, ngraph::Shape{ targetShape.size() },
targetShape);
auto broadcast = ngraph::pass::low_precision::fold<ngraph::opset1::Broadcast>(
operation,
targetShapeConst,
ngraph::op::AutoBroadcastType::NUMPY);
return broadcast;
};
bool allDequantizationShiftAreZero = true;
bool allDequantizationMultiplyAreZero = true;
for (FakeQuantizeDequantization dequantization : layerDequantizations) {
if (dequantization.subtract != nullptr) {
allDequantizationShiftAreZero = false;
}
if (dequantization.multiply != nullptr) {
allDequantizationMultiplyAreZero = false;
}
}
for (size_t i = 0; i < layerDequantizations.size(); ++i) {
const auto& dequantization = layerDequantizations[i];
if (dequantization.convert != nullptr) {
convertNodes.push_back(dequantization.convert);
}
const ngraph::element::Type precision = deqPrecision;
ngraph::Shape targetShape(layer->get_input_shape(i).size(), 1ul);
targetShape[1] = layer->get_input_shape(i)[1];
if (!allDequantizationShiftAreZero) {
subtractNodes.push_back(dequantization.subtract == nullptr ?
std::make_shared<ngraph::opset1::Constant>(precision, targetShape, std::vector<float>({ 0.f })) :
broadcastElementWiseConst(
as_type_ptr<ngraph::opset1::Constant>(dequantization.subtract->input_value(1).get_node_shared_ptr()),
targetShape));
}
if (!allDequantizationMultiplyAreZero) {
multiplyNodes.push_back(dequantization.multiply == nullptr ?
std::make_shared<ngraph::opset1::Constant>(precision, targetShape, std::vector<float>({ 1.0f })) :
broadcastElementWiseConst(
as_type_ptr<ngraph::opset1::Constant>(dequantization.multiply->input_value(1).get_node_shared_ptr()),
targetShape));
}
}
} else {
// TODO: check constant shapes here - has to be scalar
if (layerDequantizations[0].convert != nullptr) {
convertNodes.push_back(layerDequantizations[0].convert);
}
if (layerDequantizations[0].subtract != nullptr) {
subtractNodes.push_back(layerDequantizations[0].subtract->input_value(1).get_node_shared_ptr());
}
if (layerDequantizations[0].multiply != nullptr) {
multiplyNodes.push_back(layerDequantizations[0].multiply->input_value(1).get_node_shared_ptr());
}
}
// TODO: the second place (first is FQ decomposition) where dequantization operations are inserted // TODO: the second place (first is FQ decomposition) where dequantization operations are inserted
const std::shared_ptr<ngraph::Node> destination = child.shared_from_this();
if (!convertNodes.empty()) { if (!convertNodes.empty()) {
const size_t sourceOutputIdx = NetworkHelper::getChildInputIndex(source, destination); const size_t sourceOutputIdx = NetworkHelper::getChildInputIndex(source, destination);
std::shared_ptr<ngraph::Node> convert = std::shared_ptr<ngraph::Node> convert =
convertNodes[0]->clone_with_new_inputs({ destination->get_input_source_output(sourceOutputIdx) }); convertNodes[0]->clone_with_new_inputs({ destination->get_input_source_output(sourceOutputIdx) });
insert_new_node_between(source, destination, convert); insert_new_node_between(source, destination, convert);
ngraph::copy_runtime_info({ layer, convert }, convert); ngraph::copy_runtime_info({ layer, convert }, convert);
source = convert; source = convert;
@ -368,9 +384,8 @@ void ConcatTransformation::addDequantizationLayers(
const size_t sourceOutputIdx = NetworkHelper::getChildInputIndex(source, destination); const size_t sourceOutputIdx = NetworkHelper::getChildInputIndex(source, destination);
std::shared_ptr<ngraph::opset1::Subtract> subtract = std::make_shared<DequantizationSubtract>( std::shared_ptr<ngraph::opset1::Subtract> subtract = std::make_shared<DequantizationSubtract>(
destination->get_input_source_output(sourceOutputIdx), destination->get_input_source_output(sourceOutputIdx),
NetworkHelper::toScalarIfPossible(subtractNodes.size() == 1ul ? NetworkHelper::toScalarIfPossible(concatenateDeqNodes(subtractNodes)));
subtractNodes[0] :
ngraph::pass::low_precision::fold<ngraph::opset1::Concat>(subtractNodes, 1)));
insert_new_node_between(source, destination, subtract); insert_new_node_between(source, destination, subtract);
ngraph::copy_runtime_info({ layer, subtract }, subtract); ngraph::copy_runtime_info({ layer, subtract }, subtract);
source = subtract; source = subtract;
@ -381,10 +396,9 @@ void ConcatTransformation::addDequantizationLayers(
std::shared_ptr<ngraph::opset1::Multiply> multiply = std::make_shared<op::TypeRelaxed<DequantizationMultiply>>( std::shared_ptr<ngraph::opset1::Multiply> multiply = std::make_shared<op::TypeRelaxed<DequantizationMultiply>>(
DequantizationMultiply( DequantizationMultiply(
destination->get_input_source_output(sourceOutputIdx), destination->get_input_source_output(sourceOutputIdx),
NetworkHelper::toScalarIfPossible(multiplyNodes.size() == 1ul ? NetworkHelper::toScalarIfPossible(concatenateDeqNodes(multiplyNodes))),
multiplyNodes[0] :
ngraph::pass::low_precision::fold<ngraph::opset1::Concat>(multiplyNodes, 1))),
layerDequantizations[0].multiply->get_output_element_type(0)); layerDequantizations[0].multiply->get_output_element_type(0));
insert_new_node_between(source, destination, multiply); insert_new_node_between(source, destination, multiply);
ngraph::copy_runtime_info({ layer, multiply }, multiply); ngraph::copy_runtime_info({ layer, multiply }, multiply);
source = multiply; source = multiply;

View File

@ -137,6 +137,7 @@ bool ConcatMultiChannelsTransformation::transform(TransformationContext& context
auto dequantizationValuesCallback = [&]( auto dequantizationValuesCallback = [&](
std::shared_ptr<ngraph::Node> layer, std::shared_ptr<ngraph::Node> layer,
std::shared_ptr<ngraph::Node> child,
const std::string originalLayerName, const std::string originalLayerName,
std::vector<FakeQuantizeDequantization>& dequantizationsToConcatenate) { std::vector<FakeQuantizeDequantization>& dequantizationsToConcatenate) {
if (layer->get_friendly_name() != originalLayerName) { if (layer->get_friendly_name() != originalLayerName) {
@ -157,6 +158,15 @@ bool ConcatMultiChannelsTransformation::transform(TransformationContext& context
layer, layer,
dequantizations, dequantizations,
dequantizationsToConcatenate); dequantizationsToConcatenate);
if (!is_type<ngraph::opset1::Concat>(layer)) {
// for intermediate layers we should get Dq operations to be inserted between layer and child
assert(dequantizationsToConcatenate.size() == 1ul);
const size_t sourceOutputIdx = NetworkHelper::getParentOutputIndex(layer, child);
if (layer->get_input_shape(0)[1] != layer->get_output_shape(sourceOutputIdx)[1]) {
dequantizationsToConcatenate[0] = getFoldedDequantization(layer, dequantizationsToConcatenate[0], sourceOutputIdx);
}
}
}; };
addDequantizationLayers(context, subgraph, dequantizationValuesCallback); addDequantizationLayers(context, subgraph, dequantizationValuesCallback);
@ -185,137 +195,66 @@ bool ConcatMultiChannelsTransformation::isPrecisionPreserved(std::shared_ptr<Nod
return true; return true;
} }
// fill dequantizationsToMerge collection for layer with using dequantizationByFakeQuantize
void ConcatMultiChannelsTransformation::fillDequantization( void ConcatMultiChannelsTransformation::fillDequantization(
std::shared_ptr<ngraph::Node> layer,
std::unordered_map<std::string, FakeQuantizeDequantization>& dequantizationByFakeQuantize,
std::vector<FakeQuantizeDequantization>& dequantizationsToConcatenate) const {
std::shared_ptr<ngraph::opset1::FakeQuantize> currentFakeQuantize = ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(layer);
if (currentFakeQuantize) {
const auto it = dequantizationByFakeQuantize.find(currentFakeQuantize->get_friendly_name());
if (it == dequantizationByFakeQuantize.end()) {
THROW_IE_LPT_EXCEPTION(*currentFakeQuantize) << "dequantization scale values are not found";
}
const FakeQuantizeDequantization& fakeQuantizeDequantization = it->second;
dequantizationsToConcatenate.push_back(broadcastDequantiationConstant(fakeQuantizeDequantization));
} else {
fillQuantization(layer, dequantizationByFakeQuantize, dequantizationsToConcatenate);
}
}
void ConcatMultiChannelsTransformation::fillQuantization(
const std::shared_ptr<ngraph::Node> layer, const std::shared_ptr<ngraph::Node> layer,
const std::unordered_map<std::string, FakeQuantizeDequantization>& dequantizationByFakeQuantize, const std::unordered_map<std::string, FakeQuantizeDequantization>& dequantizationByFakeQuantize,
std::vector<FakeQuantizeDequantization>& dequantization) const { std::vector<FakeQuantizeDequantization>& dequantization) const {
for (size_t i = 0; i < layer->get_input_size(); ++i) { const auto fillDqByFakeQuantize = [&](const std::shared_ptr<ngraph::Node>& fq) {
std::shared_ptr<ngraph::Node> parent = layer->get_input_node_shared_ptr(i); const auto it = dequantizationByFakeQuantize.find(fq->get_friendly_name());
if (it == dequantizationByFakeQuantize.end()) {
THROW_IE_LPT_EXCEPTION(*fq) << "dequantization scale values are not found";
}
std::shared_ptr<ngraph::opset1::FakeQuantize> fakeQuantize = ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(parent); const FakeQuantizeDequantization& fakeQuantizeDequantization = it->second;
if (fakeQuantize) { dequantization.push_back(fakeQuantizeDequantization);
const auto it = dequantizationByFakeQuantize.find(fakeQuantize->get_friendly_name()); };
if (it == dequantizationByFakeQuantize.end()) {
THROW_IE_LPT_EXCEPTION(*fakeQuantize) << "dequantization scale values are not found"; if (is_type<ngraph::opset1::FakeQuantize>(layer)) {
fillDqByFakeQuantize(layer);
} else {
for (size_t i = 0; i < layer->get_input_size(); ++i) {
std::shared_ptr<ngraph::Node> parent = layer->get_input_node_shared_ptr(i);
if (as_type_ptr<ngraph::opset1::Constant>(parent)) {
continue;
} }
const FakeQuantizeDequantization& fakeQuantizeDequantization = it->second; const auto fakeQuantize = ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(parent);
dequantization.push_back(broadcastDequantiationConstant(fakeQuantizeDequantization)); if (fakeQuantize) {
} else { fillDqByFakeQuantize(fakeQuantize);
std::shared_ptr<ngraph::opset1::Concat> concat = ngraph::as_type_ptr<ngraph::opset1::Concat>(parent);
if (concat) {
std::vector<FakeQuantizeDequantization> dequantizationToConcatenate;
fillQuantization(concat, dequantizationByFakeQuantize, dequantizationToConcatenate);
// add concatenated dequantization operations to dequantization collection
dequantization.push_back(getConcatenatedDequantization(concat, dequantizationToConcatenate));
} else { } else {
std::shared_ptr<ngraph::opset1::StridedSlice> stridedSlice = ngraph::as_type_ptr<ngraph::opset1::StridedSlice>(parent); const auto concat = ngraph::as_type_ptr<ngraph::opset1::Concat>(parent);
if (stridedSlice) { if (concat) {
std::vector<FakeQuantizeDequantization> dequantizationToPropagate; std::vector<FakeQuantizeDequantization> dequantizationToConcatenate;
fillQuantization(stridedSlice, dequantizationByFakeQuantize, dequantizationToPropagate); fillDequantization(concat, dequantizationByFakeQuantize, dequantizationToConcatenate);
const size_t sourceOutputIdx = NetworkHelper::getParentOutputIndex(parent, layer); // add concatenated dequantization operations to dequantization collection
// add folded dequantization operations to dequantization colection dequantization.push_back(getConcatenatedDequantization(concat, dequantizationToConcatenate));
dequantization.push_back(getFoldedDequantization(stridedSlice, dequantizationToPropagate[0], sourceOutputIdx));
} else { } else {
fillQuantization(parent, dequantizationByFakeQuantize, dequantization); const size_t sourceOutputIdx = NetworkHelper::getParentOutputIndex(parent, layer);
if (parent->get_input_shape(0)[1] != parent->get_output_shape(sourceOutputIdx)[1]) {
std::vector<FakeQuantizeDequantization> dequantizationToPropagate;
fillDequantization(parent, dequantizationByFakeQuantize, dequantizationToPropagate);
// add folded dequantization operations to dequantization colection
dequantization.push_back(getFoldedDequantization(parent, dequantizationToPropagate[0], sourceOutputIdx));
} else {
fillDequantization(parent, dequantizationByFakeQuantize, dequantization);
}
} }
} }
} }
} }
} }
// broadcast of dequantization constants by channels
FakeQuantizeDequantization ConcatMultiChannelsTransformation::broadcastDequantiationConstant(const FakeQuantizeDequantization& deq) {
ngraph::Shape targetShape(deq.data.get_shape().size(), 1ul);
targetShape[1] = deq.data.get_shape()[1];
FakeQuantizeDequantization result;
result.data = deq.data;
result.convert = deq.convert;
const auto targetShapeConst = std::make_shared<ngraph::opset1::Constant>(
element::i64, ngraph::Shape{ targetShape.size() },
targetShape);
if (deq.subtract) {
auto broadcast = ngraph::pass::low_precision::fold<ngraph::opset1::Broadcast>(
deq.subtractConstant,
targetShapeConst,
ngraph::op::AutoBroadcastType::NUMPY);
result.subtract = deq.subtract;
result.subtractConstant = as_type_ptr<ngraph::opset1::Constant>(broadcast);
}
if (deq.multiply) {
auto broadcast = ngraph::pass::low_precision::fold<ngraph::opset1::Broadcast>(
deq.multiplyConstant,
targetShapeConst,
ngraph::op::AutoBroadcastType::NUMPY);
result.multiply = deq.multiply;
result.multiplyConstant = as_type_ptr<ngraph::opset1::Constant>(broadcast);
}
return result;
}
FakeQuantizeDequantization ConcatMultiChannelsTransformation::getConcatenatedDequantization( FakeQuantizeDequantization ConcatMultiChannelsTransformation::getConcatenatedDequantization(
const std::shared_ptr<ngraph::opset1::Concat> concat, const std::shared_ptr<ngraph::opset1::Concat> concat,
const std::vector<FakeQuantizeDequantization>& dequantization) const { const std::vector<FakeQuantizeDequantization>& dequantization) const {
bool allDequantizationShiftAreZero = true;
bool allDequantizationMultiplyAreZero = true;
for (const auto& deq : dequantization) {
if (deq.subtract != nullptr) {
allDequantizationShiftAreZero = false;
}
if (deq.multiply != nullptr) {
allDequantizationMultiplyAreZero = false;
}
}
NodeVector convertNodes; NodeVector convertNodes;
NodeVector subNodes; NodeVector subtractNodes;
NodeVector mulNodes; NodeVector multiplyNodes;
//preparing to concatenate dequantization nodes
for (const auto& deq : dequantization) {
ngraph::Shape targetShape(deq.data.get_shape().size(), 1ul);
targetShape[1] = deq.data.get_shape()[1];
if (deq.convert != nullptr) { // forming nodes for concatenation
convertNodes.push_back(deq.convert); fillDequantizationNodes(dequantization, concat, convertNodes, subtractNodes, multiplyNodes);
}
if (!allDequantizationShiftAreZero) {
subNodes.push_back(deq.subtract == nullptr ?
std::make_shared<ngraph::opset1::Constant>(deqPrecision, targetShape, std::vector<float>({ 0.f })) :
deq.subtractConstant);
}
if (!allDequantizationMultiplyAreZero) {
mulNodes.push_back(deq.multiply == nullptr ?
std::make_shared<ngraph::opset1::Constant>(deqPrecision, targetShape, std::vector<float>({ 1.0f })) :
deq.multiplyConstant);
}
}
std::shared_ptr<Node> parent = concat; std::shared_ptr<Node> parent = concat;
std::shared_ptr<DequantizationConvert> convert; std::shared_ptr<DequantizationConvert> convert;
@ -326,20 +265,16 @@ FakeQuantizeDequantization ConcatMultiChannelsTransformation::getConcatenatedDeq
std::shared_ptr<DequantizationSubtract> subtract; std::shared_ptr<DequantizationSubtract> subtract;
std::shared_ptr<ngraph::opset1::Constant> subConst; std::shared_ptr<ngraph::opset1::Constant> subConst;
if (!subNodes.empty()) { if (!subtractNodes.empty()) {
subConst = as_type_ptr<ngraph::opset1::Constant>( subConst = as_type_ptr<ngraph::opset1::Constant>(concatenateDeqNodes(subtractNodes));
subNodes.size() == 1ul ? subNodes[0] : fold<ngraph::opset1::Concat>(subNodes, 1ul));
subtract = std::make_shared<DequantizationSubtract>(parent, subConst); subtract = std::make_shared<DequantizationSubtract>(parent, subConst);
parent = subtract; parent = subtract;
} }
std::shared_ptr<DequantizationMultiply> multiply; std::shared_ptr<DequantizationMultiply> multiply;
std::shared_ptr<ngraph::opset1::Constant> mulConst; std::shared_ptr<ngraph::opset1::Constant> mulConst;
if (!mulNodes.empty()) { if (!multiplyNodes.empty()) {
mulConst = as_type_ptr<ngraph::opset1::Constant>( mulConst = as_type_ptr<ngraph::opset1::Constant>(concatenateDeqNodes(multiplyNodes));
mulNodes.size() == 1ul ? mulNodes[0] : fold<ngraph::opset1::Concat>(mulNodes, 1ul));
multiply = std::make_shared<DequantizationMultiply>(parent, mulConst); multiply = std::make_shared<DequantizationMultiply>(parent, mulConst);
} }
@ -352,24 +287,19 @@ FakeQuantizeDequantization ConcatMultiChannelsTransformation::getFoldedDequantiz
const size_t sourceOutputIdx) { const size_t sourceOutputIdx) {
OutputVector inputs = operation->input_values(); OutputVector inputs = operation->input_values();
OutputVector outputs(operation->get_output_size()); OutputVector outputs(operation->get_output_size());
Output<Node> data = operation->output(sourceOutputIdx);
std::shared_ptr<Node> parent = operation; std::shared_ptr<Node> parent = operation;
std::shared_ptr<DequantizationConvert> convert; std::shared_ptr<DequantizationConvert> convert;
if (dequantization.convert) { if (dequantization.convert) {
convert = as_type_ptr<DequantizationConvert>(dequantization.convert->clone_with_new_inputs({ parent })); convert = as_type_ptr<DequantizationConvert>(dequantization.convert->clone_with_new_inputs({ data }));
parent = convert; parent = convert;
} }
std::shared_ptr<DequantizationSubtract> subtract; std::shared_ptr<DequantizationSubtract> subtract;
std::shared_ptr<ngraph::opset1::Constant> subConst; std::shared_ptr<ngraph::opset1::Constant> subConst;
if (dequantization.subtract) { if (dequantization.subtract) {
inputs[0] = dequantization.subtractConstant; subConst = NetworkHelper::foldDequantizationConstant(dequantization.subtractConstant, operation, sourceOutputIdx);
const auto op = operation->clone_with_new_inputs(inputs);
// constant folding of subtract constant
op->constant_fold(outputs, inputs);
subConst = as_type_ptr<ngraph::opset1::Constant>(outputs[sourceOutputIdx].get_node_shared_ptr());
subtract = std::make_shared<DequantizationSubtract>(parent, subConst); subtract = std::make_shared<DequantizationSubtract>(parent, subConst);
parent = subtract; parent = subtract;
} }
@ -377,17 +307,11 @@ FakeQuantizeDequantization ConcatMultiChannelsTransformation::getFoldedDequantiz
std::shared_ptr<DequantizationMultiply> multiply; std::shared_ptr<DequantizationMultiply> multiply;
std::shared_ptr<ngraph::opset1::Constant> mulConst; std::shared_ptr<ngraph::opset1::Constant> mulConst;
if (dequantization.multiply) { if (dequantization.multiply) {
inputs[0] = dequantization.multiplyConstant; mulConst = NetworkHelper::foldDequantizationConstant(dequantization.multiplyConstant, operation, sourceOutputIdx);
const auto op = operation->clone_with_new_inputs(inputs);
// constant folding of multiply constant
op->constant_fold(outputs, inputs);
mulConst = as_type_ptr<ngraph::opset1::Constant>(outputs[sourceOutputIdx].get_node_shared_ptr());
multiply = std::make_shared<DequantizationMultiply>(parent, mulConst); multiply = std::make_shared<DequantizationMultiply>(parent, mulConst);
} }
return FakeQuantizeDequantization(operation->output(sourceOutputIdx), convert, subtract, nullptr, subConst, multiply, mulConst); return FakeQuantizeDequantization(data, convert, subtract, nullptr, subConst, multiply, mulConst);
} }
} // namespace low_precision } // namespace low_precision

View File

@ -87,6 +87,31 @@ bool NetworkHelper::isConstantPath(const std::shared_ptr<Node>& op) {
return true; return true;
} }
std::shared_ptr<opset1::Constant> NetworkHelper::foldDequantizationConstant(
const std::shared_ptr<opset1::Constant>& foldingConstant,
const std::shared_ptr<Node>& operation,
const size_t outIdx) {
OutputVector inputs = operation->input_values();
OutputVector outputs(operation->get_output_size());
if (shape_size(foldingConstant->get_shape()) == 1ul) {
return toScalar(foldingConstant);
} else {
inputs[0] = foldingConstant;
const auto op = operation->clone_with_new_inputs(inputs);
// constant folding of constant
op->constant_fold(outputs, inputs);
const auto result = as_type_ptr<opset1::Constant>(outputs[outIdx].get_node_shared_ptr());
if (result == nullptr) {
THROW_IE_LPT_EXCEPTION(*result) << "result of constant folding is not constant";
}
return result;
}
}
size_t NetworkHelper::getOutputChannelsCount(std::shared_ptr<const Node> layer, bool isOnWeights) { size_t NetworkHelper::getOutputChannelsCount(std::shared_ptr<const Node> layer, bool isOnWeights) {
if (layer->outputs().size() == 0) { if (layer->outputs().size() == 0) {
THROW_TRANSFORMATION_EXCEPTION << "Layer " << layer->get_friendly_name() << " doesn't have output tensors"; THROW_TRANSFORMATION_EXCEPTION << "Layer " << layer->get_friendly_name() << " doesn't have output tensors";

View File

@ -5,6 +5,7 @@
#include "low_precision/split.hpp" #include "low_precision/split.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "low_precision/network_helper.hpp" #include "low_precision/network_helper.hpp"
#include "low_precision/common/dequantization_op.hpp"
namespace ngraph { namespace ngraph {
namespace pass { namespace pass {
@ -22,81 +23,68 @@ bool SplitTransformation::transform(TransformationContext& context, ngraph::patt
return false; return false;
} }
const std::shared_ptr<Node> split = NetworkHelper::separateInStandaloneBranch(m.get_match_root()); const auto split = NetworkHelper::separateInStandaloneBranch(m.get_match_root());
auto dequantization = NetworkHelper::getDequantization(split); const auto dequantization = NetworkHelper::getDequantization(split);
OutputVector inputs(split->get_input_size()); OutputVector inputs = split->input_values();
for (size_t i = 0; i < split->get_input_size(); ++i) { inputs[0] = dequantization.data;
inputs[i] = split->get_input_node_shared_ptr(i);
}
const size_t dequantizationIndex = NetworkHelper::getChildInputIndex(dequantization.multiply, split); const auto newSplit = split->clone_with_new_inputs(inputs);
inputs[dequantizationIndex] = dequantization.data;
std::shared_ptr<ngraph::Node> newSplit = split->clone_with_new_inputs(inputs);
newSplit->set_friendly_name(split->get_friendly_name()); newSplit->set_friendly_name(split->get_friendly_name());
ngraph::copy_runtime_info(split, newSplit);
const ngraph::Shape subConstShape = dequantization.subtract ? const int64_t axis = as_type_ptr<opset1::Constant>(split->get_input_node_shared_ptr(1))->cast_vector<int64_t>()[0];
dequantization.subtract->get_input_node_shared_ptr(1)->get_shape() : Shape{}; const size_t normalizedAxis = normalize_axis(split->get_friendly_name(), axis, split->get_input_partial_shape(0).rank());
std::vector<float> subValues = dequantization.subtract ? as_type_ptr<opset1::Constant>( const size_t outputSize = newSplit->get_output_size();
dequantization.subtract->get_input_node_shared_ptr(1))->cast_vector<float>() : std::vector<float>();
const ngraph::Shape mulConstShape = dequantization.multiply->get_input_node_shared_ptr(1)->get_shape(); const auto splitConstant = [&](const std::shared_ptr<Node> operation) {
std::vector<float> mulValues = as_type_ptr<opset1::Constant>( // if batch is absent in constant shape - add batch
dequantization.multiply->get_input_node_shared_ptr(1))->cast_vector<float>(); const auto normalizedConstant = NetworkHelper::normalizeDequantizationShape(operation);
const auto constantShape = normalizedConstant->get_shape();
int64_t SplitedAxis = as_type_ptr<opset1::Constant>(split->get_input_node_shared_ptr(1))->cast_vector<int64_t>()[0]; OutputVector results(outputSize);
size_t axis = SplitedAxis > 0 ? SplitedAxis : split->get_input_shape(0).size() + SplitedAxis; if ((shape_size(constantShape) == 1ul) || (constantShape[normalizedAxis] == 1ul)) {
size_t outputSize = newSplit->get_output_size(); std::for_each(results.begin(), results.end(), [&](Output<Node>& elem) { elem = normalizedConstant->clone_with_new_inputs({}); });
const auto subSplitLengths = getConstSplitLengths(inputs, subConstShape, outputSize);
const auto mulSplitLengths = getConstSplitLengths(inputs, mulConstShape, outputSize);
std::vector<std::shared_ptr<ngraph::Node>> lastNodes(outputSize);
ngraph::OutputVector replacement;
for (size_t i = 0; i < outputSize; ++i) {
Output<Node> previous = newSplit->output(i);
if (dequantization.convert != nullptr) {
const std::shared_ptr<ngraph::Node> convert =
dequantization.convert->clone_with_new_inputs({ newSplit->output(i) });
previous = convert;
}
if (dequantization.subtract != nullptr) {
std::shared_ptr<ngraph::opset1::Constant> subConst;
if (!subSplitLengths.empty()) {
const auto newSubConstShape = getConstSplitShape(subSplitLengths, subConstShape, axis, i);
std::vector<float> newSubValues(
subValues.begin() + subSplitLengths[i],
subValues.begin() + subSplitLengths[i + 1]);
subConst = as_type_ptr<ngraph::opset1::Constant>(std::make_shared<ngraph::opset1::Constant>(
dequantization.subtract->get_input_element_type(1),
newSubConstShape,
newSubValues));
} else {
subConst = as_type_ptr<ngraph::opset1::Constant>(dequantization.subtract->get_input_node_shared_ptr(1)->clone_with_new_inputs({}));
}
const std::shared_ptr<ngraph::Node> subtract = std::make_shared<ngraph::opset1::Subtract>(previous, subConst);
previous = subtract;
}
std::shared_ptr<ngraph::opset1::Constant> mulConst;
if (!mulSplitLengths.empty()) {
const auto newMulConstShape = getConstSplitShape(mulSplitLengths, mulConstShape, axis, i);
std::vector<float> newMulValues(
mulValues.begin() + mulSplitLengths[i],
mulValues.begin() + mulSplitLengths[i + 1]);
mulConst = as_type_ptr<ngraph::opset1::Constant>(std::make_shared<ngraph::opset1::Constant>(
dequantization.multiply->get_input_element_type(1), newMulConstShape, newMulValues));
} else { } else {
mulConst = as_type_ptr<ngraph::opset1::Constant>(dequantization.multiply->get_input_node_shared_ptr(1)->clone_with_new_inputs({})); // prepare new inputs for constant folding
OutputVector inputs = newSplit->input_values();
inputs[0] = normalizedConstant;
const auto foldSplit = newSplit->clone_with_new_inputs(inputs);
// fold and fill results
foldSplit->constant_fold(results, inputs);
} }
const std::shared_ptr<ngraph::Node> multiply = std::make_shared<ngraph::opset1::Multiply>(previous, mulConst);
for (auto& result : results) {
result = NetworkHelper::toScalarIfPossible(result.get_node_shared_ptr());
}
return results;
};
// get splited dequantization constants
OutputVector splitedSub = dequantization.subtract ? splitConstant(dequantization.subtract) : OutputVector{};
OutputVector splitedMul = splitConstant(dequantization.multiply);
NodeVector lastNodes;
OutputVector replacement;
for (size_t i = 0; i < outputSize; ++i) {
Output<Node> parent = newSplit->output(i);
if (dequantization.convert) {
const auto convert = dequantization.convert->clone_with_new_inputs({ newSplit->output(i) });
copy_runtime_info({ newSplit, convert }, convert);
parent = convert;
}
if (dequantization.subtract) {
const auto subtract = std::make_shared<DequantizationSubtract>(parent, splitedSub[i]);
copy_runtime_info({ newSplit, subtract }, subtract);
parent = subtract;
}
const auto multiply = std::make_shared<DequantizationMultiply>(parent, splitedMul[i]);
copy_runtime_info({ newSplit, multiply }, multiply);
lastNodes.push_back(multiply); lastNodes.push_back(multiply);
replacement.push_back(multiply); replacement.push_back(multiply);
@ -107,33 +95,6 @@ bool SplitTransformation::transform(TransformationContext& context, ngraph::patt
return true; return true;
} }
std::vector<size_t> SplitTransformation::getConstSplitLengths(
const OutputVector& inputs,
const ngraph::Shape& constShape,
const size_t outputSize) const {
int64_t axis = as_type_ptr<opset1::Constant>(inputs[1].get_node_shared_ptr())->cast_vector<int64_t>()[0];
size_t splitedAxis = axis > 0 ? axis : inputs[0].get_shape().size() + axis;
if ((!constShape.empty()) && (constShape[splitedAxis] != 1)) {
std::vector<size_t> result(outputSize + 1);
result[0] = 0;
for (size_t i = 1; i < result.size(); ++i) {
result[i] = result[i - 1] + constShape[splitedAxis] / outputSize;
}
return result;
} else {
return std::vector<size_t>();
}
}
ngraph::Shape SplitTransformation::getConstSplitShape(
const std::vector<size_t>& constSplitLengths,
const ngraph::Shape& constShape, const size_t axis,
const size_t idx) const {
Shape result(constShape);
result[axis] = constSplitLengths[idx + 1] - constSplitLengths[idx];
return result;
}
void SplitTransformation::updateOutputs( void SplitTransformation::updateOutputs(
TransformationContext& context, TransformationContext& context,

View File

@ -23,7 +23,7 @@ std::shared_ptr<Node> stridedSliceDeqConstant(
//} //}
const auto stridedSliceShape = strSlice->get_input_shape(0); const auto stridedSliceShape = strSlice->get_input_shape(0);
const auto constantShape = constant->get_shape(); auto constantShape = constant->get_shape();
if (stridedSliceShape.size() != constantShape.size()) { if (stridedSliceShape.size() != constantShape.size()) {
ngraph::Shape newConstantShape; ngraph::Shape newConstantShape;
if (ngraph::shape_size(constantShape) == 1) { if (ngraph::shape_size(constantShape) == 1) {
@ -37,6 +37,7 @@ std::shared_ptr<Node> stridedSliceDeqConstant(
newConstantShape.insert(newConstantShape.begin(), stridedSliceShape[0]); newConstantShape.insert(newConstantShape.begin(), stridedSliceShape[0]);
} }
} }
constantShape = newConstantShape;
const auto newConstant = fold<ngraph::opset1::Broadcast>( const auto newConstant = fold<ngraph::opset1::Broadcast>(
constant, constant,
@ -45,13 +46,24 @@ std::shared_ptr<Node> stridedSliceDeqConstant(
} }
const auto stridedSlice = as_type_ptr<ngraph::opset1::StridedSlice>(strSlice); const auto stridedSlice = as_type_ptr<ngraph::opset1::StridedSlice>(strSlice);
auto beginMask = stridedSlice->get_begin_mask();
auto endMask = stridedSlice->get_end_mask();
for (size_t i = 0; i < constantShape.size(); ++i) {
// don't slice constant if current dimension is 1
if (constantShape[i] == 1ul) {
beginMask[i] = 1ul;
endMask[i] = 1ul;
}
}
const auto result = fold<ngraph::opset1::StridedSlice>( const auto result = fold<ngraph::opset1::StridedSlice>(
constant, constant,
stridedSlice->get_input_node_shared_ptr(1), stridedSlice->get_input_node_shared_ptr(1),
stridedSlice->get_input_node_shared_ptr(2), stridedSlice->get_input_node_shared_ptr(2),
stridedSlice->get_input_node_shared_ptr(3), stridedSlice->get_input_node_shared_ptr(3),
stridedSlice->get_begin_mask(), beginMask,
stridedSlice->get_end_mask(), endMask,
stridedSlice->get_new_axis_mask(), stridedSlice->get_new_axis_mask(),
stridedSlice->get_shrink_axis_mask(), stridedSlice->get_shrink_axis_mask(),
stridedSlice->get_ellipsis_mask()); stridedSlice->get_ellipsis_mask());

View File

@ -22,16 +22,15 @@ namespace ngraph {
namespace pass { namespace pass {
namespace low_precision { namespace low_precision {
bool isQuantizationPerChannel(const std::shared_ptr<ngraph::Node>& node) { bool operationIsSupportedInConcat(const std::shared_ptr<ngraph::Node>& node) {
if (node->outputs().size() > 1ul) { // list of operations, which change channels, but supported in ConcatTransformation
return false; if (ngraph::is_type<opset1::StridedSlice>(node) ||
} ngraph::is_type<opset1::Split>(node) ||
ngraph::is_type<opset1::VariadicSplit>(node)) {
//WA to support StridedSlice in ConcatTransformation
if (ngraph::is_type<opset1::StridedSlice>(node)) {
return true; return true;
} }
// operations, which change channels, usually don't support in ConcatTransformation
const auto inputs = node->input_values(); const auto inputs = node->input_values();
for (const auto& input : inputs) { for (const auto& input : inputs) {
if (ngraph::is_type<opset1::Constant>(input.get_node())) { if (ngraph::is_type<opset1::Constant>(input.get_node())) {
@ -82,7 +81,7 @@ bool Subgraph::fillSubgraphForQuantization(
if (fakeQuantizeChild != nullptr) { if (fakeQuantizeChild != nullptr) {
// //
} else { } else {
if (layerTransformationsManager->isPrecisionPreserved(child) && isQuantizationPerChannel(child)) { if (layerTransformationsManager->isPrecisionPreserved(child) && operationIsSupportedInConcat(child)) {
if (!fillSubgraphForIntermediate(child, handledLayers)) { if (!fillSubgraphForIntermediate(child, handledLayers)) {
return false; return false;
} }
@ -104,7 +103,7 @@ bool Subgraph::atLeastOneIsIntermediate(const std::shared_ptr<ngraph::Node>& nod
return true; return true;
} }
if (!layerTransformationsManager->isPrecisionPreserved(child) || !isQuantizationPerChannel(child)) { if (!layerTransformationsManager->isPrecisionPreserved(child) || !operationIsSupportedInConcat(child)) {
// child branch is out of subgraph // child branch is out of subgraph
continue; continue;
} }
@ -144,10 +143,6 @@ bool Subgraph::fill(const std::shared_ptr<ngraph::Node>& layer, std::unordered_s
return false; return false;
} }
} else { } else {
// WA: issue #46906
if (parent->get_output_size() != 1ul) {
return false;
}
const FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(parent, 0, true); const FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(parent, 0, true);
const std::shared_ptr<ngraph::opset1::FakeQuantize> fakeQuantizeParent = dequantization.empty() ? const std::shared_ptr<ngraph::opset1::FakeQuantize> fakeQuantizeParent = dequantization.empty() ?
ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(parent) : ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(parent) :
@ -161,7 +156,7 @@ bool Subgraph::fill(const std::shared_ptr<ngraph::Node>& layer, std::unordered_s
if (constant != nullptr) { if (constant != nullptr) {
// //
} else { } else {
if (layerTransformationsManager->isPrecisionPreserved(parent) && isQuantizationPerChannel(parent)) { if (layerTransformationsManager->isPrecisionPreserved(parent) && operationIsSupportedInConcat(parent)) {
if (!fillSubgraphForIntermediate(parent, handledLayers)) { if (!fillSubgraphForIntermediate(parent, handledLayers)) {
return false; return false;
} }
@ -197,7 +192,7 @@ bool Subgraph::fill(const std::shared_ptr<ngraph::Node>& layer, std::unordered_s
const std::shared_ptr<ngraph::opset1::FakeQuantize> fakeQuantizeChild = ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(child); const std::shared_ptr<ngraph::opset1::FakeQuantize> fakeQuantizeChild = ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(child);
if (fakeQuantizeChild != nullptr) { if (fakeQuantizeChild != nullptr) {
// //
} else if (layerTransformationsManager->isPrecisionPreserved(child) && isQuantizationPerChannel(child)) { } else if (layerTransformationsManager->isPrecisionPreserved(child) && operationIsSupportedInConcat(child)) {
if (!fillSubgraphForIntermediate(child, handledLayers)) { if (!fillSubgraphForIntermediate(child, handledLayers)) {
return false; return false;
} }
@ -221,6 +216,13 @@ bool Subgraph::empty() const {
} }
bool Subgraph::fillSubgraphForConcat(const std::shared_ptr<ngraph::opset1::Concat>& concat, std::unordered_set<std::string>& handledLayers) { bool Subgraph::fillSubgraphForConcat(const std::shared_ptr<ngraph::opset1::Concat>& concat, std::unordered_set<std::string>& handledLayers) {
const auto axis = concat->get_axis();
const size_t normalizedAxis = ngraph::normalize_axis(concat->get_friendly_name(), axis, concat->get_output_partial_shape(0).rank());
// supported only per-channel concat
if (normalizedAxis != 1ul) {
return false;
}
concatLayers.push_back(concat); concatLayers.push_back(concat);
handledLayers.insert(concat->get_friendly_name()); handledLayers.insert(concat->get_friendly_name());
layers.emplace(concat->get_friendly_name(), concat); layers.emplace(concat->get_friendly_name(), concat);

View File

@ -229,9 +229,11 @@ LowPrecisionTransformations LowPrecisionTransformer::getAllTransformations(const
add<ReluTransformation, opset1::Relu>(params). add<ReluTransformation, opset1::Relu>(params).
add<ReshapeTransformation, opset1::Reshape>(params). add<ReshapeTransformation, opset1::Reshape>(params).
add<SqueezeTransformation, opset1::Squeeze>(params). add<SqueezeTransformation, opset1::Squeeze>(params).
add<SplitTransformation, opset1::Split>(params).
add<StridedSliceTransformation, opset1::StridedSlice>(params). add<StridedSliceTransformation, opset1::StridedSlice>(params).
add<TransposeTransformation, opset1::Transpose>(params). add<TransposeTransformation, opset1::Transpose>(params).
add<UnsqueezeTransformation, opset1::Unsqueeze>(params). add<UnsqueezeTransformation, opset1::Unsqueeze>(params).
add<VariadicSplitTransformation, opset1::VariadicSplit>(params).
addCleanup<FoldConvertTransformation, opset1::Subtract>(params). addCleanup<FoldConvertTransformation, opset1::Subtract>(params).
addCleanup<FuseConvertTransformation, opset1::Multiply>(params). addCleanup<FuseConvertTransformation, opset1::Multiply>(params).

View File

@ -20,26 +20,6 @@ void VariadicSplitTransformation::registerMatcherIn(GraphRewrite& pass, Transfor
make_op_label<opset1::Constant>() })); make_op_label<opset1::Constant>() }));
} }
std::vector<size_t> VariadicSplitTransformation::getConstSplitLengths(
const OutputVector& inputs,
const ngraph::Shape& constShape,
const size_t outputSize) const {
std::vector<size_t> lengths = as_type_ptr<opset1::Constant>(inputs[2].get_node_shared_ptr())->cast_vector<size_t>();
int64_t axis = as_type_ptr<opset1::Constant>(inputs[1].get_node_shared_ptr())->cast_vector<int64_t>()[0];
size_t splitedAxis = axis > 0 ? axis : inputs[0].get_shape().size() + axis;
if ((!constShape.empty()) && (constShape[splitedAxis] != 1)) {
std::vector<size_t> result(outputSize + 1);
result[0] = 0;
for (size_t i = 1; i < result.size(); ++i) {
result[i] = result[i - 1] + lengths[i - 1];
}
return result;
} else {
return std::vector<size_t>();
}
}
} // namespace low_precision } // namespace low_precision
} // namespace pass } // namespace pass
} // namespace ngraph } // namespace ngraph

View File

@ -74,6 +74,7 @@ class ConcatTransformationTestValues {
public: public:
ngraph::pass::low_precision::LayerTransformation::Params params; ngraph::pass::low_precision::LayerTransformation::Params params;
bool multiChannels; bool multiChannels;
std::int64_t axis;
ConcatTransformationActualValues actual; ConcatTransformationActualValues actual;
ConcatTransformationResultValues result; ConcatTransformationResultValues result;
}; };
@ -114,7 +115,8 @@ public:
testValues.actual.convert2, testValues.actual.convert2,
testValues.actual.dequantization2, testValues.actual.dequantization2,
ngraph::element::undefined, ngraph::element::undefined,
{}); {},
testValues.axis);
SimpleLowPrecisionTransformer transform; SimpleLowPrecisionTransformer transform;
if (testValues.multiChannels) { if (testValues.multiChannels) {
@ -146,7 +148,8 @@ public:
testValues.result.convert2, testValues.result.convert2,
testValues.result.dequantization2, testValues.result.dequantization2,
testValues.result.precisionAfterOperation, testValues.result.precisionAfterOperation,
testValues.result.dequantizationAfter); testValues.result.dequantizationAfter,
testValues.axis);
} }
static std::string getTestCaseName(testing::TestParamInfo<ConcatTransformationParams> obj) { static std::string getTestCaseName(testing::TestParamInfo<ConcatTransformationParams> obj) {
@ -158,6 +161,7 @@ public:
result << result <<
LayerTransformation::getTestCaseNameByParams(precision, shape, testValues.params) << "_" << LayerTransformation::getTestCaseNameByParams(precision, shape, testValues.params) << "_" <<
(testValues.multiChannels ? "multiChannels_" : "notMultiChannels_") << (testValues.multiChannels ? "multiChannels_" : "notMultiChannels_") <<
"axis_" << testValues.axis << "_" <<
testValues.actual << "_" << testValues.actual << "_" <<
testValues.result << "_"; testValues.result << "_";
return result.str(); return result.str();
@ -180,6 +184,7 @@ const std::vector<ConcatTransformationTestValues> testValues = {
{ {
LayerTransformation::createParamsU8I8(), LayerTransformation::createParamsU8I8(),
false, false,
1,
{ {
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} }, { 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
{}, {},
@ -201,6 +206,7 @@ const std::vector<ConcatTransformationTestValues> testValues = {
{ {
LayerTransformation::createParamsU8I8(), LayerTransformation::createParamsU8I8(),
false, false,
1,
{ {
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f} }, { 256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f} },
{ ngraph::element::u8 }, { ngraph::element::u8 },
@ -232,6 +238,7 @@ const std::vector<ConcatTransformationTestValues> testValues = {
{ {
LayerTransformation::createParamsU8I8(), LayerTransformation::createParamsU8I8(),
true, true,
1,
{ {
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f} }, { 256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f} },
{ ngraph::element::u8 }, { ngraph::element::u8 },
@ -263,6 +270,7 @@ const std::vector<ConcatTransformationTestValues> testValues = {
{ {
LayerTransformation::createParamsU8I8(), LayerTransformation::createParamsU8I8(),
false, false,
1,
{ {
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} }, { 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
{}, {},
@ -290,6 +298,7 @@ const std::vector<ConcatTransformationTestValues> testValues = {
{ {
LayerTransformation::createParamsU8I8(), LayerTransformation::createParamsU8I8(),
true, true,
1,
{ {
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} }, { 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
{}, {},
@ -317,6 +326,7 @@ const std::vector<ConcatTransformationTestValues> testValues = {
{ {
LayerTransformation::createParamsU8I8(), LayerTransformation::createParamsU8I8(),
false, false,
1,
{ {
{ 256ul, {{1}, {1}, {1}, {1}}, {0.f}, {2.55f}, {0.f}, {2.55f} }, { 256ul, {{1}, {1}, {1}, {1}}, {0.f}, {2.55f}, {0.f}, {2.55f} },
{}, {},
@ -340,6 +350,7 @@ const std::vector<ConcatTransformationTestValues> testValues = {
{ {
LayerTransformation::createParamsU8I8(), LayerTransformation::createParamsU8I8(),
false, false,
1,
{ {
{ 256ul, {{1, 1, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}}, {0.f}, {2.55f}, {0.f}, {2.55f} }, { 256ul, {{1, 1, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}}, {0.f}, {2.55f}, {0.f}, {2.55f} },
{}, {},
@ -363,6 +374,7 @@ const std::vector<ConcatTransformationTestValues> testValues = {
{ {
LayerTransformation::createParamsU8I8(), LayerTransformation::createParamsU8I8(),
true, true,
1,
{ {
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} }, { 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
{}, {},
@ -386,6 +398,7 @@ const std::vector<ConcatTransformationTestValues> testValues = {
{ {
LayerTransformation::createParamsU8I8(), LayerTransformation::createParamsU8I8(),
true, true,
1,
{ {
{ 256ul, {{1}, {1}, {1}, {1}}, {0.f}, {2.55f}, {0.f}, {2.55f} }, { 256ul, {{1}, {1}, {1}, {1}}, {0.f}, {2.55f}, {0.f}, {2.55f} },
{}, {},
@ -409,6 +422,7 @@ const std::vector<ConcatTransformationTestValues> testValues = {
{ {
LayerTransformation::createParamsU8I8(), LayerTransformation::createParamsU8I8(),
true, true,
1,
{ {
{ {
256ul, 256ul,
@ -450,6 +464,7 @@ const std::vector<ConcatTransformationTestValues> testValues = {
{ {
LayerTransformation::createParamsU8I8(), LayerTransformation::createParamsU8I8(),
true, true,
1,
{ {
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} }, { 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
{}, {},
@ -477,6 +492,7 @@ const std::vector<ConcatTransformationTestValues> testValues = {
{ {
LayerTransformation::createParamsI8I8(), LayerTransformation::createParamsI8I8(),
false, false,
1,
{ {
{ 256ul, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f} }, { 256ul, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f} },
{}, {},
@ -500,6 +516,7 @@ const std::vector<ConcatTransformationTestValues> testValues = {
{ {
LayerTransformation::createParamsU8I8(), LayerTransformation::createParamsU8I8(),
false, false,
1,
{ {
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} }, { 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
{}, {},
@ -523,6 +540,7 @@ const std::vector<ConcatTransformationTestValues> testValues = {
{ {
LayerTransformation::createParamsU8I8(), LayerTransformation::createParamsU8I8(),
true, true,
1,
{ {
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} }, { 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
{}, {},
@ -546,6 +564,7 @@ const std::vector<ConcatTransformationTestValues> testValues = {
{ {
LayerTransformation::createParamsU8I8(), LayerTransformation::createParamsU8I8(),
false, false,
1,
{ {
{ 256ul, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f} }, { 256ul, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f} },
{}, {},
@ -569,6 +588,7 @@ const std::vector<ConcatTransformationTestValues> testValues = {
{ {
LayerTransformation::createParamsU8I8(), LayerTransformation::createParamsU8I8(),
false, false,
1,
{ {
{ 256ul, {}, {-1.28f}, {1.27f}, {0.f}, {2.3007815f} }, { 256ul, {}, {-1.28f}, {1.27f}, {0.f}, {2.3007815f} },
{}, {},
@ -588,10 +608,61 @@ const std::vector<ConcatTransformationTestValues> testValues = {
{ ngraph::element::f32, { 128 }, { 0.0302619f } } { ngraph::element::f32, { 128 }, { 0.0302619f } }
} }
}, },
// U8: concat multi channels with subtract, negative axis
{
LayerTransformation::createParamsU8I8(),
true,
-3,
{
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
{},
{},
{ 256ul, {}, {1.275f}, {2.55f}, {1.275f}, {2.55f} },
{},
{}
},
{
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
{},
{},
{ 256ul, {}, {1.275f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
{},
{},
ngraph::element::u8,
{
ngraph::element::f32,
{{ 0.f, 0.f, 0.f, -255.f, -255.f, -255.f }},
{{ 0.01f, 0.01f, 0.01f, 0.005f, 0.005f, 0.005f }}
}
}
},
// U8: concat multi channels with subtract, not supported axis
{
LayerTransformation::createParamsU8I8(),
true,
0,
{
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
{},
{},
{ 256ul, {}, {1.275f}, {2.55f}, {1.275f}, {2.55f} },
{},
{}
},
{
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
{},
{},
{ 256ul, {}, {1.275f}, {2.55f}, {1.275f}, {2.55f} },
{},
{}
},
},
// not update precisions // not update precisions
{ {
LayerTransformation::createParamsU8I8().setUpdatePrecisions(false), LayerTransformation::createParamsU8I8().setUpdatePrecisions(false),
false, false,
1,
{ {
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} }, { 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
{}, {},

View File

@ -217,6 +217,40 @@ const std::vector<ConcatTransformationTestValues> testValues = {
{ ngraph::element::f32, {}, { 0.005f } } { ngraph::element::f32, {}, { 0.005f } }
} }
}, },
// U8: concat multi channels with per-channel quantization
{
{ 1, 6, 10, 10 },
LayerTransformation::createParamsU8I8(),
true,
{
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f / 2.f}, {0.f}, {2.55f / 2.f} },
{
256ul,
ngraph::Shape({ 1, 6, 1, 1 }),
{0.f, 0.f, 0.f, 0.f, 0.f, 0.f},
{255.f, 25.5f, 2.55f, 25.5f, 255.f, 2.55f},
{0.f, 0.f, 0.f, 0.f, 0.f, 0.f},
{255.f, 25.5f, 2.55f, 25.5f, 255.f, 2.55f}
}
},
{
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f / 2.f}, {0.f}, {255.f}},
{
256ul,
ngraph::Shape({ 1, 6, 1, 1 }),
{0.f, 0.f, 0.f, 0.f, 0.f, 0.f},
{255.f, 25.5f, 2.55f, 25.5f, 255.f, 2.55f},
{0.f},
{255.f}
},
ngraph::element::u8,
{{}, {}, {}},
{{}, {}, {}},
ngraph::element::u8,
{ ngraph::element::f32, {}, {{ 0.005f, 0.005f, 0.005f, 1.f, 0.1f, 0.01f }} },
{ ngraph::element::f32, {}, {{ 0.1f, 1.f, 0.01f }} }
}
},
// I8: concat multi channels // I8: concat multi channels
{ {
{ 1, 6, 10, 10 }, { 1, 6, 10, 10 },
@ -259,9 +293,8 @@ const std::vector<ConcatTransformationTestValues> testValues = {
}, },
}; };
// TODO: Split/VariadicSplit operations are not supported in ConcatTransformation
INSTANTIATE_TEST_CASE_P( INSTANTIATE_TEST_CASE_P(
DISABLED_smoke_LPT, smoke_LPT,
ConcatWithSplitTransformation, ConcatWithSplitTransformation,
::testing::Combine( ::testing::Combine(
::testing::ValuesIn(precisions), ::testing::ValuesIn(precisions),

View File

@ -160,21 +160,30 @@ const std::vector<SplitTransformationTestValues> testValues = {
{}, {},
ngraph::element::u8, ngraph::element::u8,
{ {
{ {{ngraph::element::f32}, {1.f}, {11.f}},
{ngraph::element::f32}, {{ngraph::element::f32}, {2.f}, {22.f}},
{{1.f}, ngraph::element::f32, {1, 1, 1, 1}}, {{ngraph::element::f32}, {3.f}, {33.f}},
{{11.f}, ngraph::element::f32, {1, 1, 1, 1}} }
}, }
{ },
{ngraph::element::f32}, // U8 per channel quantization with different values (constants without batch)
{{2.f}, ngraph::element::f32, {1, 1, 1, 1}}, {
{{22.f}, ngraph::element::f32, {1, 1, 1, 1}} ngraph::Shape({ 1, 3, 16, 16 }), std::int64_t{-3}, size_t{3},
}, LayerTransformation::createParamsU8I8(),
{ {
{ngraph::element::f32}, ngraph::element::u8,
{{3.f}, ngraph::element::f32, {1, 1, 1, 1}}, {{ngraph::element::f32},
{{33.f}, ngraph::element::f32, {1, 1, 1, 1}} {{1.f, 2.f, 3.f}, ngraph::element::f32, {3, 1, 1}},
}, {{11.f, 22.f, 33.f}, ngraph::element::f32, {3, 1, 1}}}
},
{
ngraph::element::u8,
{},
ngraph::element::u8,
{
{{ngraph::element::f32}, {1.f}, {11.f}},
{{ngraph::element::f32}, {2.f}, {22.f}},
{{ngraph::element::f32}, {3.f}, {33.f}},
} }
} }
}, },
@ -193,21 +202,9 @@ const std::vector<SplitTransformationTestValues> testValues = {
{}, {},
ngraph::element::i8, ngraph::element::i8,
{ {
{ {{ngraph::element::f32}, {1.f}, {11.f}},
{ngraph::element::f32}, {{ngraph::element::f32}, {2.f}, {22.f}},
{{1.f}, ngraph::element::f32, {1, 1, 1, 1}}, {{ngraph::element::f32}, {3.f}, {33.f}},
{{11.f}, ngraph::element::f32, {1, 1, 1, 1}}
},
{
{ngraph::element::f32},
{{2.f}, ngraph::element::f32, {1, 1, 1, 1}},
{{22.f}, ngraph::element::f32, {1, 1, 1, 1}}
},
{
{ngraph::element::f32},
{{3.f}, ngraph::element::f32, {1, 1, 1, 1}},
{{33.f}, ngraph::element::f32, {1, 1, 1, 1}}
},
} }
} }
}, },
@ -226,21 +223,9 @@ const std::vector<SplitTransformationTestValues> testValues = {
{}, {},
ngraph::element::u8, ngraph::element::u8,
{ {
{ {{ngraph::element::f32}, {1.f}, {11.f}},
{ngraph::element::f32}, {{ngraph::element::f32}, {1.f}, {11.f}},
{{1.f}, ngraph::element::f32, {1, 1, 1, 1}}, {{ngraph::element::f32}, {1.f}, {11.f}},
{{11.f}, ngraph::element::f32, {1, 1, 1, 1}}
},
{
{ngraph::element::f32},
{{1.f}, ngraph::element::f32, {1, 1, 1, 1}},
{{11.f}, ngraph::element::f32, {1, 1, 1, 1}}
},
{
{ngraph::element::f32},
{{1.f}, ngraph::element::f32, {1, 1, 1, 1}},
{{11.f}, ngraph::element::f32, {1, 1, 1, 1}}
},
} }
} }
}, },
@ -259,21 +244,9 @@ const std::vector<SplitTransformationTestValues> testValues = {
{}, {},
ngraph::element::i8, ngraph::element::i8,
{ {
{ {{ngraph::element::f32}, {1.f}, {11.f}},
{ngraph::element::f32}, {{ngraph::element::f32}, {1.f}, {11.f}},
{{1.f}, ngraph::element::f32, {1, 1, 1, 1}}, {{ngraph::element::f32}, {1.f}, {11.f}}
{{11.f}, ngraph::element::f32, {1, 1, 1, 1}}
},
{
{ngraph::element::f32},
{{1.f}, ngraph::element::f32, {1, 1, 1, 1}},
{{11.f}, ngraph::element::f32, {1, 1, 1, 1}}
},
{
{ngraph::element::f32},
{{1.f}, ngraph::element::f32, {1, 1, 1, 1}},
{{11.f}, ngraph::element::f32, {1, 1, 1, 1}}
},
} }
} }
}, },
@ -358,21 +331,9 @@ const std::vector<SplitTransformationTestValues> testValues = {
{}, {},
ngraph::element::u8, ngraph::element::u8,
{ {
{ {{ngraph::element::f32}, {}, {11.f}},
{ngraph::element::f32}, {{ngraph::element::f32}, {}, {22.f}},
{}, {{ngraph::element::f32}, {}, {33.f}},
{{11.f}, ngraph::element::f32, {1, 1, 1, 1}}
},
{
{ngraph::element::f32},
{},
{{22.f}, ngraph::element::f32, {1, 1, 1, 1}}
},
{
{ngraph::element::f32},
{},
{{33.f}, ngraph::element::f32, {1, 1, 1, 1}}
},
} }
} }
}, },
@ -391,21 +352,9 @@ const std::vector<SplitTransformationTestValues> testValues = {
{}, {},
ngraph::element::i8, ngraph::element::i8,
{ {
{ {{ngraph::element::f32}, {}, {11.f}},
{ngraph::element::f32}, {{ngraph::element::f32}, {}, {22.f}},
{}, {{ngraph::element::f32}, {}, {33.f}},
{{11.f}, ngraph::element::f32, {1, 1, 1, 1}}
},
{
{ngraph::element::f32},
{},
{{22.f}, ngraph::element::f32, {1, 1, 1, 1}}
},
{
{ngraph::element::f32},
{},
{{33.f}, ngraph::element::f32, {1, 1, 1, 1}}
},
} }
} }
}, },

View File

@ -150,6 +150,17 @@ StridedSliceTransformationTestValues::LayerParams specialDimensionSlice = {
{} {}
}; };
StridedSliceTransformationTestValues::LayerParams specialDimensionEndSlice = {
{ 0, 0, 20, 0 },
{ 1, 3, 24, 24 },
{ 1, 1, 1, 1 },
{ 1, 1, 0, 1 },
{ 1, 1, 0, 1 },
{},
{},
{}
};
const std::vector<StridedSliceTransformationTestValues> stridedSliceTransformationTestValues = { const std::vector<StridedSliceTransformationTestValues> stridedSliceTransformationTestValues = {
// U8: channel slice, per-tensor quantization // U8: channel slice, per-tensor quantization
{ {
@ -311,6 +322,38 @@ const std::vector<StridedSliceTransformationTestValues> stridedSliceTransformati
{{ngraph::element::f32}, {{ 32.f, 64.f, 32.f }}, {{ 0.1f, 0.01f, 1.f }}} {{ngraph::element::f32}, {{ 32.f, 64.f, 32.f }}, {{ 0.1f, 0.01f, 1.f }}}
} }
}, },
// I8: special dimension end slice, per-channel quantization with different values
{
ngraph::Shape{1, 3, 24, 24},
LayerTransformation::createParamsI8I8(),
specialDimensionEndSlice,
{
ngraph::element::i8,
{{ngraph::element::f32}, {{ 32.f, 64.f, 32.f }}, {{ 0.1f, 0.01f, 1.f }}}
},
{
ngraph::element::i8,
{},
ngraph::element::i8,
{{ngraph::element::f32}, {{ 32.f, 64.f, 32.f }}, {{ 0.1f, 0.01f, 1.f }}}
}
},
// I8: special dimension end slice, per-tensor quantization with different values
{
ngraph::Shape{1, 3, 24, 24},
LayerTransformation::createParamsI8I8(),
specialDimensionEndSlice,
{
ngraph::element::i8,
{{ngraph::element::f32}, { 32.f }, { 0.1f }}
},
{
ngraph::element::i8,
{},
ngraph::element::i8,
{{ngraph::element::f32}, { 32.f }, { 0.1f }}
}
},
// I8: channel slice, quantization by special dimension // I8: channel slice, quantization by special dimension
{ {
ngraph::Shape{1, 3, 4, 4}, ngraph::Shape{1, 3, 4, 4},

View File

@ -177,11 +177,31 @@ const std::vector<VariadicSplitTransformationTestValues> testValues = {
{{1.f, 2.f}, ngraph::element::f32, {1, 2, 1, 1}}, {{1.f, 2.f}, ngraph::element::f32, {1, 2, 1, 1}},
{{11.f, 22.f}, ngraph::element::f32, {1, 2, 1, 1}} {{11.f, 22.f}, ngraph::element::f32, {1, 2, 1, 1}}
}, },
{{ngraph::element::f32}, {3.f}, {33.f}}
}
}
},
// U8 per channel quantization with different values (constants without batch)
{
ngraph::Shape({ 1, 3, 16, 16 }), std::int64_t{ -3 }, std::vector<size_t>{ 2, 1 },
LayerTransformation::createParamsU8I8(),
{
ngraph::element::u8,
{{ngraph::element::f32},
{{1.f, 2.f, 3.f}, ngraph::element::f32, {3, 1, 1}},
{{11.f, 22.f, 33.f}, ngraph::element::f32, {3, 1, 1}}}
},
{
ngraph::element::u8,
{},
ngraph::element::u8,
{
{ {
{ngraph::element::f32}, {ngraph::element::f32},
{{3.f}, ngraph::element::f32, {1, 1, 1, 1}}, {{1.f, 2.f}, ngraph::element::f32, {1, 2, 1, 1}},
{{33.f}, ngraph::element::f32, {1, 1, 1, 1}} {{11.f, 22.f}, ngraph::element::f32, {1, 2, 1, 1}}
} },
{{ngraph::element::f32}, {3.f}, {33.f}}
} }
} }
}, },
@ -205,11 +225,7 @@ const std::vector<VariadicSplitTransformationTestValues> testValues = {
{{1.f, 2.f}, ngraph::element::f32, {1, 2, 1, 1}}, {{1.f, 2.f}, ngraph::element::f32, {1, 2, 1, 1}},
{{11.f, 22.f}, ngraph::element::f32, {1, 2, 1, 1}} {{11.f, 22.f}, ngraph::element::f32, {1, 2, 1, 1}}
}, },
{ {{ngraph::element::f32}, {3.f}, {33.f}}
{ngraph::element::f32},
{{3.f}, ngraph::element::f32, {1, 1, 1, 1}},
{{33.f}, ngraph::element::f32, {1, 1, 1, 1}}
}
} }
} }
}, },
@ -228,16 +244,8 @@ const std::vector<VariadicSplitTransformationTestValues> testValues = {
{}, {},
ngraph::element::u8, ngraph::element::u8,
{ {
{ {{ngraph::element::f32}, {1.f}, {11.f}},
{ngraph::element::f32}, {{ngraph::element::f32}, {1.f}, {11.f}}
{{1.f, 1.f}, ngraph::element::f32, {1, 2, 1, 1}},
{{11.f, 11.f}, ngraph::element::f32, {1, 2, 1, 1}}
},
{
{ngraph::element::f32},
{{1.f}, ngraph::element::f32, {1, 1, 1, 1}},
{{11.f}, ngraph::element::f32, {1, 1, 1, 1}}
}
} }
} }
}, },
@ -256,16 +264,8 @@ const std::vector<VariadicSplitTransformationTestValues> testValues = {
{}, {},
ngraph::element::i8, ngraph::element::i8,
{ {
{ {{ngraph::element::f32}, {1.f}, {11.f}},
{ngraph::element::f32}, {{ngraph::element::f32}, {1.f}, {11.f}}
{{1.f, 1.f}, ngraph::element::f32, {1, 2, 1, 1}},
{{11.f, 11.f}, ngraph::element::f32, {1, 2, 1, 1}}
},
{
{ngraph::element::f32},
{{1.f}, ngraph::element::f32, {1, 1, 1, 1}},
{{11.f}, ngraph::element::f32, {1, 1, 1, 1}}
}
} }
} }
}, },
@ -322,21 +322,13 @@ const std::vector<VariadicSplitTransformationTestValues> testValues = {
{}, {},
ngraph::element::i8, ngraph::element::i8,
{ {
{ {{ngraph::element::f32}, {1.f}, {11.f}},
{ngraph::element::f32},
{{1.f}, ngraph::element::f32, {1, 1, 1, 1}},
{{11.f}, ngraph::element::f32, {1, 1, 1, 1}}
},
{ {
{ngraph::element::f32}, {ngraph::element::f32},
{{2.f, 3.f}, ngraph::element::f32, {1, 2, 1, 1}}, {{2.f, 3.f}, ngraph::element::f32, {1, 2, 1, 1}},
{{22.f, 33.f}, ngraph::element::f32, {1, 2, 1, 1}} {{22.f, 33.f}, ngraph::element::f32, {1, 2, 1, 1}}
}, },
{ {{ngraph::element::f32}, {4.f}, {44.f}}
{ngraph::element::f32},
{{4.f}, ngraph::element::f32, {1, 1, 1, 1}},
{{44.f}, ngraph::element::f32, {1, 1, 1, 1}}
}
} }
} }
}, },

View File

@ -45,8 +45,7 @@ const std::vector<ConcatWithSplitTransformationParam> testValues = {
} }
}; };
// TODO: Split/VariadicSplit operations are not supported in ConcatTransformation INSTANTIATE_TEST_CASE_P(smoke_LPT, ConcatWithSplitTransformation,
INSTANTIATE_TEST_CASE_P(DISABLED_smoke_LPT, ConcatWithSplitTransformation,
::testing::Combine( ::testing::Combine(
::testing::ValuesIn(netPrecisions), ::testing::ValuesIn(netPrecisions),
::testing::Values(ngraph::Shape({ 1, 6, 10, 10 })), ::testing::Values(ngraph::Shape({ 1, 6, 10, 10 })),

View File

@ -45,8 +45,7 @@ const std::vector<ConcatWithSplitTransformationParam> testValues = {
} }
}; };
// TODO: Split/VariadicSplit operations are not supported in ConcatTransformation INSTANTIATE_TEST_CASE_P(smoke_LPT, ConcatWithSplitTransformation,
INSTANTIATE_TEST_CASE_P(DISABLED_smoke_LPT, ConcatWithSplitTransformation,
::testing::Combine( ::testing::Combine(
::testing::ValuesIn(netPrecisions), ::testing::ValuesIn(netPrecisions),
::testing::Values(ngraph::Shape({ 1, 6, 10, 10 })), ::testing::Values(ngraph::Shape({ 1, 6, 10, 10 })),

View File

@ -114,7 +114,8 @@ public:
const DequantizationOperations::Convert& convert2, const DequantizationOperations::Convert& convert2,
const DequantizationOperations& dequantization2, const DequantizationOperations& dequantization2,
const ngraph::element::Type precisionAfterOperation, const ngraph::element::Type precisionAfterOperation,
const DequantizationOperations& dequantizationAfter); const DequantizationOperations& dequantizationAfter,
const std::int64_t& axis);
static std::shared_ptr<ngraph::Function> getReferenceWithNeighbors( static std::shared_ptr<ngraph::Function> getReferenceWithNeighbors(
const ngraph::element::Type precision, const ngraph::element::Type precision,

View File

@ -752,7 +752,8 @@ std::shared_ptr<ngraph::Function> ConcatFunction::get(
const DequantizationOperations::Convert& convert2, const DequantizationOperations::Convert& convert2,
const DequantizationOperations& dequantization2, const DequantizationOperations& dequantization2,
const ngraph::element::Type precisionAfterOperation, const ngraph::element::Type precisionAfterOperation,
const DequantizationOperations& dequantizationAfter) { const DequantizationOperations& dequantizationAfter,
const std::int64_t& axis) {
const auto input1 = std::make_shared<ngraph::opset1::Parameter>(inputPrecision, inputShape); const auto input1 = std::make_shared<ngraph::opset1::Parameter>(inputPrecision, inputShape);
input1->set_friendly_name("input1"); input1->set_friendly_name("input1");
@ -775,7 +776,7 @@ std::shared_ptr<ngraph::Function> ConcatFunction::get(
parent2 = makeDequantization(parent2, dequantization2); parent2 = makeDequantization(parent2, dequantization2);
} }
const std::shared_ptr<ngraph::opset1::Concat> concat = std::make_shared<ngraph::opset1::Concat>(ngraph::OutputVector{ parent1, parent2 }, 1); const std::shared_ptr<ngraph::opset1::Concat> concat = std::make_shared<ngraph::opset1::Concat>(ngraph::OutputVector{ parent1, parent2 }, axis);
auto& rtInfo = concat->get_rt_info(); auto& rtInfo = concat->get_rt_info();
rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("concat"); rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("concat");
@ -989,6 +990,13 @@ std::shared_ptr<ngraph::Function> ConcatFunction::getReferenceWithSplitedInterme
input2->set_friendly_name("input2"); input2->set_friendly_name("input2");
const auto fakeQuantize2 = makeFakeQuantizeTypeRelaxed(input2, precision, fqOnData2); const auto fakeQuantize2 = makeFakeQuantizeTypeRelaxed(input2, precision, fqOnData2);
replace_node(
fakeQuantize2->get_input_node_shared_ptr(3),
ngraph::pass::low_precision::NetworkHelper::toScalarIfPossible(fakeQuantize2->get_input_node_shared_ptr(3)));
replace_node(
fakeQuantize2->get_input_node_shared_ptr(4),
ngraph::pass::low_precision::NetworkHelper::toScalarIfPossible(fakeQuantize2->get_input_node_shared_ptr(4)));
fakeQuantize2->set_friendly_name("fakeQuantize2"); fakeQuantize2->set_friendly_name("fakeQuantize2");
low_precision::NetworkHelper::setOutDataPrecisionForTypeRelaxed(fakeQuantize2, precisionAfterOperation); low_precision::NetworkHelper::setOutDataPrecisionForTypeRelaxed(fakeQuantize2, precisionAfterOperation);
const auto deqBefore2 = makeDequantization(fakeQuantize2, dequantizationBefore1); const auto deqBefore2 = makeDequantization(fakeQuantize2, dequantizationBefore1);