Files
openvino/tests/stress_tests/scripts/compare_memcheck_2_runs.py
Ilya Churaev 0c9abf43a9 Updated copyright headers (#15124)
* Updated copyright headers

* Revert "Fixed linker warnings in docs snippets on Windows (#15119)"

This reverts commit 372699ec49.
2023-01-16 11:02:17 +04:00

208 lines
8.6 KiB
Python

#!/usr/bin/env python3
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""
Create a comparison table based on MemCheckTests results from 2 runs.
Usage: ./scripts/compare_memcheck_2_runs.py cur_source ref_source \
--db_url mongodb://server:port --db_collection collection_name \
--ref_db_collection collection_name --out_file file_name
"""
# pylint:disable=line-too-long
import argparse
import json
import logging as log
import os
import sys
from collections import OrderedDict
from glob import glob
from operator import itemgetter
from pathlib import Path
from pymongo import MongoClient
# Database arguments
from memcheck_upload import DATABASE, DB_COLLECTIONS
from memcheck_upload import create_memcheck_records
class HashableDict(dict):
    """A ``dict`` subclass that can be hashed.

    The hash is built from the sorted ``(key, value)`` pairs, so two instances
    with equal contents hash equally regardless of insertion order. This lets
    an instance be used as a key in another dictionary. Every value must be
    hashable, and an instance should not be mutated while it is used as a key.
    """

    def __hash__(self):
        pairs = sorted(self.items())
        return hash(tuple(pairs))
def get_db_memcheck_records(query, db_collection, db_name, db_url):
    """Request MemCheckTests records from the database by the provided query.

    :param query: filter document passed to ``collection.find``.
    :param db_collection: name of the collection to query.
    :param db_name: name of the database that holds ``db_collection``.
    :param db_url: MongoDB connection URL.
    :return: list of all documents matching ``query``.
    """
    client = MongoClient(db_url)
    try:
        collection = client[db_name][db_collection]
        # Materialize the cursor before the client is closed.
        return list(collection.find(query))
    finally:
        # Close the client instead of leaving it open until the process exits.
        client.close()
def get_memcheck_records(source, db_collection=None, db_name=None, db_url=None):
    """Provide MemCheckTests records from a folder with logs or from a database.

    :param source: path to a directory with MemCheckTests ``*.log`` files
        (searched recursively), or a JSON-encoded database query.
    :param db_collection: collection to query when ``source`` is not a directory.
    :param db_name: database to query when ``source`` is not a directory.
    :param db_url: MongoDB URL used when ``source`` is not a directory.
    :return: records built by ``create_memcheck_records`` from the logs, or
        the list of documents fetched from the database.
    :raises ValueError: if ``source`` is not a directory and any of the
        database arguments is missing.
    """
    if os.path.isdir(source):
        # glob() already returns a list.
        logs = glob(os.path.join(source, '**', '*.log'), recursive=True)
        return create_memcheck_records(logs, build_url=None, artifact_root=source)

    # Explicit check instead of `assert`, so the validation survives `python -O`.
    if not (db_collection and db_name and db_url):
        raise ValueError(
            "Source '{}' is not a directory: db_collection, db_name and db_url "
            "are required to query the database".format(source))
    query = json.loads(source)
    return get_db_memcheck_records(query, db_collection, db_name, db_url)
def _index_records(records, required_fields, required_metrics):
    """Index MemCheckTests records by their identifying fields.

    :param records: iterable of records; each must contain every field from
        ``required_fields`` and a ``metrics`` dict.
    :param required_fields: names of the fields that identify a record.
    :param required_metrics: names of the metrics to keep.
    :return: dict mapping the tuple of ``required_fields`` values (in that
        order) to the record's metrics restricted to ``required_metrics``.
        When several records share an identity, the last one wins.
    :raises KeyError: if a record lacks a required field or ``metrics``.
    """
    indexed = {}
    for record in records:
        identity = tuple(record[field] for field in required_fields)
        indexed[identity] = {key: val for key, val in record["metrics"].items()
                             if key in required_metrics}
    return indexed


def compare_memcheck_2_runs(cur_values, references, output_file=None):
    """Compare 2 MemCheckTests runs and prepare a report on the specified path.

    Records of both runs are matched by their ``model``, ``device`` and
    ``test_name`` fields. For each tracked metric the report holds the
    reference value, the current value, their difference and the
    reference/current ratio. For every device, the geometric mean of the
    ``ref/cur`` ratios over records present in both runs must exceed
    ``GEOMEAN_THRESHOLD``.

    :param cur_values: iterable of current MemCheckTests records (dicts with
        the identifying fields and a ``metrics`` dict).
    :param references: iterable of reference records in the same format.
    :param output_file: optional report path; an ``.html`` extension writes
        HTML, anything else writes CSV.
    :return: 0 if every comparison passed, 1 otherwise.
    """
    import pandas  # pylint:disable=import-outside-toplevel
    from scipy.stats import gmean  # pylint:disable=import-outside-toplevel

    returncode = 0
    # constants
    metric_name_template = "{} {}"
    GEOMEAN_THRESHOLD = 0.9

    # Fields should be presented in both `references` and `cur_values`.
    # Some of metrics may be missing for one of `references` and `cur_values`.
    # Report will contain data with order defined in `required_fields` and `required_metrics`
    required_fields = [
        # "metrics" should be excluded because it will be handled automatically
        "model", "device", "test_name"
    ]
    required_metrics = [
        "vmrss", "vmhwm",
        # "vmsize", "vmpeak"  # temporarily disabled as unused
    ]
    # `ops` is a template applied for every metric defined in `required_metrics`
    ops = OrderedDict([
        # x means ref, y means cur
        ("ref", lambda x, y: x),
        ("cur", lambda x, y: y),
        ("cur-ref", lambda x, y: y - x if (x is not None and y is not None) else None),
        # A zero current value would raise ZeroDivisionError. Report the ratio as
        # missing instead, which excludes the record from the geomean check below.
        ("ref/cur", lambda x, y: x / y if (x is not None and y) else None),
    ])

    def geomean_above_threshold(values):
        """Return (status, message); status is True when gmean(values) exceeds the threshold."""
        geomean = gmean(values)
        return (geomean > GEOMEAN_THRESHOLD,
                "geomean={} is less than threshold={}".format(geomean, GEOMEAN_THRESHOLD))

    # `comparison_ops` is applied to metric columns generated by `ops`
    # format: {metric_col_name: operation returning (status, message)}
    comparison_ops = {
        metric_name_template.format(metric, "ref/cur"): geomean_above_threshold
        for metric in ("vmrss", "vmhwm")
    }

    filtered_refs = _index_records(references, required_fields, required_metrics)
    filtered_cur_val = _index_records(cur_values, required_fields, required_metrics)

    # Unique identities combined from references and current values (references
    # first). The stable sort by model keeps that relative order within a model.
    identities = list(dict.fromkeys([*filtered_refs, *filtered_cur_val]))
    identities.sort(key=itemgetter(required_fields.index("model")))

    comparison_data = []
    for identity in identities:
        ref_metrics = filtered_refs.get(identity, {})
        cur_metrics = filtered_cur_val.get(identity, {})
        record = OrderedDict(zip(required_fields, identity))
        for metric in required_metrics:
            ref = ref_metrics.get(metric)
            cur = cur_metrics.get(metric)
            for op_name, op in ops.items():
                record[metric_name_template.format(metric, op_name)] = op(ref, cur)
        comparison_data.append(record)

    # Compare data using `comparison_ops`. Only rows with every value present
    # (i.e. records found in both runs) take part in the check.
    orig_data = pandas.DataFrame(comparison_data)
    data = orig_data.dropna()
    if data.empty:
        # An empty frame may have no "device" column at all, so skip the check.
        log.warning("No records present in both runs, nothing to compare")
    else:
        for device in data["device"].unique():
            frame = data[data["device"] == device]
            for field, comp_op in comparison_ops.items():
                status, msg = comp_op(frame.loc[:, field])
                if not status:
                    log.error('Comparison for field="%s" for device="%s" failed: %s', field, device, msg)
                    returncode = 1

    # dump data to file
    if output_file:
        if os.path.splitext(output_file)[1] == ".html":
            orig_data.to_html(output_file)
        else:
            orig_data.to_csv(output_file)
        log.info('Created memcheck comparison report %s', output_file)

    return returncode
def cli_parser():
    """Parse and validate command-line arguments.

    A source that is not an existing directory is treated as a JSON query to
    the database. Such a source additionally requires ``--db_url`` and the
    matching collection argument (``--db_collection`` for the current source,
    ``--ref_db_collection`` for the reference source).

    :return: the parsed ``argparse.Namespace``; on missing arguments the
        parser prints usage and exits.
    """
    parser = argparse.ArgumentParser(description='Compare 2 runs of MemCheckTests')
    parser.add_argument('cur_source',
                        help='Source of current values of MemCheckTests. '
                             'Should contain path to a folder with logs or '
                             'JSON-format query to request data from DB.')
    parser.add_argument('ref_source',
                        help='Source of reference values of MemCheckTests. '
                             'Should contain path to a folder with logs or '
                             'JSON-format query to request data from DB.')
    parser.add_argument('--db_url',
                        help='MongoDB URL in a format "mongodb://server:port".')
    parser.add_argument('--db_collection',
                        help=f'Collection name in "{DATABASE}" database to query'
                             f' data using current source.',
                        choices=DB_COLLECTIONS)
    parser.add_argument('--ref_db_collection',
                        help=f'Collection name in "{DATABASE}" database to query'
                             f' data using reference source.',
                        choices=DB_COLLECTIONS)
    parser.add_argument('--out_file', dest='output_file', type=Path,
                        help='Path to a file (with name) to save results. '
                             'Example: /home/.../file.csv')

    args = parser.parse_args()

    cur_is_dir = os.path.isdir(args.cur_source)
    ref_is_dir = os.path.isdir(args.ref_source)
    missed_args = []
    if not (cur_is_dir and ref_is_dir) and not args.db_url:
        missed_args.append("--db_url")
    if not cur_is_dir and not args.db_collection:
        missed_args.append("--db_collection")
    if not ref_is_dir and not args.ref_db_collection:
        missed_args.append("--ref_db_collection")
    if missed_args:
        # argparse.ArgumentError requires (argument, message); calling it with only
        # a message raised TypeError. parser.error() reports usage and exits with 2.
        parser.error("Arguments {} are required".format(",".join(missed_args)))
    return args
if __name__ == "__main__":
    # Configure logging so that `log.info` messages (e.g. the path of the created
    # report) are shown; an unconfigured root logger only emits warnings and above.
    log.basicConfig(format="[ %(levelname)s ] %(message)s", level=log.INFO)
    args = cli_parser()
    references = get_memcheck_records(args.ref_source, args.ref_db_collection, DATABASE, args.db_url)
    cur_values = get_memcheck_records(args.cur_source, args.db_collection, DATABASE, args.db_url)
    exit_code = compare_memcheck_2_runs(cur_values, references, output_file=args.output_file)
    sys.exit(exit_code)