[MO] Fix legacy If (#16613)

* Fix legacy If

* Add test for If op

* Small fix
This commit is contained in:
Maxim Vafin
2023-04-13 16:10:40 +02:00
committed by GitHub
parent 5795a50a22
commit c592ecd44e
2 changed files with 47 additions and 4 deletions

View File

@@ -278,11 +278,11 @@ class If(Op):
:return: if_node
"""
then_graph_nodes = if_node.then_graph.nodes()
for idx in range(len(if_node.then_graph.get_op_nodes())):
then_graph_nodes[idx]['internal_layer_id'] = idx
for node in if_node.then_graph.get_op_nodes():
then_graph_nodes[node.id]['internal_layer_id'] = node.id
else_graph_nodes = if_node.else_graph.nodes()
for idx in range(len(if_node.else_graph.get_op_nodes())):
else_graph_nodes[idx]['internal_layer_id'] = idx
for node in if_node.else_graph.get_op_nodes():
else_graph_nodes[node.id]['internal_layer_id'] = node.id
return if_node.node
def substitute_ie_attrs(self, new_attrs: dict):

View File

@@ -109,3 +109,46 @@ class TestOps(unittest.TestCase):
graph = TestOps.check_graph_can_save(model, 'is_nan_model')
is_nan_node = graph.get_op_nodes(op="IsNaN")[0]
self.assertEqual(is_nan_node["version"], "opset10")
def test_if(self):
parameter_x = opset11.parameter([2], np.float32, "pX")
parameter_y = opset11.parameter([2], np.float32, "pY")
const_z = opset11.constant(4.0, dtype=np.float32)
condition = opset11.constant(True, dtype=bool)
# then_body
x_t = opset11.parameter([2], np.float32, "X")
y_t = opset11.parameter([2], np.float32, "Y")
mmul_t = opset11.matmul(x_t, y_t, False, False)
mul_t = opset11.multiply(y_t, x_t)
then_body_res_1 = opset11.result(mmul_t)
then_body_res_2 = opset11.result(mul_t)
then_body = Model([then_body_res_1, then_body_res_2], [x_t, y_t])
# else_body
x_e = opset11.parameter([2], np.float32, "X")
z_e = opset11.parameter([], np.float32, "Z")
mul_e = opset11.multiply(x_e, z_e)
else_body_res_1 = opset11.result(z_e)
else_body_res_2 = opset11.result(mul_e)
else_body = Model([else_body_res_1, else_body_res_2], [x_e, z_e])
if_node = opset11.if_op(condition)
if_node.set_friendly_name("If_opset8")
if_node.set_then_body(then_body)
if_node.set_else_body(else_body)
if_node.set_input(parameter_x.output(0), x_t, x_e)
if_node.set_input(parameter_y.output(0), y_t, None)
if_node.set_input(const_z.output(0), None, z_e)
out1 = if_node.set_output(then_body_res_1, else_body_res_1)
out2 = if_node.set_output(then_body_res_2, else_body_res_2)
model = Model([out1, out2], [parameter_x, parameter_y])
graph = TestOps.check_graph_can_save(model, 'if_model')
if_node = graph.get_op_nodes(op="If")[0]
self.assertEqual(if_node["version"], "opset8")
_, layer_info, _ = if_node['IE'][0]
_, callable_attribute = layer_info[0]
self.assertTrue(callable(callable_attribute))
self.assertEqual(callable_attribute(if_node), "If_opset8")