Fix wrong attributes for Pad (#4216)
This commit is contained in:
parent
c1a606d507
commit
deca4fc443
@ -35,9 +35,11 @@ def pad_op_transform(graph: Graph, match: dict):
|
||||
log.info('The pad node "{}" with pad mode "{}" cannot be fused.'.format(pad_op.soft_get('name'), pad_op.mode))
|
||||
return
|
||||
|
||||
if pad_op.mode == 'constant' and pad_op.fill_value != 0.0:
|
||||
log.info('The pad node "{}" with non-zero fill value cannot be fused.'.format(pad_op.soft_get('name')))
|
||||
return
|
||||
if pad_op.mode == 'constant':
|
||||
fill_value = pad_op.in_port(3).data.get_value()
|
||||
if fill_value is None or fill_value != 0.0:
|
||||
log.info('The pad node "{}" with non-zero fill value cannot be fused.'.format(pad_op.soft_get('name')))
|
||||
return
|
||||
|
||||
input_tensor_dims = len(match['pad_output'].shape)
|
||||
for in_port in [1, 2]:
|
||||
|
@ -42,7 +42,6 @@ class Pad(Op):
|
||||
'infer': self.infer,
|
||||
|
||||
'mode': 'constant',
|
||||
'fill_value': float(0),
|
||||
|
||||
'force_precision_in_ports': {
|
||||
1: 'int64',
|
||||
@ -54,11 +53,7 @@ class Pad(Op):
|
||||
}, attrs)
|
||||
|
||||
def backend_attrs(self):
|
||||
return [('pad_mode', 'mode'),
|
||||
('pad_value', 'fill_value'),
|
||||
('pads_begin', lambda node: ','.join(map(str, node.pads[:, 0])) if node.has_valid('pads') else None),
|
||||
('pads_end', lambda node: ','.join(map(str, node.pads[:, 1])) if node.has_valid('pads') else None),
|
||||
]
|
||||
return [('pad_mode', 'mode')]
|
||||
|
||||
@staticmethod
|
||||
def infer(node):
|
||||
|
Loading…
Reference in New Issue
Block a user