[LPT]MoveFakeQuantize resolved minor comments (#9614)
* Q/DQ + mulichannel support backup fix interval mfk_functiun.cpp WIP moveDequantizationBefore add moveDequantizationBefore function add cpu and gpu tests attribute cmp false attribute cmp false rm temp line mkl-dnn update concat with multichanels for mOve_fake_quantize_function, bad runtime info for q/dq rm extra qualification fix run time info for q/dq add support of multichanel fakequantize, bad test for it work tests for multi chanel FQ rm workaround cpplint fix cpplint fix don't worl Variadic split ieFuncTest work cpuFuncTest work Fix benchmark_app build (#7577) [GPU] Added onednn dependency. (#6564) cpp lint cpplint fix get_shape fix fq constants cpp lint some fix in mfk.cpp resolve conversations, add spil_nodes function add new tests for multi-chanels, rename NetworkHelper::split_consts_before_concat() fix get fq constants * add new multi-chanels test and use constant_fold to split constant * remove extra spaces fix namespase terminated fix namespase terminated * resolved minor comments * added check for convert_q
This commit is contained in:
parent
77c2c5fab3
commit
1df9f958e2
@ -179,7 +179,7 @@ public:
|
||||
const bool updatePrecision,
|
||||
const bool moveSubtract);
|
||||
|
||||
static std::vector<std::vector<std::shared_ptr<ngraph::opset1::Constant>>> split_consts_before_concat(
|
||||
static std::vector<std::vector<std::shared_ptr<ngraph::opset1::Constant>>> splitConstantsBeforeConcat(
|
||||
const std::shared_ptr<ov::Node> concat,
|
||||
const std::vector<std::shared_ptr<opset1::Constant>> currConstants);
|
||||
|
||||
|
@ -62,9 +62,8 @@ bool MoveFakeQuantize::transform(TransformationContext& context, ngraph::pattern
|
||||
auto operation = fq->get_input_node_shared_ptr(0);
|
||||
std::shared_ptr<ngraph::Node> concat;
|
||||
bool without_operation = true;
|
||||
std::string fq_original_name = fq->get_friendly_name(),
|
||||
operation_original_name,
|
||||
convert_q_original_name;
|
||||
const std::string fq_original_name = fq->get_friendly_name();
|
||||
std::string operation_original_name;
|
||||
if (is_type<opset1::Concat>(operation)) {
|
||||
concat = operation;
|
||||
} else {
|
||||
@ -75,7 +74,7 @@ bool MoveFakeQuantize::transform(TransformationContext& context, ngraph::pattern
|
||||
if (!ConcatTransformation::isQuantizedStatic(concat)) {
|
||||
return false;
|
||||
}
|
||||
auto convert_q = (*fq->output(0).get_target_inputs().begin()).get_node()->shared_from_this();
|
||||
const auto convert_q = fq->output(0).get_target_inputs().begin()->get_node()->shared_from_this();
|
||||
bool q_dq = is_type<opset1::Convert>(convert_q);
|
||||
std::vector<std::shared_ptr<opset1::Constant>> currConstants(4);
|
||||
bool multi_chanels = false;
|
||||
@ -84,13 +83,13 @@ bool MoveFakeQuantize::transform(TransformationContext& context, ngraph::pattern
|
||||
const auto concat_axis = concatNode->get_concatenation_axis();
|
||||
for (size_t i = 0; i < 4; i++) {
|
||||
currConstants[i] = as_type_ptr<opset1::Constant>(fq->get_input_node_shared_ptr(i + 1));
|
||||
if (!multi_chanels && currConstants[i]->get_shape().size() > 1 && currConstants[i]->get_shape()[concat_axis] != 1) {
|
||||
if (!multi_chanels && currConstants[i]->get_shape().size() > (concat_axis + 1ul) && currConstants[i]->get_shape()[concat_axis] != 1) {
|
||||
multi_chanels = true;
|
||||
}
|
||||
}
|
||||
std::vector<std::vector<std::shared_ptr<ngraph::opset1::Constant>>> newConstants;
|
||||
if (multi_chanels) {
|
||||
newConstants = NetworkHelper::split_consts_before_concat(concat, currConstants);
|
||||
newConstants = NetworkHelper::splitConstantsBeforeConcat(concat, currConstants);
|
||||
}
|
||||
std::vector<std::shared_ptr<ngraph::Node>> newNodes;
|
||||
for (size_t i{ 0 }; i < number_of_concat_inputs; ++i) {
|
||||
@ -152,7 +151,11 @@ bool MoveFakeQuantize::canBeTransformed(const TransformationContext& context, st
|
||||
if (!ConcatTransformation::isQuantizedStatic(concat)) {
|
||||
return false;
|
||||
}
|
||||
auto convert_q = (*layer->output(0).get_target_inputs().begin()).get_node()->shared_from_this();
|
||||
const auto convert_q_target_inputs = layer->output(0).get_target_inputs();
|
||||
if (convert_q_target_inputs.empty()) {
|
||||
return false;
|
||||
}
|
||||
const auto convert_q = convert_q_target_inputs.begin()->get_node()->shared_from_this();
|
||||
bool q_dq = is_type<opset1::Convert>(convert_q);
|
||||
if (q_dq && (convert_q->get_output_size() != 1 || layer->get_output_size() != 1)) {
|
||||
return false;
|
||||
|
@ -1679,11 +1679,14 @@ NetworkHelper::InsertDequantizationResult NetworkHelper::moveDequantizationBefor
|
||||
const auto concatNode = as_type_ptr<opset1::Concat>(operation);
|
||||
auto axis = concatNode->get_concatenation_axis();
|
||||
if (dequantization.multiply && dequantization.multiplyConstant->get_shape().size() > 1 && dequantization.multiplyConstant->get_shape()[axis] != 1) {
|
||||
multiplyConstants = NetworkHelper::split_consts_before_concat(operation, { dequantization.multiplyConstant });
|
||||
multiplyConstants = NetworkHelper::splitConstantsBeforeConcat(operation, { dequantization.multiplyConstant });
|
||||
}
|
||||
if (dequantization.subtract && dequantization.subtractConstant->get_shape().size() > 1 && dequantization.subtractConstant->get_shape()[axis] != 1) {
|
||||
subtractConstants = NetworkHelper::split_consts_before_concat(operation, { dequantization.subtractConstant });
|
||||
subtractConstants = NetworkHelper::splitConstantsBeforeConcat(operation, { dequantization.subtractConstant });
|
||||
}
|
||||
} else {
|
||||
multiplyConstants = {{ dequantization.multiplyConstant }};
|
||||
subtractConstants = {{ dequantization.subtractConstant }};
|
||||
}
|
||||
std::vector<std::shared_ptr<ngraph::Node>> newNodes;
|
||||
for (size_t i = 0; i < operation->get_input_size(); ++i) {
|
||||
@ -1750,21 +1753,17 @@ NetworkHelper::InsertDequantizationResult NetworkHelper::moveDequantizationBefor
|
||||
NetworkHelper::copyInfo(operation, newOperation);
|
||||
replace_node(dequantization.multiply, newOperation);
|
||||
|
||||
auto op = std::dynamic_pointer_cast<ngraph::op::TypeRelaxedBase>(newOperation);
|
||||
if (op != nullptr) {
|
||||
if (updatePrecision) {
|
||||
op->set_overridden_output_type(newOperation->get_input_element_type(0));
|
||||
} else if (dequantization.multiply) {
|
||||
op->set_overridden_output_type(dequantization.multiplyConstant->get_element_type());
|
||||
} else if (dequantization.subtract) {
|
||||
op->set_overridden_output_type(dequantization.subtractConstant->get_element_type());
|
||||
}
|
||||
std::dynamic_pointer_cast<ngraph::Node>(newOperation)->validate_and_infer_types();
|
||||
if (const auto op = std::dynamic_pointer_cast<ngraph::op::TypeRelaxedBase>(newOperation)) {
|
||||
op->set_overridden_output_type(updatePrecision ?
|
||||
newOperation->get_input_element_type(0) :
|
||||
dequantization.multiplyConstant->get_element_type());
|
||||
newOperation->validate_and_infer_types();
|
||||
}
|
||||
|
||||
return InsertDequantizationResult(newOperation, dequantization.multiply);
|
||||
}
|
||||
|
||||
std::vector<std::vector<std::shared_ptr<ngraph::opset1::Constant>>> NetworkHelper::split_consts_before_concat(const std::shared_ptr<ov::Node> concat,
|
||||
std::vector<std::vector<std::shared_ptr<ngraph::opset1::Constant>>> NetworkHelper::splitConstantsBeforeConcat(const std::shared_ptr<ov::Node> concat,
|
||||
const std::vector<std::shared_ptr<opset1::Constant>> currConstants) {
|
||||
std::vector<std::vector<std::shared_ptr<ngraph::opset1::Constant>>> newConstants(currConstants.size());
|
||||
auto number_of_concat_inputs = concat->get_input_size();
|
||||
@ -1788,7 +1787,8 @@ std::vector<std::vector<std::shared_ptr<ngraph::opset1::Constant>>> NetworkHelpe
|
||||
OutputVector outputResults(split->get_output_size());
|
||||
auto foldResult = split->constant_fold(outputResults, split->input_values());
|
||||
if (!foldResult) {
|
||||
// handle potential constant fold issue here
|
||||
THROW_IE_LPT_EXCEPTION(*concat) << "error when splitting constants before concat " <<
|
||||
concat->get_friendly_name();
|
||||
}
|
||||
for (auto outputResult : outputResults) {
|
||||
auto constant = as_type_ptr<opset1::Constant>(outputResult.get_node_shared_ptr());
|
||||
|
Loading…
Reference in New Issue
Block a user