feat: symmetric fq in self attention block transformer (#8616)
This commit is contained in:
parent
246e628c79
commit
9d776ffc2f
@ -7,7 +7,7 @@ import numpy as np
|
||||
from addict import Dict
|
||||
|
||||
from .fake_quantize_configuration import read_all_fake_quantize_configurations, get_configurations_by_preset, \
|
||||
get_configurations_by_qscheme, find_fqs_to_unify, add_range_estimator_configs
|
||||
get_configurations_by_qscheme, find_fqs_to_unify, add_range_estimator_configs, change_configurations_by_model_type
|
||||
from .utils import load_hardware_config, merge_nested_dicts, get_ignored_operations
|
||||
from ...graph.model_utils import get_nodes_by_type, get_node_by_name
|
||||
from ...graph.node_utils import get_node_input, set_node_value, \
|
||||
@ -123,6 +123,8 @@ def compute_stats_layouts(config, model, qscheme=None):
|
||||
else:
|
||||
fq_configuration = get_configurations_by_qscheme(fq_configuration, qscheme)
|
||||
|
||||
change_configurations_by_model_type(model, config, fq_configuration, hardware_config)
|
||||
|
||||
# get all fake quantize nodes
|
||||
fq_nodes = get_nodes_by_type(model, ['FakeQuantize'])
|
||||
|
||||
|
@ -10,7 +10,8 @@ from ...graph.special_operations import QUANTIZE_AGNOSTIC_OPERATIONS, CONCAT_UNI
|
||||
from ...graph.utils import find_operation_matches, get_operation_list, is_data_type_quantizable
|
||||
from ...graph.model_utils import get_nodes_by_type, get_node_by_name
|
||||
from ...graph.node_utils import get_input_shape, get_all_node_outputs,\
|
||||
get_node_input, get_node_inputs, get_node_data_type
|
||||
get_node_input, get_node_inputs, get_node_data_type, check_const_input
|
||||
|
||||
from ...utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@ -109,47 +110,6 @@ def read_all_fake_quantize_configurations(config, hardware_config, model):
|
||||
:return dictionary with fake quantize names as keys and
|
||||
list of corresponding configurations as values
|
||||
"""
|
||||
|
||||
def _fake_quantize_to_types():
|
||||
""" Helper function to bypass graph and get fake quantize node
|
||||
children nodes with predefined types
|
||||
:return dictionary with fake quantize node name as a key and tuple with list of
|
||||
its quantizable descendant types and boolean specifying if fake quantize node is weights
|
||||
"""
|
||||
|
||||
def _is_quantizable(node):
|
||||
return not find_operation_matches(quantize_agnostic_ops, node)
|
||||
|
||||
def _get_node_valuable_descendant(node):
|
||||
descendants = []
|
||||
queue = deque([node])
|
||||
while queue:
|
||||
current = queue.popleft()
|
||||
children = get_all_node_outputs(current)
|
||||
for child in children:
|
||||
if not _is_quantizable(child):
|
||||
queue.append(child)
|
||||
elif child.type not in descendants:
|
||||
descendants.append((child.fullname,
|
||||
get_hardware_config_operation_type(child, available_types)))
|
||||
if current.type == 'Split' \
|
||||
and child.type == 'Concat' \
|
||||
and len({child_.fullname for child_ in children}) == 1:
|
||||
break
|
||||
return descendants
|
||||
|
||||
hw_ops = get_operation_list(hardware_config)
|
||||
quantize_agnostic_ops = [op[1] for op in
|
||||
find_operation_matches(QUANTIZE_AGNOSTIC_OPERATIONS, hw_ops)]
|
||||
|
||||
out = {}
|
||||
available_types = [layer['type'] for layer in hardware_config]
|
||||
for fq in get_nodes_by_type(model, ['FakeQuantize']):
|
||||
node_input = get_node_input(fq, 0)
|
||||
out[fq.fullname] = (_get_node_valuable_descendant(fq), node_input.type == 'Const')
|
||||
|
||||
return out
|
||||
|
||||
def _is_subset(left: dict, right: dict):
|
||||
""" Checks that x is a subset of y
|
||||
:param left: supposed to be subset of set 'right'
|
||||
@ -181,7 +141,7 @@ def read_all_fake_quantize_configurations(config, hardware_config, model):
|
||||
q_config = get_fake_quantize_configuration(config)
|
||||
|
||||
res_fq_to_hw_conf = {}
|
||||
for fq_name, (types, is_weights) in _fake_quantize_to_types().items():
|
||||
for fq_name, (types, is_weights) in _fake_quantize_to_types(model, hardware_config).items():
|
||||
fq_type = 'weights' if is_weights else 'activations'
|
||||
res_fq_to_hw_conf[fq_name] = {fq_type: []}
|
||||
for type_ in types:
|
||||
@ -416,3 +376,68 @@ def find_fqs_to_unify(model, config):
|
||||
logger.debug('')
|
||||
|
||||
return fqs_to_unify
|
||||
|
||||
|
||||
def _fake_quantize_to_types(model, hardware_config):
    """ Helper function to bypass graph and get fake quantize node
    children nodes with predefined types
    :param model: model whose FakeQuantize nodes are collected
    :param hardware_config: iterable of hardware layer descriptions,
    each of which is indexed by the 'type' key
    :return dictionary with fake quantize node name as a key and tuple with list of
    (descendant name, hardware config operation type) pairs of its quantizable
    descendants and boolean specifying if fake quantize node is weights
    (i.e. input 0 of the fake quantize node has type 'Const')
    """

    def _is_quantizable(node):
        # A node is "quantizable" when it does not match any quantize-agnostic operation
        return not find_operation_matches(quantize_agnostic_ops, node)

    def _get_node_valuable_descendant(node):
        descendants = []
        queue = deque([node])
        while queue:
            current = queue.popleft()
            children = get_all_node_outputs(current)
            for child in children:
                if not _is_quantizable(child):
                    # Quantize-agnostic node: keep walking through it
                    queue.append(child)
                else:
                    # Compare the whole (name, type) pair: the list stores tuples,
                    # so testing the bare type string against it never filtered anything
                    descendant = (child.name,
                                  get_hardware_config_operation_type(child, available_types))
                    if descendant not in descendants:
                        descendants.append(descendant)
                # All outputs of a Split feed one and the same Concat:
                # the first child is enough, skip the remaining identical ones
                if current.type == 'Split' \
                        and child.type == 'Concat' \
                        and len({child_.name for child_ in children}) == 1:
                    break
        return descendants

    hw_ops = get_operation_list(hardware_config)
    quantize_agnostic_ops = [op[1] for op in
                             find_operation_matches(QUANTIZE_AGNOSTIC_OPERATIONS, hw_ops)]

    out = {}
    available_types = [layer['type'] for layer in hardware_config]
    for fq in get_nodes_by_type(model, ['FakeQuantize']):
        node_input = get_node_input(fq, 0)
        out[fq.name] = (_get_node_valuable_descendant(fq), node_input.type == 'Const')

    return out
|
||||
|
||||
|
||||
def change_configurations_by_model_type(model, config, fq_configuration, hardware_config):
    """ Applies model type specific changes to fake quantize configurations
    :param model: model passed through to the model type specific handler
    :param config: mapping with optional 'model_type' and 'target_device' keys
    :param fq_configuration: fake quantize configurations, handed to the handler
    which updates them in place
    :param hardware_config: hardware configuration passed through to the handler
    :return None
    """
    # Both keys are optional: use .get so that a plain dict without them
    # is treated as "nothing to change" instead of raising KeyError
    if config.get('model_type') == 'transformer' \
            and config.get('target_device') in ['ANY', 'CPU', 'GPU']:
        change_configurations_by_model_type_transformer(model, fq_configuration, hardware_config)
|
||||
|
||||
|
||||
def change_configurations_by_model_type_transformer(model, fq_configuration, hardware_config):
    """ Switches activation fake quantize configurations to symmetric mode
    for fake quantize nodes feeding a MatMul that passes the check_const_input test
    :param model: model whose FakeQuantize nodes are processed
    :param fq_configuration: dictionary with fake quantize names as keys;
    the 'activations' entry of the affected nodes is replaced by a modified deep copy
    :param hardware_config: hardware configuration used to resolve descendant types
    :return None
    """
    descendants_by_fq = _fake_quantize_to_types(model, hardware_config)
    for fq_node in get_nodes_by_type(model, ['FakeQuantize']):
        descendants, is_weights = descendants_by_fq[fq_node.name]

        # When several MatMul descendants exist, the last one listed is used
        matmul_names = [name for name, op_type in descendants if op_type == 'MatMul']
        matmul_name = matmul_names[-1] if matmul_names else None

        # Only activation fake quantizes with a MatMul descendant are of interest
        if matmul_name is None or is_weights:
            continue

        matmul_node = get_node_by_name(model, matmul_name)
        if check_const_input(matmul_node):
            continue

        # Work on a copy so that a configuration object shared with
        # other fake quantize nodes is not modified
        fq_entry = fq_configuration[fq_node.name]
        fq_entry['activations'] = deepcopy(fq_entry['activations'])
        fq_entry['activations']['mode'] = 'symmetric'
|
||||
|
@ -363,6 +363,20 @@ graph TB
|
||||
|
||||
---
|
||||
|
||||
**Name:** softmax_reshape_transpose_matmul<br/>
|
||||
**Pattern:** <br/>
|
||||
|
||||
```mermaid
|
||||
graph TB
|
||||
softmax(SoftMax) --> matmul(MatMul)
|
||||
add(Add) --> reshape(Reshape)
|
||||
reshape_const(Const) --> reshape(Reshape)
|
||||
reshape(Reshape) --> transpose(Transpose)
|
||||
transpose(Transpose) --> matmul(MatMul)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
**Name:** swish_activation<br/>
|
||||
**Pattern:** <br/>
|
||||
|
||||
|
@ -159,6 +159,20 @@ def create_softmax_reshape_matmul_pattern():
|
||||
return pattern.set_name('softmax_reshape_matmul').pattern
|
||||
|
||||
|
||||
@registry_ignore_patterns('blocks')
def create_softmax_reshape_transpose_matmul_pattern():
    """ Builds the 'softmax_reshape_transpose_matmul' ignored pattern:
    a MatMul fed by a SoftMax on one side and by an
    Add -> Reshape -> Transpose chain on the other side
    :return pattern description of the built block
    """
    softmax_branch = PatternBuilder()
    transpose_branch = PatternBuilder()

    # First MatMul input: SoftMax
    softmax_node = softmax_branch.append_single_op('SoftMax', 'softmax').get_last_node()

    # Second MatMul input: Add -> Reshape -> Transpose
    # (append_op_const presumably also attaches the Const input of Reshape - TODO confirm)
    transpose_branch.append_single_op('Add', 'add').get_last_node()
    transpose_branch.append_op_const('Reshape', 'reshape')
    transpose_node = transpose_branch.append_single_op('Transpose', 'transpose').get_last_node()

    # Merge the second branch into the first one before joining both in MatMul
    softmax_branch.pattern['nodes'] += transpose_branch.pattern['nodes']
    softmax_branch.pattern['edges'] += transpose_branch.pattern['edges']
    softmax_branch.insert_single_op([transpose_node, softmax_node], None, 'MatMul', 'matmul')

    return softmax_branch.set_name('softmax_reshape_transpose_matmul').pattern
|
||||
|
||||
|
||||
@registry_ignore_patterns('blocks')
|
||||
def create_hswish_without_denominator_pattern():
|
||||
pattern = PatternBuilder()
|
||||
|
Loading…
Reference in New Issue
Block a user