From c592ecd44efd71a6d410667aed702e79bf923db1 Mon Sep 17 00:00:00 2001 From: Maxim Vafin Date: Thu, 13 Apr 2023 16:10:40 +0200 Subject: [PATCH] [MO] Fix legacy If (#16613) * Fix legacy If * Add test for If op * Small fix --- tools/mo/openvino/tools/mo/ops/If.py | 8 ++-- .../unit_tests/mo/utils/ir_reader/ops_test.py | 43 +++++++++++++++++++ 2 files changed, 47 insertions(+), 4 deletions(-) diff --git a/tools/mo/openvino/tools/mo/ops/If.py b/tools/mo/openvino/tools/mo/ops/If.py index ef239c6b2c7..d526b40d9d6 100644 --- a/tools/mo/openvino/tools/mo/ops/If.py +++ b/tools/mo/openvino/tools/mo/ops/If.py @@ -278,11 +278,11 @@ class If(Op): :return: if_node """ then_graph_nodes = if_node.then_graph.nodes() - for idx in range(len(if_node.then_graph.get_op_nodes())): - then_graph_nodes[idx]['internal_layer_id'] = idx + for node in if_node.then_graph.get_op_nodes(): + then_graph_nodes[node.id]['internal_layer_id'] = node.id else_graph_nodes = if_node.else_graph.nodes() - for idx in range(len(if_node.else_graph.get_op_nodes())): - else_graph_nodes[idx]['internal_layer_id'] = idx + for node in if_node.else_graph.get_op_nodes(): + else_graph_nodes[node.id]['internal_layer_id'] = node.id return if_node.node def substitute_ie_attrs(self, new_attrs: dict): diff --git a/tools/mo/unit_tests/mo/utils/ir_reader/ops_test.py b/tools/mo/unit_tests/mo/utils/ir_reader/ops_test.py index 386412c2a1e..87eaf26e029 100644 --- a/tools/mo/unit_tests/mo/utils/ir_reader/ops_test.py +++ b/tools/mo/unit_tests/mo/utils/ir_reader/ops_test.py @@ -109,3 +109,46 @@ class TestOps(unittest.TestCase): graph = TestOps.check_graph_can_save(model, 'is_nan_model') is_nan_node = graph.get_op_nodes(op="IsNaN")[0] self.assertEqual(is_nan_node["version"], "opset10") + + def test_if(self): + parameter_x = opset11.parameter([2], np.float32, "pX") + parameter_y = opset11.parameter([2], np.float32, "pY") + const_z = opset11.constant(4.0, dtype=np.float32) + + condition = opset11.constant(True, dtype=bool) + + # then_body + x_t = opset11.parameter([2], np.float32, "X") + y_t = opset11.parameter([2], np.float32, "Y") + mmul_t = opset11.matmul(x_t, y_t, False, False) + mul_t = opset11.multiply(y_t, x_t) + then_body_res_1 = opset11.result(mmul_t) + then_body_res_2 = opset11.result(mul_t) + then_body = Model([then_body_res_1, then_body_res_2], [x_t, y_t]) + + # else_body + x_e = opset11.parameter([2], np.float32, "X") + z_e = opset11.parameter([], np.float32, "Z") + mul_e = opset11.multiply(x_e, z_e) + else_body_res_1 = opset11.result(z_e) + else_body_res_2 = opset11.result(mul_e) + else_body = Model([else_body_res_1, else_body_res_2], [x_e, z_e]) + + if_node = opset11.if_op(condition) + if_node.set_friendly_name("If_opset8") + if_node.set_then_body(then_body) + if_node.set_else_body(else_body) + if_node.set_input(parameter_x.output(0), x_t, x_e) + if_node.set_input(parameter_y.output(0), y_t, None) + if_node.set_input(const_z.output(0), None, z_e) + out1 = if_node.set_output(then_body_res_1, else_body_res_1) + out2 = if_node.set_output(then_body_res_2, else_body_res_2) + + model = Model([out1, out2], [parameter_x, parameter_y]) + graph = TestOps.check_graph_can_save(model, 'if_model') + if_node = graph.get_op_nodes(op="If")[0] + self.assertEqual(if_node["version"], "opset8") + _, layer_info, _ = if_node['IE'][0] + _, callable_attribute = layer_info[0] + self.assertTrue(callable(callable_attribute)) + self.assertEqual(callable_attribute(if_node), "If_opset8")