[MO] Fix swish value infer (#10802)
This commit is contained in:
parent
dfdbdb4601
commit
c8f4f9b7db
@ -291,6 +291,6 @@ class Swish(Op):
|
|||||||
assert beta.ndim == 0, 'The "beta" value for node {} must be a scalar'.format(node_name)
|
assert beta.ndim == 0, 'The "beta" value for node {} must be a scalar'.format(node_name)
|
||||||
beta = beta.item()
|
beta = beta.item()
|
||||||
|
|
||||||
input_value = node.in_port(1).data.get_value()
|
input_value = node.in_port(0).data.get_value()
|
||||||
if input_value is not None and beta is not None:
|
if input_value is not None and beta is not None:
|
||||||
node.out_port(0).data.set_value(input_value / (1.0 + np.exp(-input_value * beta)))
|
node.out_port(0).data.set_value(input_value / (1.0 + np.exp(-input_value * beta)))
|
||||||
|
@ -5,7 +5,7 @@ import unittest
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from openvino.tools.mo.ops.activation_ops import Elu, SoftPlus, Mish
|
from openvino.tools.mo.ops.activation_ops import Elu, SoftPlus, Mish, Swish
|
||||||
from openvino.tools.mo.graph.graph import Node
|
from openvino.tools.mo.graph.graph import Node
|
||||||
from unit_tests.utils.graph import build_graph
|
from unit_tests.utils.graph import build_graph
|
||||||
|
|
||||||
@ -115,3 +115,32 @@ class TestActivationOp(unittest.TestCase):
|
|||||||
self.assertEqual(res_shape[i], value)
|
self.assertEqual(res_shape[i], value)
|
||||||
for i, value in enumerate(exp_value):
|
for i, value in enumerate(exp_value):
|
||||||
self.assertAlmostEqual(res_value[i], value)
|
self.assertAlmostEqual(res_value[i], value)
|
||||||
|
|
||||||
|
def test_activation_swish_infer(self):
|
||||||
|
graph = build_graph(self.nodes_attributes,
|
||||||
|
[
|
||||||
|
('node_1', 'activation_node'),
|
||||||
|
('activation_node', 'node_3')
|
||||||
|
],
|
||||||
|
{
|
||||||
|
'node_1': {
|
||||||
|
'value': np.array([-1.0, 0.0, 1.0, 20.0])
|
||||||
|
},
|
||||||
|
'activation_node': {
|
||||||
|
'op': 'Swish',
|
||||||
|
},
|
||||||
|
'node_3': {
|
||||||
|
'value': None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
graph.graph['layout'] = 'NCHW'
|
||||||
|
activation_node = Node(graph, 'activation_node')
|
||||||
|
Swish.infer(activation_node)
|
||||||
|
exp_shape = np.array([4])
|
||||||
|
res_shape = graph.node['node_3']['shape']
|
||||||
|
res_value = graph.node['node_3']['value']
|
||||||
|
exp_value = np.array([-0.26894142, 0.0, 0.73105858, 19.99999996])
|
||||||
|
for i, value in enumerate(exp_shape):
|
||||||
|
self.assertEqual(res_shape[i], value)
|
||||||
|
for i, value in enumerate(exp_value):
|
||||||
|
self.assertAlmostEqual(res_value[i], value)
|
||||||
|
Loading…
Reference in New Issue
Block a user