[POT] Added typecast while nodes updating & creation (#8621)
* Update type casting while update * Fix Pylint * Update precision setting for FQ values * Update FQ values precision setting * Fix scales dump * Update reference * Another reference update
This commit is contained in:
parent
6156626706
commit
516d510045
@ -84,7 +84,10 @@ def set_node_value(node: Node, value: np.ndarray):
|
|||||||
"""
|
"""
|
||||||
if node.type != 'Const':
|
if node.type != 'Const':
|
||||||
raise Exception('Can\'t set value for non-constant node {}'.format(node.name))
|
raise Exception('Can\'t set value for non-constant node {}'.format(node.name))
|
||||||
node.out_port(0).data.set_value(value)
|
data_type = np.float32
|
||||||
|
if node.out_port(0).is_data_type_defined():
|
||||||
|
data_type = node.out_port(0).get_data_type()
|
||||||
|
node.out_port(0).data.set_value(np.array(value).astype(data_type))
|
||||||
|
|
||||||
|
|
||||||
def get_node_value(node: Node):
|
def get_node_value(node: Node):
|
||||||
|
@ -27,7 +27,7 @@ from . import node_utils as nu
|
|||||||
from .pattern_utils import get_fq_result_pattern
|
from .pattern_utils import get_fq_result_pattern
|
||||||
from .special_operations import OPERATIONS_WITH_WEIGHTS, DETECTION_OUTPUT_FINAL_TYPES, SPLIT_OPERATIONS
|
from .special_operations import OPERATIONS_WITH_WEIGHTS, DETECTION_OUTPUT_FINAL_TYPES, SPLIT_OPERATIONS
|
||||||
from .utils import find_operation_matches, is_ignored, get_hw_aware_ignored_patterns
|
from .utils import find_operation_matches, is_ignored, get_hw_aware_ignored_patterns
|
||||||
from ..graph.node_utils import get_all_node_outputs, get_node_inputs, get_node_input
|
from ..graph.node_utils import get_all_node_outputs, get_node_inputs, get_node_input, get_weights_for_node
|
||||||
from ..graph.special_patterns import get_ignored_patterns
|
from ..graph.special_patterns import get_ignored_patterns
|
||||||
from ..utils.logger import get_logger
|
from ..utils.logger import get_logger
|
||||||
|
|
||||||
@ -683,8 +683,12 @@ def create_bias_node(graph: Graph, src_node):
|
|||||||
bias_shape = src_node.out_port(0).data.get_shape()
|
bias_shape = src_node.out_port(0).data.get_shape()
|
||||||
add_bias_shape = [1] * len(bias_shape)
|
add_bias_shape = [1] * len(bias_shape)
|
||||||
add_bias_shape[1] = bias_shape[1]
|
add_bias_shape[1] = bias_shape[1]
|
||||||
|
weights = get_weights_for_node(src_node)
|
||||||
|
bias_dtype = np.float32
|
||||||
|
if weights and weights.out_port(0).is_data_type_defined():
|
||||||
|
bias_dtype = weights.out_port(0).get_data_type()
|
||||||
add_bias = Const(graph,
|
add_bias = Const(graph,
|
||||||
{'value': np.zeros(add_bias_shape, dtype=np.float32),
|
{'value': np.zeros(add_bias_shape, dtype=bias_dtype),
|
||||||
'shape': add_bias_shape,
|
'shape': add_bias_shape,
|
||||||
'need_shape_inference': True
|
'need_shape_inference': True
|
||||||
}).create_node()
|
}).create_node()
|
||||||
@ -706,10 +710,10 @@ def create_fake_quantize_node(graph: Graph, name):
|
|||||||
fq = FakeQuantize(graph, {'name': name, 'levels': 0,
|
fq = FakeQuantize(graph, {'name': name, 'levels': 0,
|
||||||
'stop_value_propagation': True}).create_node()
|
'stop_value_propagation': True}).create_node()
|
||||||
|
|
||||||
input_low = Const(graph, {'value': 0.0}).create_node()
|
input_low = Const(graph, {'value': np.array(0.0).astype(np.float32)}).create_node()
|
||||||
input_height = Const(graph, {'value': 0.0}).create_node()
|
input_height = Const(graph, {'value': np.array(0.0).astype(np.float32)}).create_node()
|
||||||
output_low = Const(graph, {'value': 0.0}).create_node()
|
output_low = Const(graph, {'value': np.array(0.0).astype(np.float32)}).create_node()
|
||||||
output_height = Const(graph, {'value': 0.0}).create_node()
|
output_height = Const(graph, {'value': np.array(0.0).astype(np.float32)}).create_node()
|
||||||
|
|
||||||
input_low.out_port(0).connect(fq.in_port(1))
|
input_low.out_port(0).connect(fq.in_port(1))
|
||||||
input_height.out_port(0).connect(fq.in_port(2))
|
input_height.out_port(0).connect(fq.in_port(2))
|
||||||
|
File diff suppressed because one or more lines are too long
@ -4,6 +4,7 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
from copy import deepcopy
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
from addict import Dict
|
from addict import Dict
|
||||||
@ -124,8 +125,7 @@ def test_activation_scales(tmp_path, models, preset, bits, stats_path, clipping_
|
|||||||
nodes = normalize(get_fq_nodes_stats_algo(model, preset, bits, False,
|
nodes = normalize(get_fq_nodes_stats_algo(model, preset, bits, False,
|
||||||
clipping_value=clipping_value))
|
clipping_value=clipping_value))
|
||||||
local_path = os.path.join(tmp_path, '{}.json'.format(stats_path.split("_")[-2]))
|
local_path = os.path.join(tmp_path, '{}.json'.format(stats_path.split("_")[-2]))
|
||||||
local_file = open(local_path, 'w')
|
dump_intermediate_scales(local_path, nodes)
|
||||||
json.dump(nodes, local_file)
|
|
||||||
|
|
||||||
assert len(ref_nodes) == len(nodes)
|
assert len(ref_nodes) == len(nodes)
|
||||||
processed_nodes = []
|
processed_nodes = []
|
||||||
@ -152,9 +152,7 @@ def test_weights_scales(tmp_path, models):
|
|||||||
ref_weights = get_ref_stats(path_to_weights)
|
ref_weights = get_ref_stats(path_to_weights)
|
||||||
weights = get_fq_nodes_stats_algo(model, False, 8, True)
|
weights = get_fq_nodes_stats_algo(model, False, 8, True)
|
||||||
local_path = os.path.join(tmp_path, '{}.json'.format('mv2_weights'))
|
local_path = os.path.join(tmp_path, '{}.json'.format('mv2_weights'))
|
||||||
dumped = json.dumps(weights, cls=NumpyEncoder)
|
dump_intermediate_scales(local_path, weights)
|
||||||
local_file = open(local_path, 'w')
|
|
||||||
json.dump(dumped, local_file)
|
|
||||||
|
|
||||||
for fq_name in weights:
|
for fq_name in weights:
|
||||||
item_min, item_max = weights[fq_name]['low_level'], weights[fq_name]['high_level']
|
item_min, item_max = weights[fq_name]['low_level'], weights[fq_name]['high_level']
|
||||||
@ -359,3 +357,9 @@ def _get_tf_accuracy_checker_config(path_to_dataset):
|
|||||||
}]
|
}]
|
||||||
}
|
}
|
||||||
]}]})
|
]}]})
|
||||||
|
|
||||||
|
|
||||||
|
def dump_intermediate_scales(local_path, data):
|
||||||
|
data = json.dumps(deepcopy(data), cls=NumpyEncoder)
|
||||||
|
local_file = open(local_path, 'w')
|
||||||
|
json.dump(data, local_file)
|
||||||
|
Loading…
Reference in New Issue
Block a user