feat(treesitter): add offset predicate for language injection

refactor(treesitter): add directives to queries
This commit is contained in:
Steven Sojka
2020-11-24 08:50:33 -06:00
parent 82100a6bdb
commit 929f194145
4 changed files with 251 additions and 51 deletions

View File

@@ -289,7 +289,7 @@ function LanguageTree:_get_injections()
local root_node = tree:root()
local start_line, _, end_line, _ = root_node:range()
for pattern, match in self._injection_query:iter_matches(root_node, self._source, start_line, end_line+1) do
for pattern, match, metadata in self._injection_query:iter_matches(root_node, self._source, start_line, end_line+1) do
local lang = nil
local injection_node = nil
local combined = false
@@ -298,9 +298,9 @@ function LanguageTree:_get_injections()
-- using a tag with the language, for example
-- @javascript
for id, node in pairs(match) do
local data = metadata[id]
local name = self._injection_query.captures[id]
-- TODO add a way to offset the content passed to the parser.
-- Needed to shave off leading quotes and things of that nature.
local offset_range = data and data.offset
-- Lang should override any other language tag
if name == "language" then
@@ -308,7 +308,7 @@ function LanguageTree:_get_injections()
elseif name == "combined" then
combined = true
elseif name == "content" then
injection_node = node
injection_node = offset_range or node
-- Ignore any tags that start with "_"
-- Allows for other tags to be used in matches
elseif string.sub(name, 1, 1) ~= "_" then
@@ -317,7 +317,7 @@ function LanguageTree:_get_injections()
end
if not injection_node then
injection_node = node
injection_node = offset_range or node
end
end
end

View File

@@ -92,6 +92,17 @@ local function read_query_files(filenames)
return table.concat(contents, '\n')
end
local match_metatable = {
__index = function(tbl, key)
rawset(tbl, key, {})
return tbl[key]
end
}
local function new_match_metadata()
return setmetatable({}, match_metatable)
end
--- Returns the runtime query {query_name} for {lang}.
--
-- @param lang The language to use for the query
@@ -222,6 +233,44 @@ local predicate_handlers = {
-- As we provide lua-match? also expose vim-match?
predicate_handlers["vim-match?"] = predicate_handlers["match?"]
-- Directives store metadata or perform side effects against a match.
-- Directives should always end with a `!`.
-- Directive handler receive the following arguments
-- (match, pattern, bufnr, predicate)
local directive_handlers = {
["set!"] = function(_, _, _, pred, metadata)
if #pred == 4 then
-- (set! @capture "key" "value")
metadata[pred[2]][pred[3]] = pred[4]
else
-- (set! "key" "value")
metadata[pred[2]] = pred[3]
end
end,
-- Shifts the range of a node.
-- Example: (#offset! @_node 0 1 0 -1)
["offset!"] = function(match, _, _, pred, metadata)
local offset_node = match[pred[2]]
local range = {offset_node:range()}
local start_row_offset = pred[3] or 0
local start_col_offset = pred[4] or 0
local end_row_offset = pred[5] or 0
local end_col_offset = pred[6] or 0
local key = pred[7] or "offset"
range[1] = range[1] + start_row_offset
range[2] = range[2] + start_col_offset
range[3] = range[3] + end_row_offset
range[4] = range[4] + end_col_offset
-- If this produces an invalid range, we just skip it.
if range[1] < range[3] or (range[1] == range[3] and range[2] <= range[4]) then
metadata[pred[2]][key] = range
end
end
}
--- Adds a new predicates to be used in queries
--
-- @param name the name of the predicate, without leading #
@@ -229,12 +278,25 @@ predicate_handlers["vim-match?"] = predicate_handlers["match?"]
-- signature will be (match, pattern, bufnr, predicate)
function M.add_predicate(name, handler, force)
if predicate_handlers[name] and not force then
a.nvim_err_writeln(string.format("Overriding %s", name))
error(string.format("Overriding %s", name))
end
predicate_handlers[name] = handler
end
--- Adds a new directive to be used in queries
--
-- @param name the name of the directive, without leading #
-- @param handler the handler function to be used
-- signature will be (match, pattern, bufnr, predicate)
function M.add_directive(name, handler, force)
if directive_handlers[name] and not force then
error(string.format("Overriding %s", name))
end
directive_handlers[name] = handler
end
--- Returns the list of currently supported predicates
function M.list_predicates()
return vim.tbl_keys(predicate_handlers)
@@ -244,6 +306,10 @@ local function xor(x, y)
return (x or y) and not (x and y)
end
local function is_directive(name)
return string.sub(name, -1) == "!"
end
function Query:match_preds(match, pattern, source)
local preds = self.info.patterns[pattern]
@@ -254,30 +320,52 @@ function Query:match_preds(match, pattern, source)
-- Also, tree-sitter strips the leading # from predicates for us.
local pred_name
local is_not
if string.sub(pred[1], 1, 4) == "not-" then
pred_name = string.sub(pred[1], 5)
is_not = true
else
pred_name = pred[1]
is_not = false
end
local handler = predicate_handlers[pred_name]
-- Skip over directives... they will get processed after all the predicates.
if not is_directive(pred[1]) then
if string.sub(pred[1], 1, 4) == "not-" then
pred_name = string.sub(pred[1], 5)
is_not = true
else
pred_name = pred[1]
is_not = false
end
if not handler then
a.nvim_err_writeln(string.format("No handler for %s", pred[1]))
return false
end
local handler = predicate_handlers[pred_name]
local pred_matches = handler(match, pattern, source, pred)
if not handler then
error(string.format("No handler for %s", pred[1]))
return false
end
if not xor(is_not, pred_matches) then
return false
local pred_matches = handler(match, pattern, source, pred)
if not xor(is_not, pred_matches) then
return false
end
end
end
return true
end
--- Applies directives against a match and pattern.
function Query:apply_directives(match, pattern, source, metadata)
local preds = self.info.patterns[pattern]
for _, pred in pairs(preds or {}) do
if is_directive(pred[1]) then
local handler = directive_handlers[pred[1]]
if not handler then
error(string.format("No handler for %s", pred[1]))
return
end
handler(match, pattern, source, pred, metadata)
end
end
end
--- Iterates of the captures of self on a given range.
--
-- @param node The node under witch the search will occur
@@ -294,14 +382,18 @@ function Query:iter_captures(node, source, start, stop)
local raw_iter = node:_rawquery(self.query, true, start, stop)
local function iter()
local capture, captured_node, match = raw_iter()
local metadata = new_match_metadata()
if match ~= nil then
local active = self:match_preds(match, match.pattern, source)
match.active = active
if not active then
return iter() -- tail call: try next match
end
self:apply_directives(match, match.pattern, source, metadata)
end
return capture, captured_node
return capture, captured_node, metadata
end
return iter
end
@@ -322,13 +414,17 @@ function Query:iter_matches(node, source, start, stop)
local raw_iter = node:_rawquery(self.query, false, start, stop)
local function iter()
local pattern, match = raw_iter()
local metadata = new_match_metadata()
if match ~= nil then
local active = self:match_preds(match, pattern, source)
if not active then
return iter() -- tail call: try next match
end
self:apply_directives(match, pattern, source, metadata)
end
return pattern, match
return pattern, match, metadata
end
return iter
end