fix: check int branch while find fqs_to_unify (#8661)
This commit is contained in:
parent
e34a66d828
commit
d9241dda72
@ -7,10 +7,10 @@ from copy import deepcopy
|
||||
from .range_estimator import get_range_estimator_config
|
||||
from .utils import get_hardware_config_operation_type, load_hardware_config
|
||||
from ...graph.special_operations import QUANTIZE_AGNOSTIC_OPERATIONS, CONCAT_UNIFY_OUTPUTS, CONCAT_UNIFY_INPUTS
|
||||
from ...graph.utils import find_operation_matches, get_operation_list
|
||||
from ...graph.utils import find_operation_matches, get_operation_list, is_data_type_quantizable
|
||||
from ...graph.model_utils import get_nodes_by_type, get_node_by_name
|
||||
from ...graph.node_utils import get_input_shape, get_all_node_outputs,\
|
||||
get_node_input, get_node_inputs
|
||||
get_node_input, get_node_inputs, get_node_data_type
|
||||
from ...utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@ -372,13 +372,15 @@ def find_fqs_to_unify(model, config):
|
||||
# traverse down
|
||||
if node_.type == 'FakeQuantize' or _is_quantize_agnostic_op(node_):
|
||||
for child in get_all_node_outputs(node_):
|
||||
if not visited_[child.name] and \
|
||||
node_data_type = get_node_data_type(child)
|
||||
if not visited_[child.name] and is_data_type_quantizable(node_data_type) and \
|
||||
(_is_quantize_agnostic_op(child) or _is_unified_scales_op(child)):
|
||||
stack_.append(child)
|
||||
# traverse up
|
||||
if node_.type != 'FakeQuantize':
|
||||
for parent in get_node_inputs(node_):
|
||||
if parent and not visited_[parent.name] and \
|
||||
node_data_type = get_node_data_type(parent)
|
||||
if parent and not visited_[parent.name] and is_data_type_quantizable(node_data_type) and \
|
||||
(parent.type == 'FakeQuantize' or _is_quantize_agnostic_op(parent)):
|
||||
stack_.append(parent)
|
||||
|
||||
|
@ -261,3 +261,10 @@ def get_lstm_ends(read_value, assigns, ignore_nodes):
|
||||
lstm_outputs = [n for n in get_all_node_outputs(assign_input)
|
||||
if n.name not in ignore_nodes]
|
||||
return lstm_outputs
|
||||
|
||||
|
||||
def get_node_data_type(node):
|
||||
if node.type != 'Const' and node.in_port(0).get_source() is not None \
|
||||
and node.in_port(0).get_source().is_data_type_defined():
|
||||
return node.in_port(0).get_source().get_data_type()
|
||||
return None
|
||||
|
@ -5,6 +5,8 @@ from pathlib import PosixPath, WindowsPath
|
||||
from copy import deepcopy
|
||||
import json
|
||||
|
||||
import numpy as np
|
||||
|
||||
import openvino.tools.pot.version
|
||||
from .cpu_patterns import get_cpu_ignored_patterns
|
||||
from .gpu_patterns import get_gpu_ignored_patterns
|
||||
@ -212,3 +214,7 @@ def check_agnostic_and_ignored_params(model, ignored_params):
|
||||
ignored_params = new_ignored_params
|
||||
|
||||
return ignored_params
|
||||
|
||||
|
||||
def is_data_type_quantizable(type_node):
|
||||
return type_node not in (np.int32, np.int64, bool)
|
||||
|
Loading…
Reference in New Issue
Block a user