Update rule for multi-results quantization

This commit is contained in:
Malinin, Nikita 2021-10-27 17:40:07 +03:00
parent b8fb666dbc
commit d94fe7d758
2 changed files with 18 additions and 5 deletions

View File

@ -441,7 +441,12 @@ def unify_fq_scales(model, config):
def create_renamed_layers_mapping(model, stats_layout):
changed_names_map = {}
for layer_name in stats_layout:
node = get_node_by_name(model, layer_name)
node_name = layer_name
port_id = None
if isinstance(layer_name, tuple):
node_name, port_id = layer_name
node = get_node_by_name(model, node_name)
if node is not None and 'orig_node_name' in node:
changed_names_map[node.name] = node['orig_node_name']
name_change_to = node['orig_node_name'] if port_id is None else (node['orig_node_name'], port_id)
changed_names_map[layer_name] = name_change_to
return changed_names_map

View File

@ -649,12 +649,20 @@ class FakeQuantizeNameSwapper(BackReplacementPattern):
def change_names(_, match):
fq_node = match['fq']
input_node = get_node_input(fq_node, 0)
new_fq_name = copy(input_node.name)
if 'orig_node_name' in input_node:
new_fq_name = copy(input_node['orig_node_name'])
input_node_outputs = get_all_node_outputs(input_node)
if all([op.type == 'FakeQuantize' for op in input_node_outputs]):
new_fq_name += '.{}'.format(fq_node.in_port(0).get_source().idx)
fq_node['orig_fq_name'] = copy(fq_node.name)
fq_node.name = copy(input_node.name)
fq_node.name = copy(new_fq_name)
input_node['orig_node_name'] = copy(input_node.name)
input_node.name = '{original_name}/pre_fq_input'.format(original_name=input_node.name)
if 'orig_node_name' not in input_node:
input_node['orig_node_name'] = copy(input_node.name)
input_node.name = '{original_name}/pre_fq_input'.format(original_name=input_node.name)
pattern = get_fq_result_pattern()
apply_pattern(