[POT] Update IEEngine with the Dynamic model support (#10717)
* Update IEEngine with the Dynamic models support * Update with the batch * Method naming fix * Update image_loader & tests with dynamic models * Update test_sanity.py * Replace custom_mo_config from the model
This commit is contained in:
parent
3b8e960b10
commit
41818a377f
@ -5,6 +5,7 @@ from cv2 import imread, IMREAD_GRAYSCALE
|
|||||||
|
|
||||||
from openvino.runtime import Layout, Dimension # pylint: disable=E0611,E0401
|
from openvino.runtime import Layout, Dimension # pylint: disable=E0611,E0401
|
||||||
from openvino.tools.mo.utils.cli_parser import get_layout_values
|
from openvino.tools.mo.utils.cli_parser import get_layout_values
|
||||||
|
from openvino.tools.mo.front.common.partial_infer.utils import is_fully_defined
|
||||||
from ..api.data_loader import DataLoader
|
from ..api.data_loader import DataLoader
|
||||||
from ..data_loaders.utils import prepare_image, collect_img_files
|
from ..data_loaders.utils import prepare_image, collect_img_files
|
||||||
from ..utils.logger import get_logger
|
from ..utils.logger import get_logger
|
||||||
@ -34,6 +35,7 @@ class ImageLoader(DataLoader):
|
|||||||
|
|
||||||
@shape.setter
|
@shape.setter
|
||||||
def shape(self, shape):
|
def shape(self, shape):
|
||||||
|
self._is_shape_static = is_fully_defined(shape)
|
||||||
self._shape = tuple(shape)
|
self._shape = tuple(shape)
|
||||||
|
|
||||||
def _read_and_preproc_image(self, img_path):
|
def _read_and_preproc_image(self, img_path):
|
||||||
@ -47,7 +49,9 @@ class ImageLoader(DataLoader):
|
|||||||
if image is None:
|
if image is None:
|
||||||
raise Exception('Can not read the image: {}'.format(img_path))
|
raise Exception('Can not read the image: {}'.format(img_path))
|
||||||
|
|
||||||
return prepare_image(image, self._layout, (self.shape[H], self.shape[W]),
|
dst_shape = (self.shape[H], self.shape[W]) if self._is_shape_static else None
|
||||||
|
|
||||||
|
return prepare_image(image, self._layout, dst_shape,
|
||||||
self._crop_central_fraction, self._shape[C] == 1)
|
self._crop_central_fraction, self._shape[C] == 1)
|
||||||
|
|
||||||
def get_layout(self, input_node=None):
|
def get_layout(self, input_node=None):
|
||||||
|
@ -35,11 +35,13 @@ def crop(image, central_fraction):
|
|||||||
return image[start_height:start_height + dst_height, start_width:start_width + dst_width]
|
return image[start_height:start_height + dst_height, start_width:start_width + dst_width]
|
||||||
|
|
||||||
|
|
||||||
def prepare_image(image, layout, dst_shape, central_fraction=None, grayscale=False):
|
def prepare_image(image, layout, dst_shape=None, central_fraction=None, grayscale=False):
|
||||||
if central_fraction:
|
if central_fraction:
|
||||||
image = crop(image, central_fraction)
|
image = crop(image, central_fraction)
|
||||||
|
|
||||||
image = cv.resize(image, dst_shape[::-1])
|
if dst_shape:
|
||||||
|
image = cv.resize(image, dst_shape[::-1])
|
||||||
|
|
||||||
if grayscale:
|
if grayscale:
|
||||||
image = np.expand_dims(image, 2)
|
image = np.expand_dims(image, 2)
|
||||||
|
|
||||||
|
@ -230,20 +230,26 @@ class IEEngine(Engine):
|
|||||||
"""
|
"""
|
||||||
input_info = model.inputs
|
input_info = model.inputs
|
||||||
|
|
||||||
|
def is_dynamic_input(input_blob):
|
||||||
|
return input_blob.partial_shape.is_dynamic
|
||||||
|
|
||||||
|
def process_input(input_blob, input_data):
|
||||||
|
return input_data if is_dynamic_input(input_blob) else np.reshape(input_data, input_blob.shape)
|
||||||
|
|
||||||
if isinstance(image_batch[0], dict):
|
if isinstance(image_batch[0], dict):
|
||||||
feed_dict = {}
|
feed_dict = {}
|
||||||
input_blobs = {get_clean_name(in_node.get_node().friendly_name): in_node for in_node in input_info}
|
input_blobs = {get_clean_name(in_node.get_node().friendly_name): in_node for in_node in input_info}
|
||||||
for input_name in image_batch[0].keys():
|
for input_name in image_batch[0].keys():
|
||||||
input_blob = input_blobs[input_name]
|
input_blob = input_blobs[input_name]
|
||||||
input_blob_name = self._get_input_any_name(input_blob)
|
input_blob_name = self._get_input_any_name(input_blob)
|
||||||
feed_dict[input_blob_name] = np.reshape(image_batch[0][input_name], input_blob.shape)
|
feed_dict[input_blob_name] = process_input(input_blob, image_batch[0][input_name])
|
||||||
return feed_dict
|
return feed_dict
|
||||||
|
|
||||||
if len(input_info) == 1:
|
if len(input_info) == 1:
|
||||||
input_blob = next(iter(input_info))
|
input_blob = next(iter(input_info))
|
||||||
input_blob_name = self._get_input_any_name(input_blob)
|
input_blob_name = self._get_input_any_name(input_blob)
|
||||||
image_batch = {input_blob_name: np.reshape(image_batch, input_blob.shape)}
|
image_batch = {input_blob_name: process_input(input_blob, image_batch)}
|
||||||
if Shape(image_batch[input_blob_name].shape) != input_info[0].shape:
|
if not is_dynamic_input(input_blob) and Shape(image_batch[input_blob_name].shape) != input_info[0].shape:
|
||||||
raise ValueError(f"Incompatible input shapes. "
|
raise ValueError(f"Incompatible input shapes. "
|
||||||
f"Cannot infer {Shape(image_batch[input_blob_name].shape)} into {input_info[0].shape}."
|
f"Cannot infer {Shape(image_batch[input_blob_name].shape)} into {input_info[0].shape}."
|
||||||
f"Try to specify the layout of the model.")
|
f"Try to specify the layout of the model.")
|
||||||
@ -262,8 +268,9 @@ class IEEngine(Engine):
|
|||||||
lambda x: x.get_any_name() != image_info_name, input_info)))
|
lambda x: x.get_any_name() != image_info_name, input_info)))
|
||||||
image_tensor_name = image_tensor_node.get_any_name()
|
image_tensor_name = image_tensor_node.get_any_name()
|
||||||
|
|
||||||
image_tensor = (image_tensor_name, np.reshape(image_batch, input_blob.shape))
|
image_tensor = (image_tensor_name, process_input(image_tensor_node, image_batch))
|
||||||
if Shape(image_tensor[1].shape) != image_tensor_node.shape:
|
if not is_dynamic_input(image_tensor_node) and \
|
||||||
|
Shape(image_tensor[1].shape) != image_tensor_node.shape:
|
||||||
raise ValueError(f"Incompatible input shapes. "
|
raise ValueError(f"Incompatible input shapes. "
|
||||||
f"Cannot infer {Shape(image_tensor[1].shape)} into {image_tensor_node.shape}."
|
f"Cannot infer {Shape(image_tensor[1].shape)} into {image_tensor_node.shape}."
|
||||||
f"Try to specify the layout of the model.")
|
f"Try to specify the layout of the model.")
|
||||||
|
@ -113,7 +113,9 @@ def test_compression(_params, tmp_path, models):
|
|||||||
|
|
||||||
|
|
||||||
TEST_SAMPLE_MODELS = [
|
TEST_SAMPLE_MODELS = [
|
||||||
('mobilenet-v2-1.0-224', 'tf', 'DefaultQuantization', 'performance', {'accuracy@top1': 0.716})
|
('mobilenet-v2-1.0-224', 'tf', 'DefaultQuantization', 'performance', {'accuracy@top1': 0.716}, []),
|
||||||
|
('mobilenet-v2-1.0-224', 'tf', 'DefaultQuantization', 'performance', {'accuracy@top1': 0.716},
|
||||||
|
['--input_shape=[1,?,?,3]'])
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -124,7 +126,7 @@ def _sample_params(request):
|
|||||||
|
|
||||||
|
|
||||||
def test_sample_compression(_sample_params, tmp_path, models):
|
def test_sample_compression(_sample_params, tmp_path, models):
|
||||||
model_name, model_framework, algorithm, preset, expected_accuracy = _sample_params
|
model_name, model_framework, algorithm, preset, expected_accuracy, custom_mo_config = _sample_params
|
||||||
|
|
||||||
# hack for sample imports because sample app works only from sample directory
|
# hack for sample imports because sample app works only from sample directory
|
||||||
pot_dir = Path(__file__).parent.parent
|
pot_dir = Path(__file__).parent.parent
|
||||||
@ -132,7 +134,7 @@ def test_sample_compression(_sample_params, tmp_path, models):
|
|||||||
# pylint: disable=C0415
|
# pylint: disable=C0415
|
||||||
from openvino.tools.pot.api.samples.classification.classification_sample import optimize_model
|
from openvino.tools.pot.api.samples.classification.classification_sample import optimize_model
|
||||||
|
|
||||||
model = models.get(model_name, model_framework, tmp_path)
|
model = models.get(model_name, model_framework, tmp_path, custom_mo_config=custom_mo_config)
|
||||||
data_source, annotations = get_dataset_info('imagenet_1001_classes')
|
data_source, annotations = get_dataset_info('imagenet_1001_classes')
|
||||||
|
|
||||||
args = Dict({
|
args = Dict({
|
||||||
@ -167,15 +169,17 @@ def test_sample_compression(_sample_params, tmp_path, models):
|
|||||||
|
|
||||||
SIMPLIFIED_TEST_MODELS = [
|
SIMPLIFIED_TEST_MODELS = [
|
||||||
('mobilenet-v2-pytorch', 'pytorch', 'DefaultQuantization', 'performance',
|
('mobilenet-v2-pytorch', 'pytorch', 'DefaultQuantization', 'performance',
|
||||||
{'accuracy@top1': 0.701, 'accuracy@top5': 0.91})
|
{'accuracy@top1': 0.701, 'accuracy@top5': 0.91}, []),
|
||||||
|
('mobilenet-v2-pytorch', 'pytorch', 'DefaultQuantization', 'performance',
|
||||||
|
{'accuracy@top1': 0.707, 'accuracy@top5': 0.904}, ['--input_shape=[1,3,?,?]'])
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def launch_simplified_mode(tmp_path, models, engine_config):
|
def launch_simplified_mode(_simplified_params, tmp_path, models, engine_config):
|
||||||
model_name, model_framework, algorithm, preset, _ = SIMPLIFIED_TEST_MODELS[0]
|
model_name, model_framework, algorithm, preset, _, custom_mo_config = _simplified_params
|
||||||
algorithm_config = make_algo_config(algorithm, preset)
|
algorithm_config = make_algo_config(algorithm, preset)
|
||||||
|
|
||||||
model = models.get(model_name, model_framework, tmp_path)
|
model = models.get(model_name, model_framework, tmp_path, custom_mo_config=custom_mo_config)
|
||||||
config = merge_configs(model.model_params, engine_config, algorithm_config)
|
config = merge_configs(model.model_params, engine_config, algorithm_config)
|
||||||
|
|
||||||
_ = optimize(config)
|
_ = optimize(config)
|
||||||
@ -205,7 +209,12 @@ def launch_simplified_mode(tmp_path, models, engine_config):
|
|||||||
return metrics
|
return metrics
|
||||||
|
|
||||||
|
|
||||||
def test_simplified_mode(tmp_path, models):
|
@pytest.fixture(scope='module', params=SIMPLIFIED_TEST_MODELS,
|
||||||
|
ids=['{}_{}_{}_{}'.format(*m) for m in SIMPLIFIED_TEST_MODELS])
|
||||||
|
def _simplified_params(request):
|
||||||
|
return request.param
|
||||||
|
|
||||||
|
def test_simplified_mode(_simplified_params, tmp_path, models):
|
||||||
with open(PATHS2DATASETS_CONFIG.as_posix()) as f:
|
with open(PATHS2DATASETS_CONFIG.as_posix()) as f:
|
||||||
data_source = Dict(json.load(f))['ImageNet2012'].pop('source_dir')
|
data_source = Dict(json.load(f))['ImageNet2012'].pop('source_dir')
|
||||||
|
|
||||||
@ -214,8 +223,8 @@ def test_simplified_mode(tmp_path, models):
|
|||||||
'device': 'CPU',
|
'device': 'CPU',
|
||||||
'central_fraction': 0.875})
|
'central_fraction': 0.875})
|
||||||
|
|
||||||
_, _, _, _, expected_accuracy = SIMPLIFIED_TEST_MODELS[0]
|
_, _, _, _, expected_accuracy, _ = _simplified_params
|
||||||
metrics = launch_simplified_mode(tmp_path, models, engine_config)
|
metrics = launch_simplified_mode(_simplified_params, tmp_path, models, engine_config)
|
||||||
assert metrics == pytest.approx(expected_accuracy, abs=0.006)
|
assert metrics == pytest.approx(expected_accuracy, abs=0.006)
|
||||||
|
|
||||||
|
|
||||||
|
@ -23,7 +23,7 @@ class ModelStore:
|
|||||||
# load model description to self.models
|
# load model description to self.models
|
||||||
self._load_models_description()
|
self._load_models_description()
|
||||||
|
|
||||||
def get(self, name, framework, tmp_path, model_precision='FP32'):
|
def get(self, name, framework, tmp_path, model_precision='FP32', custom_mo_config=None):
|
||||||
for model in self.models:
|
for model in self.models:
|
||||||
if framework != model.framework:
|
if framework != model.framework:
|
||||||
continue
|
continue
|
||||||
@ -38,7 +38,7 @@ class ModelStore:
|
|||||||
'Couldn\'t load model {} from the framework {}'.format(model.name, model.framework))
|
'Couldn\'t load model {} from the framework {}'.format(model.name, model.framework))
|
||||||
assert omz_model_download(model) == 0,\
|
assert omz_model_download(model) == 0,\
|
||||||
'Can not download model: {}'.format(model.name)
|
'Can not download model: {}'.format(model.name)
|
||||||
convert_value = omz_model_convert(model)
|
convert_value = omz_model_convert(model, custom_mo_config)
|
||||||
assert convert_value == 0, 'Can not convert model: {}'.format(model.name)
|
assert convert_value == 0, 'Can not convert model: {}'.format(model.name)
|
||||||
model_path = tmp_path.joinpath(
|
model_path = tmp_path.joinpath(
|
||||||
model.subdirectory.as_posix(), model.precision, model.name)
|
model.subdirectory.as_posix(), model.precision, model.name)
|
||||||
|
@ -58,7 +58,7 @@ def download(config):
|
|||||||
return runner.run()
|
return runner.run()
|
||||||
|
|
||||||
|
|
||||||
def command_line_for_convert(config):
|
def command_line_for_convert(config, custom_mo_config=None):
|
||||||
python_path = DOWNLOAD_PATH.as_posix()
|
python_path = DOWNLOAD_PATH.as_posix()
|
||||||
executable = OMZ_DOWNLOADER_PATH.joinpath('converter.py').as_posix()
|
executable = OMZ_DOWNLOADER_PATH.joinpath('converter.py').as_posix()
|
||||||
cli_args = ' -o ' + config.model_params.output_dir
|
cli_args = ' -o ' + config.model_params.output_dir
|
||||||
@ -66,6 +66,9 @@ def command_line_for_convert(config):
|
|||||||
cli_args += ' --name ' + config.name
|
cli_args += ' --name ' + config.name
|
||||||
cli_args += ' --mo ' + MO_PATH.joinpath('mo.py').as_posix()
|
cli_args += ' --mo ' + MO_PATH.joinpath('mo.py').as_posix()
|
||||||
cli_args += ' --precisions ' + config.precision
|
cli_args += ' --precisions ' + config.precision
|
||||||
|
if custom_mo_config:
|
||||||
|
for custom_mo_arg in custom_mo_config:
|
||||||
|
cli_args += ' --add_mo_arg=' + custom_mo_arg
|
||||||
script_launch_cli = '{python_exe} {main_py} {args}'.format(
|
script_launch_cli = '{python_exe} {main_py} {args}'.format(
|
||||||
python_exe=sys.executable, main_py=executable, args=cli_args
|
python_exe=sys.executable, main_py=executable, args=cli_args
|
||||||
)
|
)
|
||||||
@ -77,8 +80,8 @@ def command_line_for_convert(config):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def convert(config):
|
def convert(config, custom_mo_config=None):
|
||||||
runner = Command(command_line_for_convert(config))
|
runner = Command(command_line_for_convert(config, custom_mo_config))
|
||||||
return runner.run()
|
return runner.run()
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user