feat: symmetric FQ in self-attention transformer block (#8616)

Indira Salyahova 2021-12-01 15:34:40 +03:00 committed by GitHub
parent 246e628c79
commit 9d776ffc2f
4 changed files with 99 additions and 44 deletions

View File

@@ -7,7 +7,7 @@ import numpy as np
 from addict import Dict
 from .fake_quantize_configuration import read_all_fake_quantize_configurations, get_configurations_by_preset, \
-    get_configurations_by_qscheme, find_fqs_to_unify, add_range_estimator_configs
+    get_configurations_by_qscheme, find_fqs_to_unify, add_range_estimator_configs, change_configurations_by_model_type
 from .utils import load_hardware_config, merge_nested_dicts, get_ignored_operations
 from ...graph.model_utils import get_nodes_by_type, get_node_by_name
 from ...graph.node_utils import get_node_input, set_node_value, \
@@ -123,6 +123,8 @@ def compute_stats_layouts(config, model, qscheme=None):
     else:
         fq_configuration = get_configurations_by_qscheme(fq_configuration, qscheme)
+    change_configurations_by_model_type(model, config, fq_configuration, hardware_config)
+
     # get all fake quantize nodes
     fq_nodes = get_nodes_by_type(model, ['FakeQuantize'])
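For context, a minimal sketch of the POT quantization config that activates this new code path; the two keys are exactly the ones checked by `change_configurations_by_model_type` below (illustrative only, not part of the commit):

```python
# Illustrative config sketch: 'model_type' and 'target_device' are the keys
# inspected by change_configurations_by_model_type() before rewriting FQ modes.
from addict import Dict

config = Dict({
    'model_type': 'transformer',   # any other value leaves FQ configurations untouched
    'target_device': 'CPU',        # the hook fires for 'ANY', 'CPU' and 'GPU'
})
```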

View File

@@ -10,7 +10,8 @@ from ...graph.special_operations import QUANTIZE_AGNOSTIC_OPERATIONS, CONCAT_UNI
 from ...graph.utils import find_operation_matches, get_operation_list, is_data_type_quantizable
 from ...graph.model_utils import get_nodes_by_type, get_node_by_name
 from ...graph.node_utils import get_input_shape, get_all_node_outputs,\
-    get_node_input, get_node_inputs, get_node_data_type
+    get_node_input, get_node_inputs, get_node_data_type, check_const_input
 from ...utils.logger import get_logger

 logger = get_logger(__name__)
@@ -109,47 +110,6 @@ def read_all_fake_quantize_configurations(config, hardware_config, model):
     :return dictionary with fake quantize names as keys and
      list of corresponding configurations as values
     """
-    def _fake_quantize_to_types():
-        """ Helper function to bypass graph and get fake quantize node
-        children nodes with predefined types
-        :return dictionary with fake quantize node name as a key and tuple with list of
-         its quantizable descendant types and boolean specifying if fake quantize node is weights
-        """
-        def _is_quantizable(node):
-            return not find_operation_matches(quantize_agnostic_ops, node)
-
-        def _get_node_valuable_descendant(node):
-            descendants = []
-            queue = deque([node])
-            while queue:
-                current = queue.popleft()
-                children = get_all_node_outputs(current)
-                for child in children:
-                    if not _is_quantizable(child):
-                        queue.append(child)
-                    elif child.type not in descendants:
-                        descendants.append((child.fullname,
-                                            get_hardware_config_operation_type(child, available_types)))
-                    if current.type == 'Split' \
-                            and child.type == 'Concat' \
-                            and len({child_.fullname for child_ in children}) == 1:
-                        break
-            return descendants
-
-        hw_ops = get_operation_list(hardware_config)
-        quantize_agnostic_ops = [op[1] for op in
-                                 find_operation_matches(QUANTIZE_AGNOSTIC_OPERATIONS, hw_ops)]
-        out = {}
-        available_types = [layer['type'] for layer in hardware_config]
-        for fq in get_nodes_by_type(model, ['FakeQuantize']):
-            node_input = get_node_input(fq, 0)
-            out[fq.fullname] = (_get_node_valuable_descendant(fq), node_input.type == 'Const')
-        return out
-
     def _is_subset(left: dict, right: dict):
         """ Checks that x is a subset of y
         :param left: supposed to be subset of set 'right'
@@ -181,7 +141,7 @@ def read_all_fake_quantize_configurations(config, hardware_config, model):
     q_config = get_fake_quantize_configuration(config)
     res_fq_to_hw_conf = {}
-    for fq_name, (types, is_weights) in _fake_quantize_to_types().items():
+    for fq_name, (types, is_weights) in _fake_quantize_to_types(model, hardware_config).items():
         fq_type = 'weights' if is_weights else 'activations'
         res_fq_to_hw_conf[fq_name] = {fq_type: []}
         for type_ in types:
@@ -416,3 +376,68 @@ def find_fqs_to_unify(model, config):
         logger.debug('')

     return fqs_to_unify
+
+
+def _fake_quantize_to_types(model, hardware_config):
+    """ Helper function to bypass graph and get fake quantize node
+    children nodes with predefined types
+    :return dictionary with fake quantize node name as a key and tuple with list of
+     its quantizable descendant types and boolean specifying if fake quantize node is weights
+    """
+    def _is_quantizable(node):
+        return not find_operation_matches(quantize_agnostic_ops, node)
+
+    def _get_node_valuable_descendant(node):
+        descendants = []
+        queue = deque([node])
+        while queue:
+            current = queue.popleft()
+            children = get_all_node_outputs(current)
+            for child in children:
+                if not _is_quantizable(child):
+                    queue.append(child)
+                elif child.type not in descendants:
+                    descendants.append((child.name,
+                                        get_hardware_config_operation_type(child, available_types)))
+                if current.type == 'Split' \
+                        and child.type == 'Concat' \
+                        and len({child_.name for child_ in children}) == 1:
+                    break
+        return descendants
+
+    hw_ops = get_operation_list(hardware_config)
+    quantize_agnostic_ops = [op[1] for op in
+                             find_operation_matches(QUANTIZE_AGNOSTIC_OPERATIONS, hw_ops)]
+    out = {}
+    available_types = [layer['type'] for layer in hardware_config]
+    for fq in get_nodes_by_type(model, ['FakeQuantize']):
+        node_input = get_node_input(fq, 0)
+        out[fq.name] = (_get_node_valuable_descendant(fq), node_input.type == 'Const')
+    return out
+
+
+def change_configurations_by_model_type(model, config, fq_configuration, hardware_config):
+    if config['model_type'] == 'transformer' and config['target_device'] in ['ANY', 'CPU', 'GPU']:
+        change_configurations_by_model_type_transformer(model, fq_configuration, hardware_config)
+
+
+def change_configurations_by_model_type_transformer(model, fq_configuration, hardware_config):
+    fq_types = _fake_quantize_to_types(model, hardware_config)
+    for fq in get_nodes_by_type(model, ['FakeQuantize']):
+        node_creator_fq, is_weights = fq_types[fq.name]
+        node_name = None
+        for name, type_node in node_creator_fq:
+            if type_node == 'MatMul':
+                node_name = name
+
+        if node_name is None or is_weights:
+            continue
+
+        node = get_node_by_name(model, node_name)
+
+        if not check_const_input(node):
+            fq_configuration[fq.name]['activations'] = deepcopy(fq_configuration[fq.name]['activations'])
+            fq_configuration[fq.name]['activations']['mode'] = 'symmetric'
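In short, the new hook finds every activation FakeQuantize whose nearest quantizable descendant is a MatMul without a constant input (i.e. both operands are activations, as in the query-key and scores-value products of self-attention) and forces its range mode to symmetric. A condensed, self-contained sketch of that final rewrite step (illustrative only; the nested dict layout mirrors `fq_configuration` above, and the FQ name is hypothetical):

```python
# Sketch of the symmetric-mode override applied in
# change_configurations_by_model_type_transformer() above.
from copy import deepcopy

def force_symmetric_activations(fq_configuration, fq_name):
    # Copy before mutating: several FQ nodes may share one config object,
    # and only MatMul-feeding activation FQs should receive the override.
    activations = deepcopy(fq_configuration[fq_name]['activations'])
    activations['mode'] = 'symmetric'
    fq_configuration[fq_name]['activations'] = activations

# Hypothetical FQ name and config values, for illustration only.
fq_configuration = {'attn/fq_input_0': {'activations': {'bits': 8, 'mode': 'asymmetric'}}}
force_symmetric_activations(fq_configuration, 'attn/fq_input_0')
assert fq_configuration['attn/fq_input_0']['activations']['mode'] == 'symmetric'
```

The `deepcopy` is what confines the change: without it, flipping the mode on one FQ's shared configuration dict would silently flip it for every FakeQuantize aliasing that dict.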

View File

@@ -363,6 +363,20 @@ graph TB
 ---
+**Name:** softmax_reshape_transpose_matmul<br/>
+**Pattern:** <br/>
+```mermaid
+graph TB
+    softmax(SoftMax) --> matmul(MatMul)
+    add(Add) --> reshape(Reshape)
+    reshape_const(Const) --> reshape(Reshape)
+    reshape(Reshape) --> transpose(Transpose)
+    transpose(Transpose) --> matmul(MatMul)
+```
+---
 **Name:** swish_activation<br/>
 **Pattern:** <br/>

View File

@@ -159,6 +159,20 @@ def create_softmax_reshape_matmul_pattern():
     return pattern.set_name('softmax_reshape_matmul').pattern


+@registry_ignore_patterns('blocks')
+def create_softmax_reshape_transpose_matmul_pattern():
+    pattern = PatternBuilder()
+    pattern_2 = PatternBuilder()
+    softmax_out = pattern.append_single_op('SoftMax', 'softmax').get_last_node()
+    pattern_2.append_single_op('Add', 'add').get_last_node()
+    pattern_2.append_op_const('Reshape', 'reshape')
+    transp_out = pattern_2.append_single_op('Transpose', 'transpose').get_last_node()
+    pattern.pattern['nodes'] += pattern_2.pattern['nodes']
+    pattern.pattern['edges'] += pattern_2.pattern['edges']
+    pattern.insert_single_op([transp_out, softmax_out], None, 'MatMul', 'matmul')
+    return pattern.set_name('softmax_reshape_transpose_matmul').pattern
+
+
 @registry_ignore_patterns('blocks')
 def create_hswish_without_denominator_pattern():
     pattern = PatternBuilder()
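For intuition, the subgraph this pattern describes looks like the second batched product of scaled dot-product attention: softmaxed scores multiplied by a value tensor that was biased (Add), reshaped to a per-head layout (Reshape with a Const shape), and transposed. A minimal numpy sketch of that shape flow (all names and shapes here are illustrative assumptions, not taken from the commit):

```python
# Numpy sketch of the computation the ignored pattern corresponds to.
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

batch, heads, seq, dim = 1, 2, 4, 8
scores = softmax(np.random.randn(batch, heads, seq, seq))  # SoftMax branch
v = np.random.randn(batch, seq, heads * dim) + 0.1         # Add branch (e.g. bias added upstream)
v = v.reshape(batch, seq, heads, dim)                      # Reshape (Const shape input)
v = v.transpose(0, 2, 1, 3)                                # Transpose to (batch, heads, seq, dim)
context = scores @ v                                       # MatMul joining both branches
assert context.shape == (batch, heads, seq, dim)
```

Both MatMul operands here are activations, which is exactly the case the new `change_configurations_by_model_type_transformer` hook switches to symmetric quantization.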