Fix Mish and SoftPlus value propagation functions (#2120)

* Fix Mish and SoftPlus value propagation functions

* Add unit tests for SoftPlus & Mish operations value propagation functions
This commit is contained in:
Anton Chetverikov 2020-09-11 12:58:14 +03:00 committed by GitHub
parent 1fd2df6e0d
commit e6e7f5158a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 67 additions and 6 deletions

View File

@ -233,13 +233,13 @@ class Log(Activation):
class SoftPlus(Activation):
op = 'SoftPlus'
version = 'opset4'
operation = staticmethod(lambda x: np.ln(np.exp(x) + 1.0))
operation = staticmethod(lambda x: np.log(np.exp(x) + 1.0))
class Mish(Activation):
op = 'Mish'
version = 'opset4'
operation = staticmethod(lambda x: x * np.tanh(np.ln(np.exp(x) + 1.0)))
operation = staticmethod(lambda x: x * np.tanh(np.log(np.exp(x) + 1.0)))
class HSwish(Activation):

View File

@ -18,7 +18,7 @@ import unittest
import numpy as np
from extensions.ops.activation_ops import Elu
from extensions.ops.activation_ops import Elu, SoftPlus, Mish
from mo.graph.graph import Node
from mo.utils.unittest.graph import build_graph
@ -26,12 +26,13 @@ from mo.utils.unittest.graph import build_graph
class TestActivationOp(unittest.TestCase):
nodes_attributes = {
'node_1': {
'shape': np.array([227, 227, 227, 227]),
'shape': np.array([4]),
'value': None
},
'activation_node': {
'op': 'Activation',
'kind': 'op'
'kind': 'op',
'operation': None
},
'node_3': {
'shape': None
@ -59,7 +60,7 @@ class TestActivationOp(unittest.TestCase):
graph.graph['layout'] = 'NCHW'
activation_node = Node(graph, 'activation_node')
Elu.infer(activation_node)
exp_shape = np.array([227, 227, 227, 227])
exp_shape = np.array([4])
res_shape = graph.node['node_3']['shape']
res_value = graph.node['node_3']['value']
exp_value = np.array([6., -0.98168436, -0.86466472, -0.63212056])
@ -67,3 +68,63 @@ class TestActivationOp(unittest.TestCase):
self.assertEqual(res_shape[i], value)
for i, value in enumerate(exp_value):
self.assertAlmostEqual(res_value[i], value)
def test_activation_softplus_infer(self):
graph = build_graph(self.nodes_attributes,
[
('node_1', 'activation_node'),
('activation_node', 'node_3')
],
{
'node_1': {
'value': np.array([-1.0, 0.0, 1.0, 20.0])
},
'activation_node': {
'op': 'SoftPlus',
'operation': SoftPlus.operation,
},
'node_3': {
'value': None
}
})
graph.graph['layout'] = 'NCHW'
activation_node = Node(graph, 'activation_node')
SoftPlus.infer(activation_node)
exp_shape = np.array([4])
res_shape = graph.node['node_3']['shape']
res_value = graph.node['node_3']['value']
exp_value = np.array([0.3132617, 0.6931472, 1.3132617, 20.0])
for i, value in enumerate(exp_shape):
self.assertEqual(res_shape[i], value)
for i, value in enumerate(exp_value):
self.assertAlmostEqual(res_value[i], value)
def test_activation_mish_infer(self):
graph = build_graph(self.nodes_attributes,
[
('node_1', 'activation_node'),
('activation_node', 'node_3')
],
{
'node_1': {
'value': np.array([-1.0, 0.0, 1.0, 20.0])
},
'activation_node': {
'op': 'Mish',
'operation': Mish.operation,
},
'node_3': {
'value': None
}
})
graph.graph['layout'] = 'NCHW'
activation_node = Node(graph, 'activation_node')
Mish.infer(activation_node)
exp_shape = np.array([4])
res_shape = graph.node['node_3']['shape']
res_value = graph.node['node_3']['value']
exp_value = np.array([-0.30340146, 0.0, 0.8650984, 20.0])
for i, value in enumerate(exp_shape):
self.assertEqual(res_shape[i], value)
for i, value in enumerate(exp_value):
self.assertAlmostEqual(res_value[i], value)