Fix output layout of tf yolo models converted with transformations config (#9583)
This commit is contained in:
parent
cf344a3b73
commit
9fb9e19efa
@ -23,7 +23,7 @@ class YoloRegionAddon(FrontReplacementFromConfigFileGeneral):
|
|||||||
op_outputs = [n for n, d in graph.nodes(data=True) if 'op' in d and d['op'] == 'Result']
|
op_outputs = [n for n, d in graph.nodes(data=True) if 'op' in d and d['op'] == 'Result']
|
||||||
for op_output in op_outputs:
|
for op_output in op_outputs:
|
||||||
last_node = Node(graph, op_output).in_node(0)
|
last_node = Node(graph, op_output).in_node(0)
|
||||||
op_params = dict(name=last_node.id + '/YoloRegion', axis=1, end_axis=-1)
|
op_params = dict(name=last_node.id + '/YoloRegion', axis=1, end_axis=-1, nchw_layout=True)
|
||||||
op_params.update(replacement_descriptions)
|
op_params.update(replacement_descriptions)
|
||||||
region_layer = RegionYoloOp(graph, op_params)
|
region_layer = RegionYoloOp(graph, op_params)
|
||||||
region_layer_node = region_layer.create_node([last_node])
|
region_layer_node = region_layer.create_node([last_node])
|
||||||
@ -51,7 +51,7 @@ class YoloV3RegionAddon(FrontReplacementFromConfigFileGeneral):
|
|||||||
'Refer to documentation about converting YOLO models for more information.'.format(
|
'Refer to documentation about converting YOLO models for more information.'.format(
|
||||||
', '.join(replacement_descriptions['entry_points']), input_node_name))
|
', '.join(replacement_descriptions['entry_points']), input_node_name))
|
||||||
last_node = Node(graph, input_node_name).in_node(0)
|
last_node = Node(graph, input_node_name).in_node(0)
|
||||||
op_params = dict(name=last_node.id + '/YoloRegion', axis=1, end_axis=-1, do_softmax=0)
|
op_params = dict(name=last_node.id + '/YoloRegion', axis=1, end_axis=-1, do_softmax=0, nchw_layout=True)
|
||||||
op_params.update(replacement_descriptions)
|
op_params.update(replacement_descriptions)
|
||||||
if 'masks' in op_params:
|
if 'masks' in op_params:
|
||||||
op_params['mask'] = op_params['masks'][i]
|
op_params['mask'] = op_params['masks'][i]
|
||||||
|
Loading…
Reference in New Issue
Block a user