[LPT] Avoid using std::shared_ptr<Node> when creating a node (#7357)

* [LPT] Avoid using std::shared_ptr<Node> when creating a node

* [LPT] removed unused files

* [LPT] D2STransformation: transform & isPrecisionPreserved methods are moved to base class

* [LPT] Revert redundant changes
This commit is contained in:
Vladislav Golubev 2021-09-27 11:50:17 +03:00 committed by GitHub
parent 818f385398
commit 7fa9bbf6fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
30 changed files with 158 additions and 267 deletions

View File

@ -39,8 +39,6 @@ protected:
NodeVector& convertNodes, NodeVector& convertNodes,
NodeVector& subtractNodes, NodeVector& subtractNodes,
NodeVector& multiplyNodes) const; NodeVector& multiplyNodes) const;
std::shared_ptr<Node> concatenateDeqNodes(NodeVector& nodes) const;
}; };
} // namespace low_precision } // namespace low_precision

View File

@ -14,8 +14,6 @@ class LP_TRANSFORMATIONS_API DepthToSpaceTransformation : public TransparentBase
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
DepthToSpaceTransformation(const Params& params = Params()); DepthToSpaceTransformation(const Params& params = Params());
bool transform(TransformationContext &context, ngraph::pattern::Matcher &m) override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override; bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
}; };

View File

@ -18,6 +18,7 @@ public:
~TransparentBaseTransformation() override {}; ~TransparentBaseTransformation() override {};
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) override; bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override; bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
}; };
} // namespace low_precision } // namespace low_precision

View File

