perf(decor): join predicates and matches cache

This commit is contained in:
vanaigr 2024-12-18 12:23:28 -06:00
parent ef77845b97
commit 8d2ee542a8

View File

@ -762,16 +762,7 @@ end
---@private
---@param match TSQueryMatch
---@param source integer|string
function Query:match_preds(match, source)
local _, pattern = match:info()
local preds = self.info.patterns[pattern]
if not preds then
return true
end
local captures = match:captures()
function Query:match_preds(preds, pattern, captures, source)
for _, pred in pairs(preds) do
-- Here we only want to return if a predicate DOES NOT match, and
-- continue on the other case. This way unknown predicates will not be considered,
@ -807,17 +798,9 @@ end
---@private
---@param match TSQueryMatch
---@return vim.treesitter.query.TSMetadata metadata
function Query:apply_directives(match, source)
function Query:apply_directives(preds, pattern, captures, source)
---@type vim.treesitter.query.TSMetadata
local metadata = {}
local _, pattern = match:info()
local preds = self.info.patterns[pattern]
if not preds then
return metadata
end
local captures = match:captures()
for _, pred in pairs(preds) do
if is_directive(pred[1]) then
@ -902,8 +885,10 @@ function Query:iter_captures(node, source, start, stop)
local cursor = vim._create_ts_querycursor(node, self.query, start, stop, { match_limit = 256 })
local apply_directives = memoize(match_id_hash, self.apply_directives, false)
local match_preds = memoize(match_id_hash, self.match_preds, false)
-- For faster checks that a match is not in the cache.
local highest_cached_match_id = -1
---@type table<integer, vim.treesitter.query.TSMetadata>
local match_cache = {}
local function iter(end_line)
local capture, captured_node, match = cursor:next_capture()
@ -912,16 +897,35 @@ function Query:iter_captures(node, source, start, stop)
return
end
if not match_preds(self, match, source) then
local match_id = match:info()
cursor:remove_match(match_id)
if end_line and captured_node:range() > end_line then
return nil, captured_node, nil, nil
end
return iter(end_line) -- tail call: try next match
local match_id, pattern = match:info()
--- @type vim.treesitter.query.TSMetadata
local metadata
if match_id <= highest_cached_match_id then
metadata = match_cache[match_id]
end
local metadata = apply_directives(self, match, source)
if not metadata then
local preds = self.info.patterns[pattern]
if preds then
local captures = match:captures()
if not self:match_preds(preds, pattern, captures, source) then
cursor:remove_match(match_id)
if end_line and captured_node:range() > end_line then
return nil, captured_node, nil, nil
end
return iter(end_line) -- tail call: try next match
end
metadata = self:apply_directives(preds, pattern, captures, source)
else
metadata = {}
end
highest_cached_match_id = math.max(highest_cached_match_id, match_id)
match_cache[match_id] = metadata
end
return capture, captured_node, metadata, match
end
@ -985,16 +989,21 @@ function Query:iter_matches(node, source, start, stop, opts)
end
local match_id, pattern = match:info()
if not self:match_preds(match, source) then
cursor:remove_match(match_id)
return iter() -- tail call: try next match
end
local metadata = self:apply_directives(match, source)
local preds = self.info.patterns[pattern]
local captures = match:captures()
--- @type vim.treesitter.query.TSMetadata
local metadata
if preds then
if not self:match_preds(preds, pattern, captures, source) then
cursor:remove_match(match_id)
return iter() -- tail call: try next match
end
metadata = self:apply_directives(preds, pattern, captures, source)
else
metadata = {}
end
if opts.all == false then
-- Convert the match table into the old buggy version for backward
-- compatibility. This is slow, but we only do it when the caller explicitly opted into it by