fix: check int branch while find fqs_to_unify (#8661)
This commit is contained in:
parent
e34a66d828
commit
d9241dda72
@ -7,10 +7,10 @@ from copy import deepcopy
|
|||||||
from .range_estimator import get_range_estimator_config
|
from .range_estimator import get_range_estimator_config
|
||||||
from .utils import get_hardware_config_operation_type, load_hardware_config
|
from .utils import get_hardware_config_operation_type, load_hardware_config
|
||||||
from ...graph.special_operations import QUANTIZE_AGNOSTIC_OPERATIONS, CONCAT_UNIFY_OUTPUTS, CONCAT_UNIFY_INPUTS
|
from ...graph.special_operations import QUANTIZE_AGNOSTIC_OPERATIONS, CONCAT_UNIFY_OUTPUTS, CONCAT_UNIFY_INPUTS
|
||||||
from ...graph.utils import find_operation_matches, get_operation_list
|
from ...graph.utils import find_operation_matches, get_operation_list, is_data_type_quantizable
|
||||||
from ...graph.model_utils import get_nodes_by_type, get_node_by_name
|
from ...graph.model_utils import get_nodes_by_type, get_node_by_name
|
||||||
from ...graph.node_utils import get_input_shape, get_all_node_outputs,\
|
from ...graph.node_utils import get_input_shape, get_all_node_outputs,\
|
||||||
get_node_input, get_node_inputs
|
get_node_input, get_node_inputs, get_node_data_type
|
||||||
from ...utils.logger import get_logger
|
from ...utils.logger import get_logger
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
@ -372,13 +372,15 @@ def find_fqs_to_unify(model, config):
|
|||||||
# traverse down
|
# traverse down
|
||||||
if node_.type == 'FakeQuantize' or _is_quantize_agnostic_op(node_):
|
if node_.type == 'FakeQuantize' or _is_quantize_agnostic_op(node_):
|
||||||
for child in get_all_node_outputs(node_):
|
for child in get_all_node_outputs(node_):
|
||||||
if not visited_[child.name] and \
|
node_data_type = get_node_data_type(child)
|
||||||
|
if not visited_[child.name] and is_data_type_quantizable(node_data_type) and \
|
||||||
(_is_quantize_agnostic_op(child) or _is_unified_scales_op(child)):
|
(_is_quantize_agnostic_op(child) or _is_unified_scales_op(child)):
|
||||||
stack_.append(child)
|
stack_.append(child)
|
||||||
# traverse up
|
# traverse up
|
||||||
if node_.type != 'FakeQuantize':
|
if node_.type != 'FakeQuantize':
|
||||||
for parent in get_node_inputs(node_):
|
for parent in get_node_inputs(node_):
|
||||||
if parent and not visited_[parent.name] and \
|
node_data_type = get_node_data_type(parent)
|
||||||
|
if parent and not visited_[parent.name] and is_data_type_quantizable(node_data_type) and \
|
||||||
(parent.type == 'FakeQuantize' or _is_quantize_agnostic_op(parent)):
|
(parent.type == 'FakeQuantize' or _is_quantize_agnostic_op(parent)):
|
||||||
stack_.append(parent)
|
stack_.append(parent)
|
||||||
|
|
||||||
|
@ -261,3 +261,10 @@ def get_lstm_ends(read_value, assigns, ignore_nodes):
|
|||||||
lstm_outputs = [n for n in get_all_node_outputs(assign_input)
|
lstm_outputs = [n for n in get_all_node_outputs(assign_input)
|
||||||
if n.name not in ignore_nodes]
|
if n.name not in ignore_nodes]
|
||||||
return lstm_outputs
|
return lstm_outputs
|
||||||
|
|
||||||
|
|
||||||
|
def get_node_data_type(node):
|
||||||
|
if node.type != 'Const' and node.in_port(0).get_source() is not None \
|
||||||
|
and node.in_port(0).get_source().is_data_type_defined():
|
||||||
|
return node.in_port(0).get_source().get_data_type()
|
||||||
|
return None
|
||||||
|
@ -5,6 +5,8 @@ from pathlib import PosixPath, WindowsPath
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
import openvino.tools.pot.version
|
import openvino.tools.pot.version
|
||||||
from .cpu_patterns import get_cpu_ignored_patterns
|
from .cpu_patterns import get_cpu_ignored_patterns
|
||||||
from .gpu_patterns import get_gpu_ignored_patterns
|
from .gpu_patterns import get_gpu_ignored_patterns
|
||||||
@ -212,3 +214,7 @@ def check_agnostic_and_ignored_params(model, ignored_params):
|
|||||||
ignored_params = new_ignored_params
|
ignored_params = new_ignored_params
|
||||||
|
|
||||||
return ignored_params
|
return ignored_params
|
||||||
|
|
||||||
|
|
||||||
|
def is_data_type_quantizable(type_node):
|
||||||
|
return type_node not in (np.int32, np.int64, bool)
|
||||||
|
Loading…
Reference in New Issue
Block a user