Update rule for multi-results quantization
This commit is contained in:
parent
b8fb666dbc
commit
d94fe7d758
@ -441,7 +441,12 @@ def unify_fq_scales(model, config):
|
||||
def create_renamed_layers_mapping(model, stats_layout):
|
||||
changed_names_map = {}
|
||||
for layer_name in stats_layout:
|
||||
node = get_node_by_name(model, layer_name)
|
||||
node_name = layer_name
|
||||
port_id = None
|
||||
if isinstance(layer_name, tuple):
|
||||
node_name, port_id = layer_name
|
||||
node = get_node_by_name(model, node_name)
|
||||
if node is not None and 'orig_node_name' in node:
|
||||
changed_names_map[node.name] = node['orig_node_name']
|
||||
name_change_to = node['orig_node_name'] if port_id is None else (node['orig_node_name'], port_id)
|
||||
changed_names_map[layer_name] = name_change_to
|
||||
return changed_names_map
|
||||
|
@ -649,12 +649,20 @@ class FakeQuantizeNameSwapper(BackReplacementPattern):
|
||||
def change_names(_, match):
|
||||
fq_node = match['fq']
|
||||
input_node = get_node_input(fq_node, 0)
|
||||
new_fq_name = copy(input_node.name)
|
||||
if 'orig_node_name' in input_node:
|
||||
new_fq_name = copy(input_node['orig_node_name'])
|
||||
|
||||
input_node_outputs = get_all_node_outputs(input_node)
|
||||
if all([op.type == 'FakeQuantize' for op in input_node_outputs]):
|
||||
new_fq_name += '.{}'.format(fq_node.in_port(0).get_source().idx)
|
||||
|
||||
fq_node['orig_fq_name'] = copy(fq_node.name)
|
||||
fq_node.name = copy(input_node.name)
|
||||
fq_node.name = copy(new_fq_name)
|
||||
|
||||
input_node['orig_node_name'] = copy(input_node.name)
|
||||
input_node.name = '{original_name}/pre_fq_input'.format(original_name=input_node.name)
|
||||
if 'orig_node_name' not in input_node:
|
||||
input_node['orig_node_name'] = copy(input_node.name)
|
||||
input_node.name = '{original_name}/pre_fq_input'.format(original_name=input_node.name)
|
||||
|
||||
pattern = get_fq_result_pattern()
|
||||
apply_pattern(
|
||||
|
Loading…
Reference in New Issue
Block a user