diff --git a/tools/mo/openvino/tools/mo/ops/activation_ops.py b/tools/mo/openvino/tools/mo/ops/activation_ops.py index bdf48463b9b..2d76a274c1e 100644 --- a/tools/mo/openvino/tools/mo/ops/activation_ops.py +++ b/tools/mo/openvino/tools/mo/ops/activation_ops.py @@ -291,6 +291,6 @@ class Swish(Op): assert beta.ndim == 0, 'The "beta" value for node {} must be a scalar'.format(node_name) beta = beta.item() - input_value = node.in_port(1).data.get_value() + input_value = node.in_port(0).data.get_value() if input_value is not None and beta is not None: node.out_port(0).data.set_value(input_value / (1.0 + np.exp(-input_value * beta))) diff --git a/tools/mo/unit_tests/mo/ops/activation_test.py b/tools/mo/unit_tests/mo/ops/activation_test.py index d1ea12e9726..95310356b02 100644 --- a/tools/mo/unit_tests/mo/ops/activation_test.py +++ b/tools/mo/unit_tests/mo/ops/activation_test.py @@ -5,7 +5,7 @@ import unittest import numpy as np -from openvino.tools.mo.ops.activation_ops import Elu, SoftPlus, Mish +from openvino.tools.mo.ops.activation_ops import Elu, SoftPlus, Mish, Swish from openvino.tools.mo.graph.graph import Node from unit_tests.utils.graph import build_graph @@ -115,3 +115,32 @@ class TestActivationOp(unittest.TestCase): self.assertEqual(res_shape[i], value) for i, value in enumerate(exp_value): self.assertAlmostEqual(res_value[i], value) + + def test_activation_swish_infer(self): + graph = build_graph(self.nodes_attributes, + [ + ('node_1', 'activation_node'), + ('activation_node', 'node_3') + ], + { + 'node_1': { + 'value': np.array([-1.0, 0.0, 1.0, 20.0]) + }, + 'activation_node': { + 'op': 'Swish', + }, + 'node_3': { + 'value': None + } + }) + graph.graph['layout'] = 'NCHW' + activation_node = Node(graph, 'activation_node') + Swish.infer(activation_node) + exp_shape = np.array([4]) + res_shape = graph.node['node_3']['shape'] + res_value = graph.node['node_3']['value'] + exp_value = np.array([-0.26894142, 0.0, 0.73105858, 19.99999996]) + for i, value in enumerate(exp_shape): + self.assertEqual(res_shape[i], value) + for i, value in enumerate(exp_value): + self.assertAlmostEqual(res_value[i], value)