[MO] Fix swish value infer (#10802)

This commit is contained in:
Maxim Vafin 2022-03-18 14:56:37 +03:00 committed by GitHub
parent dfdbdb4601
commit c8f4f9b7db
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 31 additions and 2 deletions

View File

@ -291,6 +291,6 @@ class Swish(Op):
assert beta.ndim == 0, 'The "beta" value for node {} must be a scalar'.format(node_name)
beta = beta.item()
input_value = node.in_port(1).data.get_value()
input_value = node.in_port(0).data.get_value()
if input_value is not None and beta is not None:
node.out_port(0).data.set_value(input_value / (1.0 + np.exp(-input_value * beta)))

View File

@ -5,7 +5,7 @@ import unittest
import numpy as np
from openvino.tools.mo.ops.activation_ops import Elu, SoftPlus, Mish
from openvino.tools.mo.ops.activation_ops import Elu, SoftPlus, Mish, Swish
from openvino.tools.mo.graph.graph import Node
from unit_tests.utils.graph import build_graph
@ -115,3 +115,32 @@ class TestActivationOp(unittest.TestCase):
self.assertEqual(res_shape[i], value)
for i, value in enumerate(exp_value):
self.assertAlmostEqual(res_value[i], value)
def test_activation_swish_infer(self):
graph = build_graph(self.nodes_attributes,
[
('node_1', 'activation_node'),
('activation_node', 'node_3')
],
{
'node_1': {
'value': np.array([-1.0, 0.0, 1.0, 20.0])
},
'activation_node': {
'op': 'Swish',
},
'node_3': {
'value': None
}
})
graph.graph['layout'] = 'NCHW'
activation_node = Node(graph, 'activation_node')
Swish.infer(activation_node)
exp_shape = np.array([4])
res_shape = graph.node['node_3']['shape']
res_value = graph.node['node_3']['value']
exp_value = np.array([-0.26894142, 0.0, 0.73105858, 19.99999996])
for i, value in enumerate(exp_shape):
self.assertEqual(res_shape[i], value)
for i, value in enumerate(exp_value):
self.assertAlmostEqual(res_value[i], value)