* [PT FE] Add implementation of MHA * [PT FE] Add tests, add scaled dot product attention * [PT FE] Fix missing transpose for Q,K,V & output Attention * [PT FE] Formatting errors * [PT FE] Fix testing class with nn.Linear * [PT FE] Fix incorrect key transpose in dot product attention computation * [PT FE] Fix incorrect matmul due to lack of transpose * [PT FE] Enable support for all boolean masks * [PT FE] Fix returned weights * [PT FE] Remove debugging artifacts * [PT FE] Remove unused nodes, optimize transpose nodes' usage, add comments to floating masks * [PT FE] Further reduce node usage, return None instead of 0 for return_weights=false * [PT FE] Allow for dynamic num_num_head, embed_dim * [PT FE] Improve error comment, remove unnecessary Unsqueeze * [PT FE] Clang format * Update tests/layer_tests/pytorch_tests/test_native_multi_head_attention.py Co-authored-by: Maxim Vafin <maxim.vafin@intel.com> * [PT FE] Add masks comments, improve mask broadcasting --------- Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>
118 lines
4.0 KiB
Python
118 lines
4.0 KiB
Python
# Copyright (C) 2018-2023 Intel Corporation
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
import re
|
|
import tempfile
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
|
|
from common import constants
|
|
|
|
|
|
def pytest_make_parametrize_id(config, val, argname):
    """Build the test-ID fragment for one parametrized argument.

    The fragment is ``<argname>:<val>`` padded with a single space on each
    side, so generated test IDs read as a sequence of ``name:value`` pairs.
    ``config`` is accepted to satisfy the hook signature but is not used.
    """
    return f" {argname}:{val} "
|
|
|
|
|
|
def pytest_collection_modifyitems(items):
    """Convert ``special_xfail`` markers into real ``xfail`` markers where they apply.

    A test may carry a ``special_xfail`` marker with two keyword arguments:

    * ``args`` -- a dict of parameter values, e.g. ``{"ie_device": "GPU"}``;
    * ``reason`` -- the reason passed to ``pytest.mark.xfail``.

    A test item is marked ``xfail`` only when its parameters, restricted to
    the keys listed in ``args``, equal ``args``. It is therefore enough to
    list only the parameters that matter (e.g. ``{"precision": "FP16"}``
    marks every FP16 case). Nested dicts are restricted the same way.
    Non-parametrized tests are treated as having no parameters.
    """

    def remove_ignored_attrs(ref_dict, dict_to_upd):
        """Return a copy of *dict_to_upd* keeping only the keys present in *ref_dict*.

        A nested dict value is pruned recursively, but only when the matching
        reference value is itself a dict; otherwise the value is kept as-is
        and left for the equality check to reject.
        """
        pruned = {}
        for key, value in dict_to_upd.items():
            if key not in ref_dict:
                continue
            ref_value = ref_dict[key]
            if isinstance(value, dict) and isinstance(ref_value, dict):
                value = remove_ignored_attrs(ref_value, value)
            pruned[key] = value
        return pruned

    for test in items:
        # Only markers attached directly to the item are considered.
        special_marks = [mark for mark in test.own_markers if "special_" in mark.name]
        for mark in special_marks:
            if mark.name != "special_xfail":
                continue
            # Non-parametrized tests have no ``callspec``; use an empty parameter
            # set instead of aborting collection with AttributeError.
            callspec = getattr(test, "callspec", None)
            params = callspec.params if callspec is not None else {}
            params = remove_ignored_attrs(mark.kwargs["args"], params)
            if mark.kwargs["args"] == params:
                test.add_marker(pytest.mark.xfail(reason=mark.kwargs["reason"]))
