[POT] Added typecast while nodes updating & creation (#8621)

* Update type casting while update

* Fix Pylint

* Update precision setting for FQ values

* Update FQ values precision setting

* Fix scales dump

* Update reference

* Another reference update
Nikita Malinin 2021-11-23 17:13:04 +03:00 committed by GitHub
parent 6156626706
commit 516d510045
4 changed files with 24 additions and 13 deletions


@@ -84,7 +84,10 @@ def set_node_value(node: Node, value: np.ndarray):
     """
     if node.type != 'Const':
         raise Exception('Can\'t set value for non-constant node {}'.format(node.name))
-    node.out_port(0).data.set_value(value)
+    data_type = np.float32
+    if node.out_port(0).is_data_type_defined():
+        data_type = node.out_port(0).get_data_type()
+    node.out_port(0).data.set_value(np.array(value).astype(data_type))


 def get_node_value(node: Node):
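Not part of the diff: a minimal NumPy-only sketch of the behaviour this hunk introduces, namely casting the incoming value to the data type already defined on the constant's output port so that, for example, an FP16 constant is not silently promoted. The variable names are illustrative stand-ins, not the POT graph API.

import numpy as np

# Illustration only (plain NumPy): cast a new value to the dtype already
# defined on the destination constant, as the patched set_node_value now does.
existing_const = np.zeros((3,), dtype=np.float16)  # assume an FP16 Const value
new_value = [0.1, 0.2, 0.3]                        # Python floats default to float64

cast_value = np.array(new_value).astype(existing_const.dtype)
assert cast_value.dtype == np.float16              # the node's precision is preserved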


@@ -27,7 +27,7 @@ from . import node_utils as nu
 from .pattern_utils import get_fq_result_pattern
 from .special_operations import OPERATIONS_WITH_WEIGHTS, DETECTION_OUTPUT_FINAL_TYPES, SPLIT_OPERATIONS
 from .utils import find_operation_matches, is_ignored, get_hw_aware_ignored_patterns
-from ..graph.node_utils import get_all_node_outputs, get_node_inputs, get_node_input
+from ..graph.node_utils import get_all_node_outputs, get_node_inputs, get_node_input, get_weights_for_node
 from ..graph.special_patterns import get_ignored_patterns
 from ..utils.logger import get_logger
@@ -683,8 +683,12 @@ def create_bias_node(graph: Graph, src_node):
     bias_shape = src_node.out_port(0).data.get_shape()
     add_bias_shape = [1] * len(bias_shape)
     add_bias_shape[1] = bias_shape[1]
+    weights = get_weights_for_node(src_node)
+    bias_dtype = np.float32
+    if weights and weights.out_port(0).is_data_type_defined():
+        bias_dtype = weights.out_port(0).get_data_type()
     add_bias = Const(graph,
-                     {'value': np.zeros(add_bias_shape, dtype=np.float32),
+                     {'value': np.zeros(add_bias_shape, dtype=bias_dtype),
                       'shape': add_bias_shape,
                       'need_shape_inference': True
                       }).create_node()
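Not part of the diff: a NumPy-only sketch of the new bias creation logic, where the zero bias inherits the weights' data type when it is defined and falls back to float32 otherwise. `weights_dtype` is an assumed stand-in for `weights.out_port(0).get_data_type()`.

import numpy as np

# Names here are illustrative, not the POT graph API.
weights_dtype = np.float16        # assumed: dtype reported by the weights Const, or None
bias_dtype = np.float32           # default, as before this change
if weights_dtype is not None:
    bias_dtype = weights_dtype    # follow the weights' precision when it is known

add_bias_shape = [1, 8, 1, 1]     # example: channel dim copied from the layer output shape
bias_value = np.zeros(add_bias_shape, dtype=bias_dtype)
assert bias_value.dtype == np.float16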
@@ -706,10 +710,10 @@ def create_fake_quantize_node(graph: Graph, name):
     fq = FakeQuantize(graph, {'name': name, 'levels': 0,
                               'stop_value_propagation': True}).create_node()

-    input_low = Const(graph, {'value': 0.0}).create_node()
-    input_height = Const(graph, {'value': 0.0}).create_node()
-    output_low = Const(graph, {'value': 0.0}).create_node()
-    output_height = Const(graph, {'value': 0.0}).create_node()
+    input_low = Const(graph, {'value': np.array(0.0).astype(np.float32)}).create_node()
+    input_height = Const(graph, {'value': np.array(0.0).astype(np.float32)}).create_node()
+    output_low = Const(graph, {'value': np.array(0.0).astype(np.float32)}).create_node()
+    output_height = Const(graph, {'value': np.array(0.0).astype(np.float32)}).create_node()

     input_low.out_port(0).connect(fq.in_port(1))
     input_height.out_port(0).connect(fq.in_port(2))
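Not part of the diff: why the FakeQuantize range constants are now wrapped explicitly. A bare Python float defaults to float64 once it is converted by NumPy, whereas the patched code pins the initial 0.0 values to float32 up front. A quick check of that behaviour:

import numpy as np

# A bare Python float turns into a float64 scalar when wrapped by NumPy ...
assert np.array(0.0).dtype == np.float64
# ... while the patched constants pin the FakeQuantize ranges to float32.
assert np.array(0.0).astype(np.float32).dtype == np.float32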

File diff suppressed because one or more lines are too long


@@ -4,6 +4,7 @@
 import json
 import os
+from copy import deepcopy

 import numpy as np
 import pytest
 from addict import Dict
@@ -124,8 +125,7 @@ def test_activation_scales(tmp_path, models, preset, bits, stats_path, clipping_
     nodes = normalize(get_fq_nodes_stats_algo(model, preset, bits, False,
                                               clipping_value=clipping_value))
     local_path = os.path.join(tmp_path, '{}.json'.format(stats_path.split("_")[-2]))
-    local_file = open(local_path, 'w')
-    json.dump(nodes, local_file)
+    dump_intermediate_scales(local_path, nodes)

     assert len(ref_nodes) == len(nodes)
     processed_nodes = []
@@ -152,9 +152,7 @@ def test_weights_scales(tmp_path, models):
     ref_weights = get_ref_stats(path_to_weights)
     weights = get_fq_nodes_stats_algo(model, False, 8, True)
     local_path = os.path.join(tmp_path, '{}.json'.format('mv2_weights'))
-    dumped = json.dumps(weights, cls=NumpyEncoder)
-    local_file = open(local_path, 'w')
-    json.dump(dumped, local_file)
+    dump_intermediate_scales(local_path, weights)

     for fq_name in weights:
         item_min, item_max = weights[fq_name]['low_level'], weights[fq_name]['high_level']
@@ -359,3 +357,9 @@ def _get_tf_accuracy_checker_config(path_to_dataset):
             }]
         }
     ]}]})
+
+
+def dump_intermediate_scales(local_path, data):
+    data = json.dumps(deepcopy(data), cls=NumpyEncoder)
+    local_file = open(local_path, 'w')
+    json.dump(data, local_file)
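Not part of the diff: the new `dump_intermediate_scales` helper relies on a `NumpyEncoder` that the test module already imports but which is not shown in these hunks. A typical encoder of that kind (a sketch, not necessarily the project's exact implementation) converts NumPy scalars and arrays into JSON-serializable Python types:

import json
import numpy as np

class NumpyEncoder(json.JSONEncoder):
    """Illustrative sketch only, not necessarily the project's implementation."""
    def default(self, o):
        if isinstance(o, np.integer):
            return int(o)
        if isinstance(o, np.floating):
            return float(o)
        if isinstance(o, np.ndarray):
            return o.tolist()
        return super().default(o)

# Example: serializing a dict of NumPy values the way dump_intermediate_scales does.
print(json.dumps({'low_level': np.float32(-1.0), 'high_level': np.float32(1.0)}, cls=NumpyEncoder))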