[ MO ] DepthToSpace & ShuffleChannels fusion (#2001)
* [ MO ] ShuffleChannel fusion * DepthToSpace fusion * test * comment
This commit is contained in:
parent
38eb4a2398
commit
510c699731
@ -16,6 +16,8 @@
|
||||
import numpy as np
|
||||
|
||||
from extensions.back.FuseTransposesSequence import FuseTransposesSequence
|
||||
from extensions.ops.depth_to_space import DepthToSpaceOp
|
||||
from extensions.ops.shufflechannel import ShuffleChannels
|
||||
from mo.back.replacement import BackReplacementPattern
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.graph.graph import Graph
|
||||
@ -33,27 +35,29 @@ class ShuffleChannelPatternOptimization(BackReplacementPattern):
|
||||
return dict(
|
||||
nodes=[
|
||||
('t_start_order', {'type': 'Const'}),
|
||||
('t_start_order_d', {'value': lambda value: value is not None and np.all(np.array_equal(value, [0, 2, 3, 1]))}),
|
||||
('t_start_order_d',
|
||||
{'value': lambda v: v is not None and np.all(np.array_equal(v, [0, 2, 3, 1]))}),
|
||||
('t_start', {'type': 'Transpose'}),
|
||||
('t_start_d', {}),
|
||||
|
||||
('reshape_dim', {'type': 'Const'}),
|
||||
('reshape_dim_d', {'value': lambda value: value is not None and value.size == 5 and np.all(value[0] == -1)}),
|
||||
('reshape_dim_d',
|
||||
{'value': lambda v: v is not None and v.size == 5 and np.all(v[0] == -1)}),
|
||||
('reshape_start', {'type': 'Reshape'}),
|
||||
('reshape_start_d', {}),
|
||||
|
||||
('t_5d_order', {'type': 'Const'}),
|
||||
('t_5d_order_d', {'value': lambda value: value is not None and np.all(np.array_equal(value, [0, 1, 2, 4, 3]))}),
|
||||
('t_5d_order_d', {'value': lambda v: v is not None and np.all(np.array_equal(v, [0, 1, 2, 4, 3]))}),
|
||||
('t_5d', {'type': 'Transpose'}),
|
||||
('t_5d_d', {}),
|
||||
|
||||
('reshape_1_dim', {'type': 'Const'}),
|
||||
('reshape_1_dim_d', {'value': lambda value: value is not None and value.size == 4 and np.all(value[0] == -1)}),
|
||||
('reshape_1_dim_d', {'value': lambda v: v is not None and v.size == 4 and np.all(v[0] == -1)}),
|
||||
('reshape_end', {'type': 'Reshape'}),
|
||||
('reshape_end_d', {}),
|
||||
|
||||
('t_end_order', {'type': 'Const'}),
|
||||
('t_end_order_d', {'value': lambda value: value is not None and np.all(np.array_equal(value, [0, 3, 1, 2]))}),
|
||||
('t_end_order_d', {'value': lambda v: v is not None and np.all(np.array_equal(v, [0, 3, 1, 2]))}),
|
||||
('t_end', {'type': 'Transpose'}),
|
||||
],
|
||||
edges=[
|
||||
@ -126,3 +130,149 @@ class ShuffleChannelPatternOptimization(BackReplacementPattern):
|
||||
|
||||
match['reshape_1_dim']['value'] = int64_array(np.take(new_end.in_port(1).data.get_value(), [0, 3, 1, 2]))
|
||||
match['reshape_1_dim'].infer(match['reshape_1_dim'])
|
||||
|
||||
|
||||
class ShuffleChannelFusion(BackReplacementPattern):
|
||||
"""
|
||||
FUSION: Reshape->Transpose->Reshape to ShuffleChannel
|
||||
We are able to perform the fusion if the pattern satisfies the conditions:
|
||||
1. Pattern input 4D shape is the same as pattern output 4D shape
|
||||
2. First Reshape splits channel dimension (1 axis) into two dimensions
|
||||
3. Transpose permutes only splitted dimensions
|
||||
4. Second Reshape pack them back
|
||||
|
||||
Fixes original models reshape-ability (Smart reshape)
|
||||
"""
|
||||
enabled = True
|
||||
force_clean_up = True
|
||||
|
||||
def run_after(self):
|
||||
return [FuseTransposesSequence]
|
||||
|
||||
@staticmethod
|
||||
def pattern():
|
||||
return dict(
|
||||
nodes=[
|
||||
('reshape_0_pattern', dict(type='Const')),
|
||||
('reshape_0_pattern_d', dict(value=lambda v: v is not None and v.size == 5 and np.all(v > 0))),
|
||||
('reshape_0', dict(type='Reshape')),
|
||||
('reshape_0_d', dict()),
|
||||
|
||||
('order', dict(type='Const')),
|
||||
('order_d', dict(value=lambda v: v is not None and np.array_equal([0, 2, 1, 3, 4], v))),
|
||||
('transpose', dict(type='Transpose')),
|
||||
('transpose_d', {}),
|
||||
|
||||
('reshape_1_pattern', dict(type='Const')),
|
||||
('reshape_1_pattern_d', dict(value=lambda v: v is not None and v.size == 4 and np.all(v > 0))),
|
||||
('reshape_1', dict(type='Reshape')),
|
||||
],
|
||||
edges=[
|
||||
('reshape_0_pattern', 'reshape_0_pattern_d'),
|
||||
('reshape_0_pattern_d', 'reshape_0'),
|
||||
('reshape_0', 'reshape_0_d'),
|
||||
('reshape_0_d', 'transpose'),
|
||||
('order', 'order_d'),
|
||||
('order_d', 'transpose'),
|
||||
('transpose', 'transpose_d'),
|
||||
('transpose_d', 'reshape_1'),
|
||||
('reshape_1_pattern', 'reshape_1_pattern_d'),
|
||||
('reshape_1_pattern_d', 'reshape_1'),
|
||||
]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def replace_pattern(graph: Graph, match: dict):
|
||||
channel_splitting_reshape = match['reshape_0']
|
||||
channel_concating_reshape = match['reshape_1']
|
||||
|
||||
initial_shape = channel_splitting_reshape.in_port(0).data.get_shape()
|
||||
resulting_shape = channel_concating_reshape.in_port(1).data.get_value()
|
||||
if not np.array_equal(initial_shape, resulting_shape):
|
||||
return
|
||||
|
||||
channel_splitted_out_shape = channel_splitting_reshape.in_port(1).data.get_value()
|
||||
if not all([initial_shape[i] == channel_splitted_out_shape[j] for i, j in {0: 0, 2: 3, 3: 4}.items()]):
|
||||
return
|
||||
|
||||
name = channel_concating_reshape.soft_get('name', channel_concating_reshape.id)
|
||||
group = channel_splitted_out_shape[1]
|
||||
shuffle_channel = ShuffleChannels(graph, {'name': name, 'group': group}).create_node()
|
||||
channel_concating_reshape.out_port(0).get_connection().set_source(shuffle_channel.out_port(0))
|
||||
shuffle_channel.in_port(0).connect(channel_splitting_reshape.in_port(0).get_source())
|
||||
|
||||
|
||||
class DepthToSpaceFusion(BackReplacementPattern):
|
||||
"""
|
||||
FUSION: Reshape->Transpose->Reshape to DepthToSpace
|
||||
We are able to perform the fusion if the pattern satisfies the conditions:
|
||||
1. Pattern has 6D input and 4D output
|
||||
2. First Reshape splits channel dimension (1 axis) into three dimensions [new_depth, block_size, block_size]
|
||||
3. Transpose permutes splitted dimensions with spatial ones
|
||||
4. Second Reshape pack block size together with spatial dimension
|
||||
|
||||
Fixes original models reshape-ability (Smart reshape)
|
||||
"""
|
||||
enabled = True
|
||||
force_clean_up = True
|
||||
|
||||
def run_after(self):
|
||||
return [FuseTransposesSequence]
|
||||
|
||||
@staticmethod
|
||||
def pattern():
|
||||
return dict(
|
||||
nodes=[
|
||||
('reshape_0_pattern', dict(type='Const')),
|
||||
('reshape_0_pattern_d', dict(value=lambda v: v is not None and v.size == 6 and np.all(v > 0))),
|
||||
('reshape_0', dict(type='Reshape')),
|
||||
('reshape_0_d', dict()),
|
||||
|
||||
('order', dict(type='Const')),
|
||||
('order_d', dict(value=lambda v: v is not None and np.array_equal([0, 1, 4, 2, 5, 3], v))),
|
||||
('transpose', dict(type='Transpose')),
|
||||
('transpose_d', {}),
|
||||
|
||||
('reshape_1_pattern', dict(type='Const')),
|
||||
('reshape_1_pattern_d', dict(value=lambda v: v is not None and v.size == 4 and np.all(v > 0))),
|
||||
('reshape_1', dict(type='Reshape')),
|
||||
],
|
||||
edges=[
|
||||
('reshape_0_pattern', 'reshape_0_pattern_d'),
|
||||
('reshape_0_pattern_d', 'reshape_0'),
|
||||
('reshape_0', 'reshape_0_d'),
|
||||
('reshape_0_d', 'transpose'),
|
||||
('order', 'order_d'),
|
||||
('order_d', 'transpose'),
|
||||
('transpose', 'transpose_d'),
|
||||
('transpose_d', 'reshape_1'),
|
||||
('reshape_1_pattern', 'reshape_1_pattern_d'),
|
||||
('reshape_1_pattern_d', 'reshape_1'),
|
||||
]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def replace_pattern(graph: Graph, match: dict):
|
||||
channel_splitting_reshape = match['reshape_0']
|
||||
channel_concating_reshape = match['reshape_1']
|
||||
|
||||
initial_shape = channel_splitting_reshape.in_port(0).data.get_shape()
|
||||
resulting_shape = channel_concating_reshape.in_port(1).data.get_value()
|
||||
if initial_shape[0] != resulting_shape[0]:
|
||||
return
|
||||
|
||||
channel_splitted_out_shape = channel_splitting_reshape.in_port(1).data.get_value()
|
||||
if not all([initial_shape[i] == channel_splitted_out_shape[j] for i, j in {0: 0, 2: 4, 3: 5}.items()]) or \
|
||||
channel_splitted_out_shape[1] != channel_splitted_out_shape[2]:
|
||||
return
|
||||
block_size = channel_splitted_out_shape[2]
|
||||
expected_output_shape = [initial_shape[0], initial_shape[1] // (block_size * block_size),
|
||||
initial_shape[2] * block_size, initial_shape[3] * block_size]
|
||||
if not np.array_equal(expected_output_shape, resulting_shape):
|
||||
return
|
||||
|
||||
name = channel_concating_reshape.soft_get('name', channel_concating_reshape.id)
|
||||
depth_to_space = DepthToSpaceOp(graph,
|
||||
{'name': name, 'block_size': block_size, 'mode': 'depth_first'}).create_node()
|
||||
channel_concating_reshape.out_port(0).get_connection().set_source(depth_to_space.out_port(0))
|
||||
depth_to_space.in_port(0).connect(channel_splitting_reshape.in_port(0).get_source())
|
||||
|
@ -0,0 +1,184 @@
|
||||
"""
|
||||
Copyright (C) 2018-2020 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
import unittest
|
||||
from argparse import Namespace
|
||||
|
||||
from generator import generate, generator
|
||||
|
||||
from extensions.back.ShuffleChannelPatternOptimization import ShuffleChannelFusion, DepthToSpaceFusion
|
||||
from extensions.ops.depth_to_space import DepthToSpaceOp
|
||||
from extensions.ops.parameter import Parameter
|
||||
from extensions.ops.shufflechannel import ShuffleChannels
|
||||
from extensions.ops.transpose import Transpose
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.ops.reshape import Reshape
|
||||
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from mo.utils.unittest.graph import build_graph, result, regular_op_with_shaped_data, \
|
||||
valued_const_with_data, connect, regular_op_with_empty_data
|
||||
|
||||
|
||||
@generator
|
||||
class ShuffleChannelFusionTest(unittest.TestCase):
|
||||
@staticmethod
|
||||
def get_graphs(input_shape, reshape_0_pattern, order, reshape_1_pattern, group):
|
||||
nodes = {
|
||||
**regular_op_with_shaped_data('input', input_shape, {'type': 'Parameter', 'shape': int64_array(input_shape),
|
||||
'infer': Parameter.infer}),
|
||||
|
||||
**valued_const_with_data('reshape_0_pattern', int64_array(reshape_0_pattern)),
|
||||
**regular_op_with_empty_data('reshape_0', {'type': 'Reshape', 'infer': Reshape.infer}),
|
||||
|
||||
**valued_const_with_data('order', int64_array(order)),
|
||||
**regular_op_with_empty_data('transpose', {'type': 'Transpose', 'infer': Transpose.infer}),
|
||||
|
||||
**valued_const_with_data('reshape_1_pattern', int64_array(reshape_1_pattern)),
|
||||
**regular_op_with_empty_data('reshape_1', {'type': 'Reshape', 'infer': Reshape.infer,
|
||||
'name': 'final_reshape'}),
|
||||
|
||||
**result(),
|
||||
}
|
||||
edges = [
|
||||
*connect('input', '0:reshape_0'),
|
||||
*connect('reshape_0_pattern', '1:reshape_0'),
|
||||
*connect('reshape_0', '0:transpose'),
|
||||
*connect('order', '1:transpose'),
|
||||
*connect('transpose', '0:reshape_1'),
|
||||
*connect('reshape_1_pattern', '1:reshape_1'),
|
||||
*connect('reshape_1', 'output'),
|
||||
]
|
||||
graph = build_graph(nodes, edges, nodes_with_edges_only=True)
|
||||
for node in graph.get_op_nodes():
|
||||
node['op'] = node['type']
|
||||
graph.clean_up()
|
||||
|
||||
ref_nodes = {
|
||||
**regular_op_with_shaped_data('input', input_shape, {'type': 'Parameter', 'shape': int64_array(input_shape),
|
||||
'infer': Parameter.infer}),
|
||||
**regular_op_with_empty_data('shuffle_channel', {'type': 'ShuffleChannels', 'infer': ShuffleChannels.infer,
|
||||
'name': 'final_reshape', 'group': group}),
|
||||
**result()
|
||||
}
|
||||
ref_edges = [*connect('input', 'shuffle_channel'), *connect('shuffle_channel', 'output')]
|
||||
graph_ref = build_graph(ref_nodes, ref_edges, nodes_with_edges_only=True)
|
||||
for node in graph_ref.get_op_nodes():
|
||||
node['op'] = node['type']
|
||||
graph_ref.clean_up()
|
||||
|
||||
return graph, graph_ref
|
||||
|
||||
@generate(*[
|
||||
([1, 512, 7, 6], [1, 2, 256, 7, 6], [0, 2, 1, 3, 4], [1, 512, 7, 6], 2),
|
||||
([2, 512, 7, 6], [2, 2, 256, 7, 6], [0, 2, 1, 3, 4], [2, 512, 7, 6], 2),
|
||||
([1, 200, 200, 200], [1, 50, 4, 200, 200], [0, 2, 1, 3, 4], [1, 200, 200, 200], 50),
|
||||
])
|
||||
def test_fusion(self, input_shape, reshape_0_pattern, order, reshape_1_pattern, group):
|
||||
graph, graph_ref = self.get_graphs(input_shape, reshape_0_pattern, order, reshape_1_pattern, group)
|
||||
ShuffleChannelFusion().find_and_replace_pattern(graph)
|
||||
graph.clean_up()
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'output')
|
||||
self.assertTrue(flag, resp)
|
||||
self.assertTrue(len(graph.get_op_nodes(name='final_reshape')) == 1 and
|
||||
graph.get_op_nodes(name='final_reshape')[0].op == 'ShuffleChannels')
|
||||
|
||||
@generate(*[
|
||||
([1, 512, 7, 6], [0, 2, 256, 7, 6], [0, 2, 1, 3, 4], [1, 512, 7, 6], 2),
|
||||
([1, 512, 7, 6], [1, 2, 256, 7, 6], [0, 2, 1, 4, 3], [1, 512, 7, 6], 2),
|
||||
([1, 512, 7, 6], [1, 2, 256, 7, 6], [0, 2, 1, 3, 4], [-1, 512, 7, 6], 2),
|
||||
])
|
||||
def test_negative(self, input_shape, reshape_0_pattern, order, reshape_1_pattern, group):
|
||||
graph, _ = self.get_graphs(input_shape, reshape_0_pattern, order, reshape_1_pattern, group)
|
||||
graph_ref = graph.copy()
|
||||
ShuffleChannelFusion().find_and_replace_pattern(graph)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'output')
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
|
||||
@generator
|
||||
class DepthToSpaceFusionTest(unittest.TestCase):
|
||||
@staticmethod
|
||||
def get_graphs(input_shape, reshape_0_pattern, order, reshape_1_pattern, block_size):
|
||||
nodes = {
|
||||
**regular_op_with_shaped_data('input', input_shape, {'type': 'Parameter', 'shape': int64_array(input_shape),
|
||||
'infer': Parameter.infer}),
|
||||
|
||||
**valued_const_with_data('reshape_0_pattern', int64_array(reshape_0_pattern)),
|
||||
**regular_op_with_empty_data('reshape_0', {'type': 'Reshape', 'infer': Reshape.infer}),
|
||||
|
||||
**valued_const_with_data('order', int64_array(order)),
|
||||
**regular_op_with_empty_data('transpose', {'type': 'Transpose', 'infer': Transpose.infer}),
|
||||
|
||||
**valued_const_with_data('reshape_1_pattern', int64_array(reshape_1_pattern)),
|
||||
**regular_op_with_empty_data('reshape_1', {'type': 'Reshape', 'infer': Reshape.infer,
|
||||
'name': 'final_reshape'}),
|
||||
|
||||
**result(),
|
||||
}
|
||||
edges = [
|
||||
*connect('input', '0:reshape_0'),
|
||||
*connect('reshape_0_pattern', '1:reshape_0'),
|
||||
*connect('reshape_0', '0:transpose'),
|
||||
*connect('order', '1:transpose'),
|
||||
*connect('transpose', '0:reshape_1'),
|
||||
*connect('reshape_1_pattern', '1:reshape_1'),
|
||||
*connect('reshape_1', 'output'),
|
||||
]
|
||||
graph = build_graph(nodes, edges, nodes_with_edges_only=True, cli=Namespace())
|
||||
for node in graph.get_op_nodes():
|
||||
node['op'] = node['type']
|
||||
graph.clean_up()
|
||||
|
||||
ref_nodes = {
|
||||
**regular_op_with_shaped_data('input', input_shape, {'type': 'Parameter', 'shape': int64_array(input_shape),
|
||||
'infer': Parameter.infer}),
|
||||
**regular_op_with_empty_data('depth_to_space', {'type': 'DepthToSpace', 'infer': DepthToSpaceOp.infer,
|
||||
'name': 'final_reshape', 'block_size': block_size}),
|
||||
**result()
|
||||
}
|
||||
ref_edges = [*connect('input', 'depth_to_space'), *connect('depth_to_space', 'output')]
|
||||
graph_ref = build_graph(ref_nodes, ref_edges, nodes_with_edges_only=True)
|
||||
for node in graph_ref.get_op_nodes():
|
||||
node['op'] = node['type']
|
||||
graph_ref.clean_up()
|
||||
graph.graph['layout'] = 'NCHW'
|
||||
graph_ref.graph['layout'] = 'NCHW'
|
||||
|
||||
return graph, graph_ref
|
||||
|
||||
@generate(*[
|
||||
([1, 512, 7, 6], [1, 2, 2, 128, 7, 6], [0, 1, 4, 2, 5, 3], [1, 128, 14, 12], 2),
|
||||
([2, 512, 7, 6], [2, 2, 2, 128, 7, 6], [0, 1, 4, 2, 5, 3], [2, 128, 14, 12], 2),
|
||||
([1, 200, 200, 200], [1, 2, 2, 50, 200, 200], [0, 1, 4, 2, 5, 3], [1, 50, 400, 400], 2),
|
||||
])
|
||||
def test_fusion(self, input_shape, reshape_0_pattern, order, reshape_1_pattern, block_size):
|
||||
graph, graph_ref = self.get_graphs(input_shape, reshape_0_pattern, order, reshape_1_pattern, block_size)
|
||||
DepthToSpaceFusion().find_and_replace_pattern(graph)
|
||||
graph.clean_up()
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'output')
|
||||
self.assertTrue(flag, resp)
|
||||
self.assertTrue(len(graph.get_op_nodes(name='final_reshape')) == 1 and
|
||||
graph.get_op_nodes(name='final_reshape')[0].op == 'DepthToSpace')
|
||||
|
||||
@generate(*[
|
||||
([1, 512, 7, 6], [0, 2, 2, 128, 7, 6], [0, 1, 4, 2, 5, 3], [1, 128, 14, 12], 2),
|
||||
([2, 512, 7, 6], [2, 2, 2, 128, 7, 6], [0, 1, 4, 2, 5, 3], [-1, 128, 14, 12], 2),
|
||||
([1, 200, 200, 200], [1, 2, 2, 50, 200, 200], [0, 1, 4, 2, 3, 5], [1, 50, 400, 400], 2),
|
||||
])
|
||||
def test_negative(self, input_shape, reshape_0_pattern, order, reshape_1_pattern, group):
|
||||
graph, _ = self.get_graphs(input_shape, reshape_0_pattern, order, reshape_1_pattern, group)
|
||||
graph_ref = graph.copy()
|
||||
DepthToSpaceFusion().find_and_replace_pattern(graph)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'output')
|
||||
self.assertTrue(flag, resp)
|
@ -58,10 +58,10 @@ class HSwishWithClamp(FrontReplacementSubgraph):
|
||||
nodes=[
|
||||
('input', dict()),
|
||||
('add', dict(op='Add')),
|
||||
('const_0', dict(op='Const', value=lambda v: v is not None and np.isclose(v, 0.0, atol=1e-6))),
|
||||
('const_3', dict(op='Const', value=lambda v: v is not None and np.isclose(v, 3.0, atol=1e-6))),
|
||||
('const_6', dict(op='Const', value=lambda v: v is not None and np.isclose(v, 6.0, atol=1e-6))),
|
||||
('const_1_6', dict(op='Const', value=lambda v: v is not None and np.isclose(v, 1 / 6.0, atol=1e-6))),
|
||||
('const_0', dict(op='Const', value=lambda v: v is not None and np.allclose(v, 0.0, atol=1e-6))),
|
||||
('const_3', dict(op='Const', value=lambda v: v is not None and np.allclose(v, 3.0, atol=1e-6))),
|
||||
('const_6', dict(op='Const', value=lambda v: v is not None and np.allclose(v, 6.0, atol=1e-6))),
|
||||
('const_1_6', dict(op='Const', value=lambda v: v is not None and np.allclose(v, 1 / 6.0, atol=1e-6))),
|
||||
('clamp', dict(op='Clamp')),
|
||||
('mul', dict(op='Mul')),
|
||||
('mul_2', dict(op='Mul')),
|
||||
@ -97,10 +97,10 @@ class HSwishWithMinMax(FrontReplacementSubgraph):
|
||||
nodes=[
|
||||
('input', dict()),
|
||||
('add', dict(op='Add')),
|
||||
('const_0', dict(op='Const', value=lambda v: v is not None and np.isclose(v, 0.0, atol=1e-6))),
|
||||
('const_3', dict(op='Const', value=lambda v: v is not None and np.isclose(v, 3.0, atol=1e-6))),
|
||||
('const_6', dict(op='Const', value=lambda v: v is not None and np.isclose(v, 6.0, atol=1e-6))),
|
||||
('const_1_6', dict(op='Const', value=lambda v: v is not None and np.isclose(v, 1 / 6.0, atol=1e-6))),
|
||||
('const_0', dict(op='Const', value=lambda v: v is not None and np.allclose(v, 0.0, atol=1e-6))),
|
||||
('const_3', dict(op='Const', value=lambda v: v is not None and np.allclose(v, 3.0, atol=1e-6))),
|
||||
('const_6', dict(op='Const', value=lambda v: v is not None and np.allclose(v, 6.0, atol=1e-6))),
|
||||
('const_1_6', dict(op='Const', value=lambda v: v is not None and np.allclose(v, 1 / 6.0, atol=1e-6))),
|
||||
('max', dict(op='Maximum')),
|
||||
('min', dict(op='Minimum')),
|
||||
('mul', dict(op='Mul')),
|
||||
|
@ -32,7 +32,7 @@ class SoftplusFusion(FrontReplacementSubgraph):
|
||||
nodes=[
|
||||
('exp', dict(op='Exp')),
|
||||
('add', dict(op='Add')),
|
||||
('const_1', dict(op='Const', value=lambda v: v is not None and np.isclose(v, 1.0, atol=1e-6))),
|
||||
('const_1', dict(op='Const', value=lambda v: v is not None and np.allclose(v, 1.0, atol=1e-6))),
|
||||
('ln', dict(op='Log')),
|
||||
],
|
||||
edges=[
|
||||
|
Loading…
Reference in New Issue
Block a user