[POT] Update BC with the Parameter nodes connection (#10848)

* Update BC with the Parameter nodes connection

* Update test_sanity with octave
This commit is contained in:
Nikita Malinin 2022-03-10 10:28:47 +03:00 committed by GitHub
parent d7372d678c
commit 4746d0881b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 5 deletions

View File

@ -260,6 +260,7 @@ class BiasCorrection(Algorithm):
outputs_shapes = {nu.create_node_name(n): nu.get_output_shape(n, 0).copy() for n in input_nodes}
inputs_data = []
param_type = 'Parameter'
nodes_data = []
for input_node in input_nodes:
input_node_name = nu.create_node_name(input_node)
c_input_shape = outputs_shapes[input_node_name]
@ -271,16 +272,19 @@ class BiasCorrection(Algorithm):
parameter_name = input_node_name + '/parameter'
param_node = ge.create_node(input_node.graph, parameter_name, param_type,
{'shape': c_input_shape, 'data_type': input_node_data_type})
for _, port in input_node.out_ports().items():
for in_port in port.get_destinations():
in_port.disconnect()
in_port.connect(param_node.out_port(0))
nodes_data.append((input_node, param_node))
inputs_data.append({
'param_name': parameter_name,
'param_shape': tuple(c_input_shape),
'input_name': input_node_name
})
for input_node, param_node in nodes_data:
for _, port in input_node.out_ports().items():
for in_port in port.get_destinations():
in_port.disconnect()
in_port.connect(param_node.out_port(0))
return inputs_data
def _create_results_after_nodes(self, output_nodes):

View File

@ -54,7 +54,9 @@ TEST_MODELS = [
('mtcnn', 'caffe', 'DefaultQuantization', 'performance', 1, {'recall': 0.76, 'map': 0.6844}, {}, 'CPU'),
('mtcnn', 'caffe', 'DefaultQuantization', 'performance', 2, {'recall': 0.76, 'map': 0.6638},
{'use_fast_bias': False}, 'CPU')
{'use_fast_bias': False}, 'CPU'),
('octave-resnet-26-0.25', 'mxnet', 'DefaultQuantization', 'performance', 300,
{'accuracy@top1': 0.766, 'accuracy@top5': 0.927}, {'use_fast_bias': False}, 'CPU'),
]
CASCADE_MAP = Dict({
'mtcnn': {