|
|
|
|
|
|
@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_makereport(item, call):
    """Attach ticket links found in xfail reasons to the pytest-html report.

    Runs as a hook wrapper around report creation. For the ``call`` phase of
    a test that ended as an expected failure, every ticket id of the form
    ``*-<digits>`` found in the xfail reason is appended to ``report.extra``
    as a pytest-html URL extra. When the pytest-html plugin is not loaded,
    no extras are added.
    """
    pytest_html = item.config.pluginmanager.getplugin('html')
    outcome = yield
    report = outcome.get_result()
    extra = getattr(report, 'extra', [])
    # pytest-html is optional; without it there is nothing to build URL extras
    # with, and dereferencing ``None.extras`` would crash the hook.
    if report.when == 'call' and pytest_html is not None:
        xfail_reason = getattr(report, 'wasxfail', None)
        # pytest reports an xfailed test as skipped and keeps the reason in ``wasxfail``.
        if report.skipped and xfail_reason:
            for ticket_num in re.findall(r"\*-\d+", xfail_reason):
                extra.append(pytest_html.extras.url(ticket_num))
    report.extra = extra
|
|
|
|
|
|
def pytest_addoption(parser):
    """Specify command-line options for all plugins.

    Registers:

    * ``--ir_version`` -- IR version to generate (int, default 11);
    * ``--use_new_frontend`` -- boolean flag;
    * ``--use_old_api`` -- boolean flag;
    * ``--tflite`` -- boolean flag.
    """
    parser.addoption(
        "--ir_version",
        default=11,
        # Convert command-line values to int so they have the same type as the
        # default; without this a user-supplied value arrives as a str.
        type=int,
        action="store",
        help="Version of IR to generate by Model Optimizer")
    parser.addoption(
        "--use_new_frontend",
        required=False,
        action="store_true",
        help="Use Model Optimizer with new FrontEnd")
    parser.addoption(
        "--use_old_api",
        action="store_true",
        help="Use old API for model processing in Inference Engine",
    )
    parser.addoption(
        "--tflite",
        required=False,
        action="store_true",
        help="Switch to tflite tests version")
|
|
|
|
|
|
@pytest.fixture(scope="session")
def ir_version(request):
    """Expose the ``--ir_version`` command-line option to tests (session scope)."""
    config = request.config
    return config.getoption('ir_version')
|
|
|
|
|
|
@pytest.fixture(scope="session")
def use_new_frontend(request):
    """Expose the ``--use_new_frontend`` command-line flag to tests (session scope)."""
    config = request.config
    return config.getoption('use_new_frontend')
|
|
|
|
|
|
@pytest.fixture(scope="session")
def use_old_api(request):
    """Expose the ``--use_old_api`` command-line flag to tests (session scope)."""
    config = request.config
    return config.getoption('use_old_api')
|
|
|
|
|
|
@pytest.fixture(scope="session")
def tflite(request):
    """Expose the ``--tflite`` command-line flag to tests (session scope)."""
    config = request.config
    return config.getoption('tflite')
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
def checks_for_keys_usage(request):
    """Reject the ``--use_old_api`` + ``--use_new_frontend`` combination.

    Applied automatically once per session. When both flags are set, the
    fixture calls ``pytest.fail`` with an explanatory message.
    """
    config = request.config
    # Only look at the frontend flag when the old API was requested.
    if not config.getoption('use_old_api'):
        return
    if config.getoption('use_new_frontend'):
        pytest.fail("Old API and new FrontEnd usage detected. Old API doesn't support new FrontEnd")
|
|
|
|
|
|
@pytest.fixture(scope="function")
def temp_dir(request):
    """Create directory for test purposes.

    Ensures ``constants.out_path`` exists, then yields (as ``str``) the path of
    a new temporary directory inside it. The directory name starts with
    ``<DEVICE>_<test name>``. DEVICE is the test's ``ie_device`` argument
    upper-cased, and any non-word characters in the test name are replaced
    with ``_``.

    NOTE(review): nothing here removes the directory explicitly; removal
    relies on ``tempfile.TemporaryDirectory``'s implicit cleanup once the
    object is released.
    """
    out_dir = constants.out_path
    Path(out_dir).mkdir(parents=True, exist_ok=True)
    node = request.node
    test_name = re.sub(r"[^\w_]", "_", node.originalname)
    device = node.funcargs["ie_device"].upper()
    tmp_dir_handle = tempfile.TemporaryDirectory(
        dir=out_dir,
        prefix=f"{device}_{test_name}",
    )
    yield str(tmp_dir_handle.name)
|