do not change names while replacing FQ nodes (#17936)
* do not change names while replacing FQ (FakeQuantize) nodes
* more solid fix
* even more solid fix
* fix UT failure
This commit is contained in:
parent
a4519f0a2c
commit
cb63b39c72
@ -36,7 +36,9 @@ public:
|
|||||||
const std::vector<ngraph::element::Type>& defaultPrecisions = precision_set::int8_support);
|
const std::vector<ngraph::element::Type>& defaultPrecisions = precision_set::int8_support);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
bool decomposeFakeQuantizeForWeightsPath(const std::shared_ptr<Node>& weightableLayer, size_t outChannelsShapeIndex = 0ul) const;
|
std::tuple<bool, std::shared_ptr<Node>, std::shared_ptr<Node>> decomposeFakeQuantizeForWeightsPath(
|
||||||
|
const std::shared_ptr<Node>& weightableLayer,
|
||||||
|
size_t outChannelsShapeIndex = 0ul) const;
|
||||||
static bool isGroup(const std::shared_ptr<Node>& node);
|
static bool isGroup(const std::shared_ptr<Node>& node);
|
||||||
static bool isDepthwise(const std::shared_ptr<Node>& node);
|
static bool isDepthwise(const std::shared_ptr<Node>& node);
|
||||||
virtual size_t getInputChannels(const std::shared_ptr<ngraph::Node> conv) const = 0;
|
virtual size_t getInputChannels(const std::shared_ptr<ngraph::Node> conv) const = 0;
|
||||||
|
@ -88,7 +88,14 @@ bool ConvolutionTransformation::transform(TransformationContext &context, ngraph
|
|||||||
|
|
||||||
convolution = NetworkHelper::separateInStandaloneBranch(convolution, defaultPrecisions);
|
convolution = NetworkHelper::separateInStandaloneBranch(convolution, defaultPrecisions);
|
||||||
|
|
||||||
const bool fqOnWeightsWasDecomposed = decomposeFakeQuantizeForWeightsPath(convolution);
|
const auto& res_tuple = decomposeFakeQuantizeForWeightsPath(convolution);
|
||||||
|
|
||||||
|
auto fqOnWeightsWasDecomposed = std::get<0>(res_tuple);
|
||||||
|
auto newFQ = std::get<1>(res_tuple);
|
||||||
|
auto dequantize = std::get<2>(res_tuple);
|
||||||
|
if (newFQ != nullptr && dequantize != nullptr)
|
||||||
|
updateOutput(context, dequantize, newFQ);
|
||||||
|
|
||||||
if (updatePrecisions && !fqOnWeightsWasDecomposed) {
|
if (updatePrecisions && !fqOnWeightsWasDecomposed) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -141,7 +141,12 @@ bool ConvolutionBackpropDataTransformation::transform(TransformationContext &con
|
|||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
decomposeFakeQuantizeForWeightsPath(convolutionBackpropData, 1ul);
|
const auto& res_tuple = decomposeFakeQuantizeForWeightsPath(convolutionBackpropData, 1ul);
|
||||||
|
auto newFQ = std::get<1>(res_tuple);
|
||||||
|
auto dequantize = std::get<2>(res_tuple);
|
||||||
|
if (newFQ != nullptr && dequantize != nullptr)
|
||||||
|
updateOutput(context, dequantize, newFQ);
|
||||||
|
|
||||||
dequantization = NetworkHelper::getDequantization(convolutionBackpropData, defaultPrecisions, 1ul);
|
dequantization = NetworkHelper::getDequantization(convolutionBackpropData, defaultPrecisions, 1ul);
|
||||||
|
|
||||||
if (const auto fq = ov::as_type_ptr<ov::opset1::FakeQuantize>(dequantization.data.get_node_shared_ptr())) {
|
if (const auto fq = ov::as_type_ptr<ov::opset1::FakeQuantize>(dequantization.data.get_node_shared_ptr())) {
|
||||||
|
@ -324,11 +324,13 @@ bool WeightableLayerTransformation::isPrecisionPreserved(std::shared_ptr<Node> l
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool WeightableLayerTransformation::decomposeFakeQuantizeForWeightsPath(const std::shared_ptr<Node>& node, const size_t outChannelsShapeIndex) const {
|
std::tuple<bool, std::shared_ptr<Node>, std::shared_ptr<Node>> WeightableLayerTransformation::decomposeFakeQuantizeForWeightsPath(
|
||||||
|
const std::shared_ptr<Node>& node,
|
||||||
|
const size_t outChannelsShapeIndex) const {
|
||||||
const auto fq = getFakeQuantizeOnWeights(node);
|
const auto fq = getFakeQuantizeOnWeights(node);
|
||||||
if (fq == nullptr) {
|
if (fq == nullptr) {
|
||||||
// FakeQuantize has been decomposed already
|
// FakeQuantize has been decomposed already
|
||||||
return true;
|
return std::make_tuple(true, nullptr, nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
const QuantizationDetails quantizationDetails = QuantizationDetails::getDetails(fq);
|
const QuantizationDetails quantizationDetails = QuantizationDetails::getDetails(fq);
|
||||||
@ -339,7 +341,7 @@ bool WeightableLayerTransformation::decomposeFakeQuantizeForWeightsPath(const st
|
|||||||
|
|
||||||
const DataPrecision dataPrecision = getDataPrecision(fq, quantizationDetails, precisions);
|
const DataPrecision dataPrecision = getDataPrecision(fq, quantizationDetails, precisions);
|
||||||
if (dataPrecision.empty()) {
|
if (dataPrecision.empty()) {
|
||||||
return false;
|
return std::make_tuple(false, nullptr, nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto tuple = NetworkHelper::decomposeFakeQuantize(
|
auto tuple = NetworkHelper::decomposeFakeQuantize(
|
||||||
@ -352,17 +354,19 @@ bool WeightableLayerTransformation::decomposeFakeQuantizeForWeightsPath(const st
|
|||||||
element::f32,
|
element::f32,
|
||||||
outChannelsShapeIndex);
|
outChannelsShapeIndex);
|
||||||
|
|
||||||
std::shared_ptr<ngraph::Node> fqOnWeights = std::get<0>(tuple);
|
std::shared_ptr<Node> fqOnWeights = std::get<0>(tuple);
|
||||||
|
std::shared_ptr<Node> dequantize = std::get<1>(tuple);
|
||||||
|
|
||||||
// TODO: LPT: issue #58685
|
// TODO: LPT: issue #58685
|
||||||
if ((!updatePrecisions) && (fqOnWeights == nullptr)) {
|
if ((!updatePrecisions) && (fqOnWeights == nullptr)) {
|
||||||
return false;
|
return std::make_tuple(false, nullptr, nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ov::as_type_ptr<ov::opset1::Constant>(fqOnWeights) == nullptr) {
|
if (ov::as_type_ptr<ov::opset1::Constant>(fqOnWeights) == nullptr) {
|
||||||
THROW_IE_LPT_EXCEPTION(*fqOnWeights) << "FakeQuantize on weights was not folded to constant";
|
THROW_IE_LPT_EXCEPTION(*fqOnWeights) << "FakeQuantize on weights was not folded to constant";
|
||||||
}
|
}
|
||||||
|
|
||||||
return true;
|
return std::make_tuple(true, fqOnWeights, dequantize);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool WeightableLayerTransformation::isGroup(const std::shared_ptr<Node>& layer) {
|
bool WeightableLayerTransformation::isGroup(const std::shared_ptr<Node>& layer) {
|
||||||
|
Loading…
Reference in New Issue
Block a user