[POT] Implement DataFreeEngine (#9484)

* [POT] Implement DataFreeEngine

* Add CLI

* Updated CLI

* Moved logic to SyntheticImageLoader

* Fix bug with draw modes

* Fix bug in DataFreeEngine

* Fix multiprocessing

* Fix pylint

* Add DataFreeEngine test

* Download models

* Fill background

* Fix test

* Fix args

* Support config option for DataFree mode

* Minor fixes

* Add data_free config

* Add more test cases

* Enable RCNN models quantization
Liubov Talamanova 2022-02-01 15:15:20 +03:00 committed by GitHub
parent 09f53b56e6
commit ca09ddd123
12 changed files with 544 additions and 16 deletions

View File

@ -0,0 +1,34 @@
{
    "model": {
        "model_name": "model_name", // Model name
        "model": "<MODEL_PATH>", // Path to model (.xml format)
        "weights": "<PATH_TO_WEIGHTS>" // Path to weights (.bin format)
    },
    "engine": {
        "type": "data_free", // Engine type
        "generate_data": "True", // (Optional) If True, generate synthetic data and store it to `data_source`.
                                 // Otherwise, the dataset from `data_source` will be used
        "layout": "NCHW", // (Optional) Layout of input data. Supported: ["NCHW", "NHWC", "CHW", "CWH"]
        "shape": "[None, None, None, None]", // (Optional) If the model has dynamic shapes, input shapes must be provided
        "data_type": "image", // (Optional) Type of data to be generated.
                              // Currently only `image` is supported;
                              // `text` and `audio` cases are planned
        "data_source": "PATH_TO_SOURCE" // (Optional) Path to the directory where the synthetic dataset
                                        // is located or will be generated and saved
    },
    "compression": {
        "algorithms": [
            {
                "name": "DefaultQuantization", // Optimization algorithm name
                "params": {
                    "preset": "performance", // Preset [performance, mixed, accuracy] which controls the quantization
                                             // mode (symmetric, mixed (weights symmetric and activations asymmetric)
                                             // and fully asymmetric, respectively)
                    "stat_subset_size": 300 // Size of the subset used to calculate activation statistics
                                            // for quantization parameter calculation
                }
            }
        ]
    }
}
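For reference, a minimal sketch of how a config like this template can be consumed programmatically, mirroring the CLI flow changed further below. The module paths for Config and create_engine, the file name data_free_config.json, and the argument-free configure_params() call are assumptions, not part of this commit:

# Sketch only: wiring the data_free template above through the POT Python helpers.
# Module paths and the config file name are assumptions.
from openvino.tools.pot.configs.config import Config
from openvino.tools.pot.graph import load_model
from openvino.tools.pot.data_loaders.creator import create_data_loader
from openvino.tools.pot.engines.creator import create_engine

config = Config.read_config('data_free_config.json')  # hypothetical name for the template above
config.configure_params()                             # assumed callable without an AC config in data_free mode

model = load_model(config.model)
data_loader = create_data_loader(config.engine, model)                       # resolves to SyntheticImageLoader
engine = create_engine(config.engine, data_loader=data_loader, metric=None)  # resolves to DataFreeEngine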

View File

@ -13,7 +13,8 @@ def get_common_argument_parser():
    parser.add_argument(
        '-c',
        '--config',
        help='Path to a config file with optimization parameters. '
             'Overrides "-q | -m | -w | --ac-config | --engine" options')

    parser.add_argument(
        '-q',
@ -47,6 +48,12 @@ def get_common_argument_parser():
        type=str,
        help='Model name. Applicable only when -q option is used.')

    parser.add_argument(
        '--engine',
        choices=['accuracy_checker', 'data_free', 'simplified'],
        type=str,
        help='Engine type. Default: `accuracy_checker`')

    parser.add_argument(
        '--ac-config',
        type=str,
@ -105,6 +112,37 @@ def get_common_argument_parser():
        default=False,
        help='Keep Convolution, Deconvolution and FullyConnected weights uncompressed')

    data_free_opt = parser.add_argument_group('DataFreeEngine options')

    data_free_opt.add_argument(
        '--data-source',
        default='../../../pot_dataset',
        help='Path to directory where synthetic dataset is located or will be generated and saved. '
             'Default: `../../../pot_dataset`')

    data_free_opt.add_argument(
        '--shape',
        type=str,
        help='Required for models with dynamic shapes. '
             'Input shape that should be fed to an input node of the model. '
             'Shape is defined as a comma-separated list of integer numbers enclosed in '
             'parentheses or square brackets, for example [1,3,227,227] or (1,227,227,3), where '
             'the order of dimensions depends on the framework input layout of the model.')

    data_free_opt.add_argument(
        '--data-type',
        type=str,
        default='image',
        choices=['image'],
        help='Type of data for generation. Default: `image`')

    data_free_opt.add_argument(
        '--generate-data',
        action='store_true',
        default=False,
        help='If specified, generate synthetic data and store it to `--data-source`. '
             'Otherwise, the dataset from `--data-source` will be used')

    return parser
@ -112,7 +150,7 @@ def check_dependencies(args):
    if (args.quantize is not None and
            (args.model is None or
             args.weights is None or
             args.ac_config is None and args.engine != 'data_free')):
        raise ValueError(
            '--quantize option requires model, weights, and AC config to be specified.')
    if args.quantize is None and args.config is None:
@ -122,6 +160,8 @@ def check_dependencies(args):
        raise ValueError('Either --config or --quantize option should be specified')
    if args.quantize == 'accuracy_aware' and args.max_drop is None:
        raise ValueError('For AccuracyAwareQuantization --max-drop should be specified')
    if args.engine == 'data_free' and args.ac_config is not None:
        raise ValueError('Either DataFree mode or AC config should be specified')

    check_extra_arguments(args, 'model')
    check_extra_arguments(args, 'weights')
    check_extra_arguments(args, 'preset')

View File

@ -35,11 +35,17 @@ def app(argv):
    _update_config_path(args)

    config = Config.read_config(args.config)

    if args.engine:
        config.engine['type'] = args.engine if args.engine else 'accuracy_checker'
        if 'data_source' not in config.engine:
            config.engine['data_source'] = args.data_source

    config.configure_params(args.ac_config)
    config.update_from_args(args)

    if config.engine.type != 'accuracy_checker' and args.evaluate:
        raise Exception('Can not make evaluation in simplified or data_free mode')

    log_dir = _create_log_path(config)

    init_logger(level=args.log_level,

View File

@ -63,6 +63,19 @@ class Config(Dict):
        self.model['output_dir'] = args.output_dir
        self.model['direct_dump'] = args.direct_dump
        self.engine['evaluate'] = args.evaluate
        if self.engine.type == 'data_free':
            if 'data_type' not in self.engine:
                self.engine['data_type'] = args.data_type
            if 'generate_data' not in self.engine:
                self.engine['generate_data'] = args.generate_data
            if 'shape' not in self.engine:
                self.engine['shape'] = args.shape
            if self.engine['generate_data']:
                subset_size = 0
                for algo in self.compression['algorithms']:
                    subset_size = max(subset_size, algo.get('stat_subset_size', 300))
                self.engine['subset_size'] = subset_size
        self.model['keep_uncompressed_weights'] = args.keep_uncompressed_weights
        if 'optimizer' in self:
            self.optimizer.params['keep_uncompressed_weights'] = args.keep_uncompressed_weights
@ -295,9 +308,9 @@ class Config(Dict):
        if 'type' not in engine or engine.type == 'accuracy_checker':
            self._configure_ac_params()
            self.engine.type = 'accuracy_checker'
        elif engine.type == 'simplified' or engine.type == 'data_free':
            if 'data_source' not in engine:
                raise KeyError(f'Missed data dir for {engine.type} engine')
            self.engine.device = engine.device if engine.device else 'CPU'
            engine.data_source = Path(engine.data_source)
        else:
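One detail worth spelling out: in the generate_data branch added above, the engine's subset_size is taken as the largest stat_subset_size requested by any compression algorithm, so enough synthetic images are produced for calibration. A tiny illustration with hypothetical algorithm entries, shaped the way that loop reads them:

# Hypothetical algorithm list; the loop above reads stat_subset_size directly from each entry.
algorithms = [
    {'name': 'DefaultQuantization', 'stat_subset_size': 300},
    {'name': 'FastBiasCorrection', 'stat_subset_size': 1000},
]
subset_size = max(algo.get('stat_subset_size', 300) for algo in algorithms)
assert subset_size == 1000  # engine['subset_size'] would be set to 1000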

View File

@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

from openvino.tools.pot.data_loaders.image_loader import ImageLoader
from openvino.tools.pot.data_loaders.synthetic_image_loader import SyntheticImageLoader
from openvino.tools.pot.graph.model_utils import get_nodes_by_type
@ -24,9 +25,16 @@ def create_data_loader(config, model):
    data_loader = None
    for in_node in inputs:
        if tuple(in_node.shape) != (1, 3):
            if config.type == 'simplified':
                data_loader = ImageLoader(config)
                data_loader.shape = in_node.shape
                data_loader.get_layout(in_node)
            elif config.type == 'data_free':
                if not config.shape:
                    config.shape = in_node.shape
                if not config.layout:
                    config.layout = in_node.graph.graph.get('layout', None)
                data_loader = SyntheticImageLoader(config)
            return data_loader

    if data_loader is None:

View File

@ -33,25 +33,34 @@ class ImageLoader(DataLoader):
        self._shape = tuple(shape)

    def _read_and_preproc_image(self, img_path):
        C = self._layout.get_index_by_name('C')
        H = self._layout.get_index_by_name('H')
        W = self._layout.get_index_by_name('W')
        image = imread(img_path, IMREAD_GRAYSCALE)\
            if self._shape[C] == 1 else imread(img_path)
        if image is None:
            raise Exception('Can not read the image: {}'.format(img_path))
        return prepare_image(image, self._layout, (self.shape[H], self.shape[W]), self._crop_central_fraction)

    def get_layout(self, input_node=None):
        if self._layout is not None:
            if 'C' not in self._layout or 'H' not in self._layout or 'W' not in self._layout:
                raise ValueError('Unexpected {} layout'.format(self._layout))
            if self._shape is not None and 'N' in self._layout and len(self._shape) == 3:
                self._layout = self._layout[1:]
            self._layout = Layout(self._layout)
            return

        if input_node:
            layout_from_ir = input_node.graph.graph.get('layout', None)
            if layout_from_ir is not None:
                if self._shape is not None and 'N' in layout_from_ir and len(self._shape) == 3:
                    layout_from_ir = layout_from_ir[1:]
                self._layout = Layout(layout_from_ir)
                return

        image_colors_dim = (Dimension(3), Dimension(1))
        num_dims = len(self._shape)
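The get_layout() change above also trims the batch dimension when a 3D (batchless) shape is paired with a four-character layout string. A small self-contained illustration with assumed values:

# Assumed values: a batchless 3D shape paired with an 'NCHW' layout string from the IR.
from openvino.runtime import Layout

layout_from_ir = 'NCHW'
shape = (3, 224, 224)
if 'N' in layout_from_ir and len(shape) == 3:
    layout_from_ir = layout_from_ir[1:]   # drop the batch dimension -> 'CHW'
layout = Layout(layout_from_ir)           # layout now matches the 3D input shape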

View File

@ -0,0 +1,327 @@
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from multiprocessing import Pool
from pathlib import Path
import os
import re
import requests

import cv2 as cv
import numpy as np

from openvino.runtime import Layout  # pylint: disable=E0611,E0401

from openvino.tools.pot.utils.logger import get_logger
from openvino.tools.pot.data_loaders.image_loader import ImageLoader
from .utils import collect_img_files

logger = get_logger(__name__)
class IFSFunction:
    def __init__(self, prev_x, prev_y):
        self.function = []
        self.xs, self.ys = [prev_x], [prev_y]
        self.select_function = []
        self.cum_proba = 0.0

    def set_param(self, params, proba, weights=None):
        if weights is not None:
            params = list(np.array(params) * np.array(weights))
        self.function.append(params)
        self.cum_proba += proba
        self.select_function.append(self.cum_proba)

    def calculate(self, iteration):
        rand = np.random.random(iteration)
        prev_x, prev_y = 0, 0
        next_x, next_y = 0, 0

        for i in range(iteration):
            for func_params, select_func in zip(self.function, self.select_function):
                a, b, c, d, e, f = func_params
                if rand[i] <= select_func:
                    next_x = prev_x * a + prev_y * b + e
                    next_y = prev_x * c + prev_y * d + f
                    break

            self.xs.append(next_x)
            self.ys.append(next_y)
            prev_x = next_x
            prev_y = next_y

    @staticmethod
    def process_nans(data):
        nan_index = np.where(np.isnan(data))
        extend = np.array(range(nan_index[0][0] - 100, nan_index[0][0]))
        delete_row = np.append(extend, nan_index)
        return delete_row

    def rescale(self, image_x, image_y, pad_x, pad_y):
        xs = np.array(self.xs)
        ys = np.array(self.ys)
        if np.any(np.isnan(xs)):
            delete_row = self.process_nans(xs)
            xs = np.delete(xs, delete_row, axis=0)
            ys = np.delete(ys, delete_row, axis=0)
        if np.any(np.isnan(ys)):
            delete_row = self.process_nans(ys)
            xs = np.delete(xs, delete_row, axis=0)
            ys = np.delete(ys, delete_row, axis=0)

        if np.min(xs) < 0.0:
            xs -= np.min(xs)
        if np.min(ys) < 0.0:
            ys -= np.min(ys)

        xmax, xmin = np.max(xs), np.min(xs)
        ymax, ymin = np.max(ys), np.min(ys)
        self.xs = np.uint16(xs / (xmax - xmin) * (image_x - 2 * pad_x) + pad_x)
        self.ys = np.uint16(ys / (ymax - ymin) * (image_y - 2 * pad_y) + pad_y)

    def draw(self, draw_type, image_x, image_y, pad_x=6, pad_y=6):
        self.rescale(image_x, image_y, pad_x, pad_y)
        image = np.zeros((image_x, image_y), dtype=np.uint8)

        for i in range(len(self.xs)):
            if draw_type == 'point':
                image[self.ys[i], self.xs[i]] = 127
            else:
                mask = '{:09b}'.format(np.random.randint(1, 512))
                patch = 127 * np.array(list(map(int, list(mask))), dtype=np.uint8).reshape(3, 3)
                x_start = self.xs[i] + 1
                y_start = self.ys[i] + 1
                image[x_start:x_start+3, y_start:y_start+3] = patch
        return image
class SyntheticImageLoader(ImageLoader):
    def __init__(self, config):
        super().__init__(config)

        np.random.seed(seed=1)
        self.subset_size = config.get('subset_size', 300)
        self._cpu_count = min(os.cpu_count(), self.subset_size)
        self._shape = config.get('shape', None)
        self.data_source = config.get('data_source', None)
        self._weights = np.array([
            0.2, 1, 1, 1, 1, 1,
            0.6, 1, 1, 1, 1, 1,
            1, 1, 1, 1, 1, 1,
            1.4, 1, 1, 1, 1, 1,
            1.8, 1, 1, 1, 1, 1,
            1, 0.2, 1, 1, 1, 1,
            1, 0.6, 1, 1, 1, 1,
            1, 1.4, 1, 1, 1, 1,
            1, 1.8, 1, 1, 1, 1,
            1, 1, 0.2, 1, 1, 1,
            1, 1, 0.6, 1, 1, 1,
            1, 1, 1.4, 1, 1, 1,
            1, 1, 1.8, 1, 1, 1,
            1, 1, 1, 0.2, 1, 1,
            1, 1, 1, 0.6, 1, 1,
            1, 1, 1, 1.4, 1, 1,
            1, 1, 1, 1.8, 1, 1,
            1, 1, 1, 1, 0.2, 1,
            1, 1, 1, 1, 0.6, 1,
            1, 1, 1, 1, 1.4, 1,
            1, 1, 1, 1, 1.8, 1,
            1, 1, 1, 1, 1, 0.2,
            1, 1, 1, 1, 1, 0.6,
            1, 1, 1, 1, 1, 1.4,
            1, 1, 1, 1, 1, 1.8,
        ]).reshape(-1, 6)
        self._threshold = 0.2
        self._iterations = 200000
        self._num_of_points = None
        self._instances = None
        self._categories = None

        if isinstance(self._shape, str):
            self._shape = list(map(int, re.findall(r'\d+', self._shape)))

        super().get_layout()
        self._check_input_shape()

        if os.path.exists(self.data_source) and os.listdir(self.data_source) and not config.generate_data:
            logger.info(f'Dataset was found in `{self.data_source}`')
        else:
            logger.info(f'Synthetic dataset will be stored in `{self.data_source}`')

        if not os.path.exists(self.data_source):
            os.mkdir(self.data_source)
        assert os.path.isdir(self.data_source)

        if config.generate_data or not os.listdir(self.data_source):
            self._download_colorization_model()
            logger.info(f'Start generating {self.subset_size} synthetic images')
            self.generate_dataset()

        self._img_files = collect_img_files(self.data_source)
    def _check_input_shape(self):
        if self._shape is None:
            raise ValueError('Input shape should be specified. Please, use `--shape`')
        if len(self._shape) < 3 or len(self._shape) > 4:
            raise ValueError(f'Input shape should have 3 or 4 dimensions, but provided {self._shape}')
        if self._shape[self._layout.get_index_by_name('C')] != 3:
            raise ValueError('SyntheticImageLoader can generate images with only channels == 3')

    def _download_colorization_model(self):
        proto_name = 'colorization_deploy_v2.prototxt'
        model_name = 'colorization_release_v2.caffemodel'
        npy_name = 'pts_in_hull.npy'

        if not os.path.exists(proto_name):
            url = 'https://raw.githubusercontent.com/richzhang/colorization/caffe/colorization/models/'
            proto = requests.get(url + proto_name)
            open(proto_name, 'wb').write(proto.content)
        if not os.path.exists(model_name):
            url = 'http://eecs.berkeley.edu/~rich.zhang/projects/2016_colorization/files/demo_v2/'
            model = requests.get(url + model_name)
            open(model_name, 'wb').write(model.content)
        if not os.path.exists(npy_name):
            url = 'https://github.com/richzhang/colorization/raw/caffe/colorization/resources/'
            pts_in_hull = requests.get(url + npy_name)
            open(npy_name, 'wb').write(pts_in_hull.content)

    def _initialize_params(self, height, width):
        default_img_size = 362 * 362
        points_coeff = max(1, int(np.round(height * width / default_img_size)))
        self._num_of_points = 100000 * points_coeff

        if self.subset_size < len(self._weights):
            self._instances = 1
            self._categories = 1
            self._weights = self._weights[:self.subset_size, :]
        else:
            self._instances = np.ceil(0.25 * self.subset_size / self._weights.shape[0]).astype(int)
            self._categories = np.ceil(self.subset_size / (self._instances * self._weights.shape[0])).astype(int)

    def generate_dataset(self):
        height = self._shape[self._layout.get_index_by_name('H')]
        width = self._shape[self._layout.get_index_by_name('W')]
        self._initialize_params(height, width)

        # to avoid multiprocessing error: can't pickle openvino.pyopenvino.Layout objects
        self._layout = str(self._layout)

        with Pool(processes=self._cpu_count) as pool:
            params = pool.map(self._generate_category, [1e-5] * self._categories)

        instances_weights = np.repeat(self._weights, self._instances, axis=0)
        weight_per_img = np.tile(instances_weights, (self._categories, 1))
        repeated_params = np.repeat(params, self._weights.shape[0] * self._instances, axis=0)
        repeated_params = repeated_params[:self.subset_size]
        weight_per_img = weight_per_img[:self.subset_size]
        assert weight_per_img.shape[0] == len(repeated_params) == self.subset_size

        splits = min(self._cpu_count, self.subset_size)
        params_per_proc = np.array_split(repeated_params, splits)
        weights_per_proc = np.array_split(weight_per_img, splits)

        generation_params = []
        offset = 0
        for param, w in zip(params_per_proc, weights_per_proc):
            indices = list(range(offset, offset + len(param)))
            offset += len(param)
            generation_params.append((param, w, height, width, indices))

        with Pool(processes=self._cpu_count) as pool:
            pool.starmap(self._generate_image_batch, generation_params)

        self._layout = Layout(self._layout)

    def _generate_image_batch(self, params, weights, height, width, indices):
        pts_in_hull = np.load('pts_in_hull.npy').transpose().reshape(2, 313, 1, 1).astype(np.float32)
        net = cv.dnn.readNetFromCaffe('colorization_deploy_v2.prototxt', 'colorization_release_v2.caffemodel')
        net.getLayer(net.getLayerId('class8_ab')).blobs = [pts_in_hull]
        net.getLayer(net.getLayerId('conv8_313_rh')).blobs = [np.full([1, 313], 2.606, np.float32)]

        for i, param, weight in zip(indices, params, weights):
            image = self._generator(param, 'gray', self._iterations, height, width, weight)
            color_image = self._colorize(image, net)
            aug_image = self._augment(color_image)
            cv.imwrite(os.path.join(self.data_source, "{:06d}.png".format(i)), aug_image)

    @staticmethod
    def _generator(params, draw_type, iterations, height=512, width=512, weight=None):
        generators = IFSFunction(prev_x=0.0, prev_y=0.0)

        for param in params:
            generators.set_param(param[:6], param[6], weight)

        generators.calculate(iterations)
        img = generators.draw(draw_type, height, width)
        return img

    def _generate_category(self, eps, height=512, width=512):
        pixels = -1
        while pixels < self._threshold:
            param_size = np.random.randint(2, 8)
            params = np.zeros((param_size, 7), dtype=np.float32)

            sum_proba = eps
            for i in range(param_size):
                a, b, c, d, e, f = np.random.uniform(-1.0, 1.0, 6)
                prob = abs(a * d - b * c)
                sum_proba += prob
                params[i] = a, b, c, d, e, f, prob
            params[:, 6] /= sum_proba

            fracral_img = self._generator(params, 'point', self._num_of_points, height, width)
            pixels = np.count_nonzero(fracral_img) / (height * width)
        return params

    @staticmethod
    def _rgb2lab(frame):
        y_coeffs = np.array([0.212671, 0.715160, 0.072169], dtype=np.float32)
        frame = np.where(frame > 0.04045, np.power((frame + 0.055) / 1.055, 2.4), frame / 12.92)
        y = frame @ y_coeffs.T
        L = np.where(y > 0.008856, 116 * np.cbrt(y) - 16, 903.3 * y)
        return L

    def _colorize(self, frame, net):
        H_orig, W_orig = frame.shape[:2]  # original image size
        if len(frame.shape) == 2 or frame.shape[-1] == 1:
            frame = np.tile(frame.reshape(H_orig, W_orig, 1), (1, 1, 3))

        frame = frame.astype(np.float32) / 255
        img_l = self._rgb2lab(frame)  # get L from Lab image
        img_rs = cv.resize(img_l, (224, 224))  # resize image to network input size
        img_l_rs = img_rs - 50  # subtract 50 for mean-centering

        net.setInput(cv.dnn.blobFromImage(img_l_rs))
        ab_dec = net.forward()[0, :, :, :].transpose((1, 2, 0))

        ab_dec_us = cv.resize(ab_dec, (W_orig, H_orig))
        img_lab_out = np.concatenate((img_l[..., np.newaxis], ab_dec_us), axis=2)  # concatenate with original image L
        img_bgr_out = np.clip(cv.cvtColor(img_lab_out, cv.COLOR_Lab2BGR), 0, 1)
        frame_normed = 255 * (img_bgr_out - img_bgr_out.min()) / (img_bgr_out.max() - img_bgr_out.min())
        frame_normed = np.array(frame_normed, dtype=np.uint8)
        return cv.resize(frame_normed, (H_orig, W_orig))

    def _augment(self, image):
        if np.random.random(1) >= 0.5:
            image = cv.flip(image, 1)

        if np.random.random(1) >= 0.5:
            image = cv.flip(image, 0)

        height, width = image.shape[:2]

        angle = np.random.uniform(-30, 30)
        rotate_matrix = cv.getRotationMatrix2D(center=(width / 2, height / 2), angle=angle, scale=1)
        image = cv.warpAffine(src=image, M=rotate_matrix, dsize=(width, height))

        image = self._fill_background(image)

        k_size = np.random.choice(list(range(1, 16, 2)))
        image = cv.GaussianBlur(image, (k_size, k_size), 0)
        return image

    @staticmethod
    def _fill_background(image):
        synthetic_background = Path(__file__).parent / 'synthetic_background.npy'
        imagenet_means = np.load(synthetic_background)
        class_id = np.random.randint(0, imagenet_means.shape[0])
        rows, cols = np.where(~np.any(image, axis=-1))  # background color = [0, 0, 0]
        image[rows, cols] = imagenet_means[class_id]
        return image
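Not part of the commit, but a compact sketch of the generation step above in isolation: random affine maps are sampled the way _generate_category() does, fed to IFSFunction, and drawn as a grayscale fractal before colorization and augmentation. The import path follows the new module; the iteration count is deliberately smaller than the loader's default of 200000:

# Sketch only: one fractal image via the IFS (iterated function system) used above.
import cv2 as cv
import numpy as np

from openvino.tools.pot.data_loaders.synthetic_image_loader import IFSFunction

np.random.seed(0)
param_size = np.random.randint(2, 8)
params = np.zeros((param_size, 7), dtype=np.float32)
sum_proba = 1e-5
for i in range(param_size):
    a, b, c, d, e, f = np.random.uniform(-1.0, 1.0, 6)
    prob = abs(a * d - b * c)      # each affine map is weighted by its determinant
    sum_proba += prob
    params[i] = a, b, c, d, e, f, prob
params[:, 6] /= sum_proba          # normalize selection probabilities

ifs = IFSFunction(prev_x=0.0, prev_y=0.0)
for param in params:
    ifs.set_param(param[:6], param[6])
ifs.calculate(20000)               # run the chaos game for 20k points (loader uses 200k)
gray = ifs.draw('gray', 512, 512)  # 512x512 uint8 image, still uncolored
cv.imwrite('fractal_preview.png', gray)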

View File

@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

from openvino.tools.pot.engines.ac_engine import ACEngine
from openvino.tools.pot.engines.data_free_engine import DataFreeEngine
from openvino.tools.pot.engines.simplified_engine import SimplifiedEngine
@ -16,4 +17,6 @@ def create_engine(config, **kwargs):
        return ACEngine(config)
    if config.type == 'simplified':
        return SimplifiedEngine(config, **kwargs)
    if config.type == 'data_free':
        return DataFreeEngine(config, **kwargs)
    raise RuntimeError('Unsupported engine type')

View File

@ -0,0 +1,19 @@
# Copyright (C) 2021-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from openvino.tools.pot.data_loaders.synthetic_image_loader import SyntheticImageLoader
from openvino.tools.pot.engines.simplified_engine import SimplifiedEngine


class DataFreeEngine(SimplifiedEngine):
    def __init__(self, config, data_loader=None, metric=None):
        super().__init__(config)
        if not data_loader:
            self._data_loader = self.get_data_loader(config)
        else:
            self._data_loader = data_loader

    def get_data_loader(self, config):
        if config.data_type == 'image':
            return SyntheticImageLoader(config)

        raise NotImplementedError("Currently data-free optimization is available for Computer Vision models only")
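A minimal usage sketch, assuming an engine config along the lines of the tests that follow; all config values here are placeholders. With no data_loader argument the engine builds a SyntheticImageLoader (which downloads the colorization model and generates images at construction time); passing an existing DataLoader instance would reuse a real dataset instead:

# Sketch only: constructing DataFreeEngine directly with a placeholder config.
from addict import Dict

from openvino.tools.pot.engines.data_free_engine import DataFreeEngine

engine_config = Dict({
    'type': 'data_free',
    'device': 'CPU',
    'data_type': 'image',            # anything else raises NotImplementedError
    'data_source': './pot_dataset',  # placeholder output directory
    'generate_data': True,
    'shape': [1, 3, 224, 224],
    'layout': 'NCHW',
    'subset_size': 30,
})

engine = DataFreeEngine(engine_config)  # 30 synthetic images are generated during construction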

View File

@ -0,0 +1,51 @@
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import os
from addict import Dict
import pytest

from openvino.tools.pot.data_loaders.creator import create_data_loader
from openvino.tools.pot.graph import load_model
from openvino.tools.pot.graph.model_utils import get_nodes_by_type

TEST_MODELS = [
    ('mobilenet-v2-pytorch', 'pytorch', None, None),
    ('mobilenet-v2-pytorch', 'pytorch', None, (3, 640, 720)),
    ('mobilenet-v2-pytorch', 'pytorch', 'HWC', (224, 224, 3)),
    ('mobilenet-v2-pytorch', 'pytorch', 'NHWC', (1, 224, 224, 3)),
    ('mobilenet-v2-pytorch', 'pytorch', 'CHW', (3, 224, 224)),
    ('mobilenet-v2-pytorch', 'pytorch', 'NCHW', (1, 3, 224, 224)),
]


@pytest.mark.parametrize(
    'model_name, model_framework, layout, input_shape', TEST_MODELS,
    ids=['{}_{}_{}_{}'.format(m[0], m[1], m[2], m[3]) for m in TEST_MODELS])
def test_generate_image(tmp_path, models, model_name, model_framework, layout, input_shape):
    path_image_data = os.path.join(tmp_path, 'pot_dataset')
    stat_subset_size = 5
    engine_config = Dict({'device': 'CPU',
                          'type': 'data_free',
                          'data_source': path_image_data,
                          'subset_size': stat_subset_size,
                          'layout': layout,
                          'shape': input_shape,
                          'generate_data': 'True'})

    model = models.get(model_name, model_framework, tmp_path)
    model = load_model(model.model_params)

    data_loader = create_data_loader(engine_config, model)
    num_images_from_data_loader = len(list(data_loader))
    num_images_in_dir = len(os.listdir(path_image_data))
    assert num_images_from_data_loader == num_images_in_dir == stat_subset_size

    image = data_loader[0]
    if input_shape is None:
        in_node = get_nodes_by_type(model, ['Parameter'], recursively=False)[0]
        input_shape = tuple(in_node.shape[1:])
    elif len(input_shape) == 4:
        input_shape = input_shape[1:]

    assert image.shape == input_shape

View File

@ -221,6 +221,24 @@ def test_simplified_mode(tmp_path, models):
    assert metrics == pytest.approx(expected_accuracy, abs=0.006)


DATAFREE_TEST_MODELS = [
    ('mobilenet-v2-pytorch', 'pytorch', 'DefaultQuantization', 'performance',
     {'accuracy@top1': 0.679, 'accuracy@top5': 0.888})
]


def test_datafree_mode(tmp_path, models):
    engine_config = Dict({'type': 'data_free',
                          'data_source': os.path.join(tmp_path, 'pot_dataset'),
                          'generate_data': 'True',
                          'subset_size': 30,
                          'device': 'CPU'})

    _, _, _, _, expected_accuracy = DATAFREE_TEST_MODELS[0]
    metrics = launch_simplified_mode(tmp_path, models, engine_config)
    assert metrics == pytest.approx(expected_accuracy, abs=0.06)


def test_frame_extractor_tool():
    # hack due to strange python imports (same as in sample test)
    pot_dir = Path(__file__).parent.parent