[POT] Add pattern se blocks (#8425)

* fix(patterns): add pattern se_block

* update reference
This commit is contained in:
Indira Salyahova 2021-11-15 12:53:07 +03:00 committed by GitHub
parent c981d2f0dd
commit 29a3f56003
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 62 additions and 2 deletions

View File

@ -292,6 +292,46 @@ graph TB
---
**Name:** se_block<br/>
**Pattern:** <br/>
```mermaid
graph TB
input(ReduceMean) --> conv_fc(Convolution, MatMul)
conv_fc(Convolution, MatMul) --> bias(Add)
bias_const(Const) --> bias(Add)
bias(Add) --> act(ReLU, PReLU)
act(ReLU, PReLU) --> conv_fc_1(Convolution, MatMul)
conv_fc_1(Convolution, MatMul) --> bias_1(Add)
bias_const_1(Const) --> bias_1(Add)
bias_1(Add) --> act_1(Sigmoid)
act_1(Sigmoid) --> output(Multiply)
style input fill:#73C2FB
style output fill:#73C2FB
```
---
**Name:** se_block_swish_activation<br/>
**Pattern:** <br/>
```mermaid
graph TB
input(ReduceMean) --> conv_fc(Convolution, MatMul)
conv_fc(Convolution, MatMul) --> bias(Add)
bias_const(Const) --> bias(Add)
bias(Add) --> swish(Swish)
swish(Swish) --> conv_fc_1(Convolution, MatMul)
conv_fc_1(Convolution, MatMul) --> bias_1(Add)
bias_const_1(Const) --> bias_1(Add)
bias_1(Add) --> act_1(Sigmoid)
act_1(Sigmoid) --> output(Multiply)
style input fill:#73C2FB
style output fill:#73C2FB
```
---
**Name:** softmax<br/>
**Pattern:** <br/>

View File

@ -28,6 +28,26 @@ def create_swish_pattern():
return pattern.set_name('swish_activation').pattern
@registry_ignore_patterns('blocks')
def create_se_pattern():
"""
Removing this pattern can drop accuracy after quantization of model w/ SE-blocks
"""
pattern = PatternBuilder()
pattern.insert_se(start_name='input', end_name='output')
return pattern.set_name('se_block').pattern
@registry_ignore_patterns('blocks')
def create_se_swish_pattern():
"""
Removing this pattern can drop accuracy after quantization of model w/ SE-blocks
"""
pattern = PatternBuilder()
pattern.insert_se(start_name='input', end_name='output', is_swish=True)
return pattern.set_name('se_block_swish_activation').pattern
@registry_ignore_patterns('blocks')
def create_biased_op_pattern():
pattern = PatternBuilder()

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:df841d3e81aa6476e54787bcfec5e334bee27b80bbad7c8446b0f7f120ca5e7e
size 570777
oid sha256:181402cfe46282cf96d82f1e0b68f2ba5ccbcaccbb1ce3e712fd1b4cb5883917
size 463708