[LPT]MoveFakeQuantize resolved minor comments (#9614)

* Q/DQ + mulichannel support

backup

fix interval

mfk_functiun.cpp

WIP moveDequantizationBefore

add moveDequantizationBefore function

add cpu and gpu tests

attribute cmp false

attribute cmp false

rm temp line

mkl-dnn update

concat with multichanels for mOve_fake_quantize_function, bad runtime info for q/dq

rm extra qualification

fix run time info for q/dq

add support of multichanel fakequantize, bad test for it

work tests for multi chanel FQ

rm workaround

cpplint fix

cpplint fix

don't worl Variadic split

ieFuncTest work

cpuFuncTest work

Fix benchmark_app build (#7577)

[GPU] Added onednn dependency. (#6564)

cpp lint

cpplint

fix get_shape

fix fq constants

cpp lint

some fix in mfk.cpp

resolve conversations, add spil_nodes function

add new tests for multi-chanels, rename NetworkHelper::split_consts_before_concat()

fix get fq constants

* add new multi-chanels test and use constant_fold to split constant

* remove extra spaces

fix namespase terminated

fix namespase terminated

* resolved minor comments

* added check for convert_q
This commit is contained in:
Nikita Demashov 2022-01-14 10:40:10 +03:00 committed by GitHub
parent 77c2c5fab3
commit 1df9f958e2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 25 additions and 22 deletions

View File

@ -179,7 +179,7 @@ public:
const bool updatePrecision,
const bool moveSubtract);
static std::vector<std::vector<std::shared_ptr<ngraph::opset1::Constant>>> split_consts_before_concat(
static std::vector<std::vector<std::shared_ptr<ngraph::opset1::Constant>>> splitConstantsBeforeConcat(
const std::shared_ptr<ov::Node> concat,
const std::vector<std::shared_ptr<opset1::Constant>> currConstants);

View File

@ -62,9 +62,8 @@ bool MoveFakeQuantize::transform(TransformationContext& context, ngraph::pattern
auto operation = fq->get_input_node_shared_ptr(0);
std::shared_ptr<ngraph::Node> concat;
bool without_operation = true;
std::string fq_original_name = fq->get_friendly_name(),
operation_original_name,
convert_q_original_name;
const std::string fq_original_name = fq->get_friendly_name();
std::string operation_original_name;
if (is_type<opset1::Concat>(operation)) {
concat = operation;
} else {
@ -75,7 +74,7 @@ bool MoveFakeQuantize::transform(TransformationContext& context, ngraph::pattern
if (!ConcatTransformation::isQuantizedStatic(concat)) {
return false;
}
auto convert_q = (*fq->output(0).get_target_inputs().begin()).get_node()->shared_from_this();
const auto convert_q = fq->output(0).get_target_inputs().begin()->get_node()->shared_from_this();
bool q_dq = is_type<opset1::Convert>(convert_q);
std::vector<std::shared_ptr<opset1::Constant>> currConstants(4);
bool multi_chanels = false;
@ -84,13 +83,13 @@ bool MoveFakeQuantize::transform(TransformationContext& context, ngraph::pattern
const auto concat_axis = concatNode->get_concatenation_axis();
for (size_t i = 0; i < 4; i++) {
currConstants[i] = as_type_ptr<opset1::Constant>(fq->get_input_node_shared_ptr(i + 1));
if (!multi_chanels && currConstants[i]->get_shape().size() > 1 && currConstants[i]->get_shape()[concat_axis] != 1) {
if (!multi_chanels && currConstants[i]->get_shape().size() > (concat_axis + 1ul) && currConstants[i]->get_shape()[concat_axis] != 1) {
multi_chanels = true;
}
}
std::vector<std::vector<std::shared_ptr<ngraph::opset1::Constant>>> newConstants;
if (multi_chanels) {
newConstants = NetworkHelper::split_consts_before_concat(concat, currConstants);
newConstants = NetworkHelper::splitConstantsBeforeConcat(concat, currConstants);
}
std::vector<std::shared_ptr<ngraph::Node>> newNodes;
for (size_t i{ 0 }; i < number_of_concat_inputs; ++i) {
@ -152,7 +151,11 @@ bool MoveFakeQuantize::canBeTransformed(const TransformationContext& context, st
if (!ConcatTransformation::isQuantizedStatic(concat)) {
return false;
}
auto convert_q = (*layer->output(0).get_target_inputs().begin()).get_node()->shared_from_this();
const auto convert_q_target_inputs = layer->output(0).get_target_inputs();
if (convert_q_target_inputs.empty()) {
return false;
}
const auto convert_q = convert_q_target_inputs.begin()->get_node()->shared_from_this();
bool q_dq = is_type<opset1::Convert>(convert_q);
if (q_dq && (convert_q->get_output_size() != 1 || layer->get_output_size() != 1)) {
return false;

View File

@ -1679,11 +1679,14 @@ NetworkHelper::InsertDequantizationResult NetworkHelper::moveDequantizationBefor
const auto concatNode = as_type_ptr<opset1::Concat>(operation);
auto axis = concatNode->get_concatenation_axis();
if (dequantization.multiply && dequantization.multiplyConstant->get_shape().size() > 1 && dequantization.multiplyConstant->get_shape()[axis] != 1) {
multiplyConstants = NetworkHelper::split_consts_before_concat(operation, { dequantization.multiplyConstant });
multiplyConstants = NetworkHelper::splitConstantsBeforeConcat(operation, { dequantization.multiplyConstant });
}
if (dequantization.subtract && dequantization.subtractConstant->get_shape().size() > 1 && dequantization.subtractConstant->get_shape()[axis] != 1) {
subtractConstants = NetworkHelper::split_consts_before_concat(operation, { dequantization.subtractConstant });
subtractConstants = NetworkHelper::splitConstantsBeforeConcat(operation, { dequantization.subtractConstant });
}
} else {
multiplyConstants = {{ dequantization.multiplyConstant }};
subtractConstants = {{ dequantization.subtractConstant }};
}
std::vector<std::shared_ptr<ngraph::Node>> newNodes;
for (size_t i = 0; i < operation->get_input_size(); ++i) {
@ -1750,21 +1753,17 @@ NetworkHelper::InsertDequantizationResult NetworkHelper::moveDequantizationBefor
NetworkHelper::copyInfo(operation, newOperation);
replace_node(dequantization.multiply, newOperation);
auto op = std::dynamic_pointer_cast<ngraph::op::TypeRelaxedBase>(newOperation);
if (op != nullptr) {
if (updatePrecision) {
op->set_overridden_output_type(newOperation->get_input_element_type(0));
} else if (dequantization.multiply) {
op->set_overridden_output_type(dequantization.multiplyConstant->get_element_type());
} else if (dequantization.subtract) {
op->set_overridden_output_type(dequantization.subtractConstant->get_element_type());
}
std::dynamic_pointer_cast<ngraph::Node>(newOperation)->validate_and_infer_types();
if (const auto op = std::dynamic_pointer_cast<ngraph::op::TypeRelaxedBase>(newOperation)) {
op->set_overridden_output_type(updatePrecision ?
newOperation->get_input_element_type(0) :
dequantization.multiplyConstant->get_element_type());
newOperation->validate_and_infer_types();
}
return InsertDequantizationResult(newOperation, dequantization.multiply);
}
std::vector<std::vector<std::shared_ptr<ngraph::opset1::Constant>>> NetworkHelper::split_consts_before_concat(const std::shared_ptr<ov::Node> concat,
std::vector<std::vector<std::shared_ptr<ngraph::opset1::Constant>>> NetworkHelper::splitConstantsBeforeConcat(const std::shared_ptr<ov::Node> concat,
const std::vector<std::shared_ptr<opset1::Constant>> currConstants) {
std::vector<std::vector<std::shared_ptr<ngraph::opset1::Constant>>> newConstants(currConstants.size());
auto number_of_concat_inputs = concat->get_input_size();
@ -1788,7 +1787,8 @@ std::vector<std::vector<std::shared_ptr<ngraph::opset1::Constant>>> NetworkHelpe
OutputVector outputResults(split->get_output_size());
auto foldResult = split->constant_fold(outputResults, split->input_values());
if (!foldResult) {
// handle potential constant fold issue here
THROW_IE_LPT_EXCEPTION(*concat) << "error when splitting constants before concat " <<
concat->get_friendly_name();
}
for (auto outputResult : outputResults) {
auto constant = as_type_ptr<opset1::Constant>(outputResult.get_node_shared_ptr());