Files
openvino/tests/layer_tests/conftest.py
Piotr Krzemiński 3d8a620ac3 [PT FE] Add aten::_native_multi_head_attention (#17550)
* [PT FE] Add implementation of MHA

* [PT FE] Add tests, add scaled dot product attention

* [PT FE] Fix missing transpose for Q,K,V & output Attention

* [PT FE] Formatting errors

* [PT FE] Fix testing class with nn.Linear

* [PT FE] Fix incorrect key transpose in dot product attention computation

* [PT FE] Fix incorrect matmul due to lack of transpose

* [PT FE] Enable support for all boolean masks

* [PT FE] Fix returned weights

* [PT FE] Remove debugging artifacts

* [PT FE] Remove unused nodes, optimize transpose nodes' usage, add comments to floating masks

* [PT FE] Further reduce node usage, return None instead of 0 for return_weights=false

* [PT FE] Allow for dynamic num_head, embed_dim

* [PT FE] Improve error comment, remove unnecessary Unsqueeze

* [PT FE] Clang format

* Update tests/layer_tests/pytorch_tests/test_native_multi_head_attention.py

Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>

* [PT FE] Add masks comments, improve mask broadcasting

---------

Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>
2023-06-05 10:55:03 +02:00

118 lines
4.0 KiB
Python

# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import re
import tempfile
from pathlib import Path
import pytest
from common import constants
def pytest_make_parametrize_id(config, val, argname):
    """Build the id fragment used for a single parametrized argument.

    The result has the form ``" <argname>:<val> "`` (one space on each side).
    ``config`` belongs to the hook signature and is not used here.
    """
    return f" {argname}:{val} "
def pytest_collection_modifyitems(items):
    """Convert matching ``special_xfail`` marks into real ``xfail`` marks.

    For every collected test that carries a marker named ``special_xfail``,
    the test's call parameters are reduced to the keys present in the
    marker's ``args`` dict (recursively for nested dicts). If the reduced
    parameters equal ``args``, the test is marked ``xfail`` with the marker's
    ``reason``.

    Args:
        items: collected test items; matching items are marked in place.
    """

    def remove_ignored_attrs(ref_dict, dict_to_upd):
        """Return a copy of ``dict_to_upd`` limited to keys found in ``ref_dict``."""
        kept = {}
        for key, value in dict_to_upd.items():
            if key not in ref_dict:
                continue
            ref_value = ref_dict[key]
            # Recurse only when both sides are dicts. A dict parameter paired
            # with a non-dict reference value is left untouched, so the final
            # equality check reports a mismatch instead of raising.
            if isinstance(value, dict) and isinstance(ref_value, dict):
                value = remove_ignored_attrs(ref_value, value)
            kept[key] = value
        return kept

    for test in items:
        # Snapshot the marks first: add_marker() below may extend own_markers.
        xfail_marks = [mark for mark in test.own_markers if mark.name == "special_xfail"]
        for mark in xfail_marks:
            # Tests without a callspec are treated as having no parameters.
            callspec = getattr(test, "callspec", None)
            params = callspec.params if callspec is not None else {}
            expected = mark.kwargs["args"]
            # Drop every parameter whose key is not listed in mark.kwargs["args"].
            # The remaining ones decide whether this test case is affected, so a
            # mark only needs to name the significant parameters
            # (e.g. {"device": "FP16"} marks every test case that runs with FP16).
            if remove_ignored_attrs(expected, params) == expected:
                test.add_marker(pytest.mark.xfail(reason=mark.kwargs["reason"]))
@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_makereport(item, call):
    """Attach ticket ids found in an xfail reason to the test report.

    Runs as a hook wrapper: the report is taken from the wrapped hook's
    outcome. For the ``call`` phase of a skipped test that has a ``wasxfail``
    reason, every substring matching ``*-<digits>`` in that reason is added
    to the report's ``extra`` list through the ``html`` plugin's
    ``extras.url`` helper.

    When no plugin is registered under the name ``html``, no extras are
    created, but ``report.extra`` is still assigned.
    """
    # getplugin() yields None when no plugin is registered under this name.
    pytest_html = item.config.pluginmanager.getplugin('html')
    outcome = yield
    report = outcome.get_result()
    extra = getattr(report, 'extra', [])
    if report.when == 'call':
        xfail_reason = getattr(report, 'wasxfail', None)
        if pytest_html is not None and report.skipped and xfail_reason:
            for ticket_num in re.findall(r"\*-\d+", xfail_reason):
                extra.append(pytest_html.extras.url(ticket_num))
        # NOTE(review): assigned inside the 'call' branch, following the
        # pytest-html documented pattern -- confirm against the upstream file.
        report.extra = extra
def pytest_addoption(parser):
    """Register the command-line options declared by this conftest.

    Each entry below is passed to ``parser.addoption`` in order, with the
    flag as the positional argument and the settings as keyword arguments.
    """
    option_table = (
        ("--ir_version", {
            "default": 11,
            "action": "store",
            "help": "Version of IR to generate by Model Optimizer",
        }),
        ("--use_new_frontend", {
            "required": False,
            "action": "store_true",
            "help": "Use Model Optimizer with new FrontEnd",
        }),
        ("--use_old_api", {
            "action": "store_true",
            "help": "Use old API for model processing in Inference Engine",
        }),
        ("--tflite", {
            "required": False,
            "action": "store_true",
            "help": "Switch to tflite tests version",
        }),
    )
    for flag, settings in option_table:
        parser.addoption(flag, **settings)
@pytest.fixture(scope="session")
def ir_version(request):
    """Session fixture returning the value of the ``ir_version`` option."""
    cli_config = request.config
    return cli_config.getoption('ir_version')
@pytest.fixture(scope="session")
def use_new_frontend(request):
    """Session fixture returning the value of the ``use_new_frontend`` option."""
    cli_config = request.config
    return cli_config.getoption('use_new_frontend')
@pytest.fixture(scope="session")
def use_old_api(request):
    """Session fixture returning the value of the ``use_old_api`` option."""
    cli_config = request.config
    return cli_config.getoption('use_old_api')
@pytest.fixture(scope="session")
def tflite(request):
    """Session fixture returning the value of the ``tflite`` option."""
    cli_config = request.config
    return cli_config.getoption('tflite')
@pytest.fixture(scope="session", autouse=True)
def checks_for_keys_usage(request):
    """Fail the session when ``use_old_api`` and ``use_new_frontend`` are both set.

    Declared ``autouse`` with session scope, so no test has to request it.
    """
    cli_config = request.config
    # The second option is read only when the first one is set.
    if not cli_config.getoption('use_old_api'):
        return
    if cli_config.getoption('use_new_frontend'):
        pytest.fail("Old API and new FrontEnd usage detected. Old API doesn't support new FrontEnd")
@pytest.fixture(scope="function")
def temp_dir(request):
    """Yield the path of a fresh temporary directory for the requesting test.

    The directory is created under ``constants.out_path`` (which is created
    first if missing). Its name is prefixed with the upper-cased ``ie_device``
    funcarg and the test's original name, with every non-word character
    replaced by ``_``. The path is yielded as a string.
    """
    output_root = constants.out_path
    Path(output_root).mkdir(parents=True, exist_ok=True)

    node = request.node
    safe_name = re.sub(r"[^\w_]", "_", node.originalname)
    device_tag = node.funcargs["ie_device"].upper()

    # The TemporaryDirectory object stays bound to a local for as long as the
    # generator is alive; cleanup() is never called explicitly here.
    workspace = tempfile.TemporaryDirectory(dir=output_root, prefix=f"{device_tag}_{safe_name}")
    yield str(workspace.name)