[MO] Add support to moc_frontend of ":" as delimiter for --input and --output (#6543)
* [MO] Add support to moc_frontend of ":" as delimiter for --input Additions: Changed default logic for 'Place::get_in(out)put_port' to return nullptr Changed default logic for 'InputModel::get_place_by_tensor(operation)_name' to return nullptr * Corrected comments in code * Missing empty line * Clang format fixes * Fix review comments * Updated test to verify review comments fixes * Update unit tests after rebase * Apply review comments
This commit is contained in:
parent
a95d59014c
commit
868fad33ab
@ -451,6 +451,17 @@ def extract_node_attrs(graph: Graph, extractor: callable):
|
|||||||
return graph
|
return graph
|
||||||
|
|
||||||
|
|
||||||
|
def raise_no_node(node_name: str):
|
||||||
|
raise Error('No node with name {}'.format(node_name))
|
||||||
|
|
||||||
|
|
||||||
|
def raise_node_name_collision(node_name: str, found_nodes: list):
|
||||||
|
raise Error('Name collision was found, there are several nodes for mask "{}": {}. '
|
||||||
|
'If your intention was to specify port for node, please instead specify node names connected to '
|
||||||
|
'this port. If your intention was to specify the node name, please add port to the node '
|
||||||
|
'name'.format(node_name, found_nodes))
|
||||||
|
|
||||||
|
|
||||||
def get_node_id_with_ports(graph: Graph, node_name: str, skip_if_no_port=True):
|
def get_node_id_with_ports(graph: Graph, node_name: str, skip_if_no_port=True):
|
||||||
"""
|
"""
|
||||||
Extracts port and node ID out of user provided name
|
Extracts port and node ID out of user provided name
|
||||||
@ -483,12 +494,9 @@ def get_node_id_with_ports(graph: Graph, node_name: str, skip_if_no_port=True):
|
|||||||
|
|
||||||
found_names.append((in_port, out_port, name))
|
found_names.append((in_port, out_port, name))
|
||||||
if len(found_names) == 0:
|
if len(found_names) == 0:
|
||||||
raise Error('No node with name {}'.format(node_name))
|
raise_no_node(node_name)
|
||||||
if len(found_names) > 1:
|
if len(found_names) > 1:
|
||||||
raise Error('Name collision was found, there are several nodes for mask "{}": {}. '
|
raise_node_name_collision(node_name, [name for _, _, name in found_names])
|
||||||
'If your intention was to specify port for node, please instead specify node names connected to '
|
|
||||||
'this port. If your intention was to specify the node name, please add port to the node '
|
|
||||||
'name'.format(node_name, [name for _, _, name in found_names]))
|
|
||||||
in_port, out_port, name = found_names[0]
|
in_port, out_port, name = found_names[0]
|
||||||
node_id = graph.get_node_id_by_name(name)
|
node_id = graph.get_node_id_by_name(name)
|
||||||
if in_port is not None:
|
if in_port is not None:
|
||||||
|
@ -1,35 +1,75 @@
|
|||||||
# Copyright (C) 2018-2021 Intel Corporation
|
# Copyright (C) 2018-2021 Intel Corporation
|
||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
import logging as log
|
|
||||||
import re
|
import re
|
||||||
from collections import defaultdict
|
|
||||||
from copy import copy
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
from mo.front.extractor import raise_no_node, raise_node_name_collision
|
||||||
from mo.utils.error import Error
|
from mo.utils.error import Error
|
||||||
|
|
||||||
from ngraph.frontend import InputModel # pylint: disable=no-name-in-module,import-error
|
from ngraph.frontend import InputModel # pylint: disable=no-name-in-module,import-error
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def decode_name_with_port(input_model: InputModel, node_name: str):
|
def decode_name_with_port(input_model: InputModel, node_name: str):
|
||||||
"""
|
"""
|
||||||
Decode name with optional port specification w/o traversing all the nodes in the graph
|
Decode name with optional port specification w/o traversing all the nodes in the graph
|
||||||
TODO: in future node_name can specify input/output port groups and indices (58562)
|
TODO: in future node_name can specify input/output port groups as well as indices (58562)
|
||||||
:param input_model: Input Model
|
:param input_model: Input Model
|
||||||
:param node_name: user provided node name
|
:param node_name: user provided node name
|
||||||
:return: decoded place in the graph
|
:return: decoded place in the graph
|
||||||
"""
|
"""
|
||||||
# Check exact match with one of the names in the graph first
|
found_nodes = []
|
||||||
|
found_node_names = []
|
||||||
|
|
||||||
node = input_model.get_place_by_tensor_name(node_name)
|
node = input_model.get_place_by_tensor_name(node_name)
|
||||||
if node:
|
if node:
|
||||||
return node
|
found_node_names.append('Tensor:' + node_name)
|
||||||
|
found_nodes.append(node)
|
||||||
|
|
||||||
|
node = input_model.get_place_by_operation_name(node_name)
|
||||||
|
if node:
|
||||||
|
found_node_names.append('Operation:' + node_name)
|
||||||
|
found_nodes.append(node)
|
||||||
|
|
||||||
|
regexp_post = r'(.+):(\d+)'
|
||||||
|
match_post = re.search(regexp_post, node_name)
|
||||||
|
if match_post:
|
||||||
|
node_post = input_model.get_place_by_operation_name(match_post.group(1))
|
||||||
|
if node_post:
|
||||||
|
node_post = node_post.get_output_port(
|
||||||
|
outputPortIndex=int(match_post.group(2)))
|
||||||
|
if node_post:
|
||||||
|
found_node_names.append(match_post.group(1))
|
||||||
|
found_nodes.append(node_post)
|
||||||
|
|
||||||
|
regexp_pre = r'(\d+):(.+)'
|
||||||
|
match_pre = re.search(regexp_pre, node_name)
|
||||||
|
if match_pre:
|
||||||
|
node_pre = input_model.get_place_by_operation_name(match_pre.group(2))
|
||||||
|
if node_pre:
|
||||||
|
node_pre = node_pre.get_input_port(
|
||||||
|
inputPortIndex=int(match_pre.group(1)))
|
||||||
|
if node_pre:
|
||||||
|
found_node_names.append(match_pre.group(2))
|
||||||
|
found_nodes.append(node_pre)
|
||||||
|
|
||||||
|
if len(found_nodes) == 0:
|
||||||
|
raise_no_node(node_name)
|
||||||
|
|
||||||
|
# Check that there is no collision, all found places shall point to same data
|
||||||
|
if not all([n.is_equal_data(found_nodes[0]) for n in found_nodes]):
|
||||||
|
raise_node_name_collision(node_name, found_node_names)
|
||||||
|
|
||||||
|
# TODO: ONNX specific (59408)
|
||||||
|
# To comply with legacy behavior, for ONNX-only there shall be considered additional 2 possibilities
|
||||||
|
# 1) "abc:1" - get_place_by_tensor_name("abc").get_producing_operation().get_output_port(1)
|
||||||
|
# 2) "1:abc" - get_place_by_tensor_name("abc").get_producing_operation().get_input_port(1)
|
||||||
|
# This logic is not going to work with other frontends
|
||||||
|
|
||||||
# TODO: Add support for input/output group name and port index here (58562)
|
# TODO: Add support for input/output group name and port index here (58562)
|
||||||
# Legacy frontends use format "number:name:number" to specify input and output port indices
|
# For new frontends logic shall be extended to additionally support input and output group names
|
||||||
# For new frontends this logic shall be extended to additionally support input and output group names
|
return found_nodes[0]
|
||||||
raise Error('There is no node with name {}'.format(node_name))
|
|
||||||
|
|
||||||
|
|
||||||
def fe_input_user_data_repack(input_model: InputModel, input_user_shapes: [None, list, dict, np.ndarray],
|
def fe_input_user_data_repack(input_model: InputModel, input_user_shapes: [None, list, dict, np.ndarray],
|
||||||
|
@ -34,6 +34,15 @@ def test_frontends():
|
|||||||
assert not status.returncode
|
assert not status.returncode
|
||||||
|
|
||||||
|
|
||||||
|
def test_moc_extractor():
|
||||||
|
setup_env()
|
||||||
|
args = [sys.executable, '-m', 'pytest',
|
||||||
|
os.path.join(os.path.dirname(__file__), 'moc_frontend/moc_extractor_test_actual.py'), '-s']
|
||||||
|
|
||||||
|
status = subprocess.run(args, env=os.environ)
|
||||||
|
assert not status.returncode
|
||||||
|
|
||||||
|
|
||||||
def test_main_test():
|
def test_main_test():
|
||||||
setup_env()
|
setup_env()
|
||||||
args = [sys.executable, '-m', 'pytest',
|
args = [sys.executable, '-m', 'pytest',
|
||||||
|
@ -0,0 +1,246 @@
|
|||||||
|
# Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from mo.moc_frontend.extractor import decode_name_with_port
|
||||||
|
from mo.utils.error import Error
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
mock_available = True
|
||||||
|
|
||||||
|
try:
|
||||||
|
# pylint: disable=no-name-in-module,import-error
|
||||||
|
from mock_mo_python_api import get_model_statistic, get_place_statistic, \
|
||||||
|
clear_frontend_statistic, clear_model_statistic, clear_place_statistic, \
|
||||||
|
clear_setup, set_equal_data, set_max_port_counts
|
||||||
|
|
||||||
|
# pylint: disable=no-name-in-module,import-error
|
||||||
|
from ngraph.frontend import FrontEndManager
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
print("No mock frontend API available,"
|
||||||
|
"ensure to use -DENABLE_TESTS=ON option when running these tests")
|
||||||
|
mock_available = False
|
||||||
|
|
||||||
|
# FrontEndManager shall be initialized and destroyed after all tests finished
|
||||||
|
# This is because destroy of FrontEndManager will unload all plugins,
|
||||||
|
# no objects shall exist after this
|
||||||
|
if mock_available:
|
||||||
|
fem = FrontEndManager()
|
||||||
|
|
||||||
|
mock_needed = pytest.mark.skipif(not mock_available,
|
||||||
|
reason="mock MO fe is not available")
|
||||||
|
|
||||||
|
|
||||||
|
class TestMainFrontend(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
clear_frontend_statistic()
|
||||||
|
clear_model_statistic()
|
||||||
|
clear_place_statistic()
|
||||||
|
clear_setup()
|
||||||
|
set_max_port_counts(10, 10)
|
||||||
|
self.fe = fem.load_by_framework('mock_mo_ngraph_frontend')
|
||||||
|
self.model = self.fe.load('abc.bin')
|
||||||
|
|
||||||
|
# Mock model has 'tensor' tensor place
|
||||||
|
@mock_needed
|
||||||
|
def test_decode_name_with_port_tensor(self):
|
||||||
|
node = decode_name_with_port(self.model, "tensor")
|
||||||
|
model_stat = get_model_statistic()
|
||||||
|
|
||||||
|
assert model_stat.get_place_by_tensor_name == 1
|
||||||
|
assert model_stat.get_place_by_operation_name == 1
|
||||||
|
assert node
|
||||||
|
|
||||||
|
# Mock model has 'operation' operation place
|
||||||
|
@mock_needed
|
||||||
|
def test_decode_name_with_port_op(self):
|
||||||
|
node = decode_name_with_port(self.model, "operation")
|
||||||
|
model_stat = get_model_statistic()
|
||||||
|
|
||||||
|
assert model_stat.get_place_by_tensor_name == 1
|
||||||
|
assert model_stat.get_place_by_operation_name == 1
|
||||||
|
assert node
|
||||||
|
|
||||||
|
# pylint: disable=wrong-spelling-in-comment
|
||||||
|
# Mock model doesn't have 'mocknoname' place
|
||||||
|
@mock_needed
|
||||||
|
def test_decode_name_with_port_noname(self):
|
||||||
|
with self.assertRaisesRegex(Error, 'No\\ node\\ with\\ name.*mocknoname*'):
|
||||||
|
decode_name_with_port(self.model, 'mocknoname')
|
||||||
|
model_stat = get_model_statistic()
|
||||||
|
assert model_stat.get_place_by_tensor_name == 1
|
||||||
|
assert model_stat.get_place_by_operation_name == 1
|
||||||
|
|
||||||
|
# Mock model has both tensor and operation with same name and non-equal data
|
||||||
|
# Collision is expected
|
||||||
|
@mock_needed
|
||||||
|
def test_decode_name_with_port_collision_op_tensor(self):
|
||||||
|
with self.assertRaisesRegex(Error, 'Name\\ collision.*tensorAndOp*'):
|
||||||
|
decode_name_with_port(self.model, 'tensorAndOp')
|
||||||
|
model_stat = get_model_statistic()
|
||||||
|
place_stat = get_place_statistic()
|
||||||
|
|
||||||
|
assert model_stat.get_place_by_tensor_name == 1
|
||||||
|
assert model_stat.get_place_by_operation_name == 1
|
||||||
|
assert place_stat.is_equal_data > 0
|
||||||
|
|
||||||
|
# Mock model has 'operation' and output port up to 10
|
||||||
|
@mock_needed
|
||||||
|
def test_decode_name_with_port_delim_op_out(self):
|
||||||
|
node = decode_name_with_port(self.model, 'operation:7')
|
||||||
|
model_stat = get_model_statistic()
|
||||||
|
place_stat = get_place_statistic()
|
||||||
|
|
||||||
|
assert model_stat.get_place_by_tensor_name == 1
|
||||||
|
assert model_stat.get_place_by_operation_name == 2
|
||||||
|
assert place_stat.get_output_port == 1
|
||||||
|
assert place_stat.lastArgInt == 7
|
||||||
|
assert node
|
||||||
|
|
||||||
|
# Mock model has 'operation' and input port up to 10
|
||||||
|
@mock_needed
|
||||||
|
def test_decode_name_with_port_delim_op_in(self):
|
||||||
|
node = decode_name_with_port(self.model, '7:operation')
|
||||||
|
model_stat = get_model_statistic()
|
||||||
|
place_stat = get_place_statistic()
|
||||||
|
|
||||||
|
assert model_stat.get_place_by_tensor_name == 1
|
||||||
|
assert model_stat.get_place_by_operation_name == 2
|
||||||
|
assert place_stat.get_input_port == 1
|
||||||
|
assert place_stat.lastArgInt == 7
|
||||||
|
assert node
|
||||||
|
|
||||||
|
# Mock model has 'operation' and 'operation:0' op places, collision is expected
|
||||||
|
@mock_needed
|
||||||
|
def test_decode_name_with_port_delim_op_collision_out(self):
|
||||||
|
with self.assertRaisesRegex(Error, 'Name\\ collision(?!.*Tensor.*).*operation\\:0*'):
|
||||||
|
decode_name_with_port(self.model, 'operation:0')
|
||||||
|
model_stat = get_model_statistic()
|
||||||
|
place_stat = get_place_statistic()
|
||||||
|
|
||||||
|
assert model_stat.get_place_by_tensor_name == 1
|
||||||
|
assert model_stat.get_place_by_operation_name == 2
|
||||||
|
assert place_stat.is_equal_data > 0
|
||||||
|
assert place_stat.get_output_port == 1
|
||||||
|
assert place_stat.lastArgInt == 0
|
||||||
|
|
||||||
|
# Mock model has 'operation' and '0:operation' op places, collision is expected
|
||||||
|
@mock_needed
|
||||||
|
def test_decode_name_with_port_delim_op_collision_in(self):
|
||||||
|
with self.assertRaisesRegex(Error, 'Name\\ collision(?!.*Tensor.*).*0\\:operation*'):
|
||||||
|
decode_name_with_port(self.model, '0:operation')
|
||||||
|
model_stat = get_model_statistic()
|
||||||
|
place_stat = get_place_statistic()
|
||||||
|
|
||||||
|
assert model_stat.get_place_by_tensor_name == 1
|
||||||
|
assert model_stat.get_place_by_operation_name == 2
|
||||||
|
assert place_stat.is_equal_data > 0
|
||||||
|
assert place_stat.get_input_port == 1
|
||||||
|
assert place_stat.lastArgInt == 0
|
||||||
|
|
||||||
|
# Mock model has 'tensor' and 'tensor:0' tensor places, no collision is expected
|
||||||
|
@mock_needed
|
||||||
|
def test_decode_name_with_port_delim_tensor_no_collision_out(self):
|
||||||
|
node = decode_name_with_port(self.model, 'tensor:0')
|
||||||
|
model_stat = get_model_statistic()
|
||||||
|
place_stat = get_place_statistic()
|
||||||
|
|
||||||
|
assert model_stat.get_place_by_tensor_name == 1
|
||||||
|
assert model_stat.get_place_by_operation_name == 2
|
||||||
|
assert place_stat.get_output_port == 0
|
||||||
|
assert node
|
||||||
|
|
||||||
|
# Mock model has 'tensor' and '0:tensor' tensor places, no collision is expected
|
||||||
|
@mock_needed
|
||||||
|
def test_decode_name_with_port_delim_tensor_no_collision_in(self):
|
||||||
|
node = decode_name_with_port(self.model, '0:tensor')
|
||||||
|
model_stat = get_model_statistic()
|
||||||
|
place_stat = get_place_statistic()
|
||||||
|
|
||||||
|
assert model_stat.get_place_by_tensor_name == 1
|
||||||
|
assert model_stat.get_place_by_operation_name == 2
|
||||||
|
assert place_stat.get_input_port == 0
|
||||||
|
assert node
|
||||||
|
|
||||||
|
# Mock model doesn't have such '1234:operation' or output port=1234 for 'operation'
|
||||||
|
@mock_needed
|
||||||
|
def test_decode_name_with_port_delim_no_port_out(self):
|
||||||
|
with self.assertRaisesRegex(Error, 'No\\ node\\ with\\ name.*operation\\:1234*'):
|
||||||
|
decode_name_with_port(self.model, 'operation:1234')
|
||||||
|
model_stat = get_model_statistic()
|
||||||
|
place_stat = get_place_statistic()
|
||||||
|
|
||||||
|
assert model_stat.get_place_by_tensor_name == 1
|
||||||
|
assert model_stat.get_place_by_operation_name == 2
|
||||||
|
assert place_stat.get_output_port == 1
|
||||||
|
assert place_stat.lastArgInt == 1234
|
||||||
|
|
||||||
|
# Mock model doesn't have such '1234:operation' or input port=1234 for 'operation'
|
||||||
|
@mock_needed
|
||||||
|
def test_decode_name_with_port_delim_no_port_in(self):
|
||||||
|
with self.assertRaisesRegex(Error, 'No\\ node\\ with\\ name.*1234\\:operation*'):
|
||||||
|
decode_name_with_port(self.model, '1234:operation')
|
||||||
|
model_stat = get_model_statistic()
|
||||||
|
place_stat = get_place_statistic()
|
||||||
|
|
||||||
|
assert model_stat.get_place_by_tensor_name == 1
|
||||||
|
assert model_stat.get_place_by_operation_name == 2
|
||||||
|
assert place_stat.get_input_port == 1
|
||||||
|
assert place_stat.lastArgInt == 1234
|
||||||
|
|
||||||
|
# Mock model has tensor with name 'conv2d:0' and operation 'conv2d' with output port = 1
|
||||||
|
# It is setup to return 'is_equal_data=True' for these tensor and port
|
||||||
|
# So no collision is expected
|
||||||
|
@mock_needed
|
||||||
|
def test_decode_name_with_port_delim_equal_data_out(self):
|
||||||
|
set_equal_data('conv2d', 'conv2d')
|
||||||
|
node = decode_name_with_port(self.model, 'conv2d:0')
|
||||||
|
model_stat = get_model_statistic()
|
||||||
|
place_stat = get_place_statistic()
|
||||||
|
|
||||||
|
assert model_stat.get_place_by_tensor_name == 1
|
||||||
|
assert model_stat.get_place_by_operation_name == 2
|
||||||
|
assert place_stat.get_output_port == 1
|
||||||
|
assert place_stat.is_equal_data > 0
|
||||||
|
assert node
|
||||||
|
|
||||||
|
# Mock model has tensor with name '0:conv2d' and operation 'conv2d' with input port = 1
|
||||||
|
# It is setup to return 'is_equal_data=True' for these tensor and port
|
||||||
|
# So no collision is expected
|
||||||
|
@mock_needed
|
||||||
|
def test_decode_name_with_port_delim_equal_data_in(self):
|
||||||
|
set_equal_data('conv2d', 'conv2d')
|
||||||
|
node = decode_name_with_port(self.model, '0:conv2d')
|
||||||
|
model_stat = get_model_statistic()
|
||||||
|
place_stat = get_place_statistic()
|
||||||
|
|
||||||
|
assert model_stat.get_place_by_tensor_name == 1
|
||||||
|
assert model_stat.get_place_by_operation_name == 2
|
||||||
|
assert place_stat.get_input_port == 1
|
||||||
|
assert place_stat.is_equal_data > 0
|
||||||
|
assert node
|
||||||
|
|
||||||
|
# Stress case: Mock model has:
|
||||||
|
# Tensor '8:9'
|
||||||
|
# Operation '8:9'
|
||||||
|
# Operation '8' with output port = 9
|
||||||
|
# Operation '9' with input port = 8
|
||||||
|
# All places point to same data - no collision is expected
|
||||||
|
@mock_needed
|
||||||
|
def test_decode_name_with_port_delim_all_same_data(self):
|
||||||
|
set_equal_data('8', '9')
|
||||||
|
node = decode_name_with_port(self.model, '8:9')
|
||||||
|
model_stat = get_model_statistic()
|
||||||
|
place_stat = get_place_statistic()
|
||||||
|
|
||||||
|
assert model_stat.get_place_by_tensor_name == 1
|
||||||
|
assert model_stat.get_place_by_operation_name == 3
|
||||||
|
assert place_stat.get_input_port == 1
|
||||||
|
assert place_stat.get_output_port == 1
|
||||||
|
# At least 3 comparisons of places are expected
|
||||||
|
assert place_stat.is_equal_data > 2
|
||||||
|
assert node
|
@ -14,6 +14,11 @@ FeStat FrontEndMockPy::m_stat = {};
|
|||||||
ModelStat InputModelMockPy::m_stat = {};
|
ModelStat InputModelMockPy::m_stat = {};
|
||||||
PlaceStat PlaceMockPy::m_stat = {};
|
PlaceStat PlaceMockPy::m_stat = {};
|
||||||
|
|
||||||
|
std::string MockSetup::m_equal_data_node1 = {};
|
||||||
|
std::string MockSetup::m_equal_data_node2 = {};
|
||||||
|
int MockSetup::m_max_input_port_index = 0;
|
||||||
|
int MockSetup::m_max_output_port_index = 0;
|
||||||
|
|
||||||
PartialShape InputModelMockPy::m_returnShape = {};
|
PartialShape InputModelMockPy::m_returnShape = {};
|
||||||
|
|
||||||
extern "C" MOCK_API FrontEndVersion GetAPIVersion()
|
extern "C" MOCK_API FrontEndVersion GetAPIVersion()
|
||||||
|
@ -21,6 +21,35 @@ using namespace ngraph;
|
|||||||
using namespace ngraph::frontend;
|
using namespace ngraph::frontend;
|
||||||
|
|
||||||
////////////////////////////////
|
////////////////////////////////
|
||||||
|
/// \brief This structure holds number static setup values
|
||||||
|
/// It will be used by Python unit tests to setup particular mock behavior
|
||||||
|
struct MOCK_API MockSetup
|
||||||
|
{
|
||||||
|
static std::string m_equal_data_node1;
|
||||||
|
static std::string m_equal_data_node2;
|
||||||
|
static int m_max_input_port_index;
|
||||||
|
static int m_max_output_port_index;
|
||||||
|
|
||||||
|
static void clear_setup()
|
||||||
|
{
|
||||||
|
m_equal_data_node1 = {};
|
||||||
|
m_equal_data_node2 = {};
|
||||||
|
m_max_input_port_index = 0;
|
||||||
|
m_max_output_port_index = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void set_equal_data(const std::string& node1, const std::string& node2)
|
||||||
|
{
|
||||||
|
m_equal_data_node1 = node1;
|
||||||
|
m_equal_data_node2 = node2;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void set_max_port_counts(int max_input, int max_output)
|
||||||
|
{
|
||||||
|
m_max_input_port_index = max_input;
|
||||||
|
m_max_output_port_index = max_output;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
/// \brief This structure holds number of calls of particular methods of Place objects
|
/// \brief This structure holds number of calls of particular methods of Place objects
|
||||||
/// It will be used by Python unit tests to verify that appropriate API
|
/// It will be used by Python unit tests to verify that appropriate API
|
||||||
@ -33,6 +62,7 @@ struct MOCK_API PlaceStat
|
|||||||
int m_is_input = 0;
|
int m_is_input = 0;
|
||||||
int m_is_output = 0;
|
int m_is_output = 0;
|
||||||
int m_is_equal = 0;
|
int m_is_equal = 0;
|
||||||
|
int m_is_equal_data = 0;
|
||||||
|
|
||||||
// Arguments tracking
|
// Arguments tracking
|
||||||
std::string m_lastArgString;
|
std::string m_lastArgString;
|
||||||
@ -46,6 +76,7 @@ struct MOCK_API PlaceStat
|
|||||||
int is_input() const { return m_is_input; }
|
int is_input() const { return m_is_input; }
|
||||||
int is_output() const { return m_is_output; }
|
int is_output() const { return m_is_output; }
|
||||||
int is_equal() const { return m_is_equal; }
|
int is_equal() const { return m_is_equal; }
|
||||||
|
int is_equal_data() const { return m_is_equal_data; }
|
||||||
|
|
||||||
// Arguments getters
|
// Arguments getters
|
||||||
std::string get_lastArgString() const { return m_lastArgString; }
|
std::string get_lastArgString() const { return m_lastArgString; }
|
||||||
@ -60,10 +91,14 @@ class MOCK_API PlaceMockPy : public Place
|
|||||||
{
|
{
|
||||||
static PlaceStat m_stat;
|
static PlaceStat m_stat;
|
||||||
std::string m_name;
|
std::string m_name;
|
||||||
|
bool m_is_op = false;
|
||||||
|
int m_portIndex = -1;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
PlaceMockPy(const std::string& name = {})
|
PlaceMockPy(const std::string& name = {}, bool is_op = false, int portIndex = -1)
|
||||||
: m_name(name)
|
: m_name(name)
|
||||||
|
, m_is_op(is_op)
|
||||||
|
, m_portIndex(portIndex)
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -84,7 +119,11 @@ public:
|
|||||||
{
|
{
|
||||||
m_stat.m_get_input_port++;
|
m_stat.m_get_input_port++;
|
||||||
m_stat.m_lastArgInt = inputPortIndex;
|
m_stat.m_lastArgInt = inputPortIndex;
|
||||||
return std::make_shared<PlaceMockPy>();
|
if (inputPortIndex < MockSetup::m_max_input_port_index)
|
||||||
|
{
|
||||||
|
return std::make_shared<PlaceMockPy>(m_name, false, inputPortIndex);
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
Place::Ptr get_input_port(const std::string& inputName) const override
|
Place::Ptr get_input_port(const std::string& inputName) const override
|
||||||
@ -114,7 +153,11 @@ public:
|
|||||||
{
|
{
|
||||||
m_stat.m_get_output_port++;
|
m_stat.m_get_output_port++;
|
||||||
m_stat.m_lastArgInt = outputPortIndex;
|
m_stat.m_lastArgInt = outputPortIndex;
|
||||||
return std::make_shared<PlaceMockPy>();
|
if (outputPortIndex < MockSetup::m_max_output_port_index)
|
||||||
|
{
|
||||||
|
return std::make_shared<PlaceMockPy>(m_name, false, outputPortIndex);
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
Place::Ptr get_output_port(const std::string& outputName) const override
|
Place::Ptr get_output_port(const std::string& outputName) const override
|
||||||
@ -149,7 +192,27 @@ public:
|
|||||||
{
|
{
|
||||||
m_stat.m_is_equal++;
|
m_stat.m_is_equal++;
|
||||||
m_stat.m_lastArgPlace = another;
|
m_stat.m_lastArgPlace = another;
|
||||||
return m_name == another->get_names().at(0);
|
std::shared_ptr<PlaceMockPy> mock = std::dynamic_pointer_cast<PlaceMockPy>(another);
|
||||||
|
return m_name == mock->m_name && m_is_op == mock->m_is_op &&
|
||||||
|
m_portIndex == mock->m_portIndex;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool is_equal_data(Ptr another) const override
|
||||||
|
{
|
||||||
|
m_stat.m_is_equal_data++;
|
||||||
|
m_stat.m_lastArgPlace = another;
|
||||||
|
std::shared_ptr<PlaceMockPy> mock = std::dynamic_pointer_cast<PlaceMockPy>(another);
|
||||||
|
if (!MockSetup::m_equal_data_node1.empty() && !MockSetup::m_equal_data_node2.empty())
|
||||||
|
{
|
||||||
|
if ((mock->m_name.find(MockSetup::m_equal_data_node1) != std::string::npos ||
|
||||||
|
mock->m_name.find(MockSetup::m_equal_data_node2) != std::string::npos) &&
|
||||||
|
(m_name.find(MockSetup::m_equal_data_node1) != std::string::npos ||
|
||||||
|
m_name.find(MockSetup::m_equal_data_node2) != std::string::npos))
|
||||||
|
{
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return mock->m_is_op == m_is_op;
|
||||||
}
|
}
|
||||||
|
|
||||||
//---------------Stat--------------------
|
//---------------Stat--------------------
|
||||||
@ -167,6 +230,7 @@ struct MOCK_API ModelStat
|
|||||||
int m_get_inputs = 0;
|
int m_get_inputs = 0;
|
||||||
int m_get_outputs = 0;
|
int m_get_outputs = 0;
|
||||||
int m_get_place_by_tensor_name = 0;
|
int m_get_place_by_tensor_name = 0;
|
||||||
|
int m_get_place_by_operation_name = 0;
|
||||||
int m_set_partial_shape = 0;
|
int m_set_partial_shape = 0;
|
||||||
int m_get_partial_shape = 0;
|
int m_get_partial_shape = 0;
|
||||||
int m_set_element_type = 0;
|
int m_set_element_type = 0;
|
||||||
@ -190,6 +254,7 @@ struct MOCK_API ModelStat
|
|||||||
int extract_subgraph() const { return m_extract_subgraph; }
|
int extract_subgraph() const { return m_extract_subgraph; }
|
||||||
int override_all_inputs() const { return m_override_all_inputs; }
|
int override_all_inputs() const { return m_override_all_inputs; }
|
||||||
int override_all_outputs() const { return m_override_all_outputs; }
|
int override_all_outputs() const { return m_override_all_outputs; }
|
||||||
|
int get_place_by_operation_name() const { return m_get_place_by_operation_name; }
|
||||||
int get_place_by_tensor_name() const { return m_get_place_by_tensor_name; }
|
int get_place_by_tensor_name() const { return m_get_place_by_tensor_name; }
|
||||||
int set_partial_shape() const { return m_set_partial_shape; }
|
int set_partial_shape() const { return m_set_partial_shape; }
|
||||||
int get_partial_shape() const { return m_get_partial_shape; }
|
int get_partial_shape() const { return m_get_partial_shape; }
|
||||||
@ -208,12 +273,31 @@ struct MOCK_API ModelStat
|
|||||||
/// \brief Mock implementation of InputModel
|
/// \brief Mock implementation of InputModel
|
||||||
/// Every call increments appropriate counters in statistic and stores argument values to statistics
|
/// Every call increments appropriate counters in statistic and stores argument values to statistics
|
||||||
/// as well
|
/// as well
|
||||||
/// ("mock_output1", "mock_output2")
|
|
||||||
class MOCK_API InputModelMockPy : public InputModel
|
class MOCK_API InputModelMockPy : public InputModel
|
||||||
{
|
{
|
||||||
static ModelStat m_stat;
|
static ModelStat m_stat;
|
||||||
static PartialShape m_returnShape;
|
static PartialShape m_returnShape;
|
||||||
|
|
||||||
|
std::set<std::string> m_operations = {
|
||||||
|
"8", "9", "8:9", "operation", "operation:0", "0:operation", "tensorAndOp", "conv2d"};
|
||||||
|
std::set<std::string> m_tensors = {"8:9",
|
||||||
|
"tensor",
|
||||||
|
"tensor:0",
|
||||||
|
"0:tensor",
|
||||||
|
"tensorAndOp",
|
||||||
|
"conv2d:0",
|
||||||
|
"0:conv2d",
|
||||||
|
"mock_input1",
|
||||||
|
"mock_input2",
|
||||||
|
"newInput1",
|
||||||
|
"newIn1",
|
||||||
|
"newIn2",
|
||||||
|
"mock_output1",
|
||||||
|
"mock_output2",
|
||||||
|
"new_output2",
|
||||||
|
"newOut1",
|
||||||
|
"newOut2"};
|
||||||
|
|
||||||
public:
|
public:
|
||||||
std::vector<Place::Ptr> get_inputs() const override
|
std::vector<Place::Ptr> get_inputs() const override
|
||||||
{
|
{
|
||||||
@ -229,12 +313,27 @@ public:
|
|||||||
std::make_shared<PlaceMockPy>("mock_output2")};
|
std::make_shared<PlaceMockPy>("mock_output2")};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Place::Ptr get_place_by_operation_name(const std::string& opName) const override
|
||||||
|
{
|
||||||
|
m_stat.m_get_place_by_operation_name++;
|
||||||
|
m_stat.m_lastArgString = opName;
|
||||||
|
if (m_operations.count(opName))
|
||||||
|
{
|
||||||
|
return std::make_shared<PlaceMockPy>(opName, true);
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
Place::Ptr get_place_by_tensor_name(const std::string& tensorName) const override
|
Place::Ptr get_place_by_tensor_name(const std::string& tensorName) const override
|
||||||
{
|
{
|
||||||
m_stat.m_get_place_by_tensor_name++;
|
m_stat.m_get_place_by_tensor_name++;
|
||||||
m_stat.m_lastArgString = tensorName;
|
m_stat.m_lastArgString = tensorName;
|
||||||
|
if (m_tensors.count(tensorName))
|
||||||
|
{
|
||||||
return std::make_shared<PlaceMockPy>(tensorName);
|
return std::make_shared<PlaceMockPy>(tensorName);
|
||||||
}
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
void override_all_outputs(const std::vector<Place::Ptr>& outputs) override
|
void override_all_outputs(const std::vector<Place::Ptr>& outputs) override
|
||||||
{
|
{
|
||||||
@ -319,7 +418,7 @@ public:
|
|||||||
|
|
||||||
static void clear_stat() { m_stat = {}; }
|
static void clear_stat() { m_stat = {}; }
|
||||||
|
|
||||||
protected:
|
private:
|
||||||
InputModel::Ptr load_impl(const std::vector<std::shared_ptr<Variant>>& params) const override
|
InputModel::Ptr load_impl(const std::vector<std::shared_ptr<Variant>>& params) const override
|
||||||
{
|
{
|
||||||
if (params.size() > 0 && is_type<VariantWrapper<std::string>>(params[0]))
|
if (params.size() > 0 && is_type<VariantWrapper<std::string>>(params[0]))
|
||||||
|
@ -21,6 +21,13 @@ static void register_mock_frontend_stat(py::module m)
|
|||||||
feStat.def_property_readonly("convert_model", &FeStat::convert_model);
|
feStat.def_property_readonly("convert_model", &FeStat::convert_model);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void register_mock_setup(py::module m)
|
||||||
|
{
|
||||||
|
m.def("clear_setup", &MockSetup::clear_setup);
|
||||||
|
m.def("set_equal_data", &MockSetup::set_equal_data);
|
||||||
|
m.def("set_max_port_counts", &MockSetup::set_max_port_counts);
|
||||||
|
}
|
||||||
|
|
||||||
static void register_mock_model_stat(py::module m)
|
static void register_mock_model_stat(py::module m)
|
||||||
{
|
{
|
||||||
m.def("get_model_statistic", &InputModelMockPy::get_stat);
|
m.def("get_model_statistic", &InputModelMockPy::get_stat);
|
||||||
@ -30,6 +37,8 @@ static void register_mock_model_stat(py::module m)
|
|||||||
py::class_<ModelStat> mdlStat(m, "ModelStat", py::dynamic_attr());
|
py::class_<ModelStat> mdlStat(m, "ModelStat", py::dynamic_attr());
|
||||||
mdlStat.def_property_readonly("get_inputs", &ModelStat::get_inputs);
|
mdlStat.def_property_readonly("get_inputs", &ModelStat::get_inputs);
|
||||||
mdlStat.def_property_readonly("get_outputs", &ModelStat::get_outputs);
|
mdlStat.def_property_readonly("get_outputs", &ModelStat::get_outputs);
|
||||||
|
mdlStat.def_property_readonly("get_place_by_operation_name",
|
||||||
|
&ModelStat::get_place_by_operation_name);
|
||||||
mdlStat.def_property_readonly("get_place_by_tensor_name", &ModelStat::get_place_by_tensor_name);
|
mdlStat.def_property_readonly("get_place_by_tensor_name", &ModelStat::get_place_by_tensor_name);
|
||||||
|
|
||||||
mdlStat.def_property_readonly("set_partial_shape", &ModelStat::set_partial_shape);
|
mdlStat.def_property_readonly("set_partial_shape", &ModelStat::set_partial_shape);
|
||||||
@ -66,12 +75,14 @@ static void register_mock_place_stat(py::module m)
|
|||||||
placeStat.def_property_readonly("is_input", &PlaceStat::is_input);
|
placeStat.def_property_readonly("is_input", &PlaceStat::is_input);
|
||||||
placeStat.def_property_readonly("is_output", &PlaceStat::is_output);
|
placeStat.def_property_readonly("is_output", &PlaceStat::is_output);
|
||||||
placeStat.def_property_readonly("is_equal", &PlaceStat::is_equal);
|
placeStat.def_property_readonly("is_equal", &PlaceStat::is_equal);
|
||||||
|
placeStat.def_property_readonly("is_equal_data", &PlaceStat::is_equal_data);
|
||||||
}
|
}
|
||||||
|
|
||||||
PYBIND11_MODULE(mock_mo_python_api, m)
|
PYBIND11_MODULE(mock_mo_python_api, m)
|
||||||
{
|
{
|
||||||
m.doc() = "Mock frontend call counters for testing Pyngraph frontend bindings";
|
m.doc() = "Mock frontend call counters for testing Pyngraph frontend bindings";
|
||||||
register_mock_frontend_stat(m);
|
register_mock_frontend_stat(m);
|
||||||
|
register_mock_setup(m);
|
||||||
register_mock_model_stat(m);
|
register_mock_model_stat(m);
|
||||||
register_mock_place_stat(m);
|
register_mock_place_stat(m);
|
||||||
}
|
}
|
||||||
|
@ -69,13 +69,14 @@ namespace ngraph
|
|||||||
/// \brief Returns a tensor place by a tensor name following framework conventions, or
|
/// \brief Returns a tensor place by a tensor name following framework conventions, or
|
||||||
/// nullptr if a tensor with this name doesn't exist.
|
/// nullptr if a tensor with this name doesn't exist.
|
||||||
/// \param tensor_name Name of tensor
|
/// \param tensor_name Name of tensor
|
||||||
/// \return Tensor place corresponding to specifed tensor name
|
/// \return Tensor place corresponding to specified tensor name or nullptr if not exists
|
||||||
virtual Place::Ptr get_place_by_tensor_name(const std::string& tensor_name) const;
|
virtual Place::Ptr get_place_by_tensor_name(const std::string& tensor_name) const;
|
||||||
|
|
||||||
/// \brief Returns an operation place by an operation name following framework
|
/// \brief Returns an operation place by an operation name following framework
|
||||||
/// conventions, or nullptr if an operation with this name doesn't exist. \param
|
/// conventions, or nullptr if an operation with this name doesn't exist.
|
||||||
/// operation_name Name of operation \return Place representing operation
|
/// \param operation_name Name of operation
|
||||||
virtual Place::Ptr get_place_by_operation_name(const std::string& operation_name);
|
/// \return Place representing operation or nullptr if not exists
|
||||||
|
virtual Place::Ptr get_place_by_operation_name(const std::string& operation_name) const;
|
||||||
|
|
||||||
/// \brief Returns an input port place by operation name and appropriate port index
|
/// \brief Returns an input port place by operation name and appropriate port index
|
||||||
/// \param operation_name Name of operation
|
/// \param operation_name Name of operation
|
||||||
@ -88,7 +89,7 @@ namespace ngraph
|
|||||||
/// \brief Returns an output port place by operation name and appropriate port index
|
/// \brief Returns an output port place by operation name and appropriate port index
|
||||||
/// \param operation_name Name of operation
|
/// \param operation_name Name of operation
|
||||||
/// \param output_port_index Index of output port for this operation
|
/// \param output_port_index Index of output port for this operation
|
||||||
/// \return Place representing output port of operation
|
/// \return Place representing output port of operation or nullptr if not exists
|
||||||
virtual Place::Ptr
|
virtual Place::Ptr
|
||||||
get_place_by_operation_name_and_output_port(const std::string& operation_name,
|
get_place_by_operation_name_and_output_port(const std::string& operation_name,
|
||||||
int output_port_index);
|
int output_port_index);
|
||||||
|
@ -232,14 +232,14 @@ namespace ngraph
|
|||||||
/// \brief For operation node returns reference to an input port; applicable if
|
/// \brief For operation node returns reference to an input port; applicable if
|
||||||
/// operation node has only one input port
|
/// operation node has only one input port
|
||||||
///
|
///
|
||||||
/// \return Input port place
|
/// \return Input port place or nullptr if not exists
|
||||||
virtual Ptr get_input_port() const;
|
virtual Ptr get_input_port() const;
|
||||||
|
|
||||||
/// \brief For operation node returns reference to an input port with specified index
|
/// \brief For operation node returns reference to an input port with specified index
|
||||||
///
|
///
|
||||||
/// \param input_port_index Input port index
|
/// \param input_port_index Input port index
|
||||||
///
|
///
|
||||||
/// \return Appropriate input port place
|
/// \return Appropriate input port place or nullptr if not exists
|
||||||
virtual Ptr get_input_port(int input_port_index) const;
|
virtual Ptr get_input_port(int input_port_index) const;
|
||||||
|
|
||||||
/// \brief For operation node returns reference to an input port with specified name;
|
/// \brief For operation node returns reference to an input port with specified name;
|
||||||
@ -247,7 +247,7 @@ namespace ngraph
|
|||||||
///
|
///
|
||||||
/// \param input_name Name of port group
|
/// \param input_name Name of port group
|
||||||
///
|
///
|
||||||
/// \return Appropriate input port place
|
/// \return Appropriate input port place or nullptr if not exists
|
||||||
virtual Ptr get_input_port(const std::string& input_name) const;
|
virtual Ptr get_input_port(const std::string& input_name) const;
|
||||||
|
|
||||||
/// \brief For operation node returns reference to an input port with specified name and
|
/// \brief For operation node returns reference to an input port with specified name and
|
||||||
@ -257,20 +257,20 @@ namespace ngraph
|
|||||||
///
|
///
|
||||||
/// \param input_port_index Input port index in a group
|
/// \param input_port_index Input port index in a group
|
||||||
///
|
///
|
||||||
/// \return Appropriate input port place
|
/// \return Appropriate input port place or nullptr if not exists
|
||||||
virtual Ptr get_input_port(const std::string& input_name, int input_port_index) const;
|
virtual Ptr get_input_port(const std::string& input_name, int input_port_index) const;
|
||||||
|
|
||||||
/// \brief For operation node returns reference to an output port; applicable for
|
/// \brief For operation node returns reference to an output port; applicable for
|
||||||
/// operations with only one output port
|
/// operations with only one output port
|
||||||
///
|
///
|
||||||
/// \return Appropriate output port place
|
/// \return Appropriate output port place or nullptr if not exists
|
||||||
virtual Ptr get_output_port() const;
|
virtual Ptr get_output_port() const;
|
||||||
|
|
||||||
/// \brief For operation node returns reference to an output port with specified index
|
/// \brief For operation node returns reference to an output port with specified index
|
||||||
///
|
///
|
||||||
/// \param output_port_index Output port index
|
/// \param output_port_index Output port index
|
||||||
///
|
///
|
||||||
/// \return Appropriate output port place
|
/// \return Appropriate output port place or nullptr if not exists
|
||||||
virtual Ptr get_output_port(int output_port_index) const;
|
virtual Ptr get_output_port(int output_port_index) const;
|
||||||
|
|
||||||
/// \brief For operation node returns reference to an output port with specified name;
|
/// \brief For operation node returns reference to an output port with specified name;
|
||||||
@ -278,7 +278,7 @@ namespace ngraph
|
|||||||
///
|
///
|
||||||
/// \param output_name Name of output port group
|
/// \param output_name Name of output port group
|
||||||
///
|
///
|
||||||
/// \return Appropriate output port place
|
/// \return Appropriate output port place or nullptr if not exists
|
||||||
virtual Ptr get_output_port(const std::string& output_name) const;
|
virtual Ptr get_output_port(const std::string& output_name) const;
|
||||||
|
|
||||||
/// \brief For operation node returns reference to an output port with specified name
|
/// \brief For operation node returns reference to an output port with specified name
|
||||||
@ -288,7 +288,7 @@ namespace ngraph
|
|||||||
///
|
///
|
||||||
/// \param output_port_index Output port index
|
/// \param output_port_index Output port index
|
||||||
///
|
///
|
||||||
/// \return Appropriate output port place
|
/// \return Appropriate output port place or nullptr if not exists
|
||||||
virtual Ptr get_output_port(const std::string& output_name,
|
virtual Ptr get_output_port(const std::string& output_name,
|
||||||
int output_port_index) const;
|
int output_port_index) const;
|
||||||
|
|
||||||
|
@ -183,7 +183,7 @@ Place::Ptr InputModel::get_place_by_tensor_name(const std::string& tensor_name)
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
Place::Ptr InputModel::get_place_by_operation_name(const std::string& operation_name)
|
Place::Ptr InputModel::get_place_by_operation_name(const std::string& operation_name) const
|
||||||
{
|
{
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -334,7 +334,7 @@ public:
|
|||||||
return std::make_shared<PlaceMockPy>();
|
return std::make_shared<PlaceMockPy>();
|
||||||
}
|
}
|
||||||
|
|
||||||
Place::Ptr get_place_by_operation_name(const std::string& operationName) override
|
Place::Ptr get_place_by_operation_name(const std::string& operationName) const override
|
||||||
{
|
{
|
||||||
m_stat.m_get_place_by_operation_name++;
|
m_stat.m_get_place_by_operation_name++;
|
||||||
m_stat.m_lastArgString = operationName;
|
m_stat.m_lastArgString = operationName;
|
||||||
|
Loading…
Reference in New Issue
Block a user