fix: check int branch while find fqs_to_unify (#8661)

This commit is contained in:
Indira Salyahova 2021-11-19 15:21:27 +03:00 committed by GitHub
parent e34a66d828
commit d9241dda72
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 19 additions and 4 deletions

View File

@ -7,10 +7,10 @@ from copy import deepcopy
from .range_estimator import get_range_estimator_config
from .utils import get_hardware_config_operation_type, load_hardware_config
from ...graph.special_operations import QUANTIZE_AGNOSTIC_OPERATIONS, CONCAT_UNIFY_OUTPUTS, CONCAT_UNIFY_INPUTS
from ...graph.utils import find_operation_matches, get_operation_list
from ...graph.utils import find_operation_matches, get_operation_list, is_data_type_quantizable
from ...graph.model_utils import get_nodes_by_type, get_node_by_name
from ...graph.node_utils import get_input_shape, get_all_node_outputs,\
get_node_input, get_node_inputs
get_node_input, get_node_inputs, get_node_data_type
from ...utils.logger import get_logger
logger = get_logger(__name__)
@ -372,13 +372,15 @@ def find_fqs_to_unify(model, config):
# traverse down
if node_.type == 'FakeQuantize' or _is_quantize_agnostic_op(node_):
for child in get_all_node_outputs(node_):
if not visited_[child.name] and \
node_data_type = get_node_data_type(child)
if not visited_[child.name] and is_data_type_quantizable(node_data_type) and \
(_is_quantize_agnostic_op(child) or _is_unified_scales_op(child)):
stack_.append(child)
# traverse up
if node_.type != 'FakeQuantize':
for parent in get_node_inputs(node_):
if parent and not visited_[parent.name] and \
node_data_type = get_node_data_type(parent)
if parent and not visited_[parent.name] and is_data_type_quantizable(node_data_type) and \
(parent.type == 'FakeQuantize' or _is_quantize_agnostic_op(parent)):
stack_.append(parent)

View File

@ -261,3 +261,10 @@ def get_lstm_ends(read_value, assigns, ignore_nodes):
lstm_outputs = [n for n in get_all_node_outputs(assign_input)
if n.name not in ignore_nodes]
return lstm_outputs
def get_node_data_type(node):
if node.type != 'Const' and node.in_port(0).get_source() is not None \
and node.in_port(0).get_source().is_data_type_defined():
return node.in_port(0).get_source().get_data_type()
return None

View File

@ -5,6 +5,8 @@ from pathlib import PosixPath, WindowsPath
from copy import deepcopy
import json
import numpy as np
import openvino.tools.pot.version
from .cpu_patterns import get_cpu_ignored_patterns
from .gpu_patterns import get_gpu_ignored_patterns
@ -212,3 +214,7 @@ def check_agnostic_and_ignored_params(model, ignored_params):
ignored_params = new_ignored_params
return ignored_params
def is_data_type_quantizable(type_node):
return type_node not in (np.int32, np.int64, bool)