[POT] Update BC with the Parameter nodes connection (#10848)
* Update BC with the Parameter nodes connection * Update test_sanity with octave
This commit is contained in:
parent
d7372d678c
commit
4746d0881b
@ -260,6 +260,7 @@ class BiasCorrection(Algorithm):
|
||||
outputs_shapes = {nu.create_node_name(n): nu.get_output_shape(n, 0).copy() for n in input_nodes}
|
||||
inputs_data = []
|
||||
param_type = 'Parameter'
|
||||
nodes_data = []
|
||||
for input_node in input_nodes:
|
||||
input_node_name = nu.create_node_name(input_node)
|
||||
c_input_shape = outputs_shapes[input_node_name]
|
||||
@ -271,16 +272,19 @@ class BiasCorrection(Algorithm):
|
||||
parameter_name = input_node_name + '/parameter'
|
||||
param_node = ge.create_node(input_node.graph, parameter_name, param_type,
|
||||
{'shape': c_input_shape, 'data_type': input_node_data_type})
|
||||
for _, port in input_node.out_ports().items():
|
||||
for in_port in port.get_destinations():
|
||||
in_port.disconnect()
|
||||
in_port.connect(param_node.out_port(0))
|
||||
nodes_data.append((input_node, param_node))
|
||||
inputs_data.append({
|
||||
'param_name': parameter_name,
|
||||
'param_shape': tuple(c_input_shape),
|
||||
'input_name': input_node_name
|
||||
})
|
||||
|
||||
for input_node, param_node in nodes_data:
|
||||
for _, port in input_node.out_ports().items():
|
||||
for in_port in port.get_destinations():
|
||||
in_port.disconnect()
|
||||
in_port.connect(param_node.out_port(0))
|
||||
|
||||
return inputs_data
|
||||
|
||||
def _create_results_after_nodes(self, output_nodes):
|
||||
|
@ -54,7 +54,9 @@ TEST_MODELS = [
|
||||
('mtcnn', 'caffe', 'DefaultQuantization', 'performance', 1, {'recall': 0.76, 'map': 0.6844}, {}, 'CPU'),
|
||||
|
||||
('mtcnn', 'caffe', 'DefaultQuantization', 'performance', 2, {'recall': 0.76, 'map': 0.6638},
|
||||
{'use_fast_bias': False}, 'CPU')
|
||||
{'use_fast_bias': False}, 'CPU'),
|
||||
('octave-resnet-26-0.25', 'mxnet', 'DefaultQuantization', 'performance', 300,
|
||||
{'accuracy@top1': 0.766, 'accuracy@top5': 0.927}, {'use_fast_bias': False}, 'CPU'),
|
||||
]
|
||||
CASCADE_MAP = Dict({
|
||||
'mtcnn': {
|
||||
|
Loading…
Reference in New Issue
Block a user