diff --git a/tools/pot/openvino/tools/pot/algorithms/quantization/fake_quantize_configuration.py b/tools/pot/openvino/tools/pot/algorithms/quantization/fake_quantize_configuration.py index 3b13387a4e8..a8ba9f332fb 100644 --- a/tools/pot/openvino/tools/pot/algorithms/quantization/fake_quantize_configuration.py +++ b/tools/pot/openvino/tools/pot/algorithms/quantization/fake_quantize_configuration.py @@ -7,10 +7,10 @@ from copy import deepcopy from .range_estimator import get_range_estimator_config from .utils import get_hardware_config_operation_type, load_hardware_config from ...graph.special_operations import QUANTIZE_AGNOSTIC_OPERATIONS, CONCAT_UNIFY_OUTPUTS, CONCAT_UNIFY_INPUTS -from ...graph.utils import find_operation_matches, get_operation_list +from ...graph.utils import find_operation_matches, get_operation_list, is_data_type_quantizable from ...graph.model_utils import get_nodes_by_type, get_node_by_name from ...graph.node_utils import get_input_shape, get_all_node_outputs,\ - get_node_input, get_node_inputs + get_node_input, get_node_inputs, get_node_data_type from ...utils.logger import get_logger logger = get_logger(__name__) @@ -372,13 +372,15 @@ def find_fqs_to_unify(model, config): # traverse down if node_.type == 'FakeQuantize' or _is_quantize_agnostic_op(node_): for child in get_all_node_outputs(node_): - if not visited_[child.name] and \ + node_data_type = get_node_data_type(child) + if not visited_[child.name] and is_data_type_quantizable(node_data_type) and \ (_is_quantize_agnostic_op(child) or _is_unified_scales_op(child)): stack_.append(child) # traverse up if node_.type != 'FakeQuantize': for parent in get_node_inputs(node_): - if parent and not visited_[parent.name] and \ + node_data_type = get_node_data_type(parent) + if parent and not visited_[parent.name] and is_data_type_quantizable(node_data_type) and \ (parent.type == 'FakeQuantize' or _is_quantize_agnostic_op(parent)): stack_.append(parent) diff --git a/tools/pot/openvino/tools/pot/graph/node_utils.py b/tools/pot/openvino/tools/pot/graph/node_utils.py index 88f7a92869f..947a58f3d57 100644 --- a/tools/pot/openvino/tools/pot/graph/node_utils.py +++ b/tools/pot/openvino/tools/pot/graph/node_utils.py @@ -261,3 +261,10 @@ def get_lstm_ends(read_value, assigns, ignore_nodes): lstm_outputs = [n for n in get_all_node_outputs(assign_input) if n.name not in ignore_nodes] return lstm_outputs + + +def get_node_data_type(node): + if node.type != 'Const' and node.in_port(0).get_source() is not None \ + and node.in_port(0).get_source().is_data_type_defined(): + return node.in_port(0).get_source().get_data_type() + return None diff --git a/tools/pot/openvino/tools/pot/graph/utils.py b/tools/pot/openvino/tools/pot/graph/utils.py index 67fe1e21d15..bce43fd8823 100644 --- a/tools/pot/openvino/tools/pot/graph/utils.py +++ b/tools/pot/openvino/tools/pot/graph/utils.py @@ -5,6 +5,8 @@ from pathlib import PosixPath, WindowsPath from copy import deepcopy import json +import numpy as np + import openvino.tools.pot.version from .cpu_patterns import get_cpu_ignored_patterns from .gpu_patterns import get_gpu_ignored_patterns @@ -212,3 +214,7 @@ def check_agnostic_and_ignored_params(model, ignored_params): ignored_params = new_ignored_params return ignored_params + + +def is_data_type_quantizable(type_node): + return type_node not in (np.int32, np.int64, bool)