[POT] Fixed bug with out port for nodes with multiple outputs (#10932)
* Fixed bug with out port * Add test * Fix test * Change test model
This commit is contained in:
committed by
GitHub
parent
cd361ecae1
commit
99f0093615
@@ -192,8 +192,11 @@ def get_quantized_input_key(quantized_node):
|
||||
Otherwise, key is tuple (fq_input name, output port number)
|
||||
"""
|
||||
quantized_input = get_node_input(quantized_node, 0)
|
||||
quantized_key = create_node_name(quantized_input)
|
||||
return quantized_key
|
||||
key = quantized_input.fullname
|
||||
if len(quantized_input.out_ports()) > 1:
|
||||
port_number = quantized_node.in_port(0).get_source().out
|
||||
key = (quantized_input.fullname, port_number)
|
||||
return key
|
||||
|
||||
|
||||
def node_with_quantized_weights(node):
|
||||
|
||||
@@ -49,7 +49,7 @@ class StatisticGraphBuilder:
|
||||
add_output_node, op_name = getattr(self, f'insert_{type_stat}')(model_graph,
|
||||
node,
|
||||
type_stat,
|
||||
node.name,
|
||||
node_name,
|
||||
**stat.kwargs)
|
||||
if add_output_node:
|
||||
if node_name not in nodes_names_map[model_graph.name]:
|
||||
@@ -99,13 +99,18 @@ class StatisticGraphBuilder:
|
||||
axis_const = self.find_axis(node, granularity, axis)
|
||||
if isinstance(axis_const, str):
|
||||
return (True, node.name)
|
||||
|
||||
out_port = self.get_out_port(node_name)
|
||||
if out_port is not None:
|
||||
node_name = f'{node_name[0]}.{out_port}'
|
||||
reduce_op = create_op_node_with_second_input(node.graph, insert_op, int64_array(axis_const),
|
||||
dict(name=f'{type_stat}_{node_name}'))
|
||||
reduce_op['fullname'] = reset_node_fullname(node.fullname, reduce_op.name)
|
||||
if node.graph != model_graph:
|
||||
Op.create_data_node(reduce_op.graph, reduce_op, {'shape': [1]})
|
||||
node.out_port(0).connect(reduce_op.in_port(0))
|
||||
return self.insert_result(model_graph, node, reduce_op, type_stat)
|
||||
|
||||
node.out_port(out_port if out_port else 0).connect(reduce_op.in_port(0))
|
||||
return self.insert_result(model_graph, node, reduce_op, type_stat, out_port)
|
||||
|
||||
def insert_min(self, model_graph, node, type_stat, node_name, **kwargs):
|
||||
return self.insert_reduce(model_graph, ReduceMin, node, kwargs.get('granularity'), type_stat, node_name)
|
||||
@@ -127,7 +132,12 @@ class StatisticGraphBuilder:
|
||||
axis_const = self.find_axis(node, kwargs.get('granularity'))
|
||||
if isinstance(axis_const, str):
|
||||
return (True, node.name)
|
||||
abs_node = Abs(node.graph, {"name": f'abs_{node_name}'}).create_node_with_data([node.out_node(0)]).in_node(0)
|
||||
|
||||
out_port = self.get_out_port(node_name)
|
||||
if out_port is not None:
|
||||
node_name = f'{node_name[0]}.{out_port}'
|
||||
abs_node = Abs(node.graph, {"name": f'abs_{node_name}'}). \
|
||||
create_node_with_data([node.out_node(out_port if out_port else 0)]).in_node(0)
|
||||
max_op = create_op_node_with_second_input(node.graph, ReduceMax, int64_array(axis_const),
|
||||
dict(name=f'{type_stat}_{node_name}'))
|
||||
|
||||
@@ -135,16 +145,18 @@ class StatisticGraphBuilder:
|
||||
Op.create_data_node(max_op.graph, max_op, {'shape': [1]})
|
||||
max_op['fullname'] = reset_node_fullname(node.fullname, max_op.name)
|
||||
abs_node.out_port(0).connect(max_op.in_port(0))
|
||||
return self.insert_result(model_graph, node, max_op, type_stat)
|
||||
return self.insert_result(model_graph, node, max_op, type_stat, out_port)
|
||||
|
||||
@staticmethod
|
||||
def insert_result(model_graph, node, child_node, name):
|
||||
def insert_result(model_graph, node, child_node, name, port=None):
|
||||
if node.graph != model_graph:
|
||||
model_graph.graph['additional_outputs'] = child_node.fullname.split('|')
|
||||
res_op = AddOutputRecursive().find_and_replace_pattern(model_graph)
|
||||
ie_result_name = res_op[0].name
|
||||
else:
|
||||
ie_result_name = f'{name}_{node.name}'
|
||||
if port is not None:
|
||||
ie_result_name = ie_result_name + f'.{port}'
|
||||
res_op = Result(node.graph, {'name': f'Result_{ie_result_name}'}).create_node()
|
||||
child_node.out_port(0).connect(res_op.in_port(0))
|
||||
return (False, ie_result_name)
|
||||
@@ -165,3 +177,8 @@ class StatisticGraphBuilder:
|
||||
node_name_in_graph = layout_name[0] if isinstance(layout_name, tuple) else layout_name
|
||||
node_name_in_graph = node_name_in_graph.replace('/pre_fq_input', '')
|
||||
return node_name_in_graph
|
||||
|
||||
@staticmethod
|
||||
def get_out_port(node_name):
|
||||
out_port = node_name[1] if isinstance(node_name, tuple) else None
|
||||
return out_port
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:52d25d0f9bc583d15d6e24072626466cbe35db93af13aca01444c37625e61f91
|
||||
size 171
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:7b0bc2275947bd97cd11b20b31a69f9175c028386d9ea9ee7f6220de339772fd
|
||||
size 332731
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:466cb973bbeece0f2fc187e9dd2640ccbde151d82ff52daa1628a5120713af49
|
||||
size 125455
|
||||
oid sha256:f96a3afe2b788ae4e460c57a4ed25da10883a553439e27ce22a726a5d4075f20
|
||||
size 125489
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:3a394d015d36fdbec8bdaea631a1f2c5383a7132f065a2e35bb664d5b620cb72
|
||||
size 256
|
||||
@@ -88,7 +88,7 @@ def create_(tmp_path, models, model_name, model_framework, quantization_mode,
|
||||
@pytest.mark.parametrize(
|
||||
'model_name, model_framework, quantization_mode, inplace_statistics, \
|
||||
algorithm, preset, granularity, add_output_nodes, type_max, type_min', TEST_MODELS,
|
||||
ids=['{}_{}_{}_{}_{}_{}_{}_{}_{}'.format(m[0], m[1], m[2], m[3], m[4],
|
||||
ids=['{}_{}_{}_{}_{}_{}_{}_{}_{}'.format(m[0], m[1], m[2], m[3], m[4].name,
|
||||
m[5], m[6], m[7], m[8], m[9]) for m in TEST_MODELS])
|
||||
def test_statistics_collector_subsets(tmp_path, models, model_name, model_framework,
|
||||
quantization_mode, inplace_statistics, algorithm,
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
from copy import copy
|
||||
from pathlib import Path
|
||||
import pytest
|
||||
import numpy as np
|
||||
@@ -18,7 +17,7 @@ from openvino.tools.pot.algorithms.quantization.minmax.algorithm import MinMaxQu
|
||||
from openvino.tools.pot.algorithms.quantization.bias_correction.algorithm import BiasCorrection
|
||||
from .utils.config import PATHS2DATASETS_CONFIG
|
||||
|
||||
TEST_MODELS = [('mobilenet-v2-pytorch', 'pytorch')]
|
||||
TEST_MODELS = [('mobilenet-v2-pytorch', 'pytorch'), ('lstm_outs_quantization', 'tf')]
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'model_name, model_framework', TEST_MODELS,
|
||||
@@ -57,7 +56,7 @@ def test_statistics_collector_subsets(tmp_path, models, model_name, model_framew
|
||||
out = {'MinMaxQuantization': collector.get_statistics_for_algorithm('MinMaxQuantization'),
|
||||
'BiasCorrection': collector.get_statistics_for_algorithm('BiasCorrection')}
|
||||
|
||||
refs_file = Path(__file__).parent / 'data/test_cases_refs/statistics_data.json'
|
||||
refs_file = Path(__file__).parent / 'data/test_cases_refs' / f'{model_name}_statistics_data.json'
|
||||
local_path = os.path.join(tmp_path, '{}_{}.json'.format(model_name, 'statistics_data'))
|
||||
local_file = open(local_path, 'w')
|
||||
|
||||
@@ -65,17 +64,25 @@ def test_statistics_collector_subsets(tmp_path, models, model_name, model_framew
|
||||
refs = json.load(file)
|
||||
|
||||
eps = 1e-3
|
||||
local_out = copy(out)
|
||||
for algo_name, algo_val in local_out.items():
|
||||
local_out = {}
|
||||
for algo_name, algo_val in out.items():
|
||||
local_out[algo_name] = {}
|
||||
for node_name, node_val in algo_val.items():
|
||||
if isinstance(node_name, tuple):
|
||||
name = f'{node_name[0]}.{node_name[1]}'
|
||||
else:
|
||||
name = node_name
|
||||
local_out[algo_name][name] = {}
|
||||
for stats_name, stats_val in node_val.items():
|
||||
local_out[algo_name][node_name][stats_name] = [np.array(v).tolist() for v in stats_val]
|
||||
local_out[algo_name][name][stats_name] = [np.array(v).tolist() for v in stats_val]
|
||||
json.dump(local_out, local_file)
|
||||
for algo_name, algo_val in out.items():
|
||||
for node_name, node_val in algo_val.items():
|
||||
for stats_name, stats_val in node_val.items():
|
||||
if stats_name in ['batch_mean_param_in', 'shape']:
|
||||
continue
|
||||
if isinstance(node_name, tuple):
|
||||
node_name = f'{node_name[0]}.{node_name[1]}'
|
||||
ref_stats_vals = refs[algo_name][node_name][stats_name]
|
||||
for ref_vals, vals in zip(ref_stats_vals, stats_val):
|
||||
assert np.max(np.abs(np.array(ref_vals) - vals)) < eps
|
||||
|
||||
Reference in New Issue
Block a user