Fix Mish and SoftPlus value propagation functions (#2120)
* Fix Mish and SoftPlus value propagation functions * Add unit tests for SoftPlus & Mish operations value propagation functions
This commit is contained in:
parent
1fd2df6e0d
commit
e6e7f5158a
@ -233,13 +233,13 @@ class Log(Activation):
|
|||||||
class SoftPlus(Activation):
|
class SoftPlus(Activation):
|
||||||
op = 'SoftPlus'
|
op = 'SoftPlus'
|
||||||
version = 'opset4'
|
version = 'opset4'
|
||||||
operation = staticmethod(lambda x: np.ln(np.exp(x) + 1.0))
|
operation = staticmethod(lambda x: np.log(np.exp(x) + 1.0))
|
||||||
|
|
||||||
|
|
||||||
class Mish(Activation):
|
class Mish(Activation):
|
||||||
op = 'Mish'
|
op = 'Mish'
|
||||||
version = 'opset4'
|
version = 'opset4'
|
||||||
operation = staticmethod(lambda x: x * np.tanh(np.ln(np.exp(x) + 1.0)))
|
operation = staticmethod(lambda x: x * np.tanh(np.log(np.exp(x) + 1.0)))
|
||||||
|
|
||||||
|
|
||||||
class HSwish(Activation):
|
class HSwish(Activation):
|
||||||
|
@ -18,7 +18,7 @@ import unittest
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from extensions.ops.activation_ops import Elu
|
from extensions.ops.activation_ops import Elu, SoftPlus, Mish
|
||||||
from mo.graph.graph import Node
|
from mo.graph.graph import Node
|
||||||
from mo.utils.unittest.graph import build_graph
|
from mo.utils.unittest.graph import build_graph
|
||||||
|
|
||||||
@ -26,12 +26,13 @@ from mo.utils.unittest.graph import build_graph
|
|||||||
class TestActivationOp(unittest.TestCase):
|
class TestActivationOp(unittest.TestCase):
|
||||||
nodes_attributes = {
|
nodes_attributes = {
|
||||||
'node_1': {
|
'node_1': {
|
||||||
'shape': np.array([227, 227, 227, 227]),
|
'shape': np.array([4]),
|
||||||
'value': None
|
'value': None
|
||||||
},
|
},
|
||||||
'activation_node': {
|
'activation_node': {
|
||||||
'op': 'Activation',
|
'op': 'Activation',
|
||||||
'kind': 'op'
|
'kind': 'op',
|
||||||
|
'operation': None
|
||||||
},
|
},
|
||||||
'node_3': {
|
'node_3': {
|
||||||
'shape': None
|
'shape': None
|
||||||
@ -59,7 +60,7 @@ class TestActivationOp(unittest.TestCase):
|
|||||||
graph.graph['layout'] = 'NCHW'
|
graph.graph['layout'] = 'NCHW'
|
||||||
activation_node = Node(graph, 'activation_node')
|
activation_node = Node(graph, 'activation_node')
|
||||||
Elu.infer(activation_node)
|
Elu.infer(activation_node)
|
||||||
exp_shape = np.array([227, 227, 227, 227])
|
exp_shape = np.array([4])
|
||||||
res_shape = graph.node['node_3']['shape']
|
res_shape = graph.node['node_3']['shape']
|
||||||
res_value = graph.node['node_3']['value']
|
res_value = graph.node['node_3']['value']
|
||||||
exp_value = np.array([6., -0.98168436, -0.86466472, -0.63212056])
|
exp_value = np.array([6., -0.98168436, -0.86466472, -0.63212056])
|
||||||
@ -67,3 +68,63 @@ class TestActivationOp(unittest.TestCase):
|
|||||||
self.assertEqual(res_shape[i], value)
|
self.assertEqual(res_shape[i], value)
|
||||||
for i, value in enumerate(exp_value):
|
for i, value in enumerate(exp_value):
|
||||||
self.assertAlmostEqual(res_value[i], value)
|
self.assertAlmostEqual(res_value[i], value)
|
||||||
|
|
||||||
|
def test_activation_softplus_infer(self):
|
||||||
|
graph = build_graph(self.nodes_attributes,
|
||||||
|
[
|
||||||
|
('node_1', 'activation_node'),
|
||||||
|
('activation_node', 'node_3')
|
||||||
|
],
|
||||||
|
{
|
||||||
|
'node_1': {
|
||||||
|
'value': np.array([-1.0, 0.0, 1.0, 20.0])
|
||||||
|
},
|
||||||
|
'activation_node': {
|
||||||
|
'op': 'SoftPlus',
|
||||||
|
'operation': SoftPlus.operation,
|
||||||
|
},
|
||||||
|
'node_3': {
|
||||||
|
'value': None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
graph.graph['layout'] = 'NCHW'
|
||||||
|
activation_node = Node(graph, 'activation_node')
|
||||||
|
SoftPlus.infer(activation_node)
|
||||||
|
exp_shape = np.array([4])
|
||||||
|
res_shape = graph.node['node_3']['shape']
|
||||||
|
res_value = graph.node['node_3']['value']
|
||||||
|
exp_value = np.array([0.3132617, 0.6931472, 1.3132617, 20.0])
|
||||||
|
for i, value in enumerate(exp_shape):
|
||||||
|
self.assertEqual(res_shape[i], value)
|
||||||
|
for i, value in enumerate(exp_value):
|
||||||
|
self.assertAlmostEqual(res_value[i], value)
|
||||||
|
|
||||||
|
def test_activation_mish_infer(self):
|
||||||
|
graph = build_graph(self.nodes_attributes,
|
||||||
|
[
|
||||||
|
('node_1', 'activation_node'),
|
||||||
|
('activation_node', 'node_3')
|
||||||
|
],
|
||||||
|
{
|
||||||
|
'node_1': {
|
||||||
|
'value': np.array([-1.0, 0.0, 1.0, 20.0])
|
||||||
|
},
|
||||||
|
'activation_node': {
|
||||||
|
'op': 'Mish',
|
||||||
|
'operation': Mish.operation,
|
||||||
|
},
|
||||||
|
'node_3': {
|
||||||
|
'value': None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
graph.graph['layout'] = 'NCHW'
|
||||||
|
activation_node = Node(graph, 'activation_node')
|
||||||
|
Mish.infer(activation_node)
|
||||||
|
exp_shape = np.array([4])
|
||||||
|
res_shape = graph.node['node_3']['shape']
|
||||||
|
res_value = graph.node['node_3']['value']
|
||||||
|
exp_value = np.array([-0.30340146, 0.0, 0.8650984, 20.0])
|
||||||
|
for i, value in enumerate(exp_shape):
|
||||||
|
self.assertEqual(res_shape[i], value)
|
||||||
|
for i, value in enumerate(exp_value):
|
||||||
|
self.assertAlmostEqual(res_value[i], value)
|
||||||
|
Loading…
Reference in New Issue
Block a user