[POT] Added typecast while nodes updating & creation (#8621)

* Update type casting while update

* Fix Pylint

* Update precision setting for FQ values

* Update FQ values precision setting

* Fix scales dump

* Update reference

* Another reference update
This commit is contained in:
Nikita Malinin 2021-11-23 17:13:04 +03:00 committed by GitHub
parent 6156626706
commit 516d510045
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 24 additions and 13 deletions

View File

@ -84,7 +84,10 @@ def set_node_value(node: Node, value: np.ndarray):
"""
if node.type != 'Const':
raise Exception('Can\'t set value for non-constant node {}'.format(node.name))
node.out_port(0).data.set_value(value)
data_type = np.float32
if node.out_port(0).is_data_type_defined():
data_type = node.out_port(0).get_data_type()
node.out_port(0).data.set_value(np.array(value).astype(data_type))
def get_node_value(node: Node):

View File

@ -27,7 +27,7 @@ from . import node_utils as nu
from .pattern_utils import get_fq_result_pattern
from .special_operations import OPERATIONS_WITH_WEIGHTS, DETECTION_OUTPUT_FINAL_TYPES, SPLIT_OPERATIONS
from .utils import find_operation_matches, is_ignored, get_hw_aware_ignored_patterns
from ..graph.node_utils import get_all_node_outputs, get_node_inputs, get_node_input
from ..graph.node_utils import get_all_node_outputs, get_node_inputs, get_node_input, get_weights_for_node
from ..graph.special_patterns import get_ignored_patterns
from ..utils.logger import get_logger
@ -683,8 +683,12 @@ def create_bias_node(graph: Graph, src_node):
bias_shape = src_node.out_port(0).data.get_shape()
add_bias_shape = [1] * len(bias_shape)
add_bias_shape[1] = bias_shape[1]
weights = get_weights_for_node(src_node)
bias_dtype = np.float32
if weights and weights.out_port(0).is_data_type_defined():
bias_dtype = weights.out_port(0).get_data_type()
add_bias = Const(graph,
{'value': np.zeros(add_bias_shape, dtype=np.float32),
{'value': np.zeros(add_bias_shape, dtype=bias_dtype),
'shape': add_bias_shape,
'need_shape_inference': True
}).create_node()
@ -706,10 +710,10 @@ def create_fake_quantize_node(graph: Graph, name):
fq = FakeQuantize(graph, {'name': name, 'levels': 0,
'stop_value_propagation': True}).create_node()
input_low = Const(graph, {'value': 0.0}).create_node()
input_height = Const(graph, {'value': 0.0}).create_node()
output_low = Const(graph, {'value': 0.0}).create_node()
output_height = Const(graph, {'value': 0.0}).create_node()
input_low = Const(graph, {'value': np.array(0.0).astype(np.float32)}).create_node()
input_height = Const(graph, {'value': np.array(0.0).astype(np.float32)}).create_node()
output_low = Const(graph, {'value': np.array(0.0).astype(np.float32)}).create_node()
output_height = Const(graph, {'value': np.array(0.0).astype(np.float32)}).create_node()
input_low.out_port(0).connect(fq.in_port(1))
input_height.out_port(0).connect(fq.in_port(2))

File diff suppressed because one or more lines are too long

View File

@ -4,6 +4,7 @@
import json
import os
from copy import deepcopy
import numpy as np
import pytest
from addict import Dict
@ -124,8 +125,7 @@ def test_activation_scales(tmp_path, models, preset, bits, stats_path, clipping_
nodes = normalize(get_fq_nodes_stats_algo(model, preset, bits, False,
clipping_value=clipping_value))
local_path = os.path.join(tmp_path, '{}.json'.format(stats_path.split("_")[-2]))
local_file = open(local_path, 'w')
json.dump(nodes, local_file)
dump_intermediate_scales(local_path, nodes)
assert len(ref_nodes) == len(nodes)
processed_nodes = []
@ -152,9 +152,7 @@ def test_weights_scales(tmp_path, models):
ref_weights = get_ref_stats(path_to_weights)
weights = get_fq_nodes_stats_algo(model, False, 8, True)
local_path = os.path.join(tmp_path, '{}.json'.format('mv2_weights'))
dumped = json.dumps(weights, cls=NumpyEncoder)
local_file = open(local_path, 'w')
json.dump(dumped, local_file)
dump_intermediate_scales(local_path, weights)
for fq_name in weights:
item_min, item_max = weights[fq_name]['low_level'], weights[fq_name]['high_level']
@ -359,3 +357,9 @@ def _get_tf_accuracy_checker_config(path_to_dataset):
}]
}
]}]})
def dump_intermediate_scales(local_path, data):
data = json.dumps(deepcopy(data), cls=NumpyEncoder)
local_file = open(local_path, 'w')
json.dump(data, local_file)