Nm/outputs quantization scheme (#12035)
* [POT] outputs quantization scheme
* [POT] remove needless blank line
* add a range_estimator config for outputs
* Add fq_config_priority
Parent: 0224e6a067
Commit: d5e8d1d968
@@ -118,7 +118,7 @@ def compute_stats_layouts(config, model, qscheme=None):
     if not config.preset:
         config.preset = 'performance'
     if not qscheme:
-        fq_configuration = get_configurations_by_preset(config, model, fq_configuration)
+        fq_configuration = get_configurations_by_preset(config, model, fq_configuration, hardware_config)
         fq_configuration = add_range_estimator_configs(fq_configuration, config)
     else:
         fq_configuration = get_configurations_by_qscheme(fq_configuration, qscheme)
@@ -130,12 +130,8 @@ def compute_stats_layouts(config, model, qscheme=None):
 
     fake_quantize_config = {}
     for fq in fq_nodes:
-        node_input = get_node_input(fq, 0)
-        is_weights = node_input.type == 'Const'
-        if is_weights:
-            fq_config = copy(fq_configuration[fq.fullname]['weights'])
-        else:
-            fq_config = copy(fq_configuration[fq.fullname]['activations'])
+        is_weights = fq['fq_group'] == 'weights'
+        fq_config = copy(fq_configuration[fq.name][fq['fq_group']])
         fake_quantize_config[fq.fullname] = fq_config
         if fq.fullname in config.layerwise_configs[0]:
             fq_config = Dict(merge_nested_dicts(fq_config, config.layerwise_configs[0][fq.fullname]))
@@ -5,9 +5,10 @@ from collections import deque, defaultdict
 from copy import deepcopy
 
 from .range_estimator import get_range_estimator_config
-from .utils import get_hardware_config_operation_type, load_hardware_config
+from .utils import load_hardware_config
 from ...graph.special_operations import QUANTIZE_AGNOSTIC_OPERATIONS, CONCAT_UNIFY_OUTPUTS, CONCAT_UNIFY_INPUTS
-from ...graph.utils import find_operation_matches, get_operation_list, is_data_type_quantizable
+from ...graph.utils import find_operation_matches, get_operation_list, is_data_type_quantizable,\
+    get_hardware_config_operation_type
 from ...graph.model_utils import get_nodes_by_type, get_node_by_name
 from ...graph.node_utils import get_input_shape, get_all_node_outputs,\
     get_node_input, get_node_inputs, get_node_data_type, check_const_input
@@ -25,13 +26,16 @@ QUANTIZATION_PARAMETERS = [
     'bits'
 ]
 
+ACTIVATION_QUANTIZATION_MODES = ['activations', 'outputs']
+QUANTIZATION_MODES = ['weights'] + ACTIVATION_QUANTIZATION_MODES
+
 
 def get_fake_quantize_configuration(config):
     """ Create fake quantization configuration from the tool configuration
    :param config: dictionary with compression section from toolkit config file
    :return dictionary with fake quantization configuration
    """
-    q_config = {'weights': {}, 'activations': {}}
+    q_config = {mode: {} for mode in QUANTIZATION_MODES}
     for op_type, q_params in q_config.items():
         op_type_config = config.get(op_type, {})
         for param_name, param_value in op_type_config.items():
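With ACTIVATION_QUANTIZATION_MODES and QUANTIZATION_MODES in place, the tool-level configuration is now read for three sections instead of two. A quick standalone check of what the new dict comprehension produces (plain Python, not part of the diff):

ACTIVATION_QUANTIZATION_MODES = ['activations', 'outputs']
QUANTIZATION_MODES = ['weights'] + ACTIVATION_QUANTIZATION_MODES

# Initial per-mode config skeleton built by get_fake_quantize_configuration()
q_config = {mode: {} for mode in QUANTIZATION_MODES}
print(q_config)  # {'weights': {}, 'activations': {}, 'outputs': {}}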
@@ -40,7 +44,7 @@ def get_fake_quantize_configuration(config):
     return q_config
 
 
-def intersect_configs(left, right):
+def intersect_configs(left, right, primary_bitwidth=None):
     """ intersect two sets of configurations """
     def _get_main_param_for_config(config):
         """ check main parameters intersection """
@@ -58,7 +62,11 @@ def intersect_configs(left, right):
         for idx, r_ in enumerate(right_[offset:]):
             r_main = _get_main_param_for_config(r_)
             if l_main == r_main:
-                if l_['bits'] <= r_['bits']:
+                if primary_bitwidth and l_['bits'] == primary_bitwidth:
                     result.append(l_)
+                elif primary_bitwidth and r_['bits'] == primary_bitwidth:
+                    result.append(r_)
+                elif l_['bits'] <= r_['bits']:
+                    result.append(l_)
                 else:
                     result.append(r_)
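The new primary_bitwidth branches change the tie-breaking rule when two hardware configurations describe the same mode and granularity: a config whose bit width equals the hardware's primary bit width now wins, otherwise the lower bit width is kept as before. A minimal standalone sketch of that selection rule (illustrative names, not part of the diff):

def pick_config(left, right, primary_bitwidth=None):
    # Prefer the config matching the hardware's primary bitwidth,
    # otherwise fall back to the lower bit width.
    if primary_bitwidth and left['bits'] == primary_bitwidth:
        return left
    if primary_bitwidth and right['bits'] == primary_bitwidth:
        return right
    return left if left['bits'] <= right['bits'] else right

l = {'mode': 'symmetric', 'granularity': 'pertensor', 'bits': 8}
r = {'mode': 'symmetric', 'granularity': 'pertensor', 'bits': 16}
print(pick_config(l, r)['bits'])                       # 8  (lower bit width wins)
print(pick_config(l, r, primary_bitwidth=16)['bits'])  # 16 (GNA-style primary bitwidth wins)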
@@ -127,29 +135,29 @@ def read_all_fake_quantize_configurations(config, hardware_config, model):
             confs = [conf for conf in op['quantization'][fq_type_]
                      if _is_subset(q_config[fq_type_], conf)]
             if confs:
-                res_conf = intersect_configs(res_conf, confs) if res_conf else confs
+                res_conf = intersect_configs(res_conf, confs, primary_bitwidth) if res_conf else confs
             else:
                 logger.warning('Fake quantize node %s does not support configuration '
                                'from tool config file (mismatch with hardware config)',
                                fq_name_)
-                res_conf = intersect_configs(res_conf, q_config[fq_type_]) \
+                res_conf = intersect_configs(res_conf, q_config[fq_type_], primary_bitwidth) \
                     if res_conf else [q_config[fq_type_]]
         if not res_conf:
             raise Exception('Fake quantize configuration cannot be empty')
         return res_conf
 
     q_config = get_fake_quantize_configuration(config)
+    primary_bitwidth = hardware_config[1]['primary_bitwidth']
 
     res_fq_to_hw_conf = {}
-    for fq_name, (types, is_weights) in _fake_quantize_to_types(model, hardware_config).items():
-        fq_type = 'weights' if is_weights else 'activations'
-        res_fq_to_hw_conf[fq_name] = {fq_type: []}
+    for fq_name, (types, fq_group) in _fake_quantize_to_types(model, hardware_config).items():
+        res_fq_to_hw_conf[fq_name] = {fq_group: []}
         for type_ in types:
             child_name, op_type = type_
             ops = [op for op in hardware_config if op_type == op['type']]
-            conf = _find_configurations(fq_name, fq_type)
+            conf = _find_configurations(fq_name, fq_group)
             if conf:
-                res_fq_to_hw_conf[fq_name][fq_type].append((child_name, conf))
+                res_fq_to_hw_conf[fq_name][fq_group].append((child_name, conf))
     return res_fq_to_hw_conf
 
 
@@ -165,7 +173,7 @@ def add_range_estimator_configs(fq_to_hw_confs, config):
     return fq_to_hw_confs
 
 
-def get_configurations_by_preset(config, model, fq_to_hw_confs):
+def get_configurations_by_preset(config, model, fq_to_hw_confs, hardware_config):
     """ Choose fake quantize configuration by preset
     :param config: dictionary with params algo section from toolkit config
     :param model: CompressedModel instance
@@ -200,25 +208,32 @@ def get_configurations_by_preset(config, model, fq_to_hw_confs):
             bridge_input_shapes = [get_input_shape(layer, i) for layer in bridge_layers for i in layer.in_ports()]
             broadcasting = _test_shapes(bridge_input_shapes)
             for fq in fqs:
-                if with_concat or unclear_layout or broadcasting:
-                    configuration = [c for c in cur_conf[fq]['activations'] if c['granularity'] == 'pertensor']
-                else:
-                    configuration = cur_conf[fq]['activations']
-                res_conf = intersect_configs(res_conf, configuration) if res_conf else configuration
+                for key in cur_conf[fq]:
+                    if key in ACTIVATION_QUANTIZATION_MODES:
+                        if with_concat or unclear_layout or broadcasting:
+                            configuration = [c for c in cur_conf[fq][key] if c['granularity'] == 'pertensor']
+                        else:
+                            configuration = cur_conf[fq][key]
+                        res_conf = intersect_configs(res_conf, configuration, primary_bitwidth) if res_conf \
+                            else configuration
             if not res_conf:
                 raise Exception('Fake quantize nodes {} cannot be unified'.format(fqs))
             for fq in fqs:
-                cur_conf[fq]['activations'] = _apply_preset_rule(preset_, fq, 'activations', res_conf)
+                for key in cur_conf[fq]:
+                    if key in ACTIVATION_QUANTIZATION_MODES:
+                        cur_conf[fq][key] = _apply_preset_rule(preset_, fq, key, res_conf)
         return cur_conf
 
+    primary_bitwidth = hardware_config[1]['primary_bitwidth']
     res = {}
     for key, value in fq_to_hw_confs_.items():
         conf = dict()
-        for i_type in ['activations', 'weights']:
+        for i_type in QUANTIZATION_MODES:
             if i_type in value:
                 res_conf = []
                 for _, configuration in value[i_type]:
-                    res_conf = intersect_configs(res_conf, configuration) if res_conf else configuration
+                    res_conf = intersect_configs(res_conf, configuration, primary_bitwidth) if res_conf \
+                        else configuration
                 if not res_conf:
                     raise Exception('Fake quantize node {} does not have a suitable configuration'
                                     ' for layers {}'.format(key, [layer for layer, _ in value[i_type]]))
@@ -413,8 +428,13 @@ def _fake_quantize_to_types(model, hardware_config):
     out = {}
     available_types = [layer['type'] for layer in hardware_config]
     for fq in get_nodes_by_type(model, ['FakeQuantize']):
-        node_input = get_node_input(fq, 0)
-        out[fq.fullname] = (_get_node_valuable_descendant(fq), node_input.type == 'Const')
+        if fq['fq_group'] == 'outputs':
+            fq_input = get_node_input(fq, 0)
+            hw_node_types = get_hardware_config_operation_type(fq_input, available_types)
+            out_data = ([(fq_input.name, hw_node_types)], fq['fq_group'])
+        else:
+            out_data = (_get_node_valuable_descendant(fq), fq['fq_group'])
+        out[fq.fullname] = out_data
 
     return out
 
@@ -427,7 +447,8 @@ def change_configurations_by_model_type(model, config, fq_configuration, hardwar
 def change_configurations_by_model_type_transformer(model, fq_configuration, hardware_config):
     fq_types = _fake_quantize_to_types(model, hardware_config)
     for fq in get_nodes_by_type(model, ['FakeQuantize']):
-        node_creator_fq, is_weights = fq_types[fq.name]
+        node_creator_fq, fq_group = fq_types[fq.name]
+        is_weights = fq_group == 'weights'
         node_name = None
         for name, type_node in node_creator_fq:
             if type_node == 'MatMul':
@@ -70,13 +70,61 @@ QUANTILE_WEIGHTS_RANGE_ESTIMATOR_CONFIG = {
     }}
 
 
+DEFAULT_OUTPUTS_RANGE_ESTIMATOR_CONFIG = {
+    'perchannel': {
+        'symmetric': {
+            'min': {'aggregator': 'min', 'type': 'min', 'granularity': 'pertensor'},
+            'max': {'aggregator': 'max', 'type': 'abs_max'}
+        },
+        'asymmetric': {
+            'min': {'aggregator': 'min', 'type': 'min'},
+            'max': {'aggregator': 'max', 'type': 'max'}
+        }
+    },
+    'pertensor': {
+        'symmetric': {
+            'min': {'aggregator': 'min', 'type': 'min'},
+            'max': {'aggregator': 'mean', 'type': 'abs_max'}
+        },
+        'asymmetric': {
+            'min': {'aggregator': 'mean', 'type': 'min'},
+            'max': {'aggregator': 'mean', 'type': 'max'}
+        }
+    }}
+
+
+QUANTILE_OUTPUTS_RANGE_ESTIMATOR_CONFIG = {
+    'perchannel': {
+        'symmetric': {
+            'min': {'aggregator': 'min', 'type': 'min', 'granularity': 'pertensor'},
+            'max': {'aggregator': 'max', 'type': 'abs_quantile', 'outlier_prob': 1e-4}
+        },
+        'asymmetric': {
+            'min': {'aggregator': 'min', 'type': 'quantile', 'outlier_prob': 1e-4},
+            'max': {'aggregator': 'max', 'type': 'quantile', 'outlier_prob': 1e-4}
+        }
+    },
+    'pertensor': {
+        'symmetric': {
+            'min': {'aggregator': 'min', 'type': 'min'},
+            'max': {'aggregator': 'mean', 'type': 'abs_quantile', 'outlier_prob': 1e-4}
+        },
+        'asymmetric': {
+            'min': {'aggregator': 'mean', 'type': 'quantile', 'outlier_prob': 1e-4},
+            'max': {'aggregator': 'mean', 'type': 'quantile', 'outlier_prob': 1e-4}
+        }
+    }}
+
+
 RANGE_ESTIMATOR_CONFIG_PRESETS = {
     'default': {
         'activations': DEFAULT_ACTIVATIONS_RANGE_ESTIMATOR_CONFIG,
+        'outputs': DEFAULT_OUTPUTS_RANGE_ESTIMATOR_CONFIG,
         'weights': DEFAULT_WEIGHTS_RANGE_ESTIMATOR_CONFIG,
     },
     'quantile': {
         'activations': QUANTILE_ACTIVATIONS_RANGE_ESTIMATOR_CONFIG,
+        'outputs': QUANTILE_OUTPUTS_RANGE_ESTIMATOR_CONFIG,
         'weights': QUANTILE_WEIGHTS_RANGE_ESTIMATOR_CONFIG,
     }
 }
@@ -90,7 +138,7 @@ def get_range_estimator_config(config, tensor_type, granularity, q_mode, preset=
         preset = range_estimator.get('preset', 'default')
     preset_config = deepcopy(RANGE_ESTIMATOR_CONFIG_PRESETS[preset][tensor_type])
     result_config = preset_config[granularity][q_mode] \
-        if tensor_type == 'activations' else preset_config[q_mode]
+        if tensor_type in ['activations', 'outputs'] else preset_config[q_mode]
     if 'min' in range_estimator:
         if 'min' not in result_config:
             result_config['min'] = {}
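With the DEFAULT_OUTPUTS/QUANTILE_OUTPUTS tables registered in RANGE_ESTIMATOR_CONFIG_PRESETS, the 'outputs' tensor type is resolved the same way as 'activations': first by granularity, then by quantization mode. A small standalone sketch of that lookup, using a trimmed copy of the preset table added above (not part of the diff):

from copy import deepcopy

# Trimmed copy of the preset table; only the branch needed for the lookup is kept.
RANGE_ESTIMATOR_CONFIG_PRESETS = {
    'default': {
        'outputs': {
            'pertensor': {
                'symmetric': {
                    'min': {'aggregator': 'min', 'type': 'min'},
                    'max': {'aggregator': 'mean', 'type': 'abs_max'}
                }
            }
        }
    }
}

def resolve(tensor_type, granularity, q_mode, preset='default'):
    preset_config = deepcopy(RANGE_ESTIMATOR_CONFIG_PRESETS[preset][tensor_type])
    # 'outputs' now goes through the granularity level, exactly like 'activations'
    return preset_config[granularity][q_mode] \
        if tensor_type in ['activations', 'outputs'] else preset_config[q_mode]

print(resolve('outputs', 'pertensor', 'symmetric'))
# {'min': {'aggregator': 'min', 'type': 'min'}, 'max': {'aggregator': 'mean', 'type': 'abs_max'}}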
@@ -10,7 +10,6 @@ from .range_estimator import get_range_estimator_config
 from ...api.engine import Engine
 from ...configs.hardware_config import HardwareConfig
 from ...engines.ac_engine import ACEngine
-from ...graph.node_utils import get_input_shape
 from ...statistics.function_selector import ACTIVATIONS, WEIGHTS, get_stats_function, AGGREGATION_FN
 from ...statistics.statistics import TensorStatistic
 
@@ -247,33 +246,6 @@ def get_quantize_op_config(op, config, opt_conf=None):
     return qconfig
 
 
-def get_hardware_config_operation_type(node, available_types):
-    """ This function gets type by child
-    for hardware configuration of FQ node
-    :param node: node-type object
-    :param available_types: available types with config
-    :return: default or special type of layer as string
-    """
-
-    def _is_depth_wise(node):
-        if node.type == 'Convolution' and node.has_valid('group'):
-            group = node['group']
-            output = node['output']
-            input_shape = get_input_shape(node, 0)
-            if group == output and input_shape[1] == output:
-                return True
-        return False
-
-    type_checkers = {
-        'DepthWiseConvolution': _is_depth_wise
-    }
-
-    for real_type in type_checkers:
-        if real_type in available_types and type_checkers[real_type](node):
-            return real_type
-    return node.type
-
-
 def get_tensor_statistics(range_estimator_config, for_weights, **kwargs):
     stats = {}
     for stats_name in ['min', 'max']:
@@ -108,6 +108,14 @@ def get_configs(args):
                         'aggregator': 'max'
                     }
                 }
             },
+            'outputs': {
+                'range_estimator': {
+                    'max': {
+                        'type': 'abs_max',
+                        'aggregator': 'max'
+                    }
+                }
+            }
         }
     }
@@ -1,5 +1,6 @@
 {
     "target_device": "GNA",
+    "primary_bitwidth": 16,
     "config": {
         "quantization": {
             "q32_a": {
@@ -29,6 +30,20 @@
                 "granularity": "pertensor",
                 "level_low": -32767,
                 "level_high": 32767
             },
+            "q32_o": {
+                "bits": 32,
+                "mode": "symmetric",
+                "granularity": "pertensor",
+                "level_low": -2147483648,
+                "level_high": 2147483647
+            },
+            "q16_o": {
+                "bits": 16,
+                "mode": "symmetric",
+                "granularity": "pertensor",
+                "level_low": -32768,
+                "level_high": 32767
+            }
         }
     },
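The new q32_o and q16_o entries give output FakeQuantize nodes their own integer ranges on GNA. As a quick sanity check, the level counts implied by level_low/level_high (plain arithmetic, not part of the diff):

# Number of representable levels for the new output quantization presets.
presets = {
    'q32_o': (-2147483648, 2147483647),   # full signed 32-bit range
    'q16_o': (-32768, 32767),             # full signed 16-bit range
}
for name, (low, high) in presets.items():
    print(name, high - low + 1)           # q32_o -> 4294967296, q16_o -> 65536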
@@ -37,34 +52,39 @@
             "type": "Convolution",
             "quantization": {
                 "activations": "q16_a",
-                "weights": "q16_w"
+                "weights": "q16_w",
+                "outputs": "q32_o"
             }
         },
         {
             "type": "MatMul",
             "quantization": {
                 "activations": "q16_a",
-                "weights": ["q8_w", "q16_w"]
+                "weights": ["q8_w", "q16_w"],
+                "outputs": "q32_o"
             }
         },
         {
             "type": "Add",
             "quantization": {
                 "activations": "q16_a",
-                "weights": "q16_w"
+                "weights": "q16_w",
+                "outputs": "q32_o"
             }
         },
         {
             "type": "Multiply",
             "quantization": {
                 "activations": "q16_a",
-                "weights": ["q8_w", "q16_w"]
+                "weights": ["q8_w", "q16_w"],
+                "outputs": "q32_o"
             }
         },
         {
             "type": "Power",
             "quantization": {
-                "activations": "q32_a"
+                "activations": "q32_a",
+                "outputs": "q16_o"
             }
         },
         {
@@ -88,55 +108,64 @@
         {
             "type": "Sigmoid",
             "quantization": {
-                "activations": "q32_a"
+                "activations": "q32_a",
+                "outputs": "q16_o"
             }
         },
         {
             "type": "Tanh",
             "quantization": {
-                "activations": "q32_a"
+                "activations": "q32_a",
+                "outputs": "q16_o"
             }
         },
         {
             "type": "ReLU",
             "quantization": {
-                "activations": "q32_a"
+                "activations": "q32_a",
+                "outputs": "q16_o"
             }
         },
         {
             "type": "PReLU",
             "quantization": {
-                "activations": "q32_a"
+                "activations": "q32_a",
+                "outputs": "q16_o"
             }
         },
         {
             "type": "Clamp",
             "quantization": {
-                "activations": "q32_a"
+                "activations": "q32_a",
+                "outputs": "q16_o"
             }
         },
         {
             "type": "Log",
             "quantization": {
-                "activations": "q32_a"
+                "activations": "q32_a",
+                "outputs": "q16_o"
             }
         },
         {
             "type": "Abs",
             "quantization": {
-                "activations": "q32_a"
+                "activations": "q32_a",
+                "outputs": "q16_o"
             }
         },
         {
             "type": "Exp",
             "quantization": {
-                "activations": "q32_a"
+                "activations": "q32_a",
+                "outputs": "q16_o"
             }
         },
         {
             "type": "Sign",
             "quantization": {
-                "activations": "q32_a"
+                "activations": "q32_a",
+                "outputs": "q16_o"
             }
         },
         {
@@ -145,6 +174,12 @@
                 "activations": "q16_a"
             }
         },
         {
+            "type": "Parameter",
+            "quantization": {
+                "outputs": "q16_o"
+            }
+        },
+        {
             "type": "Subtract",
             "quantization": {
@@ -38,6 +38,7 @@ class HardwareConfig(list):
             json_config = json.load(f, object_pairs_hook=OrderedDict)
             hw_config = cls()
             hw_config.append(Dict(('target_device', json_config['target_device'])))
+            hw_config.append(Dict(('primary_bitwidth', json_config.get('primary_bitwidth', 8))))
 
             configs = {}
             for algorithm_name, algorithm_config in json_config.get('config', {}).items():
@@ -27,15 +27,17 @@ from . import editor as ge
 from . import node_utils as nu
 from .editor import get_nodes_by_type
 from .pattern_utils import get_fq_result_pattern
-from .special_operations import OPERATIONS_WITH_WEIGHTS, DETECTION_OUTPUT_FINAL_TYPES, SPLIT_OPERATIONS
+from .special_operations import OPERATIONS_WITH_WEIGHTS, DETECTION_OUTPUT_FINAL_TYPES, \
+    SPLIT_OPERATIONS, OPERATIONS_WITH_BIAS
 from .utils import find_operation_matches, is_ignored, get_hw_aware_ignored_patterns
 from ..graph.node_utils import get_all_node_outputs, get_node_inputs, get_node_input, get_weights_for_node
 from ..graph.special_patterns import get_ignored_patterns
 from ..utils.logger import get_logger
+from .utils import get_hardware_config_operation_type
 
 logger = get_logger(__name__)
 
 
 #pylint: disable=C0302
 class SaveBNStatistics(FrontReplacementSubgraph):
     enabled = True
@@ -90,6 +92,22 @@ class InsertFakeQuantize(BackReplacementPattern):
     def quantize_operations(self, value):
         setattr(self, '_quantize_operations', value)
 
+    @property
+    def quantize_output_operations(self):
+        return getattr(self, '_quantize_output_operations', [])
+
+    @quantize_output_operations.setter
+    def quantize_output_operations(self, value):
+        setattr(self, '_quantize_output_operations', value)
+
+    @property
+    def hardware_config(self):
+        return getattr(self, '_hardware_config', [])
+
+    @hardware_config.setter
+    def hardware_config(self, value):
+        setattr(self, '_hardware_config', value)
+
     @property
     def ignored_params(self):
         return getattr(self, '_ignored_params', {'skip_model': False, 'scope': [], 'operations': []})
@@ -130,19 +148,38 @@ class InsertFakeQuantize(BackReplacementPattern):
             return
 
         if m_op.type in ['Convolution', 'ConvolutionBackpropData', 'MatMul']:
-            insert_fake_quantize(graph, m_op, [0, 1], ['fq_input', 'fq_weights'])
+            insert_fake_quantize(graph, m_op, [0, 1], ['fq_input', 'fq_weights'], ['activations', 'weights'],
+                                 hw_config=self.hardware_config)
         elif m_op.type == 'LSTMCell':
-            insert_fake_quantize(graph, m_op, [0, 1, 2, 3, 4])
+            insert_fake_quantize(graph, m_op, [0, 1, 2, 3, 4], hw_config=self.hardware_config)
         elif self.quantize_only_input(m_op):
-            insert_fake_quantize(graph, m_op, [0])
+            insert_fake_quantize(graph, m_op, [0], hw_config=self.hardware_config)
         else:
-            insert_fake_quantize(graph, m_op)
+            insert_fake_quantize(graph, m_op, hw_config=self.hardware_config)
+
+        biased_op = [op['type'] for op in OPERATIONS_WITH_BIAS]
+        if m_op.type in self.quantize_output_operations:
+            bias_node = nu.get_bias_for_node(m_op)
+            if m_op.type in biased_op and bias_node:
+                m_op = nu.get_node_output(bias_node, 0)[0]
+            insert_output_fake_quantize(graph, m_op, hw_config=self.hardware_config,
+                                        ignored_params=self.ignored_params)
 
 
 class FakeQuantizePropagation(BackReplacementPattern):
 
     enabled = False
 
+    @property
+    def hardware_config(self):
+        return getattr(self, '_hardware_config', [])
+
+    @hardware_config.setter
+    def hardware_config(self, value):
+        setattr(self, '_hardware_config', value)
+
     def remove_node_and_reset_connections(self, graph, node: Node, in_port):
         node.in_port(0).disconnect()
         node.out_port(0).get_connection().set_source(in_port)
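For operation types listed in quantize_output_operations, the output FakeQuantize is anchored after the bias Add when the operation has one, so the quantizer sees the biased result rather than the raw Convolution/MatMul output. A rough, self-contained sketch of that placement decision (the Node class and field names below are toy stand-ins for the POT graph API, not part of the diff):

from dataclasses import dataclass
from typing import Optional

@dataclass
class Node:
    name: str
    type: str
    bias: Optional['Node'] = None  # stand-in for nu.get_bias_for_node()

def output_fq_anchor(m_op: Node, quantize_output_operations, biased_types):
    """Return the node after which the output FakeQuantize should be inserted."""
    if m_op.type not in quantize_output_operations:
        return None
    if m_op.type in biased_types and m_op.bias is not None:
        # Quantize after the bias Add, not after the raw Convolution/MatMul.
        return m_op.bias
    return m_op

conv = Node('conv1', 'Convolution', bias=Node('conv1/add', 'Add'))
print(output_fq_anchor(conv, ['Convolution'], ['Convolution']).name)  # conv1/add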
@@ -155,7 +192,7 @@ class FakeQuantizePropagation(BackReplacementPattern):
         # Disconnect FQ from input and reconnect outputs to input node
         self.remove_node_and_reset_connections(graph, fq, in_port)
 
-        return insert_fake_quantize(graph, op, [0])
+        return insert_fake_quantize(graph, op, [0], hw_config=self.hardware_config)
 
     def jump_to_all_inputs(self, graph: Graph, fq: Node) -> []:
         in_port = fq.in_port(0).get_source()
@@ -165,7 +202,7 @@ class FakeQuantizePropagation(BackReplacementPattern):
         self.remove_node_and_reset_connections(graph, fq, in_port)
 
         # Insert FQ operations for all inputs
-        return insert_fake_quantize(graph, op)
+        return insert_fake_quantize(graph, op, hw_config=self.hardware_config)
 
     def jump_to_all_branch_except_const(self, graph, fq: Node) -> []:
         in_port = fq.in_port(0).get_source()
@@ -176,7 +213,7 @@ class FakeQuantizePropagation(BackReplacementPattern):
         # Disconnect FQ from input and reconnect outputs to input node
         self.remove_node_and_reset_connections(graph, fq, in_port)
 
-        return insert_fake_quantize(graph, op, ports)
+        return insert_fake_quantize(graph, op, ports, hw_config=self.hardware_config)
 
     def jump_over_split_concat(self, graph: Graph, fq: Node) -> []:
         in_port = fq.in_port(0).get_source()
@@ -298,6 +335,13 @@ class FakeQuantizePropagation(BackReplacementPattern):
             if not skip_ascent_map[input_parent_name]:
                 input_type = ('Split', 'VariadicSplit', 'Concat')
                 input_name = (input_node.name, input_parent_name)
+            if input_type == 'FakeQuantize':
+                if fq['fq_config_priority'] == 'high' and input_node['fq_config_priority'] == 'low':
+                    input_node['fq_group'] = fq['fq_group']
+                for fq_config in fq['fq_configs']:
+                    if fq_config not in input_node['fq_configs']:
+                        input_node['fq_configs'].append(fq_config)
+                logger.debug('FQ %s extended with %s configs', input_name, fq.name)
             logger.debug('FQ %s jumped over %s (%s)', fq.name, input_type, input_name)
 
             callback = self.map_op_to_fn[input_type]
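When propagation lands on a node that already carries a FakeQuantize, the new fq_config_priority attribute decides whether the existing node keeps its fq_group: a high-priority FQ (inputs/weights) overrides a low-priority output FQ, and the fq_configs lists are unioned rather than replaced. A plain-dict sketch of that merge rule (illustrative, not part of the diff):

def merge_fq(fq, input_node):
    # Mirror of the merge performed above, on plain dicts.
    if fq['fq_config_priority'] == 'high' and input_node['fq_config_priority'] == 'low':
        input_node['fq_group'] = fq['fq_group']
    for fq_config in fq['fq_configs']:
        if fq_config not in input_node['fq_configs']:
            input_node['fq_configs'].append(fq_config)

high = {'fq_group': 'activations', 'fq_config_priority': 'high',
        'fq_configs': [{'bits': 16, 'mode': 'symmetric'}]}
low = {'fq_group': 'outputs', 'fq_config_priority': 'low',
       'fq_configs': [{'bits': 32, 'mode': 'symmetric'}]}
merge_fq(high, low)
print(low['fq_group'])         # 'activations' -- high priority wins
print(len(low['fq_configs']))  # 2 -- configs are unioned, not replaced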
@@ -405,6 +449,9 @@ class FakeQuantizeOptimization(BackReplacementPattern):
             fq_consumers = sorted(fq_consumers, key=lambda x: x.name)
             # Keep only first FakeQuantize and disconnect other
             for fq in fq_consumers[1:]:
+                for fq_config in fq['fq_configs']:
+                    if fq_config not in fq_consumers[0]['fq_configs']:
+                        fq_consumers[0]['fq_configs'].append(fq_config)
                 logger.debug('Removed useless FakeQuantize {}'.format(fq.name))
                 fq.in_port(0).disconnect()
                 fq.out_port(0).get_connection().set_source(fq_consumers[0].out_port(0))
@@ -760,9 +807,9 @@ def create_bias_node(graph: Graph, src_node):
     add_bias.out_node(0)['Insert_Convert_operation_after'] = True
 
 
-def create_fake_quantize_node(graph: Graph, name, data_type=np.float32):
+def create_fake_quantize_node(graph: Graph, name, data_type=np.float32, **kwargs):
     fq = FakeQuantize(graph, {'name': name, 'levels': 0,
-                              'stop_value_propagation': True}).create_node()
+                              'stop_value_propagation': True, **kwargs}).create_node()
 
     input_low = Const(graph, {'value': np.array(0.0, dtype=data_type)}).create_node()
     input_height = Const(graph, {'value': np.array(0.0, dtype=data_type)}).create_node()
@@ -782,13 +829,17 @@ def create_fake_quantize_node(graph: Graph, name, data_type=np.float32):
     return fq
 
 
-def insert_fake_quantize(graph, node, ports=None, names=None):
+def insert_fake_quantize(graph, node, ports=None, names=None, fq_types=None, hw_config=None):
     blobs_as_inputs_nodes_type = ['Convolution', 'Deconvolution', 'MatMul']
 
     port_name = None
     if ports is not None and names is not None:
         port_name = dict(zip(ports, names))
 
+    fq_type = None
+    if fq_types is not None and ports is not None:
+        fq_type = dict(zip(ports, fq_types))
+
     new_fq = []
     for idx, port in node.in_ports().items():
         if port.disconnected():
@@ -804,19 +855,35 @@ def insert_fake_quantize(graph, node, ports=None, names=None):
 
         # This condition blocks FQ insertion after the keep_shape_ops (KSO) generated sub-graph
         # to avoid quantization of integer-like tensors
-        if port.get_source().node.type != 'Const' and port.data.get_value() is not None:
+        in_port_type = port.get_source().node.type
+        if in_port_type != 'Const' and port.data.get_value() is not None:
             continue
 
-        name = 'fq_input'
+        is_weights = in_port_type == 'Const'
+
+        name = 'fq_weights' if is_weights else 'fq_input'
         if port_name is not None and idx in port_name:
             name = port_name[idx]
 
         port_data_type = nu.get_node_data_type(node, idx)
         port_data_type = port_data_type if port_data_type else np.float32
         # Create FakeQuantize operations
-        fq_input = create_fake_quantize_node(
-            graph, '{node_name}/{name}_{idx}'.format(node_name=node.name, name=name, idx=idx), port_data_type)
+        fq_group = 'weights' if is_weights else 'activations'
+        if fq_type is not None and idx in fq_type:
+            fq_group = fq_type[idx]
+
+        fq_configs = []
+        node_type = get_hardware_config_operation_type(node, list(hw_config.keys()))
+        if hw_config is not None and hw_config[node_type]:
+            fq_configs = hw_config[node_type][fq_group]
+
+        fq_options = {
+            'fq_group': fq_group,
+            'fq_configs': copy(fq_configs),
+            'fq_config_priority': 'high'
+        }
+        fq_name = '{node_name}/{name}_{idx}'.format(node_name=node.name, name=name, idx=idx)
+        fq_input = create_fake_quantize_node(graph, fq_name, port_data_type, **fq_options)
         # Insert FakeQuantize after input
         if node.type == 'Result':
             in_port = port.get_source()
@@ -832,6 +899,53 @@ def insert_fake_quantize(graph, node, ports=None, names=None):
     return new_fq
 
 
+def insert_output_fake_quantize(graph, node, hw_config=None, ignored_params=None):
+    activation_nodes_type = ['Power', 'Sigmoid', 'Tanh', 'ReLU', 'PReLU',
+                             'Clamp', 'Log', 'Abs', 'Exp', 'Sign']
+
+    new_fq = []
+    for out_port_id, port in node.out_ports().items():
+        if port.disconnected():
+            continue
+
+        next_ports = port.get_destinations()
+        for next_port_id, next_port in enumerate(next_ports):
+            next_node = next_port.node
+
+            if next_node.type == 'ShapeOf':
+                continue
+
+            if ignored_params is not None and next_node.type != 'FakeQuantize' \
+                    and is_ignored(ignored_params, next_node):
+                continue
+
+            fq_name = '{node_name}/fq_output_{out_port_id}_{next_port_id}'.format(node_name=node.name,
+                                                                                  out_port_id=out_port_id,
+                                                                                  next_port_id=next_port_id)
+            fq_configs = hw_config[node.type]['outputs'] if hw_config is not None and hw_config[node.type] else []
+
+            fq_config_priority = 'low'
+            if node.type in activation_nodes_type + ['Parameter']:
+                fq_config_priority = 'high'
+            else:
+                fq_config_priority = 'low'
+
+            fq_options = {
+                'fq_group': 'outputs',
+                'fq_configs': copy(fq_configs),
+                'fq_config_priority': fq_config_priority
+            }
+            fq_output = create_fake_quantize_node(graph, fq_name, **fq_options)
+
+            in_port = next_port.get_source()
+            next_port.get_connection().set_source(fq_output.out_port(0))
+            in_port.connect(fq_output.in_port(0))
+            fq_output.infer(fq_output)
+            new_fq.append(fq_output)
+
+    return new_fq
+
+
 def traverse_graph(node, move_fn, stop_criteria_fn=None, criteria_fns=None):
     """ Traverse through graph dependent on move_fn
     :param node: node to start floating or sinking with some rule
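insert_output_fake_quantize walks every consumer of the node's outputs, skips ShapeOf branches and ignored consumers, and marks output FQs after activations/Parameter as high priority so they survive later merges. A small standalone sketch of just the per-consumer decisions (the type list is copied from the function above; everything else is a toy, not part of the diff):

ACTIVATION_NODES = ['Power', 'Sigmoid', 'Tanh', 'ReLU', 'PReLU',
                    'Clamp', 'Log', 'Abs', 'Exp', 'Sign']

def output_fq_plan(node_type, consumer_types):
    plan = []
    for consumer in consumer_types:
        if consumer == 'ShapeOf':
            continue  # shape sub-graphs are never quantized
        priority = 'high' if node_type in ACTIVATION_NODES + ['Parameter'] else 'low'
        plan.append({'fq_group': 'outputs', 'fq_config_priority': priority})
    return plan

print(output_fq_plan('Sigmoid', ['Convolution', 'ShapeOf']))
# [{'fq_group': 'outputs', 'fq_config_priority': 'high'}]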
@@ -7,13 +7,16 @@ from .editor import add_fullname_for_nodes
 from .special_operations import QUANTIZE_AGNOSTIC_OPERATIONS
 from .passes import InsertFakeQuantize, FakeQuantizePropagation, FakeQuantizeOptimization, RemoveFakeQuantize, \
     SpecialBlocksMarker, FakeQuantizeNameSwapper
-from .utils import find_operation_matches, get_operation_list, preprocess_ignored_params
+from .utils import find_operation_matches, get_operation_list, preprocess_ignored_params, \
+    get_operation_list_with_outputs
 
 
 class GraphTransformer:
     def __init__(self, hardware_config, quantize_inputs=False):
         self.target_device = hardware_config[0]['target_device']
         hw_ops = get_operation_list(hardware_config)
+        hw_config = {conf['type']: conf['quantization'] for conf in hardware_config if 'type' in conf}
+        quantize_output_operations = get_operation_list_with_outputs(hardware_config)
 
         quantize_agnostic_operations = [op[1] for op in find_operation_matches(
             QUANTIZE_AGNOSTIC_OPERATIONS, hw_ops)]
@@ -27,11 +30,14 @@ class GraphTransformer:
 
         self.fq_insertion = InsertFakeQuantize()
         self.fq_insertion.quantize_operations = quantize_operations
+        self.fq_insertion.quantize_output_operations = quantize_output_operations
+        self.fq_insertion.hardware_config = hw_config
 
         self.fq_propagation = FakeQuantizePropagation()
         self.fq_propagation.quantize_agnostic_operations = quantize_agnostic_operations
         self.fq_propagation.quantize_inputs = quantize_inputs
         self.fq_propagation.quantize_operations = quantize_operations
+        self.fq_propagation.hardware_config = hw_config
 
         self.fq_optimization = FakeQuantizeOptimization()
 
@@ -13,7 +13,7 @@ from .gpu_patterns import get_gpu_ignored_patterns
 from .vpu_patterns import get_vpu_ignored_patterns
 from .gna_patterns import get_gna_ignored_patterns
 from .special_operations import QUANTIZE_AGNOSTIC_OPERATIONS
-from .node_utils import get_all_node_outputs
+from .node_utils import get_all_node_outputs, get_input_shape
 
 HARDWARE_AWARE_IGNORED_PATTERNS = {
     'CPU': get_cpu_ignored_patterns(),
@@ -25,6 +25,8 @@ HARDWARE_AWARE_IGNORED_PATTERNS = {
 
 DEFAULT_PATH = 'PATH'
 
+HARDWARE_SPECIAL_FIELDS = ['target_device', 'primary_bitwidth']
+
 
 # pylint: disable=method-hidden
 class PathEncoder(json.JSONEncoder):
@@ -110,7 +112,7 @@ def find_operation_matches(src_ops, dst_ops):
 def get_operation_list(hardware_config):
     hw_ops = []
     for item in hardware_config:
-        if 'target_device' in item:
+        if any([special_value in item for special_value in HARDWARE_SPECIAL_FIELDS]):
             continue
 
         op = {}
@@ -121,6 +123,14 @@ def get_operation_list(hardware_config):
         hw_ops.append(op)
     return hw_ops
 
+def get_operation_list_with_outputs(hardware_config):
+    hw_ops = []
+    for item in hardware_config:
+        if any([special_value in item for special_value in HARDWARE_SPECIAL_FIELDS]):
+            continue
+        if 'quantization' in item and 'outputs' in item['quantization']:
+            hw_ops.append(item['type'])
+    return hw_ops
 
 def create_quantization_info_for_mo(config):
     quantization_section = {}
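get_operation_list_with_outputs simply collects the operation types whose hardware description declares an 'outputs' scheme; these become quantize_output_operations in GraphTransformer. A quick standalone check against a trimmed, GNA-like config (the config literal below is illustrative, not part of the diff):

HARDWARE_SPECIAL_FIELDS = ['target_device', 'primary_bitwidth']

def get_operation_list_with_outputs(hardware_config):
    hw_ops = []
    for item in hardware_config:
        if any([special_value in item for special_value in HARDWARE_SPECIAL_FIELDS]):
            continue
        if 'quantization' in item and 'outputs' in item['quantization']:
            hw_ops.append(item['type'])
    return hw_ops

hw = [
    {'target_device': 'GNA'},
    {'primary_bitwidth': 16},
    {'type': 'Convolution', 'quantization': {'activations': 'q16_a', 'weights': 'q16_w', 'outputs': 'q32_o'}},
    {'type': 'Reshape', 'quantization': {'activations': 'q16_a'}},
]
print(get_operation_list_with_outputs(hw))  # ['Convolution']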
@@ -219,3 +229,30 @@ def check_agnostic_and_ignored_params(model, ignored_params):
 
 def is_data_type_quantizable(type_node):
     return type_node not in (np.int32, np.int64, bool)
+
+
+def get_hardware_config_operation_type(node, available_types):
+    """ This function gets type by child
+    for hardware configuration of FQ node
+    :param node: node-type object
+    :param available_types: available types with config
+    :return: default or special type of layer as string
+    """
+
+    def _is_depth_wise(node):
+        if node.type == 'Convolution' and node.has_valid('group'):
+            group = node['group']
+            output = node['output']
+            input_shape = get_input_shape(node, 0)
+            if group == output and input_shape[1] == output:
+                return True
+        return False
+
+    type_checkers = {
+        'DepthWiseConvolution': _is_depth_wise
+    }
+
+    for real_type in type_checkers:
+        if real_type in available_types and type_checkers[real_type](node):
+            return real_type
+    return node.type
@@ -1 +1 @@
-[{"target_device": "CPU"}, {"type": "Convolution", "quantization": {"weights": [{"mode": "symmetric", "bits": 4, "granularity": "pertensor"}, {"mode": "symmetric", "bits": 8, "granularity": "pertensor"}, {"mode": "symmetric", "bits": 8, "granularity": "perchannel"}, {"mode": "asymmetric", "bits": 8, "granularity": "pertensor"}, {"mode": "asymmetric", "bits": 8, "granularity": "perchannel"}], "activations": [{"mode": "symmetric", "bits": 2, "granularity": "pertensor"}, {"mode": "symmetric", "bits": 4, "granularity": "pertensor"}, {"mode": "symmetric", "bits": 8, "granularity": "pertensor"}, {"mode": "symmetric", "bits": 8, "granularity": "perchannel"}, {"mode": "asymmetric", "bits": 8, "granularity": "pertensor"}, {"mode": "asymmetric", "bits": 8, "granularity": "perchannel"}]}}, {"type": "MatMul", "quantization": {"weights": [{"level_low": -127, "mode": "symmetric", "bits": 8, "level_high": 127, "granularity": "pertensor"}], "activations": [{"mode": "symmetric", "bits": 8, "granularity": "pertensor"}, {"mode": "symmetric", "bits": 8, "granularity": "perchannel"}, {"mode": "asymmetric", "bits": 8, "granularity": "pertensor"}, {"mode": "asymmetric", "bits": 8, "granularity": "perchannel"}]}}]
+[{"target_device": "CPU"}, {"primary_bitwidth": 8}, {"type": "Convolution", "quantization": {"weights": [{"mode": "symmetric", "bits": 4, "granularity": "pertensor"}, {"mode": "symmetric", "bits": 8, "granularity": "pertensor"}, {"mode": "symmetric", "bits": 8, "granularity": "perchannel"}, {"mode": "asymmetric", "bits": 8, "granularity": "pertensor"}, {"mode": "asymmetric", "bits": 8, "granularity": "perchannel"}], "activations": [{"mode": "symmetric", "bits": 2, "granularity": "pertensor"}, {"mode": "symmetric", "bits": 4, "granularity": "pertensor"}, {"mode": "symmetric", "bits": 8, "granularity": "pertensor"}, {"mode": "symmetric", "bits": 8, "granularity": "perchannel"}, {"mode": "asymmetric", "bits": 8, "granularity": "pertensor"}, {"mode": "asymmetric", "bits": 8, "granularity": "perchannel"}]}}, {"type": "MatMul", "quantization": {"weights": [{"level_low": -127, "mode": "symmetric", "bits": 8, "level_high": 127, "granularity": "pertensor"}], "activations": [{"mode": "symmetric", "bits": 8, "granularity": "pertensor"}, {"mode": "symmetric", "bits": 8, "granularity": "perchannel"}, {"mode": "asymmetric", "bits": 8, "granularity": "pertensor"}, {"mode": "asymmetric", "bits": 8, "granularity": "perchannel"}]}}]
tools/pot/tests/data/models/act_act_example/act_act_example.json (new executable file, 3 lines)
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:86ed8ec45c3f4674cf33529a2eb3a635a4bd53dafe95f01f03a7b751dd35f7bd
+size 161
tools/pot/tests/data/models/act_act_example/act_act_example.onnx (new executable file, 3 lines)
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e17e875df7af8f5e2204d5e211d5f94e1f30cd866976c5fbed58fd0438c09a2c
+size 612
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e6eb24c47d38ae992eef207a5c4e873db496a63eed42ac5469df4f2afc756f48
+size 16835
@@ -93,7 +93,9 @@ def test_configurations_by_preset(preset):
         'preset': preset,
         'target_device': 'CPU'
     })
+    hw_config_path = HARDWARE_CONFIG_PATH.joinpath('cpu.json').as_posix()
+    hw_config = HardwareConfig.from_json(hw_config_path)
     correct_configuration = _load_config('correct_configuration.json')
-    res = get_configurations_by_preset(config, None, correct_configuration)
+    res = get_configurations_by_preset(config, None, correct_configuration, hw_config)
     ref_configuration = _load_config('ref_configuration.json')
     assert res == ref_configuration[preset]
@@ -42,7 +42,7 @@ def test_per_channel_activations_for_depthwise(tmp_path, models, model_name, mod
                                                         ALGORITHM_CONFIG, hardware_config, model)
     ALGORITHM_CONFIG.preset = ALGORITHM_CONFIG.params.preset
     ALGORITHM_CONFIG.target_device = ALGORITHM_CONFIG.params.target_device
-    fq_configuration = get_configurations_by_preset(ALGORITHM_CONFIG, model, fq_configurations)
+    fq_configuration = get_configurations_by_preset(ALGORITHM_CONFIG, model, fq_configurations, hardware_config)
     fq_dw_names = ['Conv_4/WithoutBiases/fq_input_0', 'Conv_13/WithoutBiases/fq_input_0',
                    'Conv_22/WithoutBiases/fq_input_0', 'Conv_32/WithoutBiases/fq_input_0',
                    'Conv_41/WithoutBiases/fq_input_0', 'Conv_51/WithoutBiases/fq_input_0',
tools/pot/tests/test_graph_with_stats.py (new file, 60 lines)
@@ -0,0 +1,60 @@
+# Copyright (C) 2020-2022 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+from pathlib import Path
+
+import pytest
+from addict import Dict
+from openvino.tools.pot.data_loaders.creator import create_data_loader
+from openvino.tools.pot.engines.creator import create_engine
+from openvino.tools.pot.graph import load_model
+from openvino.tools.pot.pipeline.initializer import create_pipeline
+
+from .utils.check_graph import check_model
+from .utils.config import merge_configs
+
+TEST_MODELS = [
+    ('resnet_example', 'pytorch', 'Ranger', 'ANY'),
+    ('act_act_example', 'pytorch', 'DefaultQuantization', 'GNA')
+]
+
+
+@pytest.fixture(scope='module', params=TEST_MODELS,
+                ids=['{}_{}_{}'.format(*m) for m in TEST_MODELS])
+def _params(request):
+    return request.param
+
+
+def test_graph_with_stats(_params, tmp_path, models):
+    model_name, model_framework, algo_name, target_device = _params
+
+    algorithm_config = Dict({
+        'algorithms': [{
+            'name': 'algo_name',
+            'params': {
+                'target_device': target_device,
+                'preset': 'performance',
+                'stat_subset_size': 1
+            }
+        }]
+    })
+
+    model = models.get(model_name, model_framework, tmp_path)
+
+    test_dir = Path(__file__).parent
+    path_image_data = os.path.join(test_dir, 'data/image_data')
+    engine_config = Dict({'device': 'CPU',
+                          'type': 'simplified',
+                          'data_source': path_image_data})
+    config = merge_configs(model.model_params, engine_config, algorithm_config)
+
+    model = load_model(config.model)
+    data_loader = create_data_loader(engine_config, model)
+    engine = create_engine(config.engine, data_loader=data_loader, metric=None)
+    pipeline = create_pipeline(config.compression.algorithms, engine)
+
+    optimized_model = pipeline.run(model)
+    final_model_name = model_name + '_' + algo_name
+    check_model(tmp_path, optimized_model, final_model_name, model_framework)