@ -176,13 +176,13 @@ bool AddTransformation::transform(TransformationContext& context, ngraph::patter
// after : Y = SC2 * ( SC1' * (X1 - SH1') + X2 ) , where : // after : Y = SC2 * ( SC1' * (X1 - SH1') + X2 ) , where :
// SC1' = SC1 / SC2 // SC1' = SC1 / SC2
// SH1' = SH1 + SC2 * SH2 / SC1 // SH1' = SH1 + SC2 * SH2 / SC1
std::shared_ptr<Node> newSubtractFullPathValues = fold<opset1::Add>( auto newSubtractFullPathValues = fold<opset1::Add>(
subtractFullPathValues, subtractFullPathValues,
fold<opset1::Divide>( fold<opset1::Divide>(
fold<opset1::Multiply>(subtractEmptyPathValues, multiplyEmptyPathValues), fold<opset1::Multiply>(subtractEmptyPathValues, multiplyEmptyPathValues),
multiplyFullPathValues)); multiplyFullPathValues));
std::shared_ptr<Node> newMultiplyFullPathValues = fold<opset1::Divide>(multiplyFullPathValues, multiplyEmptyPathValues); auto newMultiplyFullPathValues = fold<opset1::Divide>(multiplyFullPathValues, multiplyEmptyPathValues);
if (NetworkHelper::isZeroConst(newSubtractFullPathValues)) { if (NetworkHelper::isZeroConst(newSubtractFullPathValues)) {
newSubtractFullPathValues = nullptr; newSubtractFullPathValues = nullptr;

View File

@ -1,19 +0,0 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "low_precision/common/operation_precision_restriction.hpp"
#include <memory>
#include <unordered_set>
#include <set>
#include <vector>
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/pattern/op/or.hpp>
#include "low_precision/network_helper.hpp"
#include "low_precision/rt_info/precisions_attribute.hpp"
using namespace ngraph;

View File

@ -70,20 +70,11 @@ bool ConcatTransformation::transform(TransformationContext& context, ngraph::pat
} }
} }
auto broadcastElementWiseConst = []( // FakeQuantize constant shape must be broadcastable to the shape on data.
// FakeQuantize constant shape must be broadcastable to the shape on data. auto broadcastElementWiseConst = [](std::shared_ptr<opset1::Constant> operation, const Shape targetShape) {
std::shared_ptr<ngraph::opset1::Constant> operation, auto targetShapeConst = std::make_shared<opset1::Constant>(element::i64, Shape{ targetShape.size() }, targetShape);
const ngraph::Shape targetShape) -> std::shared_ptr<Node> { auto broadcast = fold<ngraph::opset1::Broadcast>(operation, targetShapeConst);
auto targetShapeConst = std::make_shared<ngraph::opset1::Constant>( return broadcast;
element::i64, ngraph::Shape{ targetShape.size() },
targetShape);
auto broadcast = ngraph::pass::low_precision::fold<ngraph::opset1::Broadcast>(
operation,
targetShapeConst,
ngraph::op::AutoBroadcastType::NUMPY);
return broadcast;
}; };
bool someDqInLowPrecision = std::any_of( bool someDqInLowPrecision = std::any_of(
@ -247,15 +238,8 @@ void ConcatTransformation::fillDequantizationNodes(
// FakeQuantize constant shape must be broadcastable to the shape on data. // FakeQuantize constant shape must be broadcastable to the shape on data.
std::shared_ptr<ngraph::opset1::Constant> operation, std::shared_ptr<ngraph::opset1::Constant> operation,
const ngraph::Shape targetShape) -> std::shared_ptr<Node> { const ngraph::Shape targetShape) -> std::shared_ptr<Node> {
auto targetShapeConst = std::make_shared<ngraph::opset1::Constant>( auto targetShapeConst = opset1::Constant::create(element::i64, ngraph::Shape{ targetShape.size() }, targetShape);
element::i64, ngraph::Shape{ targetShape.size() }, auto broadcast = fold<ngraph::opset1::Broadcast>(operation, targetShapeConst);
targetShape);
auto broadcast = ngraph::pass::low_precision::fold<ngraph::opset1::Broadcast>(
operation,
targetShapeConst,
ngraph::op::AutoBroadcastType::NUMPY);
return broadcast; return broadcast;
}; };
@ -308,10 +292,6 @@ void ConcatTransformation::fillDequantizationNodes(
} }
} }
std::shared_ptr<Node> ConcatTransformation::concatenateDeqNodes(NodeVector& nodes) const {
return nodes.size() == 1ul ? nodes[0] : fold<ngraph::opset1::Concat>(nodes, 1);
}
bool ConcatTransformation::isHandled(const TransformationContext& context, const std::vector<std::shared_ptr<ngraph::Node>>& quantizationOperations) { bool ConcatTransformation::isHandled(const TransformationContext& context, const std::vector<std::shared_ptr<ngraph::Node>>& quantizationOperations) {
for (const std::shared_ptr<ngraph::Node>& quantizationLayer : quantizationOperations) { for (const std::shared_ptr<ngraph::Node>& quantizationLayer : quantizationOperations) {
if (context.quantizedFakeQuantizeNames.find(quantizationLayer->get_friendly_name()) != context.quantizedFakeQuantizeNames.end()) { if (context.quantizedFakeQuantizeNames.find(quantizationLayer->get_friendly_name()) != context.quantizedFakeQuantizeNames.end()) {

View File

@ -49,7 +49,7 @@ bool ConvertTransformation::transform(TransformationContext& context, ngraph::pa
const ngraph::element::Type precisionBefore = convert->get_input_element_type(0); const ngraph::element::Type precisionBefore = convert->get_input_element_type(0);
std::shared_ptr<opset1::Subtract> subtract = std::make_shared<op::TypeRelaxed<opset1::Subtract>>( std::shared_ptr<opset1::Subtract> subtract = std::make_shared<op::TypeRelaxed<opset1::Subtract>>(
convert->get_input_node_shared_ptr(0), convert->input_value(0),
std::make_shared<opset1::Constant>(precisionBefore, Shape{}, std::vector<size_t>({ 0 }))); std::make_shared<opset1::Constant>(precisionBefore, Shape{}, std::vector<size_t>({ 0 })));
NetworkHelper::setOutDataPrecision(subtract, convert->get_output_element_type(0)); NetworkHelper::setOutDataPrecision(subtract, convert->get_output_element_type(0));

View File

@ -181,7 +181,7 @@ bool ConvolutionBackpropDataTransformation::transform(TransformationContext &con
zeroPointShape[1] = static_cast<size_t>(weightsPShape[1].get_length()); zeroPointShape[1] = static_cast<size_t>(weightsPShape[1].get_length());
auto zeroPointConstant = fold<opset1::Broadcast>( auto zeroPointConstant = fold<opset1::Broadcast>(
subtractFromWeights->get_input_node_shared_ptr(1), subtractFromWeights->input_value(1),
std::make_shared<opset1::Constant>(element::i32, Shape{zeroPointShape.size()}, zeroPointShape)); std::make_shared<opset1::Constant>(element::i32, Shape{zeroPointShape.size()}, zeroPointShape));
replace_node(subtractFromWeights->get_input_node_shared_ptr(1), zeroPointConstant); replace_node(subtractFromWeights->get_input_node_shared_ptr(1), zeroPointConstant);
} }

View File

@ -1,22 +0,0 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "low_precision/create_precisions_dependent_attribute.hpp"
#include <assert.h>
#include <deque>
#include <memory>
#include <unordered_map>
#include <set>
#include <vector>
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/pattern/op/or.hpp>
#include "low_precision/rt_info/precisions_attribute.hpp"
#include "low_precision/rt_info/precision_preserved_attribute.hpp"
#include "low_precision/network_helper.hpp"
using namespace ngraph;
using namespace ngraph::pass::low_precision;

View File

@ -29,21 +29,6 @@ DepthToSpaceTransformation::DepthToSpaceTransformation(const Params& params) : T
this->register_matcher(m, callback); this->register_matcher(m, callback);
} }
bool DepthToSpaceTransformation::transform(TransformationContext &context, ngraph::pattern::Matcher &m) {
std::shared_ptr<Node> depthToSpace = m.get_match_root();
if (!canBeTransformed(context, depthToSpace)) {
return false;
}
depthToSpace = NetworkHelper::separateInStandaloneBranch(depthToSpace);
moveDequantizationAfter(context, depthToSpace, NetworkHelper::getDequantization(depthToSpace), true);
return true;
}
bool DepthToSpaceTransformation::isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept {
return true;
}
bool DepthToSpaceTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const { bool DepthToSpaceTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const {
if (!LayerTransformation::canBeTransformed(context, layer)) { if (!LayerTransformation::canBeTransformed(context, layer)) {
return false; return false;

View File

@ -67,7 +67,7 @@ static std::shared_ptr<Node> updateShape(std::shared_ptr<Node> constantOp, const
return constantOp; return constantOp;
} }
static std::shared_ptr<Node> getData(const std::shared_ptr<Node>& eltwise) { static std::shared_ptr<Node> getDataNode(const std::shared_ptr<Node>& eltwise) {
if (!ov::is_type<opset1::Constant>(eltwise->get_input_node_shared_ptr(0))) { if (!ov::is_type<opset1::Constant>(eltwise->get_input_node_shared_ptr(0))) {
return eltwise->get_input_node_shared_ptr(0); return eltwise->get_input_node_shared_ptr(0);
} }
@ -123,7 +123,7 @@ bool FakeQuantizeTransformation::checkElementwise(const std::shared_ptr<Node>& e
} }
} }
return fq::getData(eltwise) != nullptr; return fq::getDataNode(eltwise) != nullptr;
} }
std::shared_ptr<opset1::FakeQuantize> FakeQuantizeTransformation::fuseElementwise( std::shared_ptr<opset1::FakeQuantize> FakeQuantizeTransformation::fuseElementwise(
@ -132,8 +132,8 @@ std::shared_ptr<opset1::FakeQuantize> FakeQuantizeTransformation::fuseElementwis
const std::shared_ptr<opset1::FakeQuantize>& fakeQuantize) const { const std::shared_ptr<opset1::FakeQuantize>& fakeQuantize) const {
const std::shared_ptr<Node> eltwise = fakeQuantize->get_input_node_shared_ptr(0); const std::shared_ptr<Node> eltwise = fakeQuantize->get_input_node_shared_ptr(0);
std::shared_ptr<Node> inputLowConst_f32 = foldConvert(fakeQuantize->get_input_node_shared_ptr(1), deqPrecision); std::shared_ptr<Node> inputLowConst_f32 = foldConvert(fakeQuantize->input_value(1), deqPrecision);
std::shared_ptr<Node> inputHighConst_f32 = foldConvert(fakeQuantize->get_input_node_shared_ptr(2), deqPrecision); std::shared_ptr<Node> inputHighConst_f32 = foldConvert(fakeQuantize->input_value(2), deqPrecision);
std::shared_ptr<opset1::Constant> constant = fq::getConstant(eltwise); std::shared_ptr<opset1::Constant> constant = fq::getConstant(eltwise);
if (ov::is_type<opset1::Multiply>(eltwise) && checkElementwise(eltwise)) { if (ov::is_type<opset1::Multiply>(eltwise) && checkElementwise(eltwise)) {
@ -166,10 +166,10 @@ std::shared_ptr<opset1::FakeQuantize> FakeQuantizeTransformation::fuseElementwis
inputLowConst_f32 = fq::updateShape(fold<opset1::Add>(inputLowConst_f32, value), fakeQuantize->get_output_partial_shape(0)); inputLowConst_f32 = fq::updateShape(fold<opset1::Add>(inputLowConst_f32, value), fakeQuantize->get_output_partial_shape(0));
inputHighConst_f32 = fq::updateShape(fold<opset1::Add>(inputHighConst_f32, value), fakeQuantize->get_output_partial_shape(0)); inputHighConst_f32 = fq::updateShape(fold<opset1::Add>(inputHighConst_f32, value), fakeQuantize->get_output_partial_shape(0));
} else if (ov::is_type<opset1::Add>(eltwise) && checkElementwise(eltwise)) { } else if (ov::is_type<opset1::Add>(eltwise) && checkElementwise(eltwise)) {
if (ov::is_type<opset1::Convolution>(fq::getData(eltwise)) || if (ov::is_type<opset1::Convolution>(fq::getDataNode(eltwise)) ||
ov::is_type<opset1::GroupConvolution>(fq::getData(eltwise)) || ov::is_type<opset1::GroupConvolution>(fq::getDataNode(eltwise)) ||
ov::is_type<opset1::ConvolutionBackpropData>(fq::getData(eltwise)) || ov::is_type<opset1::ConvolutionBackpropData>(fq::getDataNode(eltwise)) ||
ov::is_type<opset1::GroupConvolutionBackpropData>(fq::getData(eltwise))) { ov::is_type<opset1::GroupConvolutionBackpropData>(fq::getDataNode(eltwise))) {
return nullptr; return nullptr;
} }
@ -189,7 +189,7 @@ std::shared_ptr<opset1::FakeQuantize> FakeQuantizeTransformation::fuseElementwis
return nullptr; return nullptr;
} }
const auto data = fq::getData(eltwise); const auto data = fq::getDataNode(eltwise);
const size_t outputIdx = NetworkHelper::getParentOutputIndex(data, eltwise); const size_t outputIdx = NetworkHelper::getParentOutputIndex(data, eltwise);
const auto newFakeQuantize = ov::as_type_ptr<opset1::FakeQuantize>(fakeQuantize->clone_with_new_inputs({ const auto newFakeQuantize = ov::as_type_ptr<opset1::FakeQuantize>(fakeQuantize->clone_with_new_inputs({

View File

@ -42,7 +42,7 @@ bool FoldConvertTransformation::transform(TransformationContext& context, ngraph
return; return;
} }
const auto resultConstant = ngraph::pass::low_precision::foldConvert(convert->get_input_node_shared_ptr(0), convert->output(0).get_element_type()); const auto resultConstant = ngraph::pass::low_precision::foldConvert(convert->input_value(0), convert->get_output_element_type(0));
assert(ov::is_type<opset1::Constant>(resultConstant)); assert(ov::is_type<opset1::Constant>(resultConstant));
replace_node(convert, resultConstant); replace_node(convert, resultConstant);

View File

@ -47,8 +47,8 @@ std::shared_ptr<Node> removeConvertIfPossibleForSubtract(
if (NetworkHelper::checkConstantValuePrecision(precisionBeforeConvert, subtract->get_input_node_shared_ptr(1))) { if (NetworkHelper::checkConstantValuePrecision(precisionBeforeConvert, subtract->get_input_node_shared_ptr(1))) {
newSubtract = std::make_shared<ngraph::op::TypeRelaxed<opset1::Subtract>>( newSubtract = std::make_shared<ngraph::op::TypeRelaxed<opset1::Subtract>>(
std::vector<ngraph::element::Type>{ element::f32, element::f32 }, std::vector<ngraph::element::Type>{}, std::vector<ngraph::element::Type>{ element::f32, element::f32 }, std::vector<ngraph::element::Type>{},
ngraph::op::TemporaryReplaceOutputType(convert->get_input_source_output(0), element::f32).get(), ngraph::op::TemporaryReplaceOutputType(convert->input_value(0), element::f32).get(),
ngraph::op::TemporaryReplaceOutputType(subtract->get_input_node_shared_ptr(1), element::f32).get()); ngraph::op::TemporaryReplaceOutputType(subtract->input_value(1), element::f32).get());
NetworkHelper::setOutDataPrecisionForTypeRelaxed(newSubtract, subtract->get_output_element_type(0)); NetworkHelper::setOutDataPrecisionForTypeRelaxed(newSubtract, subtract->get_output_element_type(0));
replace_node(subtract, newSubtract); replace_node(subtract, newSubtract);
} }
@ -63,11 +63,11 @@ bool FuseConvertTransformation::transform(TransformationContext& context, ngraph
} }
const auto convert = ov::as_type_ptr<opset1::Convert>(op->get_input_node_shared_ptr(0)); const auto convert = ov::as_type_ptr<opset1::Convert>(op->get_input_node_shared_ptr(0));
std::shared_ptr<Node> parent = convert->get_input_node_shared_ptr(0); auto parent = convert->input_value(0);
if (ov::is_type<opset1::Constant>(parent)) { if (ov::is_type<opset1::Constant>(parent.get_node_shared_ptr())) {
auto convertedConstant = foldConvert(parent, convert->get_convert_element_type()); auto convertedConstant = foldConvert(parent, convert->get_convert_element_type());
NetworkHelper::copyInfo(parent, convertedConstant); NetworkHelper::copyInfo(parent.get_node_shared_ptr(), convertedConstant);
replace_node(convert, convertedConstant); replace_node(convert, convertedConstant);
} else { } else {
std::shared_ptr<Node> newOp; std::shared_ptr<Node> newOp;
@ -77,15 +77,15 @@ bool FuseConvertTransformation::transform(TransformationContext& context, ngraph
} else if (ov::is_type<opset1::Multiply>(op)) { } else if (ov::is_type<opset1::Multiply>(op)) {
newOp = std::make_shared<ngraph::op::TypeRelaxed<opset1::Multiply>>( newOp = std::make_shared<ngraph::op::TypeRelaxed<opset1::Multiply>>(
std::vector<ngraph::element::Type>{ element::f32, element::f32 }, std::vector<ngraph::element::Type>{}, std::vector<ngraph::element::Type>{ element::f32, element::f32 }, std::vector<ngraph::element::Type>{},
ngraph::op::TemporaryReplaceOutputType(convert->get_input_source_output(0), element::f32).get(), ngraph::op::TemporaryReplaceOutputType(convert->input_value(0), element::f32).get(),
ngraph::op::TemporaryReplaceOutputType(op->get_input_node_shared_ptr(1), element::f32).get()); ngraph::op::TemporaryReplaceOutputType(op->input_value(1), element::f32).get());
NetworkHelper::setOutDataPrecisionForTypeRelaxed(newOp, op->get_output_element_type(0)); NetworkHelper::setOutDataPrecisionForTypeRelaxed(newOp, op->get_output_element_type(0));
replace_node(op, newOp); replace_node(op, newOp);
} else if (ov::is_type<opset1::Add>(op)) { } else if (ov::is_type<opset1::Add>(op)) {
newOp = std::make_shared<ngraph::op::TypeRelaxed<opset1::Add>>( newOp = std::make_shared<ngraph::op::TypeRelaxed<opset1::Add>>(
std::vector<ngraph::element::Type>{ element::f32, element::f32 }, std::vector<ngraph::element::Type>{}, std::vector<ngraph::element::Type>{ element::f32, element::f32 }, std::vector<ngraph::element::Type>{},
ngraph::op::TemporaryReplaceOutputType(convert->get_input_source_output(0), element::f32).get(), ngraph::op::TemporaryReplaceOutputType(convert->input_value(0), element::f32).get(),
ngraph::op::TemporaryReplaceOutputType(op->get_input_node_shared_ptr(1), element::f32).get()); ngraph::op::TemporaryReplaceOutputType(op->input_value(1), element::f32).get());
NetworkHelper::setOutDataPrecisionForTypeRelaxed(newOp, op->get_output_element_type(0)); NetworkHelper::setOutDataPrecisionForTypeRelaxed(newOp, op->get_output_element_type(0));
replace_node(op, newOp); replace_node(op, newOp);
} }

View File

@ -54,7 +54,7 @@ std::shared_ptr<Node> updateShape(std::shared_ptr<Node> op, const PartialShape&
return op; return op;
} }
std::shared_ptr<Node> getData(const std::shared_ptr<Node>& eltwise) { std::shared_ptr<Node> getDataNode(const std::shared_ptr<Node>& eltwise) {
if (!ov::is_type<opset1::Constant>(eltwise->get_input_node_shared_ptr(0))) { if (!ov::is_type<opset1::Constant>(eltwise->get_input_node_shared_ptr(0))) {
return eltwise->get_input_node_shared_ptr(0); return eltwise->get_input_node_shared_ptr(0);
} }
@ -108,7 +108,7 @@ bool eltwiseWithConstant(const std::shared_ptr<Node>& eltwise) {
} }
} }
return getData(eltwise) != nullptr; return getDataNode(eltwise) != nullptr;
} }
} // namespace fuse_fq } // namespace fuse_fq
@ -144,8 +144,8 @@ std::shared_ptr<opset1::FakeQuantize> FuseFakeQuantizeTransformation::handle(
inputLowConst = fuse_fq::updateShape(fold<opset1::Add>(inputLowConst, value), fakeQuantize->get_output_partial_shape(0)); inputLowConst = fuse_fq::updateShape(fold<opset1::Add>(inputLowConst, value), fakeQuantize->get_output_partial_shape(0));
inputHightConst = fuse_fq::updateShape(fold<opset1::Add>(inputHightConst, value), fakeQuantize->get_output_partial_shape(0)); inputHightConst = fuse_fq::updateShape(fold<opset1::Add>(inputHightConst, value), fakeQuantize->get_output_partial_shape(0));
} else if (ov::is_type<opset1::Add>(eltwise) && fuse_fq::eltwiseWithConstant(eltwise)) { } else if (ov::is_type<opset1::Add>(eltwise) && fuse_fq::eltwiseWithConstant(eltwise)) {
if (ov::is_type<opset1::Convolution>(fuse_fq::getData(eltwise)) || if (ov::is_type<opset1::Convolution>(fuse_fq::getDataNode(eltwise)) ||
ov::is_type<opset1::GroupConvolution>(fuse_fq::getData(eltwise))) { ov::is_type<opset1::GroupConvolution>(fuse_fq::getDataNode(eltwise))) {
return nullptr; return nullptr;
} }
@ -157,15 +157,18 @@ std::shared_ptr<opset1::FakeQuantize> FuseFakeQuantizeTransformation::handle(
inputHightConst = fuse_fq::updateShape(fold<opset1::Subtract>(inputHightConst, value), fakeQuantize->get_output_partial_shape(0)); inputHightConst = fuse_fq::updateShape(fold<opset1::Subtract>(inputHightConst, value), fakeQuantize->get_output_partial_shape(0));
} else if (ov::is_type<opset1::Convert>(eltwise)) { } else if (ov::is_type<opset1::Convert>(eltwise)) {
// issue #40611 // issue #40611
if ((eltwise->input(0).get_element_type() == element::i32) && (eltwise->output(0).get_element_type() == element::f32)) { if ((eltwise->get_input_element_type(0) == element::i32) && (eltwise->get_output_element_type(0) == element::f32)) {
return nullptr; return nullptr;
} }
} else { } else {
return nullptr; return nullptr;
} }
const auto data = fuse_fq::getDataNode(eltwise);
const size_t outputIdx = NetworkHelper::getParentOutputIndex(data, eltwise);
std::shared_ptr<opset1::FakeQuantize> newFakeQuantize = ov::as_type_ptr<opset1::FakeQuantize>(fakeQuantize->clone_with_new_inputs({ std::shared_ptr<opset1::FakeQuantize> newFakeQuantize = ov::as_type_ptr<opset1::FakeQuantize>(fakeQuantize->clone_with_new_inputs({
fuse_fq::getData(eltwise), data->output(outputIdx),
inputLowConst, inputLowConst,
inputHightConst, inputHightConst,
fakeQuantize->input_value(3), fakeQuantize->input_value(3),

View File

@ -46,9 +46,12 @@ bool FuseMultiplyToFakeQuantizeTransformation::transform(TransformationContext&
} }
const auto multiplyConstant = multiply->get_input_node_shared_ptr(1); const auto multiplyConstant = multiply->get_input_node_shared_ptr(1);
if (!ov::is_type<opset1::Constant>(multiplyConstant)) {
return false;
}
auto outputLowConst_f32 = foldConvert(fakeQuantize->get_input_node_shared_ptr(3), deqPrecision); auto outputLowConst_f32 = foldConvert(fakeQuantize->input_value(3), deqPrecision);
auto outputHighConst_f32 = foldConvert(fakeQuantize->get_input_node_shared_ptr(4), deqPrecision); auto outputHighConst_f32 = foldConvert(fakeQuantize->input_value(4), deqPrecision);
const auto value = multiplyConstant->get_output_element_type(0) == element::f32 ? const auto value = multiplyConstant->get_output_element_type(0) == element::f32 ?
multiplyConstant : multiplyConstant :
@ -57,9 +60,6 @@ bool FuseMultiplyToFakeQuantizeTransformation::transform(TransformationContext&
outputLowConst_f32 = fold<opset1::Multiply>(outputLowConst_f32, value); outputLowConst_f32 = fold<opset1::Multiply>(outputLowConst_f32, value);
outputHighConst_f32 = fold<opset1::Multiply>(outputHighConst_f32, value); outputHighConst_f32 = fold<opset1::Multiply>(outputHighConst_f32, value);
const auto fakeQuantizeParent = fakeQuantize->get_input_node_shared_ptr(0);
const size_t parentIndex = NetworkHelper::getParentOutputIndex(fakeQuantizeParent, fakeQuantize);
const auto inputLow = foldConvert(fakeQuantize->input_value(1), deqPrecision); const auto inputLow = foldConvert(fakeQuantize->input_value(1), deqPrecision);
const auto inputHigh = foldConvert(fakeQuantize->input_value(2), deqPrecision); const auto inputHigh = foldConvert(fakeQuantize->input_value(2), deqPrecision);
NetworkHelper::copyInfo(fakeQuantize->get_input_node_shared_ptr(1), inputLow); NetworkHelper::copyInfo(fakeQuantize->get_input_node_shared_ptr(1), inputLow);
@ -69,7 +69,7 @@ bool FuseMultiplyToFakeQuantizeTransformation::transform(TransformationContext&
auto newFakeQuantize = std::make_shared<op::TypeRelaxed<opset1::FakeQuantize>>( auto newFakeQuantize = std::make_shared<op::TypeRelaxed<opset1::FakeQuantize>>(
opset1::FakeQuantize( opset1::FakeQuantize(
fakeQuantizeParent->output(parentIndex), fakeQuantize->input_value(0),
inputLow, inputLow,
inputHigh, inputHigh,
outputLowConst_f32, outputLowConst_f32,

View File

@ -45,9 +45,12 @@ bool FuseSubtractToFakeQuantizeTransformation::transform(TransformationContext&
} }
const auto subtractConstant = subtract->get_input_node_shared_ptr(1); const auto subtractConstant = subtract->get_input_node_shared_ptr(1);
if (!ov::is_type<opset1::Constant>(subtractConstant)) {
return false;
}
auto outputLowConst_f32 = foldConvert(fakeQuantize->get_input_node_shared_ptr(3), deqPrecision); auto outputLowConst_f32 = foldConvert(fakeQuantize->input_value(3), deqPrecision);
auto outputHighConst_f32 = foldConvert(fakeQuantize->get_input_node_shared_ptr(4), deqPrecision); auto outputHighConst_f32 = foldConvert(fakeQuantize->input_value(4), deqPrecision);
const auto value = subtractConstant->get_output_element_type(0) == element::f32 ? const auto value = subtractConstant->get_output_element_type(0) == element::f32 ?
subtractConstant : subtractConstant :
@ -56,9 +59,6 @@ bool FuseSubtractToFakeQuantizeTransformation::transform(TransformationContext&
outputLowConst_f32 = fold<opset1::Subtract>(outputLowConst_f32, value); outputLowConst_f32 = fold<opset1::Subtract>(outputLowConst_f32, value);
outputHighConst_f32 = fold<opset1::Subtract>(outputHighConst_f32, value); outputHighConst_f32 = fold<opset1::Subtract>(outputHighConst_f32, value);
const auto fakeQuantizeParent = fakeQuantize->get_input_node_shared_ptr(0);
const size_t parentIndex = NetworkHelper::getParentOutputIndex(fakeQuantizeParent, fakeQuantize);
const auto inputLow = foldConvert(fakeQuantize->input_value(1), deqPrecision); const auto inputLow = foldConvert(fakeQuantize->input_value(1), deqPrecision);
const auto inputHigh = foldConvert(fakeQuantize->input_value(2), deqPrecision); const auto inputHigh = foldConvert(fakeQuantize->input_value(2), deqPrecision);
NetworkHelper::copyInfo(fakeQuantize->get_input_node_shared_ptr(1), inputLow); NetworkHelper::copyInfo(fakeQuantize->get_input_node_shared_ptr(1), inputLow);
@ -68,7 +68,7 @@ bool FuseSubtractToFakeQuantizeTransformation::transform(TransformationContext&
auto newFakeQuantize = std::make_shared<op::TypeRelaxed<opset1::FakeQuantize>>( auto newFakeQuantize = std::make_shared<op::TypeRelaxed<opset1::FakeQuantize>>(
opset1::FakeQuantize( opset1::FakeQuantize(
fakeQuantizeParent->output(parentIndex), fakeQuantize->input_value(0),
inputLow, inputLow,
inputHigh, inputHigh,
outputLowConst_f32, outputLowConst_f32,

View File

@ -109,7 +109,7 @@ bool MatMulTransformation::transform(TransformationContext &context, ngraph::pat
// multiply by weights: [1, ..., 1, Y] x [Y, Z] => [1, ..., 1, Z] // multiply by weights: [1, ..., 1, Y] x [Y, Z] => [1, ..., 1, Z]
const auto newSubConst = NetworkHelper::toScalarIfPossible(fold<opset1::MatMul>( const auto newSubConst = NetworkHelper::toScalarIfPossible(fold<opset1::MatMul>(
broadcastedConst, broadcastedConst,
foldConvert(newMatMul->get_input_node_shared_ptr(1), newMatMul->get_element_type()), foldConvert(newMatMul->input_value(1), newMatMul->get_element_type()),
newMatMul->get_transpose_a(), newMatMul->get_transpose_a(),
newMatMul->get_transpose_b())); newMatMul->get_transpose_b()));

View File

@ -77,10 +77,10 @@ bool MultiplyTransformation::transform(TransformationContext& context, ngraph::p
return false; return false;
} }
auto multiplyParent = multiply->get_input_source_output(multiplyBranch.first); auto multiplyParent = multiply->input_value(multiplyBranch.first);
auto constParent = multiply->get_input_source_output(multiplyBranch.first == 0 ? 1 : 0); auto constParent = multiply->input_value(multiplyBranch.first == 0 ? 1 : 0);
auto multiplyParentParent = multiplyParent.get_node_shared_ptr()->get_input_source_output(multiplyBranch.second); auto multiplyParentParent = multiplyParent.get_node_shared_ptr()->input_value(multiplyBranch.second);
auto multiplyParentConst = multiplyParent.get_node_shared_ptr()->get_input_source_output(multiplyBranch.second == 0 ? 1 : 0); auto multiplyParentConst = multiplyParent.get_node_shared_ptr()->input_value(multiplyBranch.second == 0 ? 1 : 0);
newMultiply = std::make_shared<op::TypeRelaxed<opset1::Multiply>>( newMultiply = std::make_shared<op::TypeRelaxed<opset1::Multiply>>(
std::vector<ngraph::element::Type>{ element::f32, element::f32 }, std::vector<ngraph::element::Type>{ element::f32, element::f32 },
@ -127,7 +127,7 @@ bool MultiplyTransformation::transform(TransformationContext& context, ngraph::p
// before: Y = (SC1 * (X1 - SH1)) * (SC2 * X2) // before: Y = (SC1 * (X1 - SH1)) * (SC2 * X2)
// after : Y = (SC1' * (X1 - SH1)) * (X2) , where : // after : Y = (SC1' * (X1 - SH1)) * (X2) , where :
// SC1' = SC1 * SC2 // SC1' = SC1 * SC2
std::shared_ptr<Node> newMultiplyValuesFullPath = fold<opset1::Multiply>(multiplyValuesEmptyPath, multiplyValuesFullPath); auto newMultiplyValuesFullPath = fold<opset1::Multiply>(multiplyValuesEmptyPath, multiplyValuesFullPath);
OutputVector inputs{ {}, {} }; OutputVector inputs{ {}, {} };
inputs[emptyPathIndex] = dequantizationEmptyPath.data; inputs[emptyPathIndex] = dequantizationEmptyPath.data;
inputs[fullPathIndex] = std::make_shared<opset1::Multiply>( inputs[fullPathIndex] = std::make_shared<opset1::Multiply>(

View File

@ -149,7 +149,7 @@ bool MVNTransformation::transform(TransformationContext &context, ngraph::patter
if (ov::is_type<op::MVN>(mvn)) { if (ov::is_type<op::MVN>(mvn)) {
newMVN = mvn->copy_with_new_inputs({dequantization.data}); newMVN = mvn->copy_with_new_inputs({dequantization.data});
} else { } else {
newMVN = mvn->copy_with_new_inputs({dequantization.data, mvn->get_input_node_shared_ptr(1)}); newMVN = mvn->copy_with_new_inputs({dequantization.data, mvn->input_value(1)});
} }
NetworkHelper::setOutDataPrecisionForTypeRelaxed(newMVN, deqPrecision); NetworkHelper::setOutDataPrecisionForTypeRelaxed(newMVN, deqPrecision);
NetworkHelper::copyInfo(mvn, newMVN); NetworkHelper::copyInfo(mvn, newMVN);

View File

@ -233,10 +233,10 @@ std::shared_ptr<Node> NetworkHelper::swapMultiplyAndAdd(std::shared_ptr<opset1::
if (multiplyConst == nullptr) if (multiplyConst == nullptr)
return addAfterMultiply; return addAfterMultiply;
const auto x = multiply->get_input_source_output(multiplyInputBranch); const auto x = multiply->input_value(multiplyInputBranch);
auto a = multiply->get_input_node_shared_ptr(multiplyInputBranch == 0 ? 1 : 0); auto a = as_type_ptr<opset1::Constant>(multiply->get_input_node_shared_ptr(multiplyInputBranch == 0 ? 1 : 0));
auto b = addAfterMultiply->get_input_node_shared_ptr(multiplyBranch == 0 ? 1 : 0); auto b = as_type_ptr<opset1::Constant>(addAfterMultiply->get_input_node_shared_ptr(multiplyBranch == 0 ? 1 : 0));
std::shared_ptr<Node> bDivA; std::shared_ptr<opset1::Constant> bDivA;
const auto aPShape = a->get_output_partial_shape(0); const auto aPShape = a->get_output_partial_shape(0);
assert(aPShape.is_static()); assert(aPShape.is_static());
@ -248,8 +248,8 @@ std::shared_ptr<Node> NetworkHelper::swapMultiplyAndAdd(std::shared_ptr<opset1::
if ((shape_size(bShape) == 1) || (shape_size(aShape) == 1) || (shape_size(bShape) == shape_size(aShape))) { if ((shape_size(bShape) == 1) || (shape_size(aShape) == 1) || (shape_size(bShape) == shape_size(aShape))) {
// safely division to avoid NaN // safely division to avoid NaN
const std::vector<float> bValues = ov::as_type_ptr<opset1::Constant>(b)->cast_vector<float>(); const std::vector<float> bValues = b->cast_vector<float>();
const std::vector<float> aValues = ov::as_type_ptr<opset1::Constant>(a)->cast_vector<float>(); const std::vector<float> aValues = a->cast_vector<float>();
const bool aBroadcasted = bValues.size() > aValues.size(); const bool aBroadcasted = bValues.size() > aValues.size();
const bool bBroadcasted = bValues.size() < aValues.size(); const bool bBroadcasted = bValues.size() < aValues.size();
std::vector<float> bDivAValues(aBroadcasted ? bValues.size() : aValues.size()); std::vector<float> bDivAValues(aBroadcasted ? bValues.size() : aValues.size());
@ -271,16 +271,16 @@ std::shared_ptr<Node> NetworkHelper::swapMultiplyAndAdd(std::shared_ptr<opset1::
aBroadcasted ? bShape : aShape, aBroadcasted ? bShape : aShape,
bDivAValues); bDivAValues);
} else { } else {
b = foldConvert(b, element::f32); b = as_type_ptr<opset1::Constant>(foldConvert(b->output(0), element::f32));
a = foldConvert(a, element::f32); a = as_type_ptr<opset1::Constant>(foldConvert(a->output(0), element::f32));
bDivA = fold<opset1::Divide>(b, a); bDivA = as_type_ptr<opset1::Constant>(fold<opset1::Divide>(b->output(0), a->output(0)));
// TODO: issue #49868 // TODO: issue #49868
bDivA = foldConvert(bDivA, a->get_output_element_type(0)); bDivA = as_type_ptr<opset1::Constant>(foldConvert(bDivA->output(0), a->get_element_type()));
} }
OutputVector inputs{ {}, {} }; OutputVector inputs{ {}, {} };
inputs[0] = x; inputs[0] = x;
inputs[1] = bDivA; inputs[1] = bDivA->output(0);
std::shared_ptr<opset1::Add> newAdd = std::make_shared<op::TypeRelaxed<opset1::Add>>( std::shared_ptr<opset1::Add> newAdd = std::make_shared<op::TypeRelaxed<opset1::Add>>(
std::vector<element::Type>{element::f32, element::f32}, std::vector<element::Type>{element::f32, element::f32},
@ -292,8 +292,8 @@ std::shared_ptr<Node> NetworkHelper::swapMultiplyAndAdd(std::shared_ptr<opset1::
auto newMultiply = std::make_shared<op::TypeRelaxed<opset1::Multiply>>( auto newMultiply = std::make_shared<op::TypeRelaxed<opset1::Multiply>>(
std::vector<element::Type>{element::f32, element::f32}, std::vector<element::Type>{element::f32, element::f32},
std::vector<element::Type>{ multiply->get_output_element_type(0) }, std::vector<element::Type>{ multiply->get_output_element_type(0) },
ngraph::op::TemporaryReplaceOutputType(newAdd, element::f32).get(), ngraph::op::TemporaryReplaceOutputType(newAdd->output(0), element::f32).get(),
ngraph::op::TemporaryReplaceOutputType(a, element::f32).get()); ngraph::op::TemporaryReplaceOutputType(a->output(0), element::f32).get());
copyInfo({ multiply, newMultiply }, newMultiply); copyInfo({ multiply, newMultiply }, newMultiply);
replace_node(addAfterMultiply, newMultiply); replace_node(addAfterMultiply, newMultiply);
@ -460,7 +460,7 @@ std::shared_ptr<ngraph::opset1::Multiply> NetworkHelper::optimizeMultipliesAfter
} }
auto newInput = multiply->input_value(1 - constant1->output(0).get_target_inputs().begin()->get_index()); auto newInput = multiply->input_value(1 - constant1->output(0).get_target_inputs().begin()->get_index());
auto multiplyResult = fold<opset1::Multiply>(constant1, constant2); auto multiplyResult = fold<opset1::Multiply>(constant1->output(0), constant2->output(0));
{ {
// optimize constant shape: used in rfcn-resnet101-coco // optimize constant shape: used in rfcn-resnet101-coco
const auto multiplyResultConstant = ov::as_type_ptr<opset1::Constant>(multiplyResult); const auto multiplyResultConstant = ov::as_type_ptr<opset1::Constant>(multiplyResult);
@ -526,13 +526,13 @@ FakeQuantizeDequantization NetworkHelper::foldDequantization(const std::shared_p
} }
if (dequantization.subtract != nullptr) { if (dequantization.subtract != nullptr) {
if (dequantization.subtract->input(0).get_element_type() != dequantization.subtract->input(1).get_element_type()) { if (dequantization.subtract->get_input_element_type(0) != dequantization.subtract->get_input_element_type(1)) {
return dequantization; return dequantization;
} }
if (dequantization.subtractConvert != nullptr) { if (dequantization.subtractConvert != nullptr) {
const auto convertionResult = foldConvert( const auto convertionResult = foldConvert(
dequantization.subtractConstant, dequantization.subtractConstant->output(0),
dequantization.subtractConvert->get_element_type()); dequantization.subtractConvert->get_element_type());
if (ov::is_type<opset1::Constant>(convertionResult)) { if (ov::is_type<opset1::Constant>(convertionResult)) {
replace_node(dequantization.subtractConvert, convertionResult); replace_node(dequantization.subtractConvert, convertionResult);
@ -541,8 +541,8 @@ FakeQuantizeDequantization NetworkHelper::foldDequantization(const std::shared_p
} }
const std::shared_ptr<Node> result = fold<opset1::Subtract>( const std::shared_ptr<Node> result = fold<opset1::Subtract>(
dequantization.subtract->get_input_node_shared_ptr(0), dequantization.subtract->input_value(0),
dequantization.subtract->get_input_node_shared_ptr(1)); dequantization.subtract->input_value(1));
if (ov::is_type<opset1::Constant>(result)) { if (ov::is_type<opset1::Constant>(result)) {
if (inPlace) { if (inPlace) {
copyInfo(dequantization.subtract, result); copyInfo(dequantization.subtract, result);
@ -555,18 +555,18 @@ FakeQuantizeDequantization NetworkHelper::foldDequantization(const std::shared_p
} }
if (dequantization.multiply != nullptr) { if (dequantization.multiply != nullptr) {
if (dequantization.multiply->input(0).get_element_type() != dequantization.multiply->input(1).get_element_type()) { if (dequantization.multiply->get_input_element_type(0) != dequantization.multiply->get_input_element_type(1)) {
return dequantization; return dequantization;
} }
std::shared_ptr<Node> result = fold<opset1::Multiply>( std::shared_ptr<Node> result = fold<opset1::Multiply>(
dequantization.multiply->get_input_node_shared_ptr(0), dequantization.multiply->input_value(0),
dequantization.multiply->get_input_node_shared_ptr(1)); dequantization.multiply->input_value(1));
if (!ov::is_type<opset1::Constant>(result)) { if (!ov::is_type<opset1::Constant>(result)) {
return dequantization; return dequantization;
} }
if (dequantization.multiply->get_output_element_type(0) != result->get_element_type()) { if (dequantization.multiply->get_output_element_type(0) != result->get_element_type()) {
result = foldConvert(result, dequantization.multiply->get_output_element_type(0)); result = foldConvert(result->output(0), dequantization.multiply->get_output_element_type(0));
} }
if (inPlace) { if (inPlace) {
copyInfo(dequantization.multiply, result); copyInfo(dequantization.multiply, result);
@ -599,7 +599,7 @@ std::shared_ptr<ngraph::Node> NetworkHelper::separateInStandaloneBranch(std::sha
outputs.push_back(input.get_source_output()); outputs.push_back(input.get_source_output());
} }
auto subtract = dequantization.subtract->clone_with_new_inputs({parent, parentOnWeights->clone_with_new_inputs(outputs) }); auto subtract = dequantization.subtract->clone_with_new_inputs({parent, parentOnWeights->clone_with_new_inputs(outputs)->output(0) });
subtract->set_friendly_name(""); subtract->set_friendly_name("");
copy_runtime_info(parent.get_node_shared_ptr(), subtract); copy_runtime_info(parent.get_node_shared_ptr(), subtract);
parent = subtract->output(0); parent = subtract->output(0);
@ -608,7 +608,7 @@ std::shared_ptr<ngraph::Node> NetworkHelper::separateInStandaloneBranch(std::sha
if (dequantization.multiply != nullptr) { if (dequantization.multiply != nullptr) {
auto multiply = dequantization.multiply->clone_with_new_inputs({ auto multiply = dequantization.multiply->clone_with_new_inputs({
parent, parent,
dequantization.multiply->get_input_node_shared_ptr(1)->clone_with_new_inputs({}) }); dequantization.multiply->get_input_node_shared_ptr(1)->clone_with_new_inputs({})->output(0) });
multiply->set_friendly_name(""); multiply->set_friendly_name("");
copy_runtime_info(parent.get_node_shared_ptr(), multiply); copy_runtime_info(parent.get_node_shared_ptr(), multiply);
parent = multiply->output(0); parent = multiply->output(0);
@ -650,11 +650,11 @@ std::shared_ptr<opset1::FakeQuantize> NetworkHelper::fuseConvert(const std::shar
std::shared_ptr<opset1::FakeQuantize> newFakeQuantize = std::make_shared<ngraph::op::TypeRelaxed<opset1::FakeQuantize>>( std::shared_ptr<opset1::FakeQuantize> newFakeQuantize = std::make_shared<ngraph::op::TypeRelaxed<opset1::FakeQuantize>>(
std::vector<ngraph::element::Type>{ element::f32, element::f32, element::f32, element::f32, element::f32 }, std::vector<ngraph::element::Type>{ element::f32, element::f32, element::f32, element::f32, element::f32 },
std::vector<ngraph::element::Type>{}, std::vector<ngraph::element::Type>{},
ngraph::op::TemporaryReplaceOutputType(fakeQuantize->get_input_node_shared_ptr(0), element::f32).get(), ngraph::op::TemporaryReplaceOutputType(fakeQuantize->input_value(0), element::f32).get(),
ngraph::op::TemporaryReplaceOutputType(fakeQuantize->get_input_node_shared_ptr(1), element::f32).get(), ngraph::op::TemporaryReplaceOutputType(fakeQuantize->input_value(1), element::f32).get(),
ngraph::op::TemporaryReplaceOutputType(fakeQuantize->get_input_node_shared_ptr(2), element::f32).get(), ngraph::op::TemporaryReplaceOutputType(fakeQuantize->input_value(2), element::f32).get(),
ngraph::op::TemporaryReplaceOutputType(fakeQuantize->get_input_node_shared_ptr(3), element::f32).get(), ngraph::op::TemporaryReplaceOutputType(fakeQuantize->input_value(3), element::f32).get(),
ngraph::op::TemporaryReplaceOutputType(fakeQuantize->get_input_node_shared_ptr(4), element::f32).get(), ngraph::op::TemporaryReplaceOutputType(fakeQuantize->input_value(4), element::f32).get(),
fakeQuantize->get_levels()); fakeQuantize->get_levels());
NetworkHelper::setOutDataPrecisionForTypeRelaxed(newFakeQuantize, node->get_output_element_type(0)); NetworkHelper::setOutDataPrecisionForTypeRelaxed(newFakeQuantize, node->get_output_element_type(0));
replace_node(node->shared_from_this(), newFakeQuantize); replace_node(node->shared_from_this(), newFakeQuantize);
@ -889,14 +889,14 @@ std::shared_ptr<opset1::FakeQuantize> NetworkHelper::composeFakeQuantize(const s
if (dequantization.subtract != nullptr) { if (dequantization.subtract != nullptr) {
const auto subtractValue = (dequantization.subtractConvert == nullptr) ? const auto subtractValue = (dequantization.subtractConvert == nullptr) ?
dequantization.subtractConstant : dequantization.subtractConstant :
foldConvert(dequantization.subtractConstant, dequantization.subtractConvert->output(0).get_element_type()); foldConvert(dequantization.subtractConstant->output(0), dequantization.subtractConvert->get_destination_type());
const std::shared_ptr<opset1::FakeQuantize> replacement = std::make_shared<op::TypeRelaxed<opset1::FakeQuantize>>( const std::shared_ptr<opset1::FakeQuantize> replacement = std::make_shared<op::TypeRelaxed<opset1::FakeQuantize>>(
newFakeQuantize->input_value(0), newFakeQuantize->input_value(0),
newFakeQuantize->input_value(1), newFakeQuantize->input_value(1),
newFakeQuantize->input_value(2), newFakeQuantize->input_value(2),
fold<opset1::Subtract>(newFakeQuantize->get_input_node_shared_ptr(3), subtractValue), fold<opset1::Subtract>(newFakeQuantize->input_value(3), subtractValue),
fold<opset1::Subtract>(newFakeQuantize->get_input_node_shared_ptr(4), subtractValue), fold<opset1::Subtract>(newFakeQuantize->input_value(4), subtractValue),
newFakeQuantize->get_levels(), newFakeQuantize->get_levels(),
newFakeQuantize->get_auto_broadcast()); newFakeQuantize->get_auto_broadcast());
replace_node(dequantization.subtract, replacement); replace_node(dequantization.subtract, replacement);
@ -907,11 +907,9 @@ std::shared_ptr<opset1::FakeQuantize> NetworkHelper::composeFakeQuantize(const s
if (dequantization.multiply != nullptr) { if (dequantization.multiply != nullptr) {
// multiply different precision constants (value1 & value2) and convert result to first argument precision (value1) // multiply different precision constants (value1 & value2) and convert result to first argument precision (value1)
auto multiply = []( auto multiply = [](const Output<Node>& value1, const Output<Node>& value2) {
const std::shared_ptr<ngraph::Node>& value1, const ngraph::element::Type precision1 = value1.get_element_type();
const std::shared_ptr<ngraph::Node>& value2) -> std::shared_ptr<ngraph::Node> { const ngraph::element::Type precision2 = value2.get_element_type();
const ngraph::element::Type precision1 = value1->output(0).get_element_type();
const ngraph::element::Type precision2 = value2->output(0).get_element_type();
// 1) precision1 & precision2 are not equal but similar // 1) precision1 & precision2 are not equal but similar
// 2) precision2 >= precision1 // 2) precision2 >= precision1
assert((precision2.is_real() == precision1.is_real()) && (precision2.bitwidth() >= precision1.bitwidth())); assert((precision2.is_real() == precision1.is_real()) && (precision2.bitwidth() >= precision1.bitwidth()));
@ -921,7 +919,7 @@ std::shared_ptr<opset1::FakeQuantize> NetworkHelper::composeFakeQuantize(const s
value2); value2);
if (output->output(0).get_element_type() != precision1) { if (output->output(0).get_element_type() != precision1) {
output = foldConvert(output, precision1); output = foldConvert(output->output(0), precision1);
} }
return output; return output;
@ -931,8 +929,8 @@ std::shared_ptr<opset1::FakeQuantize> NetworkHelper::composeFakeQuantize(const s
newFakeQuantize->input_value(0ul), newFakeQuantize->input_value(0ul),
newFakeQuantize->input_value(1ul), newFakeQuantize->input_value(1ul),
newFakeQuantize->input_value(2ul), newFakeQuantize->input_value(2ul),
multiply(newFakeQuantize->get_input_node_shared_ptr(3ul), dequantization.multiplyConstant), multiply(newFakeQuantize->input_value(3ul), dequantization.multiplyConstant),
multiply(newFakeQuantize->get_input_node_shared_ptr(4ul), dequantization.multiplyConstant), multiply(newFakeQuantize->input_value(4ul), dequantization.multiplyConstant),
newFakeQuantize->get_levels(), newFakeQuantize->get_levels(),
newFakeQuantize->get_auto_broadcast()); newFakeQuantize->get_auto_broadcast());
@ -956,8 +954,6 @@ std::tuple<std::shared_ptr<Node>, std::shared_ptr<Node>> NetworkHelper::decompos
const bool updatePrecision, const bool updatePrecision,
const element::Type deqPrecision, const element::Type deqPrecision,
const size_t outChannelsShapeIndex) { const size_t outChannelsShapeIndex) {
using std::make_shared;
const auto outputLow = fq->input_value(3); const auto outputLow = fq->input_value(3);
const auto outputHigh = fq->input_value(4); const auto outputHigh = fq->input_value(4);
@ -1015,8 +1011,8 @@ std::tuple<std::shared_ptr<Node>, std::shared_ptr<Node>> NetworkHelper::decompos
nullptr; nullptr;
std::shared_ptr<Node> scale = std::make_shared<opset1::Constant>(element::f32, outputLow.get_shape(), scales); std::shared_ptr<Node> scale = std::make_shared<opset1::Constant>(element::f32, outputLow.get_shape(), scales);
auto newMin = make_shared<opset1::Constant>(outputLow.get_element_type(), outputLow.get_shape(), minValues); auto newMin = std::make_shared<opset1::Constant>(outputLow.get_element_type(), outputLow.get_shape(), minValues);
auto newMax = make_shared<opset1::Constant>(outputLow.get_element_type(), outputLow.get_shape(), maxValues); auto newMax = std::make_shared<opset1::Constant>(outputLow.get_element_type(), outputLow.get_shape(), maxValues);
if (isScalarLike(newMin)) { if (isScalarLike(newMin)) {
newMin = toScalar(newMin); newMin = toScalar(newMin);
@ -1072,7 +1068,7 @@ std::tuple<std::shared_ptr<Node>, std::shared_ptr<Node>> NetworkHelper::decompos
std::shared_ptr<opset1::Constant> newFqConstant = ov::as_type_ptr<opset1::Constant>(newFQ); std::shared_ptr<opset1::Constant> newFqConstant = ov::as_type_ptr<opset1::Constant>(newFQ);
if (ov::is_type<opset1::Constant>(newFQ)) { if (ov::is_type<opset1::Constant>(newFQ)) {
convert = foldConvert(newFQ, precision); convert = foldConvert(newFQ->output(0), precision);
} else if (ov::is_type<opset1::FakeQuantize>(newFQ)) { } else if (ov::is_type<opset1::FakeQuantize>(newFQ)) {
newFQ = setOutDataPrecision(ov::as_type_ptr<opset1::FakeQuantize>(newFQ), precision); newFQ = setOutDataPrecision(ov::as_type_ptr<opset1::FakeQuantize>(newFQ), precision);
convert = newFQ; convert = newFQ;
@ -1192,11 +1188,9 @@ FakeQuantizeDequantization NetworkHelper::createDequantizationFromFakeQuantize(
const bool hasZeroPoint, const bool hasZeroPoint,
const bool updatePrecision, const bool updatePrecision,
const element::Type deqPrecision) { const element::Type deqPrecision) {
using std::make_shared;
const ngraph::element::Type_t fqPrecision = fq->get_output_element_type(0); const ngraph::element::Type_t fqPrecision = fq->get_output_element_type(0);
auto newMin = make_shared<opset1::Constant>(fqPrecision, Shape{}, min); auto newMin = std::make_shared<opset1::Constant>(fqPrecision, Shape{}, min);
auto newMax = make_shared<opset1::Constant>(fqPrecision, Shape{}, max); auto newMax = std::make_shared<opset1::Constant>(fqPrecision, Shape{}, max);
auto outputLow = fq->input_value(3); auto outputLow = fq->input_value(3);
auto outputHigh = fq->input_value(4); auto outputHigh = fq->input_value(4);
@ -1205,12 +1199,12 @@ FakeQuantizeDequantization NetworkHelper::createDequantizationFromFakeQuantize(
const std::shared_ptr<opset1::Constant> scale = ov::as_type_ptr<opset1::Constant>(foldConvert(fold<opset1::Divide>( const std::shared_ptr<opset1::Constant> scale = ov::as_type_ptr<opset1::Constant>(foldConvert(fold<opset1::Divide>(
fold<opset1::Subtract>(outputHigh, outputLow), fold<opset1::Subtract>(outputHigh, outputLow),
fold<opset1::Subtract>(newMax, newMin)), deqPrecision)); fold<opset1::Subtract>(newMax->output(0), newMin->output(0))), deqPrecision));
assert(scale != nullptr); assert(scale != nullptr);
std::shared_ptr<opset1::Constant> shift = hasZeroPoint ? std::shared_ptr<opset1::Constant> shift = hasZeroPoint ?
ov::as_type_ptr<opset1::Constant>(foldConvert(fold<opset1::Divide>( ov::as_type_ptr<opset1::Constant>(foldConvert(fold<opset1::Divide>(
fold<opset1::Subtract>(fold<opset1::Multiply>(newMin, outputHigh), fold<opset1::Multiply>(newMax, outputLow)), fold<opset1::Subtract>(fold<opset1::Multiply>(newMin->output(0), outputHigh), fold<opset1::Multiply>(newMax->output(0), outputLow)),
fold<opset1::Subtract>(outputHigh, outputLow)), deqPrecision)) : fold<opset1::Subtract>(outputHigh, outputLow)), deqPrecision)) :
nullptr; nullptr;
assert((!hasZeroPoint) || (hasZeroPoint && shift != nullptr)); assert((!hasZeroPoint) || (hasZeroPoint && shift != nullptr));
@ -1240,7 +1234,7 @@ FakeQuantizeDequantization NetworkHelper::createDequantizationFromFakeQuantize(
std::shared_ptr<ngraph::opset1::Subtract> subtract; std::shared_ptr<ngraph::opset1::Subtract> subtract;
if (shift != nullptr) { if (shift != nullptr) {
subtract = make_shared<ngraph::op::TypeRelaxed<opset1::Subtract>>(parent, shift); subtract = std::make_shared<ngraph::op::TypeRelaxed<opset1::Subtract>>(parent, shift);
subtract->set_output_type(0, deqPrecision, subtract->get_output_partial_shape(0)); subtract->set_output_type(0, deqPrecision, subtract->get_output_partial_shape(0));
parent = subtract; parent = subtract;
} else { } else {
@ -1416,16 +1410,16 @@ FakeQuantizeDequantization NetworkHelper::normalizeDequantization(FakeQuantizeDe
return dequantization; return dequantization;
} }
if (dequantization.multiply != nullptr && ov::as_type_ptr<ngraph::opset1::Constant>(dequantization.multiply->get_input_node_shared_ptr(0))) { if (dequantization.multiply != nullptr && ov::as_type_ptr<ngraph::opset1::Constant>(dequantization.multiply->get_input_node_shared_ptr(0))) {
std::shared_ptr<Node> leftParent = dequantization.multiply->get_input_node_shared_ptr(0); const auto leftParent = dequantization.multiply->input_value(0);
std::shared_ptr<Node> rightParent = dequantization.multiply->get_input_node_shared_ptr(1); const auto rightParent = dequantization.multiply->input_value(1);
std::shared_ptr<opset1::Multiply> normalized_multiply = ov::as_type_ptr<opset1::Multiply>( std::shared_ptr<opset1::Multiply> normalized_multiply = ov::as_type_ptr<opset1::Multiply>(
dequantization.multiply->clone_with_new_inputs({rightParent, leftParent})); dequantization.multiply->clone_with_new_inputs({rightParent, leftParent}));
replace_node(dequantization.multiply, normalized_multiply); replace_node(dequantization.multiply, normalized_multiply);
dequantization.multiply = normalized_multiply; dequantization.multiply = normalized_multiply;
} }
if (dequantization.subtract != nullptr && ov::as_type_ptr<ngraph::opset1::Constant>(dequantization.subtract->get_input_node_shared_ptr(0))) { if (dequantization.subtract != nullptr && ov::as_type_ptr<ngraph::opset1::Constant>(dequantization.subtract->get_input_node_shared_ptr(0))) {
std::shared_ptr<Node> leftParent = dequantization.subtract->get_input_node_shared_ptr(0); const auto leftParent = dequantization.subtract->input_value(0);
std::shared_ptr<Node> rightParent = dequantization.subtract->get_input_node_shared_ptr(1); const auto rightParent = dequantization.subtract->input_value(1);
std::shared_ptr<opset1::Subtract> normalized_subtract = ov::as_type_ptr<opset1::Subtract>( std::shared_ptr<opset1::Subtract> normalized_subtract = ov::as_type_ptr<opset1::Subtract>(
dequantization.subtract->clone_with_new_inputs({rightParent, leftParent})); dequantization.subtract->clone_with_new_inputs({rightParent, leftParent}));
replace_node(dequantization.subtract, normalized_subtract); replace_node(dequantization.subtract, normalized_subtract);
@ -1452,7 +1446,7 @@ std::shared_ptr<opset1::Constant> NetworkHelper::normalizeDequantizationShape(co
std::iota(unsqueezeConstantShape.begin(), unsqueezeConstantShape.end(), 0ul); std::iota(unsqueezeConstantShape.begin(), unsqueezeConstantShape.end(), 0ul);
const auto newConstant = fold<opset1::Unsqueeze>( const auto newConstant = fold<opset1::Unsqueeze>(
constant, constant->output(0),
op::Constant::create(element::i32, Shape{ unsqueezeConstantShape.size() }, unsqueezeConstantShape)); op::Constant::create(element::i32, Shape{ unsqueezeConstantShape.size() }, unsqueezeConstantShape));
return ov::as_type_ptr<opset1::Constant>(newConstant); return ov::as_type_ptr<opset1::Constant>(newConstant);
@ -1471,13 +1465,13 @@ std::shared_ptr<opset1::Constant> NetworkHelper::normalizeDequantizationShape(co
FakeQuantizeDequantizationValues NetworkHelper::createEmptyValues(const FakeQuantizeDequantization& dequantization, const element::Type precision) { FakeQuantizeDequantizationValues NetworkHelper::createEmptyValues(const FakeQuantizeDequantization& dequantization, const element::Type precision) {
const std::shared_ptr<Node> multiplyConstant = dequantization.multiply ? const std::shared_ptr<Node> multiplyConstant = dequantization.multiply ?
dequantization.multiplyConstant->get_element_type() != precision ? dequantization.multiplyConstant->get_element_type() != precision ?
foldConvert(dequantization.multiplyConstant, precision) : foldConvert(dequantization.multiplyConstant->output(0), precision) :
dequantization.multiplyConstant : dequantization.multiplyConstant :
std::make_shared<opset1::Constant>(precision, Shape({}), std::vector<float>({ 1.f })); std::make_shared<opset1::Constant>(precision, Shape({}), std::vector<float>({ 1.f }));
const std::shared_ptr<Node> subtractConstant = dequantization.subtract ? const std::shared_ptr<Node> subtractConstant = dequantization.subtract ?
dequantization.subtractConstant->get_element_type() != precision ? dequantization.subtractConstant->get_element_type() != precision ?
foldConvert(dequantization.subtractConstant, precision) : foldConvert(dequantization.subtractConstant->output(0), precision) :
dequantization.subtractConstant : dequantization.subtractConstant :
std::make_shared<opset1::Constant>(precision, Shape({}), std::vector<float>({ 0.f })); std::make_shared<opset1::Constant>(precision, Shape({}), std::vector<float>({ 0.f }));
@ -1538,7 +1532,7 @@ std::shared_ptr<Node> NetworkHelper::optimizeSubtract(std::shared_ptr<opset1::Su
NetworkHelper::copyInfo(shift, roundedShift); NetworkHelper::copyInfo(shift, roundedShift);
// Propagate convertInputType down // Propagate convertInputType down
replacement = std::make_shared<op::TypeRelaxed<opset1::Subtract>>(data, roundedShift); replacement = std::make_shared<op::TypeRelaxed<opset1::Subtract>>(data, roundedShift->output(0));
NetworkHelper::copyInfo(subtract, replacement); NetworkHelper::copyInfo(subtract, replacement);
NetworkHelper::setOutDataPrecisionForTypeRelaxed(replacement, convertOutputType); NetworkHelper::setOutDataPrecisionForTypeRelaxed(replacement, convertOutputType);
replace_node(subtract, replacement); replace_node(subtract, replacement);
@ -1546,7 +1540,7 @@ std::shared_ptr<Node> NetworkHelper::optimizeSubtract(std::shared_ptr<opset1::Su
return replacement; return replacement;
} else if (ov::is_type<opset1::Convert>(subtractParent) && ov::is_type<opset1::Constant>(subtractParent->get_input_node_shared_ptr(0))) { } else if (ov::is_type<opset1::Convert>(subtractParent) && ov::is_type<opset1::Constant>(subtractParent->get_input_node_shared_ptr(0))) {
auto replacement = std::make_shared<op::TypeRelaxed<opset1::Subtract>>(data, subtractParent->get_input_node_shared_ptr(0)); auto replacement = std::make_shared<op::TypeRelaxed<opset1::Subtract>>(data, subtractParent->input_value(0));
NetworkHelper::copyInfo(subtract, replacement); NetworkHelper::copyInfo(subtract, replacement);
NetworkHelper::setOutDataPrecisionForTypeRelaxed(replacement, convertOutputType); NetworkHelper::setOutDataPrecisionForTypeRelaxed(replacement, convertOutputType);
replace_node(subtract, replacement); replace_node(subtract, replacement);
@ -1569,11 +1563,9 @@ NetworkHelper::InsertDequantizationResult NetworkHelper::moveDequantizationAfter
(NetworkHelper::getDequantization(operation).multiplyConstant == nullptr) || (NetworkHelper::getDequantization(operation).multiplyConstant == nullptr) ||
(NetworkHelper::getDequantization(operation).multiplyConstant.get() == dequantization.multiplyConstant.get())); (NetworkHelper::getDequantization(operation).multiplyConstant.get() == dequantization.multiplyConstant.get()));
std::vector<Output<Node>> inputs(operation->get_input_size()); assert(operation->get_output_size() == 1);
for (size_t i = 0; i < operation->get_input_size(); ++i) {
inputs[i] = operation->get_input_node_shared_ptr(i);
}
OutputVector inputs = operation->input_values();
const size_t dequantizationIndex = getChildInputIndex(dequantization.multiply, operation); const size_t dequantizationIndex = getChildInputIndex(dequantization.multiply, operation);
inputs[dequantizationIndex] = moveSubtract ? inputs[dequantizationIndex] = moveSubtract ?
dequantization.data : dequantization.data :
@ -1623,7 +1615,7 @@ NetworkHelper::InsertDequantizationResult NetworkHelper::moveDequantizationAfter
ngraph::op::TemporaryReplaceOutputType( ngraph::op::TemporaryReplaceOutputType(
dequantization.subtractConstant->output(0).get_element_type() == parentPrecision ? dequantization.subtractConstant->output(0).get_element_type() == parentPrecision ?
dequantization.subtractConstant : dequantization.subtractConstant :
foldConvert(dequantization.subtractConstant, parentPrecision), element::f32).get()); foldConvert(dequantization.subtractConstant->output(0), parentPrecision), element::f32).get());
ngraph::copy_runtime_info({ newOperation, parent }, parent); ngraph::copy_runtime_info({ newOperation, parent }, parent);
} else { } else {
parent = std::make_shared<opset1::Subtract>(parent, dequantization.subtractConvert); parent = std::make_shared<opset1::Subtract>(parent, dequantization.subtractConvert);

View File

@ -30,15 +30,15 @@ std::shared_ptr<Node> moveThroughElementwise(const std::shared_ptr<Node>& reshap
assert(ov::is_type<opset1::Constant>(elementwiseValues)); assert(ov::is_type<opset1::Constant>(elementwiseValues));
const std::shared_ptr<opset1::Reshape> newReshape = ov::as_type_ptr<opset1::Reshape>(reshape->clone_with_new_inputs({ const std::shared_ptr<opset1::Reshape> newReshape = ov::as_type_ptr<opset1::Reshape>(reshape->clone_with_new_inputs({
elementwise->get_input_node_shared_ptr(0ul), elementwise->input_value(0),
reshapeValues })); reshapeValues }));
std::shared_ptr<Node> newElementwiseValues; std::shared_ptr<Node> newElementwiseValues;
const Shape elementwiseValuesShape = elementwiseValues->output(0).get_shape(); const Shape elementwiseValuesShape = elementwiseValues->get_output_shape(0);
if (!elementwiseValuesShape.empty() && (elementwiseValuesShape.size() != 1ul)) { if (!elementwiseValuesShape.empty() && (elementwiseValuesShape.size() != 1ul)) {
// update shape constant value to avoid eltwise constan value broadcasting // update shape constant value to avoid eltwise constan value broadcasting
const Shape elementwiseShape = elementwise->output(0).get_shape(); const Shape elementwiseShape = elementwise->get_output_shape(0);
const std::vector<size_t> reshapeValuesVector = ov::as_type_ptr<opset1::Constant>(reshapeValues)->cast_vector<size_t>(); const std::vector<size_t> reshapeValuesVector = ov::as_type_ptr<opset1::Constant>(reshapeValues)->cast_vector<size_t>();
const std::vector<size_t> newReshapeValuesVector = ngraph::pass::low_precision::NetworkHelper::updateReshapeValues( const std::vector<size_t> newReshapeValuesVector = ngraph::pass::low_precision::NetworkHelper::updateReshapeValues(
@ -47,13 +47,13 @@ std::shared_ptr<Node> moveThroughElementwise(const std::shared_ptr<Node>& reshap
reshapeValuesVector); reshapeValuesVector);
const auto newReshapeValues = std::make_shared<opset1::Constant>( const auto newReshapeValues = std::make_shared<opset1::Constant>(
reshapeValues->output(0).get_element_type(), reshapeValues->get_output_element_type(0),
Shape{ newReshapeValuesVector.size() }, Shape{ newReshapeValuesVector.size() },
newReshapeValuesVector); newReshapeValuesVector);
newElementwiseValues = ngraph::pass::low_precision::fold_reshape<opset1::Reshape>( newElementwiseValues = ngraph::pass::low_precision::fold_reshape<opset1::Reshape>(
elementwiseValues->output(0), elementwiseValues,
newReshapeValues->output(0), newReshapeValues,
ov::as_type_ptr<opset1::Reshape>(reshape)->get_special_zero()); ov::as_type_ptr<opset1::Reshape>(reshape)->get_special_zero());
assert(ov::is_type<opset1::Constant>(newElementwiseValues)); assert(ov::is_type<opset1::Constant>(newElementwiseValues));
} else { } else {
@ -71,7 +71,7 @@ std::shared_ptr<Node> moveThroughElementwise(const std::shared_ptr<Node>& reshap
} }
std::shared_ptr<Node> moveThroughConvert(const std::shared_ptr<Node>& reshape, const std::shared_ptr<Node>& convert) { std::shared_ptr<Node> moveThroughConvert(const std::shared_ptr<Node>& reshape, const std::shared_ptr<Node>& convert) {
const auto newReshape = reshape->clone_with_new_inputs({ convert->get_input_node_shared_ptr(0), reshape->get_input_node_shared_ptr(1) }); const auto newReshape = reshape->clone_with_new_inputs({ convert->input_value(0), reshape->input_value(1) });
const auto newConvert = convert->clone_with_new_inputs({ newReshape }); const auto newConvert = convert->clone_with_new_inputs({ newReshape });
replace_node(reshape, newConvert); replace_node(reshape, newConvert);
copy_runtime_info({ convert, reshape }, { newReshape, newConvert }); copy_runtime_info({ convert, reshape }, { newReshape, newConvert });
@ -81,7 +81,7 @@ std::shared_ptr<Node> moveThroughConvert(const std::shared_ptr<Node>& reshape, c
void fuseConstant(const std::shared_ptr<Node>& reshape, const std::shared_ptr<Node>& constant) { void fuseConstant(const std::shared_ptr<Node>& reshape, const std::shared_ptr<Node>& constant) {
ngraph::OutputVector result(1); ngraph::OutputVector result(1);
reshape->constant_fold(result, { constant->output(0), reshape->get_input_node_ptr(1)->output(0) }); reshape->constant_fold(result, { constant, reshape->input_value(1) });
const auto newConstant = result[0].get_node_shared_ptr(); const auto newConstant = result[0].get_node_shared_ptr();
replace_node(reshape, newConstant); replace_node(reshape, newConstant);
copy_runtime_info({ constant, reshape }, newConstant); copy_runtime_info({ constant, reshape }, newConstant);

View File

@ -30,8 +30,8 @@ std::shared_ptr<Node> moveThroughElementwise(const std::shared_ptr<Node>& transp
elementwiseValuesConvert->get_input_node_shared_ptr(0ul); elementwiseValuesConvert->get_input_node_shared_ptr(0ul);
assert(ov::is_type<opset1::Constant>(elementwiseValues)); assert(ov::is_type<opset1::Constant>(elementwiseValues));
const auto transposeValuesShape = transposeValues->output(0).get_shape(); const auto transposeValuesShape = transposeValues->get_output_shape(0);
const auto elementwiseValuesShape = elementwiseValues->output(0).get_shape(); const auto elementwiseValuesShape = elementwiseValues->get_output_shape(0);
if (elementwiseValuesShape.size() != shape_size(transposeValuesShape)) { if (elementwiseValuesShape.size() != shape_size(transposeValuesShape)) {
if (shape_size(elementwiseValuesShape) != 1ul) { if (shape_size(elementwiseValuesShape) != 1ul) {
return nullptr; return nullptr;
@ -51,8 +51,8 @@ std::shared_ptr<Node> moveThroughElementwise(const std::shared_ptr<Node>& transp
transposeValues })); transposeValues }));
const auto newElementwiseValues = ngraph::pass::low_precision::fold<opset1::Transpose>( const auto newElementwiseValues = ngraph::pass::low_precision::fold<opset1::Transpose>(
elementwiseValues->output(0), elementwiseValues,
transposeValues->output(0)); transposeValues);
assert(ov::is_type<opset1::Constant>(newElementwiseValues)); assert(ov::is_type<opset1::Constant>(newElementwiseValues));
const auto newElementwise = elementwise->clone_with_new_inputs({ const auto newElementwise = elementwise->clone_with_new_inputs({
@ -68,7 +68,7 @@ std::shared_ptr<Node> moveThroughElementwise(const std::shared_ptr<Node>& transp
} }
std::shared_ptr<Node> moveThroughConvert(const std::shared_ptr<Node>& transpose, const std::shared_ptr<Node>& convert) { std::shared_ptr<Node> moveThroughConvert(const std::shared_ptr<Node>& transpose, const std::shared_ptr<Node>& convert) {
const auto newTranspose = transpose->clone_with_new_inputs({convert->get_input_node_shared_ptr(0), transpose->get_input_node_ptr(1)->output(0) }); const auto newTranspose = transpose->clone_with_new_inputs({convert->input_value(0), transpose->input_value(1) });
const auto newConvert = convert->clone_with_new_inputs({ newTranspose }); const auto newConvert = convert->clone_with_new_inputs({ newTranspose });
replace_node(transpose, newConvert); replace_node(transpose, newConvert);
copy_runtime_info({ convert, transpose }, { newTranspose, newConvert }); copy_runtime_info({ convert, transpose }, { newTranspose, newConvert });
@ -78,8 +78,8 @@ std::shared_ptr<Node> moveThroughConvert(const std::shared_ptr<Node>& transpose,
void fuseConstant(const std::shared_ptr<Node>& transpose, const std::shared_ptr<Node>& constant) { void fuseConstant(const std::shared_ptr<Node>& transpose, const std::shared_ptr<Node>& constant) {
const auto newConstant = ngraph::pass::low_precision::fold<opset1::Transpose>( const auto newConstant = ngraph::pass::low_precision::fold<opset1::Transpose>(
constant->output(0), constant,
transpose->get_input_node_ptr(1)->output(0)); transpose->input_value(1));
replace_node(transpose, newConstant); replace_node(transpose, newConstant);
copy_runtime_info({ constant, transpose }, newConstant); copy_runtime_info({ constant, transpose }, newConstant);

View File

@ -63,7 +63,7 @@ void reshapeDequantizationConstant(const std::shared_ptr<opset1::Reshape>& resha
} }
} }
const auto reshapeOutputPShape = reshape->output(0).get_partial_shape(); const auto reshapeOutputPShape = reshape->get_output_partial_shape(0);
const auto reshapeOutputRank = reshapeOutputPShape.rank(); const auto reshapeOutputRank = reshapeOutputPShape.rank();
assert(reshapeOutputRank.is_static()); assert(reshapeOutputRank.is_static());
assert(reshapeOutputRank.get_length() >= 2); assert(reshapeOutputRank.get_length() >= 2);

View File

@ -52,7 +52,7 @@ std::shared_ptr<VariantWrapper<std::shared_ptr<IntervalsAlignmentAttribute>>> Va
FakeQuantizeDequantization dequantization; FakeQuantizeDequantization dequantization;
{ {
const auto targetInputs = node->output(0).get_target_inputs(); const auto targetInputs = node->get_output_target_inputs(0);
if (targetInputs.size() == 1ul) { if (targetInputs.size() == 1ul) {
dequantization = NetworkHelper::getDequantizationBelow(node, true); dequantization = NetworkHelper::getDequantizationBelow(node, true);
} }
@ -75,7 +75,7 @@ std::shared_ptr<VariantWrapper<std::shared_ptr<IntervalsAlignmentAttribute>>> Va
auto multiplyResult = dequantization.multiplyConstant == nullptr ? auto multiplyResult = dequantization.multiplyConstant == nullptr ?
node->get_input_node_ptr(3)->shared_from_this() : node->get_input_node_ptr(3)->shared_from_this() :
fold<opset1::Multiply>( fold<opset1::Multiply>(
foldConvert(node->get_input_node_ptr(3)->shared_from_this(), params.deqPrecision), foldConvert(node->input_value(3), params.deqPrecision),
dequantization.multiplyConstant); dequantization.multiplyConstant);
auto multiplyResultConstant = ov::as_type_ptr<opset1::Constant>(multiplyResult); auto multiplyResultConstant = ov::as_type_ptr<opset1::Constant>(multiplyResult);
@ -87,7 +87,7 @@ std::shared_ptr<VariantWrapper<std::shared_ptr<IntervalsAlignmentAttribute>>> Va
auto multiplyResult = dequantization.multiplyConstant == nullptr ? auto multiplyResult = dequantization.multiplyConstant == nullptr ?
node->get_input_node_ptr(4)->shared_from_this() : node->get_input_node_ptr(4)->shared_from_this() :
fold<opset1::Multiply>( fold<opset1::Multiply>(
foldConvert(node->get_input_node_ptr(4)->shared_from_this(), params.deqPrecision), foldConvert(node->input_value(4), params.deqPrecision),
dequantization.multiplyConstant); dequantization.multiplyConstant);
auto multiplyResultConstant = ov::as_type_ptr<opset1::Constant>(multiplyResult); auto multiplyResultConstant = ov::as_type_ptr<opset1::Constant>(multiplyResult);

View File

@ -1,16 +0,0 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "low_precision/rt_info/shared_value_attribute.hpp"
#include <memory>
#include <string>
#include <unordered_map>
#include <iterator>
#include <vector>
#include <ngraph/opsets/opset1.hpp>
#include "low_precision/network_helper.hpp"
using namespace ngraph;

View File

@ -47,7 +47,7 @@ bool SqueezeTransformation::transform(TransformationContext& context, ngraph::pa
return NetworkHelper::toScalar(dequantizationOpConstant); return NetworkHelper::toScalar(dequantizationOpConstant);
} }
if (constantShape.size() == inputRankValue) { if (constantShape.size() == inputRankValue) {
return ov::as_type_ptr<opset1::Constant>(fold<opset1::Squeeze>(dequantizationOpConstant, squeeze->get_input_node_shared_ptr(1))); return ov::as_type_ptr<opset1::Constant>(fold<opset1::Squeeze>(dequantizationOpConstant, squeeze->input_value(1)));
} }
return dequantizationOpConstant; return dequantizationOpConstant;

View File

@ -62,9 +62,9 @@ std::shared_ptr<opset1::Constant> stridedSliceDeqConstant(
const auto result = fold<ngraph::opset1::StridedSlice>( const auto result = fold<ngraph::opset1::StridedSlice>(
constant, constant,
stridedSlice->get_input_node_shared_ptr(1), stridedSlice->input_value(1),
stridedSlice->get_input_node_shared_ptr(2), stridedSlice->input_value(2),
stridedSlice->get_input_node_shared_ptr(3), stridedSlice->input_value(3),
beginMask, beginMask,
endMask, endMask,
stridedSlice->get_new_axis_mask(), stridedSlice->get_new_axis_mask(),

View File

@ -55,10 +55,10 @@ bool SubtractTransformation::transform(TransformationContext& context, ngraph::p
// X * SC - SH = X * SC - SH' * SC // X * SC - SH = X * SC - SH' * SC
// SH' = SH / SC // SH' = SH / SC
std::shared_ptr<opset1::Subtract> newSubtract = ov::as_type_ptr<opset1::Subtract>(subtract->copy_with_new_inputs({ std::shared_ptr<opset1::Subtract> newSubtract = ov::as_type_ptr<opset1::Subtract>(subtract->copy_with_new_inputs({
dequantization.multiply->get_input_node_shared_ptr(0), dequantization.multiply->input_value(0),
ngraph::pass::low_precision::fold<ngraph::opset1::Divide>( ngraph::pass::low_precision::fold<ngraph::opset1::Divide>(
subtract->get_input_node_shared_ptr(1), subtract->input_value(1),
dequantization.multiply->get_input_node_shared_ptr(1)) dequantization.multiply->input_value(1))
})); }));
std::shared_ptr<Node> newMultiply = dequantization.multiply->copy_with_new_inputs({ std::shared_ptr<Node> newMultiply = dequantization.multiply->copy_with_new_inputs({
@ -72,8 +72,8 @@ bool SubtractTransformation::transform(TransformationContext& context, ngraph::p
if (dequantization.subtract != nullptr) { if (dequantization.subtract != nullptr) {
std::shared_ptr<opset1::Subtract> newSubtract = ov::as_type_ptr<opset1::Subtract>(subtract->copy_with_new_inputs({ std::shared_ptr<opset1::Subtract> newSubtract = ov::as_type_ptr<opset1::Subtract>(subtract->copy_with_new_inputs({
dequantization.subtract->get_input_node_shared_ptr(0), dequantization.subtract->input_value(0),
fold<ngraph::opset1::Add>(subtract->get_input_node_shared_ptr(1), dequantization.subtractConstant) fold<ngraph::opset1::Add>(subtract->input_value(1), dequantization.subtractConstant)
})); }));
replace_node(subtract, newSubtract); replace_node(subtract, newSubtract);
@ -86,8 +86,8 @@ bool SubtractTransformation::transform(TransformationContext& context, ngraph::p
subtract->set_output_type(0, originalPrecision, subtract->get_output_partial_shape(0)); subtract->set_output_type(0, originalPrecision, subtract->get_output_partial_shape(0));
replace_node(subtract, std::make_shared<op::TypeRelaxed<opset1::Subtract>>( replace_node(subtract, std::make_shared<op::TypeRelaxed<opset1::Subtract>>(
subtract->get_input_node_shared_ptr(0), subtract->input_value(0),
subtract->get_input_node_shared_ptr(1))); subtract->input_value(1)));
} }
return true; return true;
} }

View File

@ -4,9 +4,7 @@
#include "low_precision/transparent_base_transformation.hpp" #include "low_precision/transparent_base_transformation.hpp"
#include <algorithm>
#include <memory> #include <memory>
#include <string>
#include <vector> #include <vector>
#include "low_precision/network_helper.hpp" #include "low_precision/network_helper.hpp"
@ -16,27 +14,20 @@ using namespace ngraph::pass;
using namespace ngraph::pass::low_precision; using namespace ngraph::pass::low_precision;
bool TransparentBaseTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher &m) { bool TransparentBaseTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher &m) {
auto operation = m.get_match_root(); std::shared_ptr<Node> op = m.get_match_root();
const std::shared_ptr<Node> dequantization = operation->input_value(0).get_node_shared_ptr(); if (!canBeTransformed(context, op)) {
// const std::shared_ptr<Node> dequantizationParent = dequantization->input_value(0).get_node_shared_ptr(); return false;
}
// auto newOperation = operation->copy_with_new_inputs({ dequantizationParent }); op = NetworkHelper::separateInStandaloneBranch(op);
// const auto newDequantization = dequantization->copy_with_new_inputs({ moveDequantizationAfter(context, op, NetworkHelper::getDequantization(op), true);
// newOperation,
// dequantization->input_value(1),
// dequantization->input_value(2) });
// const std::string friendlyName = operation->get_friendly_name();
//// TODO: new operation name has to be unique
// newOperation->set_friendly_name(friendlyName + "_original");
// newDequantization->set_friendly_name(friendlyName);
// replace_node(operation, newDequantization);
// NetworkHelper::moveDequantization(operation, dequantization);
return true; return true;
} }
bool TransparentBaseTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const { bool TransparentBaseTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const {
return true; return true;
} }
// Reports whether the given layer preserves precision.
// This implementation returns true unconditionally; `layer` is not inspected.
bool TransparentBaseTransformation::isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept {
return true;
}

View File

@ -48,7 +48,7 @@ bool UnsqueezeTransformation::transform(TransformationContext& context, ngraph::
} }
if (constantShape.size() == inputRankValue) { if (constantShape.size() == inputRankValue) {
return ov::as_type_ptr<opset1::Constant>(fold<opset1::Unsqueeze>(dequantizationOpConstant, unsqueeze->get_input_node_shared_ptr(1))); return ov::as_type_ptr<opset1::Constant>(fold<opset1::Unsqueeze>(dequantizationOpConstant, unsqueeze->input_value(1)));
} }
return dequantizationOpConstant; return dequantizationOpConstant;