mirror of
https://github.com/neovim/neovim.git
synced 2025-02-25 18:55:25 -06:00
refactor(treesitter): redesign query iterating
Problem: `TSNode:_rawquery()` is complicated, has known issues and the Lua and C code is awkwardly coupled (see logic with `active`). Solution: - Add `TSQueryCursor` and `TSQueryMatch` bindings. - Replace `TSNode:_rawquery()` with `TSQueryCursor:next_capture()` and `TSQueryCursor:next_match()` - Do more stuff in Lua - API for `Query:iter_captures()` and `Query:iter_matches()` remains the same. - `treesitter.c` no longer contains any logic related to predicates. - Add `match_limit` option to `iter_matches()`. Default is still 256.
This commit is contained in:
parent
16a416cb3c
commit
aca2048bcd
@ -1152,6 +1152,10 @@ Query:iter_captures({node}, {source}, {start}, {stop})
|
|||||||
end
|
end
|
||||||
<
|
<
|
||||||
|
|
||||||
|
Note: ~
|
||||||
|
• Captures are only returned if the query pattern of a specific capture
|
||||||
|
contained predicates.
|
||||||
|
|
||||||
Parameters: ~
|
Parameters: ~
|
||||||
• {node} (`TSNode`) under which the search will occur
|
• {node} (`TSNode`) under which the search will occur
|
||||||
• {source} (`integer|string`) Source buffer or string to extract text
|
• {source} (`integer|string`) Source buffer or string to extract text
|
||||||
@ -1162,7 +1166,7 @@ Query:iter_captures({node}, {source}, {start}, {stop})
|
|||||||
Defaults to `node:end_()`.
|
Defaults to `node:end_()`.
|
||||||
|
|
||||||
Return: ~
|
Return: ~
|
||||||
(`fun(end_line: integer?): integer, TSNode, vim.treesitter.query.TSMetadata, table<integer, TSNode>`)
|
(`fun(end_line: integer?): integer, TSNode, vim.treesitter.query.TSMetadata, table<integer,TSNode[]>?`)
|
||||||
capture id, capture node, metadata, match
|
capture id, capture node, metadata, match
|
||||||
|
|
||||||
*Query:iter_matches()*
|
*Query:iter_matches()*
|
||||||
@ -1206,6 +1210,8 @@ Query:iter_matches({node}, {source}, {start}, {stop}, {opts})
|
|||||||
• max_start_depth (integer) if non-zero, sets the maximum
|
• max_start_depth (integer) if non-zero, sets the maximum
|
||||||
start depth for each match. This is used to prevent
|
start depth for each match. This is used to prevent
|
||||||
traversing too deep into a tree.
|
traversing too deep into a tree.
|
||||||
|
• match_limit (integer) Set the maximum number of
|
||||||
|
in-progress matches (Default: 256).
|
||||||
• all (boolean) When set, the returned match table maps
|
• all (boolean) When set, the returned match table maps
|
||||||
capture IDs to a list of nodes. Older versions of
|
capture IDs to a list of nodes. Older versions of
|
||||||
iter_matches incorrectly mapped capture IDs to a single
|
iter_matches incorrectly mapped capture IDs to a single
|
||||||
|
@ -34,22 +34,6 @@ error('Cannot require a meta file')
|
|||||||
---@field byte_length fun(self: TSNode): integer
|
---@field byte_length fun(self: TSNode): integer
|
||||||
local TSNode = {}
|
local TSNode = {}
|
||||||
|
|
||||||
---@param query TSQuery
|
|
||||||
---@param captures true
|
|
||||||
---@param start? integer
|
|
||||||
---@param end_? integer
|
|
||||||
---@param opts? table
|
|
||||||
---@return fun(): integer, TSNode, vim.treesitter.query.TSMatch
|
|
||||||
function TSNode:_rawquery(query, captures, start, end_, opts) end
|
|
||||||
|
|
||||||
---@param query TSQuery
|
|
||||||
---@param captures false
|
|
||||||
---@param start? integer
|
|
||||||
---@param end_? integer
|
|
||||||
---@param opts? table
|
|
||||||
---@return fun(): integer, vim.treesitter.query.TSMatch
|
|
||||||
function TSNode:_rawquery(query, captures, start, end_, opts) end
|
|
||||||
|
|
||||||
---@alias TSLoggerCallback fun(logtype: 'parse'|'lex', msg: string)
|
---@alias TSLoggerCallback fun(logtype: 'parse'|'lex', msg: string)
|
||||||
|
|
||||||
---@class TSParser: userdata
|
---@class TSParser: userdata
|
||||||
@ -90,3 +74,31 @@ vim._ts_parse_query = function(lang, query) end
|
|||||||
---@param lang string
|
---@param lang string
|
||||||
---@return TSParser
|
---@return TSParser
|
||||||
vim._create_ts_parser = function(lang) end
|
vim._create_ts_parser = function(lang) end
|
||||||
|
|
||||||
|
--- @class TSQueryMatch: userdata
|
||||||
|
--- @field captures fun(self: TSQueryMatch): table<integer,TSNode[]>
|
||||||
|
local TSQueryMatch = {}
|
||||||
|
|
||||||
|
--- @return integer match_id
|
||||||
|
--- @return integer pattern_index
|
||||||
|
function TSQueryMatch:info() end
|
||||||
|
|
||||||
|
--- @class TSQueryCursor: userdata
|
||||||
|
--- @field remove_match fun(self: TSQueryCursor, id: integer)
|
||||||
|
local TSQueryCursor = {}
|
||||||
|
|
||||||
|
--- @return integer capture
|
||||||
|
--- @return TSNode captured_node
|
||||||
|
--- @return TSQueryMatch match
|
||||||
|
function TSQueryCursor:next_capture() end
|
||||||
|
|
||||||
|
--- @return TSQueryMatch match
|
||||||
|
function TSQueryCursor:next_match() end
|
||||||
|
|
||||||
|
--- @param node TSNode
|
||||||
|
--- @param query TSQuery
|
||||||
|
--- @param start integer?
|
||||||
|
--- @param stop integer?
|
||||||
|
--- @param opts? { max_start_depth?: integer, match_limit?: integer}
|
||||||
|
--- @return TSQueryCursor
|
||||||
|
function vim._create_ts_querycursor(node, query, start, stop, opts) end
|
||||||
|
@ -122,7 +122,7 @@ local parse = vim.func._memoize(hash_parse, function(node, buf, lang)
|
|||||||
end)
|
end)
|
||||||
|
|
||||||
--- @param buf integer
|
--- @param buf integer
|
||||||
--- @param match vim.treesitter.query.TSMatch
|
--- @param match table<integer,TSNode[]>
|
||||||
--- @param query vim.treesitter.Query
|
--- @param query vim.treesitter.Query
|
||||||
--- @param lang_context QueryLinterLanguageContext
|
--- @param lang_context QueryLinterLanguageContext
|
||||||
--- @param diagnostics vim.Diagnostic[]
|
--- @param diagnostics vim.Diagnostic[]
|
||||||
|
@ -258,7 +258,7 @@ end)
|
|||||||
--- handling the "any" vs "all" semantics. They are called from the
|
--- handling the "any" vs "all" semantics. They are called from the
|
||||||
--- predicate_handlers table with the appropriate arguments for each predicate.
|
--- predicate_handlers table with the appropriate arguments for each predicate.
|
||||||
local impl = {
|
local impl = {
|
||||||
--- @param match vim.treesitter.query.TSMatch
|
--- @param match table<integer,TSNode[]>
|
||||||
--- @param source integer|string
|
--- @param source integer|string
|
||||||
--- @param predicate any[]
|
--- @param predicate any[]
|
||||||
--- @param any boolean
|
--- @param any boolean
|
||||||
@ -293,7 +293,7 @@ local impl = {
|
|||||||
return not any
|
return not any
|
||||||
end,
|
end,
|
||||||
|
|
||||||
--- @param match vim.treesitter.query.TSMatch
|
--- @param match table<integer,TSNode[]>
|
||||||
--- @param source integer|string
|
--- @param source integer|string
|
||||||
--- @param predicate any[]
|
--- @param predicate any[]
|
||||||
--- @param any boolean
|
--- @param any boolean
|
||||||
@ -333,7 +333,7 @@ local impl = {
|
|||||||
end,
|
end,
|
||||||
})
|
})
|
||||||
|
|
||||||
--- @param match vim.treesitter.query.TSMatch
|
--- @param match table<integer,TSNode[]>
|
||||||
--- @param source integer|string
|
--- @param source integer|string
|
||||||
--- @param predicate any[]
|
--- @param predicate any[]
|
||||||
--- @param any boolean
|
--- @param any boolean
|
||||||
@ -356,7 +356,7 @@ local impl = {
|
|||||||
end
|
end
|
||||||
end)(),
|
end)(),
|
||||||
|
|
||||||
--- @param match vim.treesitter.query.TSMatch
|
--- @param match table<integer,TSNode[]>
|
||||||
--- @param source integer|string
|
--- @param source integer|string
|
||||||
--- @param predicate any[]
|
--- @param predicate any[]
|
||||||
--- @param any boolean
|
--- @param any boolean
|
||||||
@ -383,13 +383,7 @@ local impl = {
|
|||||||
end,
|
end,
|
||||||
}
|
}
|
||||||
|
|
||||||
---@nodoc
|
---@alias TSPredicate fun(match: table<integer,TSNode[]>, pattern: integer, source: integer|string, predicate: any[]): boolean
|
||||||
---@class vim.treesitter.query.TSMatch
|
|
||||||
---@field pattern? integer
|
|
||||||
---@field active? boolean
|
|
||||||
---@field [integer] TSNode[]
|
|
||||||
|
|
||||||
---@alias TSPredicate fun(match: vim.treesitter.query.TSMatch, pattern: integer, source: integer|string, predicate: any[]): boolean
|
|
||||||
|
|
||||||
-- Predicate handler receive the following arguments
|
-- Predicate handler receive the following arguments
|
||||||
-- (match, pattern, bufnr, predicate)
|
-- (match, pattern, bufnr, predicate)
|
||||||
@ -504,7 +498,7 @@ predicate_handlers['any-vim-match?'] = predicate_handlers['any-match?']
|
|||||||
---@field [integer] vim.treesitter.query.TSMetadata
|
---@field [integer] vim.treesitter.query.TSMetadata
|
||||||
---@field [string] integer|string
|
---@field [string] integer|string
|
||||||
|
|
||||||
---@alias TSDirective fun(match: vim.treesitter.query.TSMatch, _, _, predicate: (string|integer)[], metadata: vim.treesitter.query.TSMetadata)
|
---@alias TSDirective fun(match: table<integer,TSNode[]>, _, _, predicate: (string|integer)[], metadata: vim.treesitter.query.TSMetadata)
|
||||||
|
|
||||||
-- Predicate handler receive the following arguments
|
-- Predicate handler receive the following arguments
|
||||||
-- (match, pattern, bufnr, predicate)
|
-- (match, pattern, bufnr, predicate)
|
||||||
@ -726,13 +720,19 @@ local function is_directive(name)
|
|||||||
end
|
end
|
||||||
|
|
||||||
---@private
|
---@private
|
||||||
---@param match vim.treesitter.query.TSMatch
|
---@param match TSQueryMatch
|
||||||
---@param pattern integer
|
|
||||||
---@param source integer|string
|
---@param source integer|string
|
||||||
function Query:match_preds(match, pattern, source)
|
function Query:match_preds(match, source)
|
||||||
|
local _, pattern = match:info()
|
||||||
local preds = self.info.patterns[pattern]
|
local preds = self.info.patterns[pattern]
|
||||||
|
|
||||||
for _, pred in pairs(preds or {}) do
|
if not preds then
|
||||||
|
return true
|
||||||
|
end
|
||||||
|
|
||||||
|
local captures = match:captures()
|
||||||
|
|
||||||
|
for _, pred in pairs(preds) do
|
||||||
-- Here we only want to return if a predicate DOES NOT match, and
|
-- Here we only want to return if a predicate DOES NOT match, and
|
||||||
-- continue on the other case. This way unknown predicates will not be considered,
|
-- continue on the other case. This way unknown predicates will not be considered,
|
||||||
-- which allows some testing and easier user extensibility (#12173).
|
-- which allows some testing and easier user extensibility (#12173).
|
||||||
@ -754,7 +754,7 @@ function Query:match_preds(match, pattern, source)
|
|||||||
return false
|
return false
|
||||||
end
|
end
|
||||||
|
|
||||||
local pred_matches = handler(match, pattern, source, pred)
|
local pred_matches = handler(captures, pattern, source, pred)
|
||||||
|
|
||||||
if not xor(is_not, pred_matches) then
|
if not xor(is_not, pred_matches) then
|
||||||
return false
|
return false
|
||||||
@ -765,23 +765,33 @@ function Query:match_preds(match, pattern, source)
|
|||||||
end
|
end
|
||||||
|
|
||||||
---@private
|
---@private
|
||||||
---@param match vim.treesitter.query.TSMatch
|
---@param match TSQueryMatch
|
||||||
---@param metadata vim.treesitter.query.TSMetadata
|
---@return vim.treesitter.query.TSMetadata metadata
|
||||||
function Query:apply_directives(match, pattern, source, metadata)
|
function Query:apply_directives(match, source)
|
||||||
|
---@type vim.treesitter.query.TSMetadata
|
||||||
|
local metadata = {}
|
||||||
|
local _, pattern = match:info()
|
||||||
local preds = self.info.patterns[pattern]
|
local preds = self.info.patterns[pattern]
|
||||||
|
|
||||||
for _, pred in pairs(preds or {}) do
|
if not preds then
|
||||||
|
return metadata
|
||||||
|
end
|
||||||
|
|
||||||
|
local captures = match:captures()
|
||||||
|
|
||||||
|
for _, pred in pairs(preds) do
|
||||||
if is_directive(pred[1]) then
|
if is_directive(pred[1]) then
|
||||||
local handler = directive_handlers[pred[1]]
|
local handler = directive_handlers[pred[1]]
|
||||||
|
|
||||||
if not handler then
|
if not handler then
|
||||||
error(string.format('No handler for %s', pred[1]))
|
error(string.format('No handler for %s', pred[1]))
|
||||||
return
|
|
||||||
end
|
end
|
||||||
|
|
||||||
handler(match, pattern, source, pred, metadata)
|
handler(captures, pattern, source, pred, metadata)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
return metadata
|
||||||
end
|
end
|
||||||
|
|
||||||
--- Returns the start and stop value if set else the node's range.
|
--- Returns the start and stop value if set else the node's range.
|
||||||
@ -831,8 +841,10 @@ end
|
|||||||
---@param start? integer Starting line for the search. Defaults to `node:start()`.
|
---@param start? integer Starting line for the search. Defaults to `node:start()`.
|
||||||
---@param stop? integer Stopping line for the search (end-exclusive). Defaults to `node:end_()`.
|
---@param stop? integer Stopping line for the search (end-exclusive). Defaults to `node:end_()`.
|
||||||
---
|
---
|
||||||
---@return (fun(end_line: integer|nil): integer, TSNode, vim.treesitter.query.TSMetadata, table<integer, TSNode>):
|
---@return (fun(end_line: integer|nil): integer, TSNode, vim.treesitter.query.TSMetadata, table<integer,TSNode[]>?):
|
||||||
--- capture id, capture node, metadata, match
|
--- capture id, capture node, metadata, match
|
||||||
|
---
|
||||||
|
---@note Captures are only returned if the query pattern of a specific capture contained predicates.
|
||||||
function Query:iter_captures(node, source, start, stop)
|
function Query:iter_captures(node, source, start, stop)
|
||||||
if type(source) == 'number' and source == 0 then
|
if type(source) == 'number' and source == 0 then
|
||||||
source = api.nvim_get_current_buf()
|
source = api.nvim_get_current_buf()
|
||||||
@ -840,24 +852,38 @@ function Query:iter_captures(node, source, start, stop)
|
|||||||
|
|
||||||
start, stop = value_or_node_range(start, stop, node)
|
start, stop = value_or_node_range(start, stop, node)
|
||||||
|
|
||||||
local raw_iter = node:_rawquery(self.query, true, start, stop) ---@type fun(): integer, TSNode, vim.treesitter.query.TSMatch
|
local cursor = vim._create_ts_querycursor(node, self.query, start, stop, { match_limit = 256 })
|
||||||
|
|
||||||
|
local max_match_id = -1
|
||||||
|
|
||||||
local function iter(end_line)
|
local function iter(end_line)
|
||||||
local capture, captured_node, match = raw_iter()
|
local capture, captured_node, match = cursor:next_capture()
|
||||||
|
|
||||||
|
if not capture then
|
||||||
|
return
|
||||||
|
end
|
||||||
|
|
||||||
|
local captures --- @type table<integer,TSNode[]>?
|
||||||
|
local match_id, pattern_index = match:info()
|
||||||
|
|
||||||
local metadata = {}
|
local metadata = {}
|
||||||
|
|
||||||
if match ~= nil then
|
local preds = self.info.patterns[pattern_index] or {}
|
||||||
local active = self:match_preds(match, match.pattern, source)
|
|
||||||
match.active = active
|
if #preds > 0 and match_id > max_match_id then
|
||||||
if not active then
|
captures = match:captures()
|
||||||
|
max_match_id = match_id
|
||||||
|
if not self:match_preds(match, source) then
|
||||||
|
cursor:remove_match(match_id)
|
||||||
if end_line and captured_node:range() > end_line then
|
if end_line and captured_node:range() > end_line then
|
||||||
return nil, captured_node, nil
|
return nil, captured_node, nil
|
||||||
end
|
end
|
||||||
return iter(end_line) -- tail call: try next match
|
return iter(end_line) -- tail call: try next match
|
||||||
end
|
end
|
||||||
|
|
||||||
self:apply_directives(match, match.pattern, source, metadata)
|
metadata = self:apply_directives(match, source)
|
||||||
end
|
end
|
||||||
return capture, captured_node, metadata, match
|
return capture, captured_node, metadata, captures
|
||||||
end
|
end
|
||||||
return iter
|
return iter
|
||||||
end
|
end
|
||||||
@ -899,45 +925,54 @@ end
|
|||||||
---@param opts? table Optional keyword arguments:
|
---@param opts? table Optional keyword arguments:
|
||||||
--- - max_start_depth (integer) if non-zero, sets the maximum start depth
|
--- - max_start_depth (integer) if non-zero, sets the maximum start depth
|
||||||
--- for each match. This is used to prevent traversing too deep into a tree.
|
--- for each match. This is used to prevent traversing too deep into a tree.
|
||||||
|
--- - match_limit (integer) Set the maximum number of in-progress matches (Default: 256).
|
||||||
--- - all (boolean) When set, the returned match table maps capture IDs to a list of nodes.
|
--- - all (boolean) When set, the returned match table maps capture IDs to a list of nodes.
|
||||||
--- Older versions of iter_matches incorrectly mapped capture IDs to a single node, which is
|
--- Older versions of iter_matches incorrectly mapped capture IDs to a single node, which is
|
||||||
--- incorrect behavior. This option will eventually become the default and removed.
|
--- incorrect behavior. This option will eventually become the default and removed.
|
||||||
---
|
---
|
||||||
---@return (fun(): integer, table<integer, TSNode[]>, table): pattern id, match, metadata
|
---@return (fun(): integer, table<integer, TSNode[]>, table): pattern id, match, metadata
|
||||||
function Query:iter_matches(node, source, start, stop, opts)
|
function Query:iter_matches(node, source, start, stop, opts)
|
||||||
local all = opts and opts.all
|
opts = opts or {}
|
||||||
|
opts.match_limit = opts.match_limit or 256
|
||||||
|
|
||||||
if type(source) == 'number' and source == 0 then
|
if type(source) == 'number' and source == 0 then
|
||||||
source = api.nvim_get_current_buf()
|
source = api.nvim_get_current_buf()
|
||||||
end
|
end
|
||||||
|
|
||||||
start, stop = value_or_node_range(start, stop, node)
|
start, stop = value_or_node_range(start, stop, node)
|
||||||
|
|
||||||
local raw_iter = node:_rawquery(self.query, false, start, stop, opts) ---@type fun(): integer, vim.treesitter.query.TSMatch
|
local cursor = vim._create_ts_querycursor(node, self.query, start, stop, opts)
|
||||||
|
|
||||||
local function iter()
|
local function iter()
|
||||||
local pattern, match = raw_iter()
|
local match = cursor:next_match()
|
||||||
local metadata = {}
|
|
||||||
|
|
||||||
if match ~= nil then
|
if not match then
|
||||||
local active = self:match_preds(match, pattern, source)
|
return
|
||||||
if not active then
|
|
||||||
return iter() -- tail call: try next match
|
|
||||||
end
|
|
||||||
|
|
||||||
self:apply_directives(match, pattern, source, metadata)
|
|
||||||
end
|
end
|
||||||
|
|
||||||
if not all then
|
local match_id, pattern = match:info()
|
||||||
|
|
||||||
|
if not self:match_preds(match, source) then
|
||||||
|
cursor:remove_match(match_id)
|
||||||
|
return iter() -- tail call: try next match
|
||||||
|
end
|
||||||
|
|
||||||
|
local metadata = self:apply_directives(match, source)
|
||||||
|
|
||||||
|
local captures = match:captures()
|
||||||
|
|
||||||
|
if not opts.all then
|
||||||
-- Convert the match table into the old buggy version for backward
|
-- Convert the match table into the old buggy version for backward
|
||||||
-- compatibility. This is slow. Plugin authors, if you're reading this, set the "all"
|
-- compatibility. This is slow. Plugin authors, if you're reading this, set the "all"
|
||||||
-- option!
|
-- option!
|
||||||
local old_match = {} ---@type table<integer, TSNode>
|
local old_match = {} ---@type table<integer, TSNode>
|
||||||
for k, v in pairs(match or {}) do
|
for k, v in pairs(captures or {}) do
|
||||||
old_match[k] = v[#v]
|
old_match[k] = v[#v]
|
||||||
end
|
end
|
||||||
return pattern, old_match, metadata
|
return pattern, old_match, metadata
|
||||||
end
|
end
|
||||||
|
|
||||||
return pattern, match, metadata
|
return pattern, captures, metadata
|
||||||
end
|
end
|
||||||
return iter
|
return iter
|
||||||
end
|
end
|
||||||
|
@ -1909,6 +1909,9 @@ static void nlua_add_treesitter(lua_State *const lstate) FUNC_ATTR_NONNULL_ALL
|
|||||||
lua_pushcfunction(lstate, tslua_push_parser);
|
lua_pushcfunction(lstate, tslua_push_parser);
|
||||||
lua_setfield(lstate, -2, "_create_ts_parser");
|
lua_setfield(lstate, -2, "_create_ts_parser");
|
||||||
|
|
||||||
|
lua_pushcfunction(lstate, tslua_push_querycursor);
|
||||||
|
lua_setfield(lstate, -2, "_create_ts_querycursor");
|
||||||
|
|
||||||
lua_pushcfunction(lstate, tslua_add_language);
|
lua_pushcfunction(lstate, tslua_add_language);
|
||||||
lua_setfield(lstate, -2, "_ts_add_language");
|
lua_setfield(lstate, -2, "_ts_add_language");
|
||||||
|
|
||||||
|
@ -33,14 +33,9 @@
|
|||||||
#define TS_META_NODE "treesitter_node"
|
#define TS_META_NODE "treesitter_node"
|
||||||
#define TS_META_QUERY "treesitter_query"
|
#define TS_META_QUERY "treesitter_query"
|
||||||
#define TS_META_QUERYCURSOR "treesitter_querycursor"
|
#define TS_META_QUERYCURSOR "treesitter_querycursor"
|
||||||
|
#define TS_META_QUERYMATCH "treesitter_querymatch"
|
||||||
#define TS_META_TREECURSOR "treesitter_treecursor"
|
#define TS_META_TREECURSOR "treesitter_treecursor"
|
||||||
|
|
||||||
typedef struct {
|
|
||||||
TSQueryCursor *cursor;
|
|
||||||
int predicated_match;
|
|
||||||
int max_match_id;
|
|
||||||
} TSLua_cursor;
|
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
LuaRef cb;
|
LuaRef cb;
|
||||||
lua_State *lstate;
|
lua_State *lstate;
|
||||||
@ -108,7 +103,6 @@ static struct luaL_Reg node_meta[] = {
|
|||||||
{ "named_descendant_for_range", node_named_descendant_for_range },
|
{ "named_descendant_for_range", node_named_descendant_for_range },
|
||||||
{ "parent", node_parent },
|
{ "parent", node_parent },
|
||||||
{ "iter_children", node_iter_children },
|
{ "iter_children", node_iter_children },
|
||||||
{ "_rawquery", node_rawquery },
|
|
||||||
{ "next_sibling", node_next_sibling },
|
{ "next_sibling", node_next_sibling },
|
||||||
{ "prev_sibling", node_prev_sibling },
|
{ "prev_sibling", node_prev_sibling },
|
||||||
{ "next_named_sibling", node_next_named_sibling },
|
{ "next_named_sibling", node_next_named_sibling },
|
||||||
@ -130,18 +124,27 @@ static struct luaL_Reg query_meta[] = {
|
|||||||
{ NULL, NULL }
|
{ NULL, NULL }
|
||||||
};
|
};
|
||||||
|
|
||||||
// cursors are not exposed, but still needs garbage collection
|
// TSQueryCursor
|
||||||
static struct luaL_Reg querycursor_meta[] = {
|
static struct luaL_Reg querycursor_meta[] = {
|
||||||
|
{ "remove_match", querycursor_remove_match },
|
||||||
|
{ "next_capture", querycursor_next_capture },
|
||||||
|
{ "next_match", querycursor_next_match },
|
||||||
{ "__gc", querycursor_gc },
|
{ "__gc", querycursor_gc },
|
||||||
{ NULL, NULL }
|
{ NULL, NULL }
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// TSQueryMatch
|
||||||
|
static struct luaL_Reg querymatch_meta[] = {
|
||||||
|
{ "info", querymatch_info },
|
||||||
|
{ "captures", querymatch_captures },
|
||||||
|
{ NULL, NULL }
|
||||||
|
};
|
||||||
|
|
||||||
static struct luaL_Reg treecursor_meta[] = {
|
static struct luaL_Reg treecursor_meta[] = {
|
||||||
{ "__gc", treecursor_gc },
|
{ "__gc", treecursor_gc },
|
||||||
{ NULL, NULL }
|
{ NULL, NULL }
|
||||||
};
|
};
|
||||||
|
|
||||||
static kvec_t(TSQueryCursor *) cursors = KV_INITIAL_VALUE;
|
|
||||||
static PMap(cstr_t) langs = MAP_INIT;
|
static PMap(cstr_t) langs = MAP_INIT;
|
||||||
|
|
||||||
static void build_meta(lua_State *L, const char *tname, const luaL_Reg *meta)
|
static void build_meta(lua_State *L, const char *tname, const luaL_Reg *meta)
|
||||||
@ -166,6 +169,7 @@ void tslua_init(lua_State *L)
|
|||||||
build_meta(L, TS_META_NODE, node_meta);
|
build_meta(L, TS_META_NODE, node_meta);
|
||||||
build_meta(L, TS_META_QUERY, query_meta);
|
build_meta(L, TS_META_QUERY, query_meta);
|
||||||
build_meta(L, TS_META_QUERYCURSOR, querycursor_meta);
|
build_meta(L, TS_META_QUERYCURSOR, querycursor_meta);
|
||||||
|
build_meta(L, TS_META_QUERYMATCH, querymatch_meta);
|
||||||
build_meta(L, TS_META_TREECURSOR, treecursor_meta);
|
build_meta(L, TS_META_TREECURSOR, treecursor_meta);
|
||||||
|
|
||||||
ts_set_allocator(xmalloc, xcalloc, xrealloc, xfree);
|
ts_set_allocator(xmalloc, xcalloc, xrealloc, xfree);
|
||||||
@ -1361,173 +1365,156 @@ static int node_equal(lua_State *L)
|
|||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// assumes the match table being on top of the stack
|
int tslua_push_querycursor(lua_State *L)
|
||||||
static void set_match(lua_State *L, TSQueryMatch *match, int nodeidx)
|
|
||||||
{
|
|
||||||
// [match]
|
|
||||||
for (size_t i = 0; i < match->capture_count; i++) {
|
|
||||||
lua_rawgeti(L, -1, (int)match->captures[i].index + 1); // [match, captures]
|
|
||||||
if (lua_isnil(L, -1)) { // [match, nil]
|
|
||||||
lua_pop(L, 1); // [match]
|
|
||||||
lua_createtable(L, 1, 0); // [match, captures]
|
|
||||||
}
|
|
||||||
push_node(L, match->captures[i].node, nodeidx); // [match, captures, node]
|
|
||||||
lua_rawseti(L, -2, (int)lua_objlen(L, -2) + 1); // [match, captures]
|
|
||||||
lua_rawseti(L, -2, (int)match->captures[i].index + 1); // [match]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static int query_next_match(lua_State *L)
|
|
||||||
{
|
|
||||||
TSLua_cursor *ud = lua_touserdata(L, lua_upvalueindex(1));
|
|
||||||
TSQueryCursor *cursor = ud->cursor;
|
|
||||||
|
|
||||||
TSQuery *query = query_check(L, lua_upvalueindex(3));
|
|
||||||
TSQueryMatch match;
|
|
||||||
if (ts_query_cursor_next_match(cursor, &match)) {
|
|
||||||
lua_pushinteger(L, match.pattern_index + 1); // [index]
|
|
||||||
lua_createtable(L, (int)ts_query_capture_count(query), 0); // [index, match]
|
|
||||||
set_match(L, &match, lua_upvalueindex(2));
|
|
||||||
return 2;
|
|
||||||
}
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
static int query_next_capture(lua_State *L)
|
|
||||||
{
|
|
||||||
// Upvalues are:
|
|
||||||
// [ cursor, node, query, current_match ]
|
|
||||||
TSLua_cursor *ud = lua_touserdata(L, lua_upvalueindex(1));
|
|
||||||
TSQueryCursor *cursor = ud->cursor;
|
|
||||||
|
|
||||||
TSQuery *query = query_check(L, lua_upvalueindex(3));
|
|
||||||
|
|
||||||
if (ud->predicated_match > -1) {
|
|
||||||
lua_getfield(L, lua_upvalueindex(4), "active");
|
|
||||||
bool active = lua_toboolean(L, -1);
|
|
||||||
lua_pop(L, 1);
|
|
||||||
if (!active) {
|
|
||||||
ts_query_cursor_remove_match(cursor, (uint32_t)ud->predicated_match);
|
|
||||||
}
|
|
||||||
ud->predicated_match = -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
TSQueryMatch match;
|
|
||||||
uint32_t capture_index;
|
|
||||||
if (ts_query_cursor_next_capture(cursor, &match, &capture_index)) {
|
|
||||||
TSQueryCapture capture = match.captures[capture_index];
|
|
||||||
|
|
||||||
// TODO(vigoux): handle capture quantifiers here
|
|
||||||
lua_pushinteger(L, capture.index + 1); // [index]
|
|
||||||
push_node(L, capture.node, lua_upvalueindex(2)); // [index, node]
|
|
||||||
|
|
||||||
// Now check if we need to run the predicates
|
|
||||||
uint32_t n_pred;
|
|
||||||
ts_query_predicates_for_pattern(query, match.pattern_index, &n_pred);
|
|
||||||
|
|
||||||
if (n_pred > 0 && (ud->max_match_id < (int)match.id)) {
|
|
||||||
ud->max_match_id = (int)match.id;
|
|
||||||
|
|
||||||
// Create a new cleared match table
|
|
||||||
lua_createtable(L, (int)ts_query_capture_count(query), 2); // [index, node, match]
|
|
||||||
set_match(L, &match, lua_upvalueindex(2));
|
|
||||||
lua_pushinteger(L, match.pattern_index + 1);
|
|
||||||
lua_setfield(L, -2, "pattern");
|
|
||||||
|
|
||||||
if (match.capture_count > 1) {
|
|
||||||
ud->predicated_match = (int)match.id;
|
|
||||||
lua_pushboolean(L, false);
|
|
||||||
lua_setfield(L, -2, "active");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set current_match to the new match
|
|
||||||
lua_replace(L, lua_upvalueindex(4)); // [index, node]
|
|
||||||
lua_pushvalue(L, lua_upvalueindex(4)); // [index, node, match]
|
|
||||||
return 3;
|
|
||||||
}
|
|
||||||
return 2;
|
|
||||||
}
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
static int node_rawquery(lua_State *L)
|
|
||||||
{
|
{
|
||||||
TSNode node;
|
TSNode node;
|
||||||
if (!node_check(L, 1, &node)) {
|
if (!node_check(L, 1, &node)) {
|
||||||
return 0;
|
return luaL_error(L, "TSNode expected");
|
||||||
}
|
}
|
||||||
|
|
||||||
TSQuery *query = query_check(L, 2);
|
TSQuery *query = query_check(L, 2);
|
||||||
|
if (!query) {
|
||||||
TSQueryCursor *cursor;
|
return luaL_error(L, "TSQuery expected");
|
||||||
if (kv_size(cursors) > 0) {
|
|
||||||
cursor = kv_pop(cursors);
|
|
||||||
} else {
|
|
||||||
cursor = ts_query_cursor_new();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ts_query_cursor_set_max_start_depth(cursor, UINT32_MAX);
|
TSQueryCursor *cursor = ts_query_cursor_new();
|
||||||
ts_query_cursor_set_match_limit(cursor, 256);
|
|
||||||
ts_query_cursor_exec(cursor, query, node);
|
ts_query_cursor_exec(cursor, query, node);
|
||||||
|
|
||||||
bool captures = lua_toboolean(L, 3);
|
if (lua_gettop(L) >= 3) {
|
||||||
|
uint32_t start = (uint32_t)luaL_checkinteger(L, 3);
|
||||||
if (lua_gettop(L) >= 4) {
|
uint32_t end = lua_gettop(L) >= 4 ? (uint32_t)luaL_checkinteger(L, 4) : MAXLNUM;
|
||||||
uint32_t start = (uint32_t)luaL_checkinteger(L, 4);
|
|
||||||
uint32_t end = lua_gettop(L) >= 5 ? (uint32_t)luaL_checkinteger(L, 5) : MAXLNUM;
|
|
||||||
ts_query_cursor_set_point_range(cursor, (TSPoint){ start, 0 }, (TSPoint){ end, 0 });
|
ts_query_cursor_set_point_range(cursor, (TSPoint){ start, 0 }, (TSPoint){ end, 0 });
|
||||||
}
|
}
|
||||||
|
|
||||||
if (lua_gettop(L) >= 6 && !lua_isnil(L, 6)) {
|
if (lua_gettop(L) >= 5 && !lua_isnil(L, 5)) {
|
||||||
if (!lua_istable(L, 6)) {
|
if (!lua_istable(L, 5)) {
|
||||||
return luaL_error(L, "table expected");
|
return luaL_error(L, "table expected");
|
||||||
}
|
}
|
||||||
lua_pushnil(L);
|
lua_pushnil(L); // [dict, ..., nil]
|
||||||
// stack: [dict, ..., nil]
|
while (lua_next(L, 5)) {
|
||||||
while (lua_next(L, 6)) {
|
// [dict, ..., key, value]
|
||||||
// stack: [dict, ..., key, value]
|
|
||||||
if (lua_type(L, -2) == LUA_TSTRING) {
|
if (lua_type(L, -2) == LUA_TSTRING) {
|
||||||
char *k = (char *)lua_tostring(L, -2);
|
char *k = (char *)lua_tostring(L, -2);
|
||||||
if (strequal("max_start_depth", k)) {
|
if (strequal("max_start_depth", k)) {
|
||||||
uint32_t max_start_depth = (uint32_t)lua_tointeger(L, -1);
|
uint32_t max_start_depth = (uint32_t)lua_tointeger(L, -1);
|
||||||
ts_query_cursor_set_max_start_depth(cursor, max_start_depth);
|
ts_query_cursor_set_max_start_depth(cursor, max_start_depth);
|
||||||
|
} else if (strequal("match_limit", k)) {
|
||||||
|
uint32_t match_limit = (uint32_t)lua_tointeger(L, -1);
|
||||||
|
ts_query_cursor_set_match_limit(cursor, match_limit);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
lua_pop(L, 1); // pop the value; lua_next will pop the key.
|
// pop the value; lua_next will pop the key.
|
||||||
// stack: [dict, ..., key]
|
lua_pop(L, 1); // [dict, ..., key]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TSLua_cursor *ud = lua_newuserdata(L, sizeof(*ud)); // [udata]
|
TSQueryCursor **ud = lua_newuserdata(L, sizeof(*ud)); // [node, query, ..., udata]
|
||||||
ud->cursor = cursor;
|
*ud = cursor;
|
||||||
ud->predicated_match = -1;
|
lua_getfield(L, LUA_REGISTRYINDEX, TS_META_QUERYCURSOR); // [node, query, ..., udata, meta]
|
||||||
ud->max_match_id = -1;
|
lua_setmetatable(L, -2); // [node, query, ..., udata]
|
||||||
|
|
||||||
lua_getfield(L, LUA_REGISTRYINDEX, TS_META_QUERYCURSOR);
|
// Copy the fenv which contains the nodes tree.
|
||||||
lua_setmetatable(L, -2); // [udata]
|
lua_getfenv(L, 1); // [udata, reftable]
|
||||||
lua_pushvalue(L, 1); // [udata, node]
|
lua_setfenv(L, -2); // [udata]
|
||||||
|
|
||||||
// include query separately, as to keep a ref to it for gc
|
|
||||||
lua_pushvalue(L, 2); // [udata, node, query]
|
|
||||||
|
|
||||||
if (captures) {
|
|
||||||
// placeholder for match state
|
|
||||||
lua_createtable(L, (int)ts_query_capture_count(query), 2); // [u, n, q, match]
|
|
||||||
lua_pushcclosure(L, query_next_capture, 4); // [closure]
|
|
||||||
} else {
|
|
||||||
lua_pushcclosure(L, query_next_match, 3); // [closure]
|
|
||||||
}
|
|
||||||
|
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static int querycursor_remove_match(lua_State *L)
|
||||||
|
{
|
||||||
|
TSQueryCursor *cursor = querycursor_check(L, 1);
|
||||||
|
uint32_t match_id = (uint32_t)luaL_checkinteger(L, 2);
|
||||||
|
ts_query_cursor_remove_match(cursor, match_id);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void push_querymatch(lua_State *L, TSQueryMatch *match, int uindex)
|
||||||
|
{
|
||||||
|
TSQueryMatch *ud = lua_newuserdata(L, sizeof(TSQueryMatch)); // [udata]
|
||||||
|
*ud = *match;
|
||||||
|
lua_getfield(L, LUA_REGISTRYINDEX, TS_META_QUERYMATCH); // [udata, meta]
|
||||||
|
lua_setmetatable(L, -2); // [udata]
|
||||||
|
|
||||||
|
// Copy the fenv which contains the nodes tree.
|
||||||
|
lua_getfenv(L, uindex); // [udata, reftable]
|
||||||
|
lua_setfenv(L, -2); // [udata]
|
||||||
|
}
|
||||||
|
|
||||||
|
static int querycursor_next_capture(lua_State *L)
|
||||||
|
{
|
||||||
|
TSQueryCursor *cursor = querycursor_check(L, 1);
|
||||||
|
|
||||||
|
TSQueryMatch match;
|
||||||
|
uint32_t capture_index;
|
||||||
|
if (!ts_query_cursor_next_capture(cursor, &match, &capture_index)) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
TSQueryCapture capture = match.captures[capture_index];
|
||||||
|
|
||||||
|
// Handle capture quantifiers here
|
||||||
|
lua_pushinteger(L, capture.index + 1); // [index]
|
||||||
|
push_node(L, capture.node, 1); // [index, node]
|
||||||
|
push_querymatch(L, &match, 1);
|
||||||
|
|
||||||
|
return 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int querycursor_next_match(lua_State *L)
|
||||||
|
{
|
||||||
|
TSQueryCursor *cursor = querycursor_check(L, 1);
|
||||||
|
|
||||||
|
TSQueryMatch match;
|
||||||
|
if (!ts_query_cursor_next_match(cursor, &match)) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
push_querymatch(L, &match, 1);
|
||||||
|
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
static TSQueryCursor *querycursor_check(lua_State *L, int index)
|
||||||
|
{
|
||||||
|
TSQueryCursor **ud = luaL_checkudata(L, index, TS_META_QUERYCURSOR);
|
||||||
|
return *ud;
|
||||||
|
}
|
||||||
|
|
||||||
static int querycursor_gc(lua_State *L)
|
static int querycursor_gc(lua_State *L)
|
||||||
{
|
{
|
||||||
TSLua_cursor *ud = luaL_checkudata(L, 1, TS_META_QUERYCURSOR);
|
TSQueryCursor *cursor = querycursor_check(L, 1);
|
||||||
kv_push(cursors, ud->cursor);
|
ts_query_cursor_delete(cursor);
|
||||||
ud->cursor = NULL;
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static int querymatch_info(lua_State *L)
|
||||||
|
{
|
||||||
|
TSQueryMatch *ud = luaL_checkudata(L, 1, TS_META_QUERYMATCH);
|
||||||
|
lua_pushinteger(L, ud->id);
|
||||||
|
lua_pushinteger(L, ud->pattern_index + 1);
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int querymatch_captures(lua_State *L)
|
||||||
|
{
|
||||||
|
TSQueryMatch *match = luaL_checkudata(L, 1, TS_META_QUERYMATCH);
|
||||||
|
lua_newtable(L); // [match, nodes, captures]
|
||||||
|
for (size_t i = 0; i < match->capture_count; i++) {
|
||||||
|
TSQueryCapture capture = match->captures[i];
|
||||||
|
int index = (int)capture.index + 1;
|
||||||
|
|
||||||
|
lua_rawgeti(L, -1, index); // [match, node, captures]
|
||||||
|
if (lua_isnil(L, -1)) { // [match, node, captures, nil]
|
||||||
|
lua_pop(L, 1); // [match, node, captures]
|
||||||
|
lua_newtable(L); // [match, node, captures, nodes]
|
||||||
|
}
|
||||||
|
push_node(L, capture.node, 1); // [match, node, captures, nodes, node]
|
||||||
|
lua_rawseti(L, -2, (int)lua_objlen(L, -2) + 1); // [match, node, captures, nodes]
|
||||||
|
lua_rawseti(L, -2, index); // [match, node, captures]
|
||||||
|
}
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
// Query methods
|
// Query methods
|
||||||
|
|
||||||
int tslua_parse_query(lua_State *L)
|
int tslua_parse_query(lua_State *L)
|
||||||
@ -1638,7 +1625,7 @@ static void query_err_string(const char *src, int error_offset, TSQueryError err
|
|||||||
static TSQuery *query_check(lua_State *L, int index)
|
static TSQuery *query_check(lua_State *L, int index)
|
||||||
{
|
{
|
||||||
TSQuery **ud = luaL_checkudata(L, index, TS_META_QUERY);
|
TSQuery **ud = luaL_checkudata(L, index, TS_META_QUERY);
|
||||||
return *ud;
|
return ud ? *ud : NULL;
|
||||||
}
|
}
|
||||||
|
|
||||||
static int query_gc(lua_State *L)
|
static int query_gc(lua_State *L)
|
||||||
|
Loading…
Reference in New Issue
Block a user