changed Exit op infer function implementation: re-write in new API to avoid unmasking of shapes (#9664)
This commit is contained in:
committed by
GitHub
parent
6c2d1e923c
commit
b9293dc424
@@ -1,7 +1,6 @@
|
||||
# Copyright (C) 2018-2022 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from openvino.tools.mo.front.common.partial_infer.utils import mo_array
|
||||
from openvino.tools.mo.graph.graph import Node, Graph
|
||||
from openvino.tools.mo.ops.op import Op
|
||||
|
||||
@@ -20,8 +19,11 @@ class Exit(Op):
|
||||
|
||||
@staticmethod
|
||||
def exit_infer(node: Node):
|
||||
output_shape = node.in_node(0).shape
|
||||
output_value = node.in_node(0).value
|
||||
for _, out_node in node.graph.out_edges(node.id):
|
||||
node.graph.node[out_node]['shape'] = mo_array(output_shape)
|
||||
node.graph.node[out_node]['value'] = None if output_value is None else mo_array(output_value)
|
||||
output_shape = node.in_port(0).data.get_shape()
|
||||
output_value = node.in_port(0).data.get_value()
|
||||
|
||||
for port in node.out_ports():
|
||||
if not node.out_port(port).disconnected():
|
||||
node.out_port(port).data.set_shape(output_shape)
|
||||
if output_value is not None:
|
||||
node.out_port(port).data.set_value(output_value)
|
||||
|
||||
46
tools/mo/unit_tests/mo/ops/exit_test.py
Normal file
46
tools/mo/unit_tests/mo/ops/exit_test.py
Normal file
@@ -0,0 +1,46 @@
|
||||
# Copyright (C) 2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import numpy as np
|
||||
import unittest
|
||||
|
||||
from openvino.tools.mo.front.common.partial_infer.utils import int64_array, dynamic_dimension_value
|
||||
from openvino.tools.mo.graph.graph import Node
|
||||
from openvino.tools.mo.ops.Exit import Exit
|
||||
from unit_tests.utils.graph import build_graph, regular_op_with_empty_data, result, connect, shaped_parameter
|
||||
|
||||
|
||||
# test for TensorIterator
|
||||
graph_nodes = {
|
||||
**shaped_parameter("input", int64_array([1, 4, 64, 54])),
|
||||
**regular_op_with_empty_data("exit", {'op': "Exit"}),
|
||||
**result("output")
|
||||
}
|
||||
|
||||
|
||||
class ExitTest(unittest.TestCase):
|
||||
def test_exit_static(self):
|
||||
graph = build_graph(nodes_attrs=graph_nodes,
|
||||
edges=[*connect('input', 'exit'),
|
||||
*connect('exit', 'output')],
|
||||
nodes_with_edges_only=True)
|
||||
exit_node = Node(graph, 'exit')
|
||||
in_node = Node(graph, 'input')
|
||||
|
||||
Exit.exit_infer(exit_node)
|
||||
|
||||
self.assertTrue(np.ma.allequal(exit_node.out_port(0).data.get_shape(), in_node.shape))
|
||||
|
||||
def test_exit_dynamic(self):
|
||||
graph = build_graph(nodes_attrs=graph_nodes,
|
||||
edges=[*connect('input', 'exit'),
|
||||
*connect('exit', 'output')],
|
||||
nodes_with_edges_only=True)
|
||||
exit_node = Node(graph, 'exit')
|
||||
in_node = Node(graph, 'input')
|
||||
shape = int64_array([-1, 36])
|
||||
in_node.shape = np.ma.masked_array(shape, mask=shape == -1, fill_value=dynamic_dimension_value)
|
||||
in_node.out_port(0).data.set_shape(in_node.shape)
|
||||
|
||||
Exit.exit_infer(exit_node)
|
||||
|
||||
self.assertTrue(np.ma.allequal(exit_node.out_port(0).data.get_shape(), in_node.shape))
|
||||
Reference in New Issue
Block a user