[MO] Fix swish value infer (#10802)
This commit is contained in:
parent
dfdbdb4601
commit
c8f4f9b7db
@ -291,6 +291,6 @@ class Swish(Op):
|
||||
assert beta.ndim == 0, 'The "beta" value for node {} must be a scalar'.format(node_name)
|
||||
beta = beta.item()
|
||||
|
||||
input_value = node.in_port(1).data.get_value()
|
||||
input_value = node.in_port(0).data.get_value()
|
||||
if input_value is not None and beta is not None:
|
||||
node.out_port(0).data.set_value(input_value / (1.0 + np.exp(-input_value * beta)))
|
||||
|
@ -5,7 +5,7 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from openvino.tools.mo.ops.activation_ops import Elu, SoftPlus, Mish
|
||||
from openvino.tools.mo.ops.activation_ops import Elu, SoftPlus, Mish, Swish
|
||||
from openvino.tools.mo.graph.graph import Node
|
||||
from unit_tests.utils.graph import build_graph
|
||||
|
||||
@ -115,3 +115,32 @@ class TestActivationOp(unittest.TestCase):
|
||||
self.assertEqual(res_shape[i], value)
|
||||
for i, value in enumerate(exp_value):
|
||||
self.assertAlmostEqual(res_value[i], value)
|
||||
|
||||
def test_activation_swish_infer(self):
|
||||
graph = build_graph(self.nodes_attributes,
|
||||
[
|
||||
('node_1', 'activation_node'),
|
||||
('activation_node', 'node_3')
|
||||
],
|
||||
{
|
||||
'node_1': {
|
||||
'value': np.array([-1.0, 0.0, 1.0, 20.0])
|
||||
},
|
||||
'activation_node': {
|
||||
'op': 'Swish',
|
||||
},
|
||||
'node_3': {
|
||||
'value': None
|
||||
}
|
||||
})
|
||||
graph.graph['layout'] = 'NCHW'
|
||||
activation_node = Node(graph, 'activation_node')
|
||||
Swish.infer(activation_node)
|
||||
exp_shape = np.array([4])
|
||||
res_shape = graph.node['node_3']['shape']
|
||||
res_value = graph.node['node_3']['value']
|
||||
exp_value = np.array([-0.26894142, 0.0, 0.73105858, 19.99999996])
|
||||
for i, value in enumerate(exp_shape):
|
||||
self.assertEqual(res_shape[i], value)
|
||||
for i, value in enumerate(exp_value):
|
||||
self.assertAlmostEqual(res_value[i], value)
|
||||
|
Loading…
Reference in New Issue
Block a user