[POT] Add transformer block pattern (#15044)

This commit is contained in:
Liubov Talamanova 2023-01-12 11:37:21 +00:00 committed by GitHub
parent 051597cf2c
commit b8d51a9e1b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -293,6 +293,21 @@ def create_softmax_reshape_transpose_matmul_pattern():
return pattern.set_name('softmax_reshape_transpose_matmul').pattern
@registry_ignore_patterns('blocks')
def create_softmax_reshape_transpose_gather_matmul_pattern():
pattern = PatternBuilder()
pattern_2 = PatternBuilder()
softmax_out = pattern.append_single_op('SoftMax', 'softmax').get_last_node()
pattern_2.append_single_op('Add', 'add').get_last_node()
pattern_2.append_op_const('Reshape', 'reshape')
pattern_2.append_single_op('Transpose', 'transpose').get_last_node()
gather_out = pattern_2.append_single_op('Gather', 'gather').get_last_node()
pattern.pattern['nodes'] += pattern_2.pattern['nodes']
pattern.pattern['edges'] += pattern_2.pattern['edges']
pattern.insert_single_op([softmax_out, gather_out], None, 'MatMul', 'matmul')
return pattern.set_name('softmax_reshape_transpose_gather_matmul').pattern
@registry_ignore_patterns('blocks')
def create_hswish_without_denominator_pattern():
pattern = PatternBuilder()