Fix Mish and SoftPlus value propagation functions (#2120)
* Fix Mish and SoftPlus value propagation functions * Add unit tests for SoftPlus & Mish operations value propagation functions
This commit is contained in:
parent
1fd2df6e0d
commit
e6e7f5158a
@ -233,13 +233,13 @@ class Log(Activation):
|
||||
class SoftPlus(Activation):
|
||||
op = 'SoftPlus'
|
||||
version = 'opset4'
|
||||
operation = staticmethod(lambda x: np.ln(np.exp(x) + 1.0))
|
||||
operation = staticmethod(lambda x: np.log(np.exp(x) + 1.0))
|
||||
|
||||
|
||||
class Mish(Activation):
|
||||
op = 'Mish'
|
||||
version = 'opset4'
|
||||
operation = staticmethod(lambda x: x * np.tanh(np.ln(np.exp(x) + 1.0)))
|
||||
operation = staticmethod(lambda x: x * np.tanh(np.log(np.exp(x) + 1.0)))
|
||||
|
||||
|
||||
class HSwish(Activation):
|
||||
|
@ -18,7 +18,7 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from extensions.ops.activation_ops import Elu
|
||||
from extensions.ops.activation_ops import Elu, SoftPlus, Mish
|
||||
from mo.graph.graph import Node
|
||||
from mo.utils.unittest.graph import build_graph
|
||||
|
||||
@ -26,12 +26,13 @@ from mo.utils.unittest.graph import build_graph
|
||||
class TestActivationOp(unittest.TestCase):
|
||||
nodes_attributes = {
|
||||
'node_1': {
|
||||
'shape': np.array([227, 227, 227, 227]),
|
||||
'shape': np.array([4]),
|
||||
'value': None
|
||||
},
|
||||
'activation_node': {
|
||||
'op': 'Activation',
|
||||
'kind': 'op'
|
||||
'kind': 'op',
|
||||
'operation': None
|
||||
},
|
||||
'node_3': {
|
||||
'shape': None
|
||||
@ -59,7 +60,7 @@ class TestActivationOp(unittest.TestCase):
|
||||
graph.graph['layout'] = 'NCHW'
|
||||
activation_node = Node(graph, 'activation_node')
|
||||
Elu.infer(activation_node)
|
||||
exp_shape = np.array([227, 227, 227, 227])
|
||||
exp_shape = np.array([4])
|
||||
res_shape = graph.node['node_3']['shape']
|
||||
res_value = graph.node['node_3']['value']
|
||||
exp_value = np.array([6., -0.98168436, -0.86466472, -0.63212056])
|
||||
@ -67,3 +68,63 @@ class TestActivationOp(unittest.TestCase):
|
||||
self.assertEqual(res_shape[i], value)
|
||||
for i, value in enumerate(exp_value):
|
||||
self.assertAlmostEqual(res_value[i], value)
|
||||
|
||||
def test_activation_softplus_infer(self):
|
||||
graph = build_graph(self.nodes_attributes,
|
||||
[
|
||||
('node_1', 'activation_node'),
|
||||
('activation_node', 'node_3')
|
||||
],
|
||||
{
|
||||
'node_1': {
|
||||
'value': np.array([-1.0, 0.0, 1.0, 20.0])
|
||||
},
|
||||
'activation_node': {
|
||||
'op': 'SoftPlus',
|
||||
'operation': SoftPlus.operation,
|
||||
},
|
||||
'node_3': {
|
||||
'value': None
|
||||
}
|
||||
})
|
||||
graph.graph['layout'] = 'NCHW'
|
||||
activation_node = Node(graph, 'activation_node')
|
||||
SoftPlus.infer(activation_node)
|
||||
exp_shape = np.array([4])
|
||||
res_shape = graph.node['node_3']['shape']
|
||||
res_value = graph.node['node_3']['value']
|
||||
exp_value = np.array([0.3132617, 0.6931472, 1.3132617, 20.0])
|
||||
for i, value in enumerate(exp_shape):
|
||||
self.assertEqual(res_shape[i], value)
|
||||
for i, value in enumerate(exp_value):
|
||||
self.assertAlmostEqual(res_value[i], value)
|
||||
|
||||
def test_activation_mish_infer(self):
|
||||
graph = build_graph(self.nodes_attributes,
|
||||
[
|
||||
('node_1', 'activation_node'),
|
||||
('activation_node', 'node_3')
|
||||
],
|
||||
{
|
||||
'node_1': {
|
||||
'value': np.array([-1.0, 0.0, 1.0, 20.0])
|
||||
},
|
||||
'activation_node': {
|
||||
'op': 'Mish',
|
||||
'operation': Mish.operation,
|
||||
},
|
||||
'node_3': {
|
||||
'value': None
|
||||
}
|
||||
})
|
||||
graph.graph['layout'] = 'NCHW'
|
||||
activation_node = Node(graph, 'activation_node')
|
||||
Mish.infer(activation_node)
|
||||
exp_shape = np.array([4])
|
||||
res_shape = graph.node['node_3']['shape']
|
||||
res_value = graph.node['node_3']['value']
|
||||
exp_value = np.array([-0.30340146, 0.0, 0.8650984, 20.0])
|
||||
for i, value in enumerate(exp_shape):
|
||||
self.assertEqual(res_shape[i], value)
|
||||
for i, value in enumerate(exp_value):
|
||||
self.assertAlmostEqual(res_value[i], value)
|
||||
|
Loading…
Reference in New Issue
Block a user