[POT] Enable quantization inside subgraphs for CPU (#12776)
* Enable LSTMCell quantization * Refactored unify function * reverse dict * Fix docs * Add LSTM and GRU sequences * Change scales to unified_scales in hw configs * Fix inplace statistics * Shift axis for subgraphs * Changed HW config * Update HW configs for tests * Not quantize GRU ops with linear_before_reset: true * Enable tests * Support models with scalar input * Fix tests * return back submodules * Fix tests * Update GPU HW config * Fix comments * Fix axis
This commit is contained in:
parent
5eae673220
commit
bb00a9e664
@ -150,5 +150,5 @@ When the optimized model is a cascaded one (consists of several submodels, for e
|
||||
* Samples:
|
||||
* [Quantization of 3D segmentation model](https://github.com/openvinotoolkit/openvino/tree/master/tools/pot/openvino/tools/pot/api/samples/3d_segmentation)
|
||||
* [Quantization of Face Detection model](https://github.com/openvinotoolkit/openvino/tree/master/tools/pot/openvino/tools/pot/api/samples/face_detection)
|
||||
* [Quantizatin of speech model for GNA device](https://github.com/openvinotoolkit/openvino/tree/master/tools/pot/openvino/tools/pot/api/samples/speech)
|
||||
* [Quantization of speech model for GNA device](https://github.com/openvinotoolkit/openvino/tree/master/tools/pot/openvino/tools/pot/api/samples/speech)
|
||||
|
||||
|
@ -347,6 +347,7 @@ class BiasCorrection(Algorithm):
|
||||
|
||||
if model_copy.is_cascade:
|
||||
ref_stats_layout = {add_name: {'mean_per_channel': TensorStatisticAxis(asf.mean_per_channel_axis,
|
||||
graph_depth=add_name.count('|'),
|
||||
channel=self._channel_axis)}}
|
||||
self._engine.set_model(model_copy)
|
||||
_, q_outputs = self._engine.predict(ref_stats_layout, self._sampler)
|
||||
@ -422,7 +423,7 @@ class BiasCorrection(Algorithm):
|
||||
add_node = self._get_add_node_for_bias(node)
|
||||
add_node_name = add_node.fullname
|
||||
if 'orig_node_name' in add_node:
|
||||
add_node_name = nu.reset_node_fullname(add_node_name, add_node['orig_node_name'])
|
||||
add_node_name = add_node['orig_node_name']
|
||||
axis = OPERATIONS_CHANNEL_AXIS[node.type]
|
||||
self._channel_axis[add_node_name] = axis
|
||||
node_name = node.fullname
|
||||
@ -434,6 +435,7 @@ class BiasCorrection(Algorithm):
|
||||
{'mean_per_channel': TensorStatisticAxis(granularity='perchannel',
|
||||
type='mean',
|
||||
inplace_statistics=self.config['inplace_statistics'],
|
||||
graph_depth=add_node_name.count('|'),
|
||||
channel=self._channel_axis)}
|
||||
statistics_layout[add_node_name]["shape"] = TensorStatistic(func=lambda x, **kwargs: x.shape,
|
||||
shape_for_inference=True)
|
||||
|
@ -131,7 +131,7 @@ def compute_stats_layouts(config, model, qscheme=None):
|
||||
fake_quantize_config = {}
|
||||
for fq in fq_nodes:
|
||||
is_weights = fq['fq_group'] == 'weights'
|
||||
fq_config = copy(fq_configuration[fq.name][fq['fq_group']])
|
||||
fq_config = copy(fq_configuration[fq.fullname][fq['fq_group']])
|
||||
fake_quantize_config[fq.fullname] = fq_config
|
||||
if fq.fullname in config.layerwise_configs[0]:
|
||||
fq_config = Dict(merge_nested_dicts(fq_config, config.layerwise_configs[0][fq.fullname]))
|
||||
|
@ -6,14 +6,14 @@ from copy import deepcopy
|
||||
|
||||
from .range_estimator import get_range_estimator_config
|
||||
from .utils import load_hardware_config
|
||||
from ...graph.special_operations import QUANTIZE_AGNOSTIC_OPERATIONS, CONCAT_UNIFY_OUTPUTS, CONCAT_UNIFY_INPUTS
|
||||
from ...graph.special_operations import QUANTIZE_AGNOSTIC_OPERATIONS, CONCAT_UNIFY_OUTPUTS,\
|
||||
CONCAT_UNIFY_INPUTS, TYPES_TO_QUANTIZABLE_PORTS
|
||||
from ...graph.utils import find_operation_matches, get_operation_list, is_data_type_quantizable,\
|
||||
get_hardware_config_operation_type
|
||||
from ...graph.model_utils import get_nodes_by_type, get_node_by_name
|
||||
from ...graph.node_utils import get_input_shape, get_all_node_outputs,\
|
||||
get_node_input, get_node_inputs, get_node_data_type, check_const_input
|
||||
get_node_input, get_node_inputs, get_node_data_type, check_const_input, reset_node_fullname
|
||||
from ...graph.passes import traverse_graph
|
||||
|
||||
from ...utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@ -212,6 +212,9 @@ def get_configurations_by_preset(config, model, fq_to_hw_confs, hardware_config)
|
||||
for key in cur_conf[fq]:
|
||||
if key in ACTIVATION_QUANTIZATION_MODES:
|
||||
if with_concat or unclear_layout or broadcasting:
|
||||
if not isinstance(cur_conf[fq]['activations'], list):
|
||||
cur_conf[fq]['activations'] = [cur_conf[fq]['activations']]
|
||||
|
||||
configuration = [c for c in cur_conf[fq][key] if c['granularity'] == 'pertensor']
|
||||
else:
|
||||
configuration = cur_conf[fq][key]
|
||||
@ -280,12 +283,51 @@ def get_configurations_by_qscheme(fq_to_hw_confs, qscheme):
|
||||
return res
|
||||
|
||||
|
||||
def unify_recurrent_fqs(model, recurrent_hw_ops):
|
||||
recurrent_fqs_to_unify = []
|
||||
|
||||
def source_fn(op):
|
||||
return [p for p in get_node_inputs(op) if p and p.type != 'Const']
|
||||
|
||||
def criteria_fn(op):
|
||||
return op.type == 'FakeQuantize'
|
||||
|
||||
def get_fqs_fullname(node, port):
|
||||
input_node = get_node_input(node, port)
|
||||
if criteria_fn(input_node):
|
||||
return [input_node.fullname]
|
||||
_, criteria = traverse_graph(input_node,
|
||||
move_fn=source_fn,
|
||||
stop_criteria_fn=criteria_fn,
|
||||
criteria_fns=criteria_fn)
|
||||
if criteria_fn in criteria:
|
||||
input_fqs = [reset_node_fullname(input_node.fullname, node_name) for node_name in criteria[criteria_fn]]
|
||||
return input_fqs
|
||||
return []
|
||||
|
||||
def recurrent_fq_to_unify(cell_fullname, fqs):
|
||||
unique_fqs = set().union(*fqs)
|
||||
is_unified = all([get_node_input(get_node_by_name(model, name), 0).type != 'Const' for name in unique_fqs])
|
||||
if len(unique_fqs) >= 2 and is_unified:
|
||||
recurrent_fqs_to_unify.append([[cell_fullname], list(unique_fqs)])
|
||||
|
||||
nodes = get_nodes_by_type(model, types=recurrent_hw_ops.keys(), recursively=True)
|
||||
for node in nodes:
|
||||
unify_group_indices = recurrent_hw_ops[node.type]
|
||||
for indices in unify_group_indices:
|
||||
fqs = [get_fqs_fullname(node, i) for i in indices]
|
||||
recurrent_fq_to_unify(node.fullname, fqs)
|
||||
|
||||
return recurrent_fqs_to_unify
|
||||
|
||||
|
||||
def find_fqs_to_unify(model, config):
|
||||
def _get_unified_scales_ops(hw_ops_):
|
||||
unified_scales_ops_ = []
|
||||
for hw_op in hw_ops_:
|
||||
if 'attributes' in hw_op and 'scales' in hw_op['attributes']:
|
||||
del hw_op['attributes']['scales']
|
||||
if 'attributes' in hw_op and 'unified_scales' in hw_op['attributes'] and \
|
||||
hw_op['attributes']['unified_scales'] == 'all':
|
||||
del hw_op['attributes']['unified_scales']
|
||||
if not hw_op['attributes']:
|
||||
del hw_op['attributes']
|
||||
unified_scales_ops_.append(hw_op)
|
||||
@ -371,6 +413,13 @@ def find_fqs_to_unify(model, config):
|
||||
def _has_const_input(layer):
|
||||
return 'Const' in [parent.type for parent in get_node_inputs(layer) if parent]
|
||||
|
||||
def _get_unified_recurrent_scales_ops(hw_ops_):
|
||||
unified_scales_ops_ = {}
|
||||
for hw_op in hw_ops_:
|
||||
if hw_op['type'] in TYPES_TO_QUANTIZABLE_PORTS and 'unified_scales' in hw_op['attributes']:
|
||||
unified_scales_ops_[hw_op['type']] = hw_op['attributes']['unified_scales']
|
||||
return unified_scales_ops_
|
||||
|
||||
def _process_node(node_, stack_, visited_, to_unify_):
|
||||
visited_[node_.fullname] = True
|
||||
if _is_unified_scales_op(node_) or _is_agnostic_branching_op(node_):
|
||||
@ -397,8 +446,9 @@ def find_fqs_to_unify(model, config):
|
||||
per_channel_quantizable = _get_quantizable_per_ch_ops(hardware_config)
|
||||
hw_ops = get_operation_list(hardware_config)
|
||||
quantize_agnostic_ops = [op[1] for op in find_operation_matches(QUANTIZE_AGNOSTIC_OPERATIONS, hw_ops)]
|
||||
recurrent_hw_ops = _get_unified_recurrent_scales_ops(hw_ops)
|
||||
unified_scales_ops = _get_unified_scales_ops(hw_ops)
|
||||
if not unified_scales_ops:
|
||||
if not (unified_scales_ops or recurrent_hw_ops):
|
||||
return []
|
||||
|
||||
visited = defaultdict(lambda: False)
|
||||
@ -418,6 +468,9 @@ def find_fqs_to_unify(model, config):
|
||||
len(to_unify[1]) > 1:
|
||||
fqs_to_unify.append(to_unify)
|
||||
|
||||
recurrent_fqs = unify_recurrent_fqs(model, recurrent_hw_ops)
|
||||
fqs_to_unify.extend(recurrent_fqs)
|
||||
|
||||
fqs_to_unify = sorted([[sorted(c[0]), sorted(c[1])] for c in fqs_to_unify])
|
||||
logger.debug('Operations and corresponding fake quantize nodes to unify scales:')
|
||||
for ops, fqs in fqs_to_unify:
|
||||
@ -482,7 +535,7 @@ def change_configurations_by_model_type(model, config, fq_configuration, hardwar
|
||||
def change_configurations_by_model_type_transformer(model, fq_configuration, hardware_config):
|
||||
fq_types = _fake_quantize_to_types(model, hardware_config)
|
||||
for fq in get_nodes_by_type(model, ['FakeQuantize']):
|
||||
node_creator_fq, fq_group = fq_types[fq.name]
|
||||
node_creator_fq, fq_group = fq_types[fq.fullname]
|
||||
is_weights = fq_group == 'weights'
|
||||
node_name = None
|
||||
for name, type_node in node_creator_fq:
|
||||
|
@ -149,12 +149,14 @@ class FastBiasCorrection(Algorithm):
|
||||
inputs_outputs_layout[op_output_name] = {
|
||||
"mean_per_channel": TensorStatisticAxis(inplace_statistics=inplace_statistics,
|
||||
granularity='perchannel', type='mean',
|
||||
graph_depth=op_output_name.count('|'),
|
||||
channel=self._channel_axis)}
|
||||
|
||||
input_name = get_quantized_input_key(quantized_node)
|
||||
inputs_outputs_layout[input_name] = {
|
||||
"mean_per_channel": TensorStatisticAxis(inplace_statistics=inplace_statistics,
|
||||
granularity='perchannel', type='mean',
|
||||
graph_depth=op_output_name.count('|'),
|
||||
channel=self._channel_axis)}
|
||||
inputs_outputs_layout[input_name]["shape"] = TensorStatistic(func=lambda x, **kwargs: x.shape,
|
||||
shape_for_inference=True)
|
||||
|
@ -158,7 +158,7 @@ describe internal representation of the DL model and how to work with it.
|
||||
```
|
||||
class openvino.tools.pot.IEEngine(config, data_loader=None, metric=None)
|
||||
```
|
||||
IEEngine is a helper which implements Engine class based on [OpenVINO™ Inference Engine Python* API](ie_python_api/api.html).
|
||||
IEEngine is a helper which implements Engine class based on [OpenVINO™ Inference Engine Python* API](https://docs.openvino.ai/latest/api/ie_python_api/api.html).
|
||||
This class support inference in synchronous and asynchronous modes and can be reused as-is in the custom pipeline or
|
||||
with some modifications, e.g. in case of custom post-processing of inference results.
|
||||
|
||||
|
@ -6,12 +6,12 @@
|
||||
:maxdepth: 1
|
||||
:hidden:
|
||||
|
||||
Quantizatiing Image Classification Model <pot_example_classification_README>
|
||||
Quantizatiing Object Detection Model with Accuracy Control <pot_example_object_detection_README>
|
||||
Quantizatiing Cascaded Model <pot_example_face_detection_README>
|
||||
Quantizatiing Semantic Segmentation Model <pot_example_segmentation_README>
|
||||
Quantizatiing 3D Segmentation Model <pot_example_3d_segmentation_README>
|
||||
Quantizatiing for GNA Device <pot_example_speech_README>
|
||||
Quantizing Image Classification Model <pot_example_classification_README>
|
||||
Quantizing Object Detection Model with Accuracy Control <pot_example_object_detection_README>
|
||||
Quantizing Cascaded Model <pot_example_face_detection_README>
|
||||
Quantizing Semantic Segmentation Model <pot_example_segmentation_README>
|
||||
Quantizing 3D Segmentation Model <pot_example_3d_segmentation_README>
|
||||
Quantizing for GNA Device <pot_example_speech_README>
|
||||
|
||||
@endsphinxdirective
|
||||
|
||||
@ -56,7 +56,7 @@ The following examples demonstrate the implementation of `Engine`, `Metric`, and
|
||||
|
||||
6. [Quantizing for GNA Device](./speech/README.md)
|
||||
- Uses models from Kaldi
|
||||
- Implements `DataLoader` to data in .ark format
|
||||
- Implements `DataLoader` to load data in .ark format
|
||||
- Uses DefaultQuantization algorithm for quantization model
|
||||
|
||||
After execution of each example above the quantized model is placed into the folder `optimized`. The accuracy validation of the quantized model is performed right after the quantization.
|
||||
|
@ -118,7 +118,7 @@ class MTCNNEngine(IEEngine):
|
||||
if sampler is None:
|
||||
sampler = BatchSampler(self.data_loader)
|
||||
if stats_layout:
|
||||
model_with_stat_op, nodes_names_map, output_to_node_names = self._statistic_graph_builder. \
|
||||
model_with_stat_op, nodes_names_map, node_to_result_names = self._statistic_graph_builder. \
|
||||
insert_statistic(copy.deepcopy(self._nx_model),
|
||||
stats_layout, stat_aliases)
|
||||
self.set_model(model_with_stat_op)
|
||||
@ -137,7 +137,7 @@ class MTCNNEngine(IEEngine):
|
||||
|
||||
align_stat_names_with_results(model_output_names,
|
||||
nodes_name,
|
||||
output_to_node_names,
|
||||
node_to_result_names,
|
||||
stats_layout,
|
||||
stat_aliases)
|
||||
|
||||
@ -153,8 +153,8 @@ class MTCNNEngine(IEEngine):
|
||||
process_accumulated_stats(stat_names_aliases=stat_names_aliases,
|
||||
accumulated_stats=self._accumulated_layer_stats)
|
||||
|
||||
if stats_layout:
|
||||
restore_original_node_names(output_to_node_names, accumulated_stats, stats_layout, stat_aliases)
|
||||
if stats_layout and stat_aliases:
|
||||
restore_original_node_names(node_to_result_names, accumulated_stats, stats_layout, stat_aliases)
|
||||
|
||||
metrics = None
|
||||
if self._metric:
|
||||
|
@ -225,7 +225,27 @@
|
||||
{
|
||||
"type": "Concat",
|
||||
"attributes": {
|
||||
"scales": "unified"
|
||||
"unified_scales": "all"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "LSTMSequence",
|
||||
"quantization": {
|
||||
"activations": "q8_a",
|
||||
"weights": "q8_w_sym"
|
||||
},
|
||||
"attributes": {
|
||||
"unified_scales": [[0, 1], [4, 5]]
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "GRUSequence",
|
||||
"quantization": {
|
||||
"activations": "q8_a",
|
||||
"weights": "q8_w_sym"
|
||||
},
|
||||
"attributes": {
|
||||
"unified_scales": [[0, 1], [3, 4]]
|
||||
}
|
||||
},
|
||||
{"type": "Reshape"},
|
||||
|
@ -225,7 +225,7 @@
|
||||
{
|
||||
"type": "Concat",
|
||||
"attributes": {
|
||||
"scales": "unified"
|
||||
"unified_scales": "all"
|
||||
}
|
||||
},
|
||||
{"type": "Reshape"},
|
||||
|
@ -72,162 +72,162 @@
|
||||
},
|
||||
{
|
||||
"type": "Add",
|
||||
"attributes": {
|
||||
"scales": "unified"
|
||||
},
|
||||
"quantization": {
|
||||
"activations": ["q8_tn", "q4_tn", "q8_tn"],
|
||||
"weights": ["q8_ch", "q4_tn", "q8_ch"]
|
||||
},
|
||||
"attributes": {
|
||||
"unified_scales": "all"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "Multiply",
|
||||
"attributes": {
|
||||
"scales": "unified"
|
||||
},
|
||||
"quantization": {
|
||||
"activations": ["q8_tn", "q4_tn", "q8_tn"],
|
||||
"weights": ["q8_ch", "q4_tn", "q8_ch"]
|
||||
},
|
||||
"attributes": {
|
||||
"unified_scales": "all"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "Maximum",
|
||||
"attributes": {
|
||||
"scales": "unified"
|
||||
},
|
||||
"quantization": {
|
||||
"activations": ["q8_tn", "q4_tn", "q8_tn"],
|
||||
"weights": ["q8_ch", "q4_tn", "q8_ch"]
|
||||
},
|
||||
"attributes": {
|
||||
"unified_scales": "all"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "Less",
|
||||
"attributes": {
|
||||
"scales": "unified"
|
||||
},
|
||||
"quantization": {
|
||||
"activations": "q8_tn",
|
||||
"weights": "q8_ch"
|
||||
},
|
||||
"attributes": {
|
||||
"unified_scales": "all"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "LessEqual",
|
||||
"attributes": {
|
||||
"scales": "unified"
|
||||
},
|
||||
"quantization": {
|
||||
"activations": "q8_tn",
|
||||
"weights": "q8_ch"
|
||||
},
|
||||
"attributes": {
|
||||
"unified_scales": "all"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "Greater",
|
||||
"attributes": {
|
||||
"scales": "unified"
|
||||
},
|
||||
"quantization": {
|
||||
"activations": "q8_tn",
|
||||
"weights": "q8_ch"
|
||||
},
|
||||
"attributes": {
|
||||
"unified_scales": "all"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "GreaterEqual",
|
||||
"attributes": {
|
||||
"scales": "unified"
|
||||
},
|
||||
"quantization": {
|
||||
"activations": "q8_tn",
|
||||
"weights": "q8_ch"
|
||||
},
|
||||
"attributes": {
|
||||
"unified_scales": "all"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "Divide",
|
||||
"attributes": {
|
||||
"scales": "unified"
|
||||
},
|
||||
"quantization": {
|
||||
"activations": "q8_tn",
|
||||
"weights": "q8_ch"
|
||||
},
|
||||
"attributes": {
|
||||
"unified_scales": "all"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "Minimum",
|
||||
"attributes": {
|
||||
"scales": "unified"
|
||||
},
|
||||
"quantization": {
|
||||
"activations": ["q8_tn", "q4_tn", "q8_tn"],
|
||||
"weights": ["q8_ch", "q4_tn", "q8_ch"]
|
||||
"weights": ["q8_ch", "q4_tn", "q8_ch"],
|
||||
},
|
||||
"attributes": {
|
||||
"unified_scales": "all"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "Equal",
|
||||
"attributes": {
|
||||
"scales": "unified"
|
||||
},
|
||||
"quantization": {
|
||||
"activations": "q8_tn",
|
||||
"weights": "q8_ch"
|
||||
},
|
||||
"attributes": {
|
||||
"unified_scales": "all"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "Subtract",
|
||||
"attributes": {
|
||||
"scales": "unified"
|
||||
},
|
||||
"quantization": {
|
||||
"activations": ["q8_tn", "q4_tn", "q8_tn"],
|
||||
"weights": ["q8_ch", "q4_tn", "q8_ch"]
|
||||
"weights": ["q8_ch", "q4_tn", "q8_ch"],
|
||||
},
|
||||
"attributes": {
|
||||
"unified_scales": "all"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "NotEqual",
|
||||
"attributes": {
|
||||
"scales": "unified"
|
||||
},
|
||||
"quantization": {
|
||||
"activations": "q8_tn",
|
||||
"weights": "q8_ch"
|
||||
},
|
||||
"attributes": {
|
||||
"unified_scales": "all"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "FloorMod",
|
||||
"attributes": {
|
||||
"scales": "unified"
|
||||
},
|
||||
"quantization": {
|
||||
"activations": "q8_tn",
|
||||
"weights": "q8_ch"
|
||||
},
|
||||
"attributes": {
|
||||
"unified_scales": "all"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "LogicalOr",
|
||||
"attributes": {
|
||||
"scales": "unified"
|
||||
},
|
||||
"quantization": {
|
||||
"activations": "q8_tn",
|
||||
"weights": "q8_ch"
|
||||
},
|
||||
"attributes": {
|
||||
"unified_scales": "all"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "LogicalXor",
|
||||
"attributes": {
|
||||
"scales": "unified"
|
||||
},
|
||||
"quantization": {
|
||||
"activations": "q8_tn",
|
||||
"weights": "q8_ch"
|
||||
},
|
||||
"attributes": {
|
||||
"unified_scales": "all"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "LogicalAnd",
|
||||
"attributes": {
|
||||
"scales": "unified"
|
||||
},
|
||||
"quantization": {
|
||||
"activations": "q8_tn",
|
||||
"weights": "q8_ch"
|
||||
},
|
||||
"attributes": {
|
||||
"unified_scales": "all"
|
||||
}
|
||||
},
|
||||
{
|
||||
|
@ -125,7 +125,7 @@ class ACEngine(Engine):
|
||||
callback_layout, stat_names_aliases = {}, {}
|
||||
# add outputs for activation statistics collection
|
||||
if stats_layout is not None:
|
||||
model_with_stat_op, nodes_names_map, output_to_node_names = self._statistic_graph_builder.\
|
||||
model_with_stat_op, nodes_names_map, node_to_result_names = self._statistic_graph_builder.\
|
||||
insert_statistic(copy.deepcopy(self._nx_model),
|
||||
stats_layout, stat_aliases)
|
||||
self.set_model(model_with_stat_op)
|
||||
@ -149,7 +149,7 @@ class ACEngine(Engine):
|
||||
|
||||
align_stat_names_with_results(model_output_names,
|
||||
nodes_name,
|
||||
output_to_node_names,
|
||||
node_to_result_names,
|
||||
stats_layout,
|
||||
stat_aliases)
|
||||
|
||||
@ -214,8 +214,8 @@ class ACEngine(Engine):
|
||||
self._per_sample_metrics.clear()
|
||||
self.dump_prediction_to_annotation = False
|
||||
|
||||
if stats_layout:
|
||||
restore_original_node_names(output_to_node_names, accumulated_stats, stats_layout, stat_aliases)
|
||||
if stats_layout and stat_aliases:
|
||||
restore_original_node_names(node_to_result_names, accumulated_stats, stats_layout, stat_aliases)
|
||||
|
||||
return metrics, accumulated_stats
|
||||
|
||||
|
@ -94,7 +94,7 @@ class IEEngine(Engine):
|
||||
|
||||
stat_names_aliases = None
|
||||
if stats_layout:
|
||||
model_with_stat_op, nodes_names_map, output_to_node_names = self._statistic_graph_builder.\
|
||||
model_with_stat_op, nodes_names_map, node_to_result_names = self._statistic_graph_builder.\
|
||||
insert_statistic(copy.deepcopy(self._nx_model),
|
||||
stats_layout, stat_aliases)
|
||||
self.set_model(model_with_stat_op)
|
||||
@ -109,7 +109,7 @@ class IEEngine(Engine):
|
||||
|
||||
align_stat_names_with_results(model_output_names,
|
||||
nodes_name,
|
||||
output_to_node_names,
|
||||
node_to_result_names,
|
||||
stats_layout,
|
||||
stat_aliases)
|
||||
|
||||
@ -127,8 +127,8 @@ class IEEngine(Engine):
|
||||
process_accumulated_stats(accumulated_stats=self._accumulated_layer_stats,
|
||||
stat_names_aliases=stat_names_aliases)
|
||||
|
||||
if stats_layout:
|
||||
restore_original_node_names(output_to_node_names, accumulated_stats, stats_layout, stat_aliases)
|
||||
if stats_layout and stat_aliases:
|
||||
restore_original_node_names(node_to_result_names, accumulated_stats, stats_layout, stat_aliases)
|
||||
|
||||
# Calculate metrics of required type. Reset collected statistics
|
||||
metrics = None
|
||||
@ -259,7 +259,7 @@ class IEEngine(Engine):
|
||||
)
|
||||
input_data_batched = input_data_batched.squeeze()
|
||||
if is_sampler_batchfied:
|
||||
if input_data_batched.shape[batch_dim] != len(input_data):
|
||||
if len(input_data_batched.shape) > batch_dim and input_data_batched.shape[batch_dim] != len(input_data):
|
||||
input_data_batched = np.expand_dims(input_data_batched, batch_dim)
|
||||
|
||||
if is_dynamic_input(input_blob):
|
||||
|
@ -119,30 +119,32 @@ def update_stats(stats_layout: dict, stat_aliases: dict, old_key: str, new_key:
|
||||
stat_aliases[algo_name][new_key] = stat_aliases[algo_name].pop(old_key)
|
||||
|
||||
|
||||
def restore_original_node_names(output2node, accumulated_stats, stats_layout, stat_aliases):
|
||||
if output2node and stats_layout:
|
||||
for out_name, original_node_name in output2node.items():
|
||||
def restore_original_node_names(node2result, accumulated_stats, stats_layout, stat_aliases):
|
||||
if node2result and stats_layout:
|
||||
for original_node_name, out_name in node2result.items():
|
||||
accumulated_stats[original_node_name] = accumulated_stats.pop(out_name)
|
||||
update_stats(stats_layout, stat_aliases, out_name, original_node_name)
|
||||
|
||||
|
||||
def align_stat_names_with_results(result_names, nodes_name, output2node, stats_layout, stat_aliases):
|
||||
""" Change node name in stast to result name if in the original model the subgraph had 1 output,
|
||||
def align_stat_names_with_results(result_names, nodes_name, node2result, stats_layout, stat_aliases):
|
||||
""" Change node name in stats to result name if in the original model the subgraph had 1 output,
|
||||
but after adding outputs in the subgraph, the number of output ports increased.
|
||||
For such nodes, it is necessary to add a '.0' to the original output name
|
||||
:param: result_names: names of Result nodes
|
||||
:param: nodes_name: node name in graph
|
||||
:param: output2node: a dict storing the matching of the result to the node
|
||||
:param: node2result: a dict storing the matching of the result to the node
|
||||
:param: stats_layout: dict of stats collection functions
|
||||
:param: stat_aliases: dict of algorithms collections stats
|
||||
"""
|
||||
if output2node:
|
||||
if node2result:
|
||||
for original_out_name in nodes_name:
|
||||
if original_out_name not in result_names and (original_out_name, 0) not in stats_layout:
|
||||
if isinstance(original_out_name, str) and \
|
||||
original_out_name not in result_names and \
|
||||
(original_out_name, 0) not in stats_layout:
|
||||
out_name_with_port = original_out_name + '.0'
|
||||
assert out_name_with_port in result_names
|
||||
update_stats(stats_layout, stat_aliases, original_out_name, out_name_with_port)
|
||||
output2node[out_name_with_port] = original_out_name
|
||||
node2result[original_out_name] = out_name_with_port
|
||||
|
||||
|
||||
def process_raw_output(raw_output):
|
||||
|
@ -52,25 +52,25 @@ def make_copy_fake_quantize(nodes, edges, fq):
|
||||
fq_attrs['levels'] = int(fq_attrs['levels'])
|
||||
|
||||
nodes.extend([
|
||||
(fq.name, fq.type, fq_attrs),
|
||||
(input_low.name, input_low.type,
|
||||
(fq.fullname, fq.type, fq_attrs),
|
||||
(input_low.fullname, input_low.type,
|
||||
{'value': input_low.value}),
|
||||
(input_height.name, input_height.type,
|
||||
(input_height.fullname, input_height.type,
|
||||
{'value': input_height.value}),
|
||||
(output_low.name, output_low.type,
|
||||
(output_low.fullname, output_low.type,
|
||||
{'value': output_low.value}),
|
||||
(output_height.name, output_height.type,
|
||||
(output_height.fullname, output_height.type,
|
||||
{'value': output_height.value}),
|
||||
(weights.name, weights.type, {'value': weights.value.copy()})])
|
||||
(weights.fullname, weights.type, {'value': weights.value.copy()})])
|
||||
|
||||
edges.extend([
|
||||
(weights.name, fq.name, {'out': 0, 'in': 0}),
|
||||
(input_low.name, fq.name, {'out': 0, 'in': 1}),
|
||||
(input_height.name, fq.name, {'out': 0, 'in': 2}),
|
||||
(output_low.name, fq.name, {'out': 0, 'in': 3}),
|
||||
(output_height.name, fq.name, {'out': 0, 'in': 4})
|
||||
(weights.fullname, fq.fullname, {'out': 0, 'in': 0}),
|
||||
(input_low.fullname, fq.fullname, {'out': 0, 'in': 1}),
|
||||
(input_height.fullname, fq.fullname, {'out': 0, 'in': 2}),
|
||||
(output_low.fullname, fq.fullname, {'out': 0, 'in': 3}),
|
||||
(output_height.fullname, fq.fullname, {'out': 0, 'in': 4})
|
||||
])
|
||||
return fq.name
|
||||
return fq.fullname
|
||||
|
||||
|
||||
def make_copy_graph_attrs(model, input_name, input_shape):
|
||||
@ -95,7 +95,7 @@ def make_copy_graph_attrs(model, input_name, input_shape):
|
||||
|
||||
|
||||
def build_graph_for_node(model, input_name, input_shape, node, remove_bias=False, remove_fake_quantize=False):
|
||||
""" Build the Graph (input - node - output). The Convolution, FullyConnected node types are supported.
|
||||
""" Build the Graph (input - node - output). The Convolution, MatMul node types are supported.
|
||||
:param model: source model
|
||||
:param input_name: name of the input node in the generated graph
|
||||
:param input_shape: shape of the input node in the generated graph
|
||||
@ -113,36 +113,36 @@ def build_graph_for_node(model, input_name, input_shape, node, remove_bias=False
|
||||
if node.has_valid('output') and node.has_valid('get_output_feature_dim'):
|
||||
node_attrs['get_output_feature_dim'] = None
|
||||
|
||||
nodes.append((node.name, node.type, node_attrs))
|
||||
edges.append((input_name, node.name, {'out': 0, 'in': 0}))
|
||||
nodes.append((node.fullname, node.type, node_attrs))
|
||||
edges.append((input_name, node.fullname, {'out': 0, 'in': 0}))
|
||||
|
||||
parent_nodes = get_node_inputs(node)
|
||||
if parent_nodes[1].type == 'FakeQuantize' and not remove_fake_quantize:
|
||||
fq = parent_nodes[1]
|
||||
fq_name = make_copy_fake_quantize(nodes, edges, fq)
|
||||
edges.append((fq_name, node.name, {'out': 0, 'in': 1}))
|
||||
edges.append((fq_name, node.fullname, {'out': 0, 'in': 1}))
|
||||
else:
|
||||
weights = parent_nodes[1]
|
||||
nodes.append((weights.name, weights.type, {'value': weights.value.copy()}))
|
||||
edges.append((weights.name, node.name, {'out': 0, 'in': 1}))
|
||||
nodes.append((weights.fullname, weights.type, {'value': weights.value.copy()}))
|
||||
edges.append((weights.fullname, node.fullname, {'out': 0, 'in': 1}))
|
||||
|
||||
if not remove_bias:
|
||||
if parent_nodes[2].type == 'FakeQuantize' and not remove_fake_quantize:
|
||||
fq = parent_nodes[1]
|
||||
fq_name = make_copy_fake_quantize(nodes, edges, fq)
|
||||
edges.append((fq_name, node.name, {'out': 0, 'in': 2}))
|
||||
edges.append((fq_name, node.fullname, {'out': 0, 'in': 2}))
|
||||
else:
|
||||
weights = parent_nodes[2]
|
||||
nodes.append((weights.name, weights.type, {'value': weights.value.copy()}))
|
||||
edges.append((weights.name, node.name, {'out': 0, 'in': 2}))
|
||||
nodes.append((weights.fullname, weights.type, {'value': weights.value.copy()}))
|
||||
edges.append((weights.fullname, node.fullname, {'out': 0, 'in': 2}))
|
||||
|
||||
result_name = '{}/out'.format(node.name)
|
||||
result_name = '{}/out'.format(node.fullname)
|
||||
nodes.append((result_name, 'Result', {}))
|
||||
edges.append((node.name, result_name, {'out': 0, 'in': 0}))
|
||||
edges.append((node.fullname, result_name, {'out': 0, 'in': 0}))
|
||||
graph = build_graph(*make_copy_graph_attrs(model, input_name, input_shape), nodes, edges)
|
||||
|
||||
# Add the neccessary attribute to the new graph
|
||||
src_node = get_node_by_name(graph, node.name)
|
||||
src_node = get_node_by_name(graph, node.fullname)
|
||||
weights_node = get_node_input(src_node, 1)
|
||||
weights_node = get_node_input(weights_node, 0) \
|
||||
if weights_node.type == 'FakeQuantize' else weights_node
|
||||
|
@ -41,7 +41,7 @@ def find_node(graph: Graph, name):
|
||||
|
||||
|
||||
# TODO: set recursively = True to enable subgraphs quantization
|
||||
def get_node_by_name(graph: Graph, name: str, recursively: bool = False) -> Node:
|
||||
def get_node_by_name(graph: Graph, name: str, recursively: bool = True) -> Node:
|
||||
""" Returns node by name
|
||||
:param graph: NetworkX model to take node
|
||||
:param name: name of the node
|
||||
@ -105,7 +105,7 @@ def connect_nodes_by_name(graph: Graph, src_node_name, src_port, dst_node_name,
|
||||
|
||||
|
||||
# TODO: set recursively = True to enable subgraphs quantization
|
||||
def get_all_operation_nodes(graph: Graph, recursively: bool = False):
|
||||
def get_all_operation_nodes(graph: Graph, recursively: bool = True):
|
||||
""" Returns sequence of all nodes in graph
|
||||
:param graph: NetworkX model to take nodes
|
||||
:param recursively: whether return all nodes from the graph
|
||||
@ -121,7 +121,7 @@ def get_all_operation_nodes(graph: Graph, recursively: bool = False):
|
||||
|
||||
|
||||
# TODO: set recursively = True to enable subgraphs quantization
|
||||
def get_nodes_by_type(graph: Graph, types: list, recursively: bool = False) -> list:
|
||||
def get_nodes_by_type(graph: Graph, types: list, recursively: bool = True) -> list:
|
||||
""" Returns all nodes with type from types collection
|
||||
:param graph: NetworkX model to collect nodes
|
||||
:param types: list of required types
|
||||
|
@ -5,6 +5,7 @@ import os
|
||||
import tempfile
|
||||
from copy import deepcopy
|
||||
from openvino.tools.mo.graph.graph import Graph
|
||||
from openvino.tools.mo.middle.pattern_match import for_graph_and_each_sub_graph_recursively
|
||||
from openvino.tools.mo.utils.ir_reader.restore_graph import restore_graph_from_ir, save_restored_graph
|
||||
from openvino.tools.mo.utils.logger import init_logger
|
||||
from openvino.runtime import Core # pylint: disable=E0401,E0611
|
||||
@ -57,7 +58,7 @@ def load_graph(model_config, target_device='ANY'):
|
||||
meta_data['quantization_parameters'] = model_config.quantization_info
|
||||
graph_from_ir.meta_data = meta_data
|
||||
graph_from_ir.graph['cmd_params'] = orig_graph_from_ir.graph['cmd_params']
|
||||
remove_converts(graph_from_ir)
|
||||
for_graph_and_each_sub_graph_recursively(graph_from_ir, remove_converts)
|
||||
model_preprocessing(graph_from_ir)
|
||||
if os.path.exists(serialized_xml_path):
|
||||
os.remove(serialized_xml_path)
|
||||
|
@ -4,6 +4,7 @@
|
||||
import networkx as nx
|
||||
from openvino.tools.mo.graph.graph import Node
|
||||
from openvino.tools.mo.middle.passes.infer import type_infer
|
||||
from openvino.tools.mo.middle.pattern_match import for_graph_and_each_sub_graph_recursively
|
||||
|
||||
from . import editor as ge, builder as gb
|
||||
from .nx_model import CompressedModel
|
||||
@ -55,11 +56,11 @@ def add_outputs(models, node_names):
|
||||
def compress_model_weights(model: CompressedModel):
|
||||
"""Apply transformations to save model weights to INT8."""
|
||||
for model_dict in model.models:
|
||||
compress_weights(model_dict['model'])
|
||||
for_graph_and_each_sub_graph_recursively(model_dict['model'], compress_weights)
|
||||
|
||||
|
||||
# TODO: set recursively = True to enable subgraphs quantization
|
||||
def get_nodes_by_type(model: CompressedModel, types: list, recursively: bool = False):
|
||||
def get_nodes_by_type(model: CompressedModel, types: list, recursively: bool = True):
|
||||
""" Returns all nodes with type from types collection
|
||||
:param model: CompressedModel model
|
||||
:param types: list of required types
|
||||
@ -87,7 +88,7 @@ def get_node_by_name(model: CompressedModel, name: str) -> Node:
|
||||
|
||||
|
||||
# TODO: set recursively = True to enable subgraphs quantization
|
||||
def get_all_operation_nodes(model: CompressedModel, recursively: bool = False):
|
||||
def get_all_operation_nodes(model: CompressedModel, recursively: bool = True):
|
||||
""" Returns sequence of all nodes in all graphs
|
||||
:param model: CompressedModel model
|
||||
:param recursively: whether return all nodes from the model
|
||||
|
@ -288,7 +288,8 @@ def create_node_name(input_node, mode=tuple):
|
||||
|
||||
|
||||
def get_node_data_type(node, port_id=0):
|
||||
if node.type != 'Const' and node.in_port(port_id).get_source() is not None \
|
||||
if node.type != 'Const' and port_id in node.in_ports() \
|
||||
and node.in_port(port_id).get_source() is not None \
|
||||
and node.in_port(port_id).get_source().is_data_type_defined():
|
||||
return node.in_port(port_id).get_source().get_data_type()
|
||||
return None
|
||||
|
@ -28,7 +28,7 @@ from . import node_utils as nu
|
||||
from .editor import get_nodes_by_type
|
||||
from .pattern_utils import get_fq_result_pattern
|
||||
from .special_operations import OPERATIONS_WITH_WEIGHTS, DETECTION_OUTPUT_FINAL_TYPES, \
|
||||
SPLIT_OPERATIONS, OPERATIONS_WITH_BIAS
|
||||
SPLIT_OPERATIONS, OPERATIONS_WITH_BIAS, TYPES_TO_QUANTIZABLE_PORTS
|
||||
from .utils import find_operation_matches, is_ignored, get_hw_aware_ignored_patterns
|
||||
from ..graph.node_utils import get_all_node_outputs, get_node_inputs, get_node_input, get_weights_for_node
|
||||
from ..graph.special_patterns import get_ignored_patterns
|
||||
@ -149,8 +149,9 @@ class InsertFakeQuantize(BackReplacementPattern):
|
||||
|
||||
if m_op.type in ['Convolution', 'ConvolutionBackpropData', 'MatMul']:
|
||||
insert_fake_quantize(graph, m_op, [0, 1], hw_config=self.hardware_config, input_priority_types=self.input_priority_types)
|
||||
elif m_op.type == 'LSTMCell':
|
||||
insert_fake_quantize(graph, m_op, [0, 1, 2, 3, 4], hw_config=self.hardware_config, input_priority_types=self.input_priority_types)
|
||||
elif m_op.type in TYPES_TO_QUANTIZABLE_PORTS:
|
||||
ports = TYPES_TO_QUANTIZABLE_PORTS[m_op.type]
|
||||
insert_fake_quantize(graph, m_op, ports, hw_config=self.hardware_config, input_priority_types=self.input_priority_types)
|
||||
elif self.quantize_only_input(m_op):
|
||||
insert_fake_quantize(graph, m_op, [0], hw_config=self.hardware_config, input_priority_types=self.input_priority_types)
|
||||
else:
|
||||
@ -756,10 +757,10 @@ class FakeQuantizeNameSwapper(BackReplacementPattern):
|
||||
if len(input_node_outputs) > 1 and all([op.type == 'FakeQuantize' for op in input_node_outputs]):
|
||||
new_fq_name += '.{}'.format(fq_node.in_port(0).get_source().idx)
|
||||
|
||||
fq_node['orig_fq_name'] = copy(fq_node.name)
|
||||
fq_node['orig_fq_name'] = nu.reset_node_fullname(input_node.fullname, copy(fq_node.name))
|
||||
|
||||
if 'orig_node_name' not in input_node:
|
||||
input_node['orig_node_name'] = copy(input_node.name)
|
||||
input_node['orig_node_name'] = copy(input_node.fullname)
|
||||
rename_node(input_node, f'{input_node.name}/pre_fq_input')
|
||||
rename_node(fq_node, new_fq_name)
|
||||
|
||||
@ -831,6 +832,7 @@ def create_fake_quantize_node(graph: Graph, name, data_type=np.float32, **kwargs
|
||||
|
||||
def insert_fake_quantize(graph, node, ports=None, names=None, fq_types=None, hw_config=None, input_priority_types=[]):
|
||||
blobs_as_inputs_nodes_type = ['Convolution', 'Deconvolution', 'MatMul']
|
||||
gru_node_types = ['GRUCell', 'GRUSequence']
|
||||
|
||||
port_name = None
|
||||
if ports is not None and names is not None:
|
||||
@ -850,6 +852,10 @@ def insert_fake_quantize(graph, node, ports=None, names=None, fq_types=None, hw_
|
||||
if 'bin' in node.in_edges()[idx]:
|
||||
del node.in_edges()[idx]['bin']
|
||||
|
||||
# Temporary WA until oneDNN supports it (ticket 82164)
|
||||
if node.type in gru_node_types and node.linear_before_reset:
|
||||
continue
|
||||
|
||||
if ports is not None and idx not in ports:
|
||||
continue
|
||||
|
||||
@ -875,7 +881,7 @@ def insert_fake_quantize(graph, node, ports=None, names=None, fq_types=None, hw_
|
||||
fq_configs = []
|
||||
if hw_config is not None:
|
||||
node_type = get_hardware_config_operation_type(node, list(hw_config.keys()))
|
||||
if hw_config[node_type]:
|
||||
if hw_config[node_type] and hw_config[node_type][fq_group]:
|
||||
fq_configs = hw_config[node_type][fq_group]
|
||||
else:
|
||||
node_type = None
|
||||
|
@ -65,6 +65,9 @@ DETECTION_OUTPUT_FINAL_TYPES = [
|
||||
{'type': 'TopK'}
|
||||
]
|
||||
|
||||
# TODO: Add attributes to GraphTransformer hw_config
|
||||
TYPES_TO_QUANTIZABLE_PORTS = {'LSTMSequence': [0, 1, 4, 5], 'GRUSequence': [0, 1, 3, 4]}
|
||||
|
||||
ELTWISE_TYPES = ['Add', 'Multiply', 'Subtract', 'Divide', 'Less', 'LessEqual', 'Greater', 'GreaterEqual',
|
||||
'Equal', 'NotEqual', 'FloorMod', 'LogicalOr', 'LogicalXor', 'LogicalAnd', 'Maximum', 'Minimum']
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
# Copyright (C) 2020-2022 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from openvino.tools.mo.middle.pattern_match import for_graph_and_each_sub_graph_recursively
|
||||
from openvino.tools.mo.middle.passes.infer import type_infer
|
||||
|
||||
from .editor import add_fullname_for_nodes
|
||||
@ -83,10 +84,7 @@ class GraphTransformer:
|
||||
for model_dict in model.models:
|
||||
self.fq_insertion.ignored_params = ignored_params_[model_dict['name']] if model.is_cascade \
|
||||
else ignored_params_
|
||||
self._insert_fake_quantize(model_dict['model'])
|
||||
# TODO: Uncomment to enable subgraphs quantization
|
||||
# from mo.middle.pattern_match import for_graph_and_each_sub_graph_recursively
|
||||
# for_graph_and_each_sub_graph_recursively(model_dict['model'], self._insert_fake_quantize)
|
||||
for_graph_and_each_sub_graph_recursively(model_dict['model'], self._insert_fake_quantize)
|
||||
add_fullname_for_nodes(model_dict['model'])
|
||||
return model
|
||||
|
||||
|
@ -73,6 +73,8 @@ def mean_per_channel(acts, **_):
|
||||
@compute_act_stats_fn_per_channel.register('mean_axis')
|
||||
def mean_per_channel_axis(acts, layer_key=None, **kwargs):
|
||||
axis = kwargs.get('channel', {}).get(layer_key, 1)
|
||||
if axis >= 0:
|
||||
axis += kwargs.get('graph_depth', 0)
|
||||
return calculate_per_channel_stats(acts, np.mean, axis=axis)
|
||||
|
||||
|
||||
|
@ -24,15 +24,25 @@ from ..statistics.function_selector import ACTIVATIONS, get_stats_function
|
||||
# pylint: disable=R0912
|
||||
class StatisticGraphBuilder:
|
||||
def insert_statistic(self, model, stats_layout, stat_aliases=None):
|
||||
output_to_node_names = {}
|
||||
node_to_result_names = {}
|
||||
nodes_names_map = {m['model'].name: {} for m in model.models}
|
||||
if stat_aliases is None or model is None:
|
||||
if stat_aliases is None:
|
||||
for node_name in stats_layout.keys():
|
||||
node_name_in_graph = self.get_graph_node_name(node_name)
|
||||
node = get_node_by_name(model, node_name_in_graph)
|
||||
node_graph = node.graph
|
||||
nodes_names_map[node_graph.name][node_name] = convert_to_outputs_name(node_name)
|
||||
return model, nodes_names_map, output_to_node_names
|
||||
node_in_main_graph = get_node_by_name(model, node_name_in_graph.split('|')[0])
|
||||
model_graph = node_in_main_graph.graph
|
||||
|
||||
if model_graph == node_graph:
|
||||
nodes_names_map[model_graph.name][node_name] = convert_to_outputs_name(node_name)
|
||||
else:
|
||||
result_name = self.add_subgraph_output(model_graph, node_name)
|
||||
node_to_result_names[node_name] = result_name
|
||||
for node_name, result_name in node_to_result_names.items():
|
||||
stats_layout[result_name] = stats_layout.pop(node_name)
|
||||
return model, nodes_names_map, node_to_result_names
|
||||
|
||||
copy_stat_aliases = deepcopy(stat_aliases)
|
||||
for algo_name, node_stats in copy_stat_aliases.items():
|
||||
for node_name, stats in node_stats.items():
|
||||
@ -42,17 +52,17 @@ class StatisticGraphBuilder:
|
||||
model_graph = node_in_main_graph.graph
|
||||
for stat, _ in list(stats.items()):
|
||||
if not isinstance(stat, Statistic) or not stat.kwargs.get('inplace_statistics', False):
|
||||
if node_name not in nodes_names_map[model_graph.name]:
|
||||
if node_name not in nodes_names_map[model_graph.name] and model_graph == node.graph:
|
||||
nodes_names_map[model_graph.name][node_name] = convert_to_outputs_name(node_name)
|
||||
continue
|
||||
type_stat = stat.kwargs['type']
|
||||
add_output_node, op_name = getattr(self, f'insert_{type_stat}')(model_graph,
|
||||
node,
|
||||
type_stat,
|
||||
node_name,
|
||||
node_name, # name with port
|
||||
**stat.kwargs)
|
||||
if add_output_node:
|
||||
if node_name not in nodes_names_map[model_graph.name]:
|
||||
if node_name not in nodes_names_map[model_graph.name] and model_graph == node.graph:
|
||||
nodes_names_map[model_graph.name][node_name] = convert_to_outputs_name(op_name)
|
||||
class_statistic = TensorStatistic if isinstance(stat, TensorStatistic) else TensorStatisticAxis
|
||||
fn = get_stats_function(ACTIVATIONS, type_stat, stat.kwargs.get('granularity'),
|
||||
@ -76,35 +86,31 @@ class StatisticGraphBuilder:
|
||||
|
||||
# add output if node in subgraph
|
||||
if model_graph != node.graph:
|
||||
if node_name in nodes_names_map[model_graph.name]:
|
||||
del nodes_names_map[model_graph.name][node_name]
|
||||
|
||||
# Don't need adding extra output to the same node, but for another algo
|
||||
if node_name_in_graph in output_to_node_names.values():
|
||||
result_name = next((result for result, node in output_to_node_names.items()
|
||||
if node == node_name_in_graph))
|
||||
if node_name in node_to_result_names:
|
||||
result_name = node_to_result_names[node_name]
|
||||
else:
|
||||
model_graph.graph['additional_outputs'] = node_name_in_graph.split('|')
|
||||
results = AddOutputRecursive().find_and_replace_pattern(model_graph)
|
||||
assert len(results) == 1
|
||||
result_name = results[0].name
|
||||
if node_name in stats_layout:
|
||||
stats_layout[result_name] = stats_layout.pop(node_name)
|
||||
stat_aliases[algo_name][result_name] = stat_aliases[algo_name].pop(node_name)
|
||||
output_to_node_names[result_name] = node_name_in_graph
|
||||
result_name = self.add_subgraph_output(model_graph, node_name)
|
||||
node_to_result_names[node_name] = result_name
|
||||
|
||||
return model, nodes_names_map, output_to_node_names
|
||||
for node_name, result_name in node_to_result_names.items():
|
||||
stats_layout[result_name] = stats_layout.pop(node_name)
|
||||
for algo_name in copy_stat_aliases:
|
||||
if node_name in stat_aliases[algo_name]:
|
||||
stat_aliases[algo_name][result_name] = stat_aliases[algo_name].pop(node_name)
|
||||
|
||||
return model, nodes_names_map, node_to_result_names
|
||||
|
||||
def insert_reduce(self, model_graph, insert_op, node, granularity, type_stat, node_name, axis=1):
|
||||
axis_const = self.find_axis(node, granularity, axis)
|
||||
out_port = self.get_out_port(node_name)
|
||||
axis_const = self.find_axis(node, granularity, axis, port=out_port)
|
||||
if isinstance(axis_const, str):
|
||||
return (True, node.name)
|
||||
|
||||
out_port = self.get_out_port(node_name)
|
||||
if out_port is not None:
|
||||
node_name = f'{node_name[0]}.{out_port}'
|
||||
reduce_op = create_op_node_with_second_input(node.graph, insert_op, int64_array(axis_const),
|
||||
dict(name=f'{type_stat}_{node_name}'))
|
||||
dict(name=f'{type_stat}_{node_name.split("|")[-1]}'))
|
||||
reduce_op['fullname'] = reset_node_fullname(node.fullname, reduce_op.name)
|
||||
if node.graph != model_graph:
|
||||
Op.create_data_node(reduce_op.graph, reduce_op, {'shape': [1]})
|
||||
@ -129,17 +135,18 @@ class StatisticGraphBuilder:
|
||||
axis_channel)
|
||||
|
||||
def insert_abs_max(self, model_graph, node, type_stat, node_name, **kwargs):
|
||||
axis_const = self.find_axis(node, kwargs.get('granularity'))
|
||||
out_port = self.get_out_port(node_name)
|
||||
axis_const = self.find_axis(node, kwargs.get('granularity'), port=out_port)
|
||||
if isinstance(axis_const, str):
|
||||
return (True, node.name)
|
||||
|
||||
out_port = self.get_out_port(node_name)
|
||||
if out_port is not None:
|
||||
node_name = f'{node_name[0]}.{out_port}'
|
||||
abs_node = Abs(node.graph, {"name": f'abs_{node_name}'}). \
|
||||
clean_name = node_name.split("|")[-1]
|
||||
abs_node = Abs(node.graph, {"name": f'abs_{clean_name}'}). \
|
||||
create_node_with_data([node.out_node(out_port if out_port else 0)]).in_node(0)
|
||||
max_op = create_op_node_with_second_input(node.graph, ReduceMax, int64_array(axis_const),
|
||||
dict(name=f'{type_stat}_{node_name}'))
|
||||
dict(name=f'{type_stat}_{clean_name}'))
|
||||
|
||||
if node.graph != model_graph:
|
||||
Op.create_data_node(max_op.graph, max_op, {'shape': [1]})
|
||||
@ -147,12 +154,11 @@ class StatisticGraphBuilder:
|
||||
abs_node.out_port(0).connect(max_op.in_port(0))
|
||||
return self.insert_result(model_graph, node, max_op, type_stat, out_port)
|
||||
|
||||
@staticmethod
|
||||
def insert_result(model_graph, node, child_node, name, port=None):
|
||||
def insert_result(self, model_graph, node, child_node, name, port=None):
|
||||
if node.graph != model_graph:
|
||||
model_graph.graph['additional_outputs'] = child_node.fullname.split('|')
|
||||
res_op = AddOutputRecursive().find_and_replace_pattern(model_graph)
|
||||
ie_result_name = res_op[0].name
|
||||
ie_result_name = self.add_subgraph_output(model_graph, child_node.fullname)
|
||||
else:
|
||||
ie_result_name = f'{name}_{node.name}'
|
||||
if port is not None:
|
||||
@ -162,8 +168,8 @@ class StatisticGraphBuilder:
|
||||
return (False, ie_result_name)
|
||||
|
||||
@staticmethod
|
||||
def find_axis(node, granularity, axis=1):
|
||||
shape = len(get_output_shape(node, 0))
|
||||
def find_axis(node, granularity, axis=1, port=None):
|
||||
shape = len(get_output_shape(node, port if port else 0))
|
||||
if shape < 3 and granularity == 'perchannel':
|
||||
return node.name
|
||||
axis_const = list(i for i in range(shape))
|
||||
@ -182,3 +188,11 @@ class StatisticGraphBuilder:
|
||||
def get_out_port(node_name):
|
||||
out_port = node_name[1] if isinstance(node_name, tuple) else None
|
||||
return out_port
|
||||
|
||||
def add_subgraph_output(self, model_graph, node_name):
|
||||
name = self.get_graph_node_name(node_name)
|
||||
port = node_name[1] if isinstance(node_name, tuple) else 0
|
||||
model_graph.graph['additional_outputs'] = name.split('|')
|
||||
results = AddOutputRecursive().find_and_replace_pattern(model_graph)
|
||||
result_name = results[port].name
|
||||
return result_name
|
||||
|
File diff suppressed because one or more lines are too long
@ -1 +1 @@
|
||||
[{"target_device": "CPU"}, {"primary_bitwidth": 8, "input_priority_types": []}, {"type": "Convolution", "quantization": {"weights": [{"mode": "symmetric", "bits": 4, "granularity": "pertensor"}, {"mode": "symmetric", "bits": 8, "granularity": "pertensor"}, {"mode": "symmetric", "bits": 8, "granularity": "perchannel"}, {"mode": "asymmetric", "bits": 8, "granularity": "pertensor"}, {"mode": "asymmetric", "bits": 8, "granularity": "perchannel"}], "activations": [{"mode": "symmetric", "bits": 2, "granularity": "pertensor"}, {"mode": "symmetric", "bits": 4, "granularity": "pertensor"}, {"mode": "symmetric", "bits": 8, "granularity": "pertensor"}, {"mode": "symmetric", "bits": 8, "granularity": "perchannel"}, {"mode": "asymmetric", "bits": 8, "granularity": "pertensor"}, {"mode": "asymmetric", "bits": 8, "granularity": "perchannel"}]}}, {"type": "MatMul", "quantization": {"weights": [{"level_low": -127, "mode": "symmetric", "bits": 8, "level_high": 127, "granularity": "pertensor"}], "activations": [{"mode": "symmetric", "bits": 8, "granularity": "pertensor"}, {"mode": "symmetric", "bits": 8, "granularity": "perchannel"}, {"mode": "asymmetric", "bits": 8, "granularity": "pertensor"}, {"mode": "asymmetric", "bits": 8, "granularity": "perchannel"}]}}]
|
||||
[{"target_device": "CPU"}, {"primary_bitwidth": 8, "input_priority_types": []}, {"type": "Convolution", "quantization": {"activations": [{"bits": 2, "mode": "symmetric", "granularity": "pertensor"}, {"bits": 4, "mode": "symmetric", "granularity": "pertensor"}, {"bits": 8, "mode": "symmetric", "granularity": "pertensor"}, {"bits": 8, "mode": "symmetric", "granularity": "perchannel"}, {"bits": 8, "mode": "asymmetric", "granularity": "pertensor"}, {"bits": 8, "mode": "asymmetric", "granularity": "perchannel"}], "weights": [{"bits": 4, "mode": "symmetric", "granularity": "pertensor"}, {"bits": 8, "mode": "symmetric", "granularity": "pertensor"}, {"bits": 8, "mode": "symmetric", "granularity": "perchannel"}, {"bits": 8, "mode": "asymmetric", "granularity": "pertensor"}, {"bits": 8, "mode": "asymmetric", "granularity": "perchannel"}]}}, {"type": "MatMul", "quantization": {"activations": [{"bits": 8, "mode": "symmetric", "granularity": "pertensor"}, {"bits": 8, "mode": "symmetric", "granularity": "perchannel"}, {"bits": 8, "mode": "asymmetric", "granularity": "pertensor"}, {"bits": 8, "mode": "asymmetric", "granularity": "perchannel"}], "weights": [{"bits": 8, "mode": "symmetric", "granularity": "pertensor", "level_low": -127, "level_high": 127}]}}]
|
@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a57db75ff796b259d9291c1f69ebe27540e7b88bc214b4ed4eebf7f25341c316
|
||||
size 260
|
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:75dbf91a789ad20e045bd6e853a81c2af8aaef7caf6afc0a62c47f80e49f7aba
|
||||
size 81835
|
||||
oid sha256:bc32d4292f1206f8a1d81fb3795ec96c205278160ff0955338a3c7d5af1bc4e9
|
||||
size 167486
|
||||
|
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:40715bc6862c36af61230b8afc67cc30f70c3e2b9fa78c2c20d1f8afd095fb99
|
||||
size 103631
|
||||
oid sha256:9166a46afced78799c364134a0a4c9ef6eadd5cd188ec7c6136d684a1a89476b
|
||||
size 165788
|
||||
|
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:81a8f061fbe4e248b0150151a1506b90ad1f6b66e9a37399d6e371fae215eb76
|
||||
size 23986
|
||||
oid sha256:7a75521eee344b9fb4bc896a492416b799fc15a912fe601746665764dd6cc679
|
||||
size 20481
|
||||
|
@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:07601534d5916fdb50f2f8ed4afd82ea0a28f6c52a234879026a7a3ce54831fb
|
||||
size 57540
|
@ -28,7 +28,8 @@ TEST_MODELS = [
|
||||
('lstm_example', 'pytorch', 'GNA'),
|
||||
#('multiple_outputs_net_example', 'tf', 'GNA'),
|
||||
('resnet_example', 'pytorch', 'CPU_SPR'),
|
||||
#('tensor_iterator_example', 'tf', 'ANY'),
|
||||
('tensor_iterator_example', 'tf', 'ANY'),
|
||||
('ti_decomposition_example', 'tf', 'GNA'),
|
||||
('softsign_example', 'tf', 'GNA'),
|
||||
('gather_example', 'tf', 'CPU'),
|
||||
('split_concat_example', 'pytorch', 'ANY'),
|
||||
|
Loading…
Reference in New Issue
Block a user