[POT] Add transformer block pattern (#15044)
This commit is contained in:
parent
051597cf2c
commit
b8d51a9e1b
@ -293,6 +293,21 @@ def create_softmax_reshape_transpose_matmul_pattern():
|
||||
return pattern.set_name('softmax_reshape_transpose_matmul').pattern
|
||||
|
||||
|
||||
@registry_ignore_patterns('blocks')
|
||||
def create_softmax_reshape_transpose_gather_matmul_pattern():
|
||||
pattern = PatternBuilder()
|
||||
pattern_2 = PatternBuilder()
|
||||
softmax_out = pattern.append_single_op('SoftMax', 'softmax').get_last_node()
|
||||
pattern_2.append_single_op('Add', 'add').get_last_node()
|
||||
pattern_2.append_op_const('Reshape', 'reshape')
|
||||
pattern_2.append_single_op('Transpose', 'transpose').get_last_node()
|
||||
gather_out = pattern_2.append_single_op('Gather', 'gather').get_last_node()
|
||||
pattern.pattern['nodes'] += pattern_2.pattern['nodes']
|
||||
pattern.pattern['edges'] += pattern_2.pattern['edges']
|
||||
pattern.insert_single_op([softmax_out, gather_out], None, 'MatMul', 'matmul')
|
||||
return pattern.set_name('softmax_reshape_transpose_gather_matmul').pattern
|
||||
|
||||
|
||||
@registry_ignore_patterns('blocks')
|
||||
def create_hswish_without_denominator_pattern():
|
||||
pattern = PatternBuilder()
|
||||
|
Loading…
Reference in New Issue
Block a user