Fix Gelu parameter name (#4895)
* Rename the attribute of the "Gelu" operation from "approximation" to "approximation_mode"
* Update the MO coverage-check file accordingly
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
branch = True
|
||||
|
||||
source =
|
||||
extensions/
|
||||
mo/
|
||||
mo.py
|
||||
|
||||
|
||||
@@ -126,7 +126,7 @@ class GeLUMergerErf(FrontReplacementPattern):
|
||||
# check that the values match the approximation
|
||||
if fabs(div_param - sqrt2) < 1e-06 and mul_param == 0.5 and add_param == 1.0:
|
||||
log.debug('Confirmed Erf-based GELU pattern after {} with name {}'.format(inp_node.op, inp_name))
|
||||
gelu = GeLUOP(graph, dict(name=inp_name + '/GELU_', approximation='erf')).create_node()
|
||||
gelu = GeLUOP(graph, dict(name=inp_name + '/GELU_', approximation_mode='erf')).create_node()
|
||||
div.in_port(0).get_connection().set_destination(gelu.in_port(0))
|
||||
out_node.out_port(0).get_connection().set_source(gelu.out_port(0))
|
||||
rename_nodes([(out_node, node_name + '/TBD'), (gelu, node_name)])
|
||||
|
||||
@@ -23,7 +23,7 @@ from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from mo.utils.unittest.graph import build_graph, const, regular_op, result, build_graph
|
||||
|
||||
ref_nodes = {**regular_op('input', {'type': 'Parameter'}),
|
||||
**regular_op('gelu', {'type': 'Gelu', 'approximation': 'erf', 'name': 'final_mul'}),
|
||||
**regular_op('gelu', {'type': 'Gelu', 'approximation_mode': 'erf', 'name': 'final_mul'}),
|
||||
**result('result')
|
||||
}
|
||||
ref_edges = [('input', 'gelu'), ('gelu', 'result')]
|
||||
@@ -65,7 +65,7 @@ class GeLUMergerErfTest(unittest.TestCase):
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
|
||||
self.assertTrue(flag, resp)
|
||||
self.assertTrue(graph.get_op_nodes(op='Gelu')[0].approximation == 'erf')
|
||||
self.assertTrue(graph.get_op_nodes(op='Gelu')[0].approximation_mode == 'erf')
|
||||
self.assertTrue(len(graph.get_op_nodes(name='final_mul')) == 1 and
|
||||
graph.get_op_nodes(name='final_mul')[0].op == 'Gelu')
|
||||
|
||||
@@ -90,7 +90,7 @@ class GeLUMergerErfTest(unittest.TestCase):
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
|
||||
self.assertTrue(flag, resp)
|
||||
self.assertTrue(graph.get_op_nodes(op='Gelu')[0].approximation == 'erf')
|
||||
self.assertTrue(graph.get_op_nodes(op='Gelu')[0].approximation_mode == 'erf')
|
||||
self.assertTrue(len(graph.get_op_nodes(name='final_mul')) == 1 and
|
||||
graph.get_op_nodes(name='final_mul')[0].op == 'Gelu')
|
||||
|
||||
@@ -115,6 +115,6 @@ class GeLUMergerErfTest(unittest.TestCase):
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
|
||||
self.assertTrue(flag, resp)
|
||||
self.assertTrue(graph.get_op_nodes(op='Gelu')[0].approximation == 'erf')
|
||||
self.assertTrue(graph.get_op_nodes(op='Gelu')[0].approximation_mode == 'erf')
|
||||
self.assertTrue(len(graph.get_op_nodes(name='final_mul')) == 1 and
|
||||
graph.get_op_nodes(name='final_mul')[0].op == 'Gelu')
|
||||
|
||||
@@ -74,6 +74,6 @@ class GeLUMergerTanh(FrontReplacementSubgraph):
|
||||
# check that the values match the approximation
|
||||
if fabs(mul0_param - sqrt2pi) < 1e-06 and fabs(mul_param - 0.044715) < 1e-06 and mul1_param == 0.5:
|
||||
log.debug('Confirmed TanH-based GELU pattern after {} with name {}'.format(inp.op, inp.name))
|
||||
gelu = GeLUOP(graph, dict(name=inp.name + '/GELU_', approximation='tanh')).create_node()
|
||||
gelu = GeLUOP(graph, dict(name=inp.name + '/GELU_', approximation_mode='tanh')).create_node()
|
||||
inp_port.connect(gelu.in_port(0))
|
||||
match['mul2'].out_port(0).get_connection().set_source(gelu.out_port(0))
|
||||
|
||||
@@ -54,7 +54,7 @@ nodes_attributes_tanh = {
|
||||
|
||||
nodes_attributes_ref = {
|
||||
'inp': {'kind': 'op', 'op': 'AnyOp'},
|
||||
'gelu': {'kind': 'op', 'op': 'Gelu', 'approximation': 'tanh'},
|
||||
'gelu': {'kind': 'op', 'op': 'Gelu', 'approximation_mode': 'tanh'},
|
||||
'out': {'kind': 'op', 'op': 'AnyOp'},
|
||||
}
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ class LeakyReLUFrontExtractor(FrontExtractorOp):
|
||||
else:
|
||||
LeakyReLU.update_node_stat(node, {'negative_slope': negative_slope})
|
||||
elif act_type == 'gelu':
|
||||
GeLUOP.update_node_stat(node, {'approximation': 'erf'})
|
||||
GeLUOP.update_node_stat(node, {'approximation_mode': 'erf'})
|
||||
else:
|
||||
raise Error(
|
||||
"Operation '{}' not supported. Please register it as custom op. " +
|
||||
|
||||
@@ -35,6 +35,6 @@ class GeLUOP(Op):
|
||||
|
||||
def backend_attrs(self):
|
||||
if self.get_opset() == 'opset7':
|
||||
return ['approximation']
|
||||
return ['approximation_mode']
|
||||
else:
|
||||
return []
|
||||
|
||||
Reference in New Issue
Block a user