Fix Gelu parameter name (#4895)

* Rename attribute of the "Gelu" operation from "approximation" to "approximation_mode"

* Updated MO file for coverage check
This commit is contained in:
Evgeny Lazarev
2021-03-22 19:35:32 +03:00
committed by GitHub
parent 9d69c0e0ec
commit 062bfd88d3
7 changed files with 10 additions and 9 deletions

View File

@@ -3,6 +3,7 @@
branch = True
source =
extensions/
mo/
mo.py

View File

@@ -126,7 +126,7 @@ class GeLUMergerErf(FrontReplacementPattern):
# check that the values match the approximation
if fabs(div_param - sqrt2) < 1e-06 and mul_param == 0.5 and add_param == 1.0:
log.debug('Confirmed Erf-based GELU pattern after {} with name {}'.format(inp_node.op, inp_name))
gelu = GeLUOP(graph, dict(name=inp_name + '/GELU_', approximation='erf')).create_node()
gelu = GeLUOP(graph, dict(name=inp_name + '/GELU_', approximation_mode='erf')).create_node()
div.in_port(0).get_connection().set_destination(gelu.in_port(0))
out_node.out_port(0).get_connection().set_source(gelu.out_port(0))
rename_nodes([(out_node, node_name + '/TBD'), (gelu, node_name)])

View File

@@ -23,7 +23,7 @@ from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph, const, regular_op, result, build_graph
ref_nodes = {**regular_op('input', {'type': 'Parameter'}),
**regular_op('gelu', {'type': 'Gelu', 'approximation': 'erf', 'name': 'final_mul'}),
**regular_op('gelu', {'type': 'Gelu', 'approximation_mode': 'erf', 'name': 'final_mul'}),
**result('result')
}
ref_edges = [('input', 'gelu'), ('gelu', 'result')]
@@ -65,7 +65,7 @@ class GeLUMergerErfTest(unittest.TestCase):
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
self.assertTrue(flag, resp)
self.assertTrue(graph.get_op_nodes(op='Gelu')[0].approximation == 'erf')
self.assertTrue(graph.get_op_nodes(op='Gelu')[0].approximation_mode == 'erf')
self.assertTrue(len(graph.get_op_nodes(name='final_mul')) == 1 and
graph.get_op_nodes(name='final_mul')[0].op == 'Gelu')
@@ -90,7 +90,7 @@ class GeLUMergerErfTest(unittest.TestCase):
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
self.assertTrue(flag, resp)
self.assertTrue(graph.get_op_nodes(op='Gelu')[0].approximation == 'erf')
self.assertTrue(graph.get_op_nodes(op='Gelu')[0].approximation_mode == 'erf')
self.assertTrue(len(graph.get_op_nodes(name='final_mul')) == 1 and
graph.get_op_nodes(name='final_mul')[0].op == 'Gelu')
@@ -115,6 +115,6 @@ class GeLUMergerErfTest(unittest.TestCase):
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
self.assertTrue(flag, resp)
self.assertTrue(graph.get_op_nodes(op='Gelu')[0].approximation == 'erf')
self.assertTrue(graph.get_op_nodes(op='Gelu')[0].approximation_mode == 'erf')
self.assertTrue(len(graph.get_op_nodes(name='final_mul')) == 1 and
graph.get_op_nodes(name='final_mul')[0].op == 'Gelu')

View File

@@ -74,6 +74,6 @@ class GeLUMergerTanh(FrontReplacementSubgraph):
# check that the values match the approximation
if fabs(mul0_param - sqrt2pi) < 1e-06 and fabs(mul_param - 0.044715) < 1e-06 and mul1_param == 0.5:
log.debug('Confirmed TanH-based GELU pattern after {} with name {}'.format(inp.op, inp.name))
gelu = GeLUOP(graph, dict(name=inp.name + '/GELU_', approximation='tanh')).create_node()
gelu = GeLUOP(graph, dict(name=inp.name + '/GELU_', approximation_mode='tanh')).create_node()
inp_port.connect(gelu.in_port(0))
match['mul2'].out_port(0).get_connection().set_source(gelu.out_port(0))

View File

@@ -54,7 +54,7 @@ nodes_attributes_tanh = {
nodes_attributes_ref = {
'inp': {'kind': 'op', 'op': 'AnyOp'},
'gelu': {'kind': 'op', 'op': 'Gelu', 'approximation': 'tanh'},
'gelu': {'kind': 'op', 'op': 'Gelu', 'approximation_mode': 'tanh'},
'out': {'kind': 'op', 'op': 'AnyOp'},
}

View File

@@ -52,7 +52,7 @@ class LeakyReLUFrontExtractor(FrontExtractorOp):
else:
LeakyReLU.update_node_stat(node, {'negative_slope': negative_slope})
elif act_type == 'gelu':
GeLUOP.update_node_stat(node, {'approximation': 'erf'})
GeLUOP.update_node_stat(node, {'approximation_mode': 'erf'})
else:
raise Error(
"Operation '{}' not supported. Please register it as custom op. " +

View File

@@ -35,6 +35,6 @@ class GeLUOP(Op):
def backend_attrs(self):
if self.get_opset() == 'opset7':
return ['approximation']
return ['approximation_mode']
else:
return []