Merge pull request #13367 from nvim-treesitter/offset-lang-injection

feat(treesitter): add offset predicate for language injection
This commit is contained in:
Björn Linse 2020-12-16 13:59:36 +01:00 committed by GitHub
commit 5e202f69b3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 251 additions and 51 deletions

View File

@ -197,11 +197,11 @@ query:iter_captures({node}, {bufnr}, {start_row}, {end_row})
as the node, i e to get syntax highlight matches in the current as the node, i e to get syntax highlight matches in the current
viewport) viewport)
The iterator returns two values, a numeric id identifying the capture The iterator returns three values, a numeric id identifying the capture,
and the captured node. The following example shows how to get captures the captured node, and metadata from any directives processing the match.
by name: The following example shows how to get captures by name:
> >
for id, node in query:iter_captures(tree:root(), bufnr, first, last) do for id, node, metadata in query:iter_captures(tree:root(), bufnr, first, last) do
local name = query.captures[id] -- name of the capture in the query local name = query.captures[id] -- name of the capture in the query
-- typically useful info about the node: -- typically useful info about the node:
local type = node:type() -- type of the captured node local type = node:type() -- type of the captured node
@ -213,16 +213,19 @@ query:iter_matches({node}, {bufnr}, {start_row}, {end_row})
*query:iter_matches()* *query:iter_matches()*
Iterate over all matches within a node. The arguments are the same as Iterate over all matches within a node. The arguments are the same as
for |query:iter_captures()| but the iterated values are different: for |query:iter_captures()| but the iterated values are different:
an (1-based) index of the pattern in the query, and a table mapping an (1-based) index of the pattern in the query, a table mapping
capture indices to nodes. If the query has more than one pattern capture indices to nodes, and metadata from any directives processing the match.
the capture table might be sparse, and e.g. `pairs` should be used and not If the query has more than one pattern the capture table might be sparse,
`ipairs`. Here an example iterating over all captures in and e.g. `pairs()` method should be used over `ipairs`.
every match: Here an example iterating over all captures in every match:
> >
for pattern, match in cquery:iter_matches(tree:root(), bufnr, first, last) do for pattern, match, metadata in cquery:iter_matches(tree:root(), bufnr, first, last) do
for id,node in pairs(match) do for id, node in pairs(match) do
local name = query.captures[id] local name = query.captures[id]
-- `node` was captured by the `name` capture in the match -- `node` was captured by the `name` capture in the match
local node_data = metadata[id] -- Node level metadata
... use the info here ... ... use the info here ...
end end
end end
@ -265,6 +268,29 @@ Here is a list of built-in predicates :
Each predicate has a `not-` prefixed predicate that is just the negation of Each predicate has a `not-` prefixed predicate that is just the negation of
the predicate. the predicate.
Treesitter Query Directive *lua-treesitter-directives*
Treesitter queries can also contain `directives`. Directives store metadata for a node
or match and perform side effects. for example, the |set!| predicate sets metadata on
the match or node : >
((identifier) @foo (#set! "type" "parameter"))
Here is a list of built-in directives:
`set!` *ts-directive-set!*
Sets key/value metadata for a specific node or match : >
((identifier) @foo (#set! @foo "kind" "parameter"))
((node1) @left (node2) @right (#set! "type" "pair"))
<
`offset!` *ts-predicate-offset!*
Takes the range of the captured node and applies the offsets
to it's range : >
((idenfitier) @constant (#offset! @constant 0 1 0 -1))
< This will generate a range object for the captured node with the
offsets applied. The arguments are
`({capture_id}, {start_row}, {start_col}, {end_row}, {end_col}, {key?})`
The default key is "offset".
*vim.treesitter.query.add_predicate()* *vim.treesitter.query.add_predicate()*
vim.treesitter.query.add_predicate({name}, {handler}) vim.treesitter.query.add_predicate({name}, {handler})
@ -277,6 +303,16 @@ vim.treesitter.query.list_predicates()
This lists the currently available predicates to use in queries. This lists the currently available predicates to use in queries.
*vim.treesitter.query.add_directive()*
vim.treesitter.query.add_directive({name}, {handler})
This adds a directive with the name {name} to be used in queries.
{handler} should be a function whose signature will be : >
handler(match, pattern, bufnr, predicate, metadata)
Handlers can set match level data by setting directly on the metadata object `metadata.key = value`
Handlers can set node level data by using the capture id on the metadata table
`metadata[capture_id].key = value`
Treesitter syntax highlighting (WIP) *lua-treesitter-highlight* Treesitter syntax highlighting (WIP) *lua-treesitter-highlight*
NOTE: This is a partially implemented feature, and not usable as a default NOTE: This is a partially implemented feature, and not usable as a default

View File

@ -289,7 +289,7 @@ function LanguageTree:_get_injections()
local root_node = tree:root() local root_node = tree:root()
local start_line, _, end_line, _ = root_node:range() local start_line, _, end_line, _ = root_node:range()
for pattern, match in self._injection_query:iter_matches(root_node, self._source, start_line, end_line+1) do for pattern, match, metadata in self._injection_query:iter_matches(root_node, self._source, start_line, end_line+1) do
local lang = nil local lang = nil
local injection_node = nil local injection_node = nil
local combined = false local combined = false
@ -298,9 +298,9 @@ function LanguageTree:_get_injections()
-- using a tag with the language, for example -- using a tag with the language, for example
-- @javascript -- @javascript
for id, node in pairs(match) do for id, node in pairs(match) do
local data = metadata[id]
local name = self._injection_query.captures[id] local name = self._injection_query.captures[id]
-- TODO add a way to offset the content passed to the parser. local offset_range = data and data.offset
-- Needed to shave off leading quotes and things of that nature.
-- Lang should override any other language tag -- Lang should override any other language tag
if name == "language" then if name == "language" then
@ -308,7 +308,7 @@ function LanguageTree:_get_injections()
elseif name == "combined" then elseif name == "combined" then
combined = true combined = true
elseif name == "content" then elseif name == "content" then
injection_node = node injection_node = offset_range or node
-- Ignore any tags that start with "_" -- Ignore any tags that start with "_"
-- Allows for other tags to be used in matches -- Allows for other tags to be used in matches
elseif string.sub(name, 1, 1) ~= "_" then elseif string.sub(name, 1, 1) ~= "_" then
@ -317,7 +317,7 @@ function LanguageTree:_get_injections()
end end
if not injection_node then if not injection_node then
injection_node = node injection_node = offset_range or node
end end
end end
end end

View File

@ -92,6 +92,17 @@ local function read_query_files(filenames)
return table.concat(contents, '\n') return table.concat(contents, '\n')
end end
local match_metatable = {
__index = function(tbl, key)
rawset(tbl, key, {})
return tbl[key]
end
}
local function new_match_metadata()
return setmetatable({}, match_metatable)
end
--- Returns the runtime query {query_name} for {lang}. --- Returns the runtime query {query_name} for {lang}.
-- --
-- @param lang The language to use for the query -- @param lang The language to use for the query
@ -222,6 +233,44 @@ local predicate_handlers = {
-- As we provide lua-match? also expose vim-match? -- As we provide lua-match? also expose vim-match?
predicate_handlers["vim-match?"] = predicate_handlers["match?"] predicate_handlers["vim-match?"] = predicate_handlers["match?"]
-- Directives store metadata or perform side effects against a match.
-- Directives should always end with a `!`.
-- Directive handler receive the following arguments
-- (match, pattern, bufnr, predicate)
local directive_handlers = {
["set!"] = function(_, _, _, pred, metadata)
if #pred == 4 then
-- (set! @capture "key" "value")
metadata[pred[2]][pred[3]] = pred[4]
else
-- (set! "key" "value")
metadata[pred[2]] = pred[3]
end
end,
-- Shifts the range of a node.
-- Example: (#offset! @_node 0 1 0 -1)
["offset!"] = function(match, _, _, pred, metadata)
local offset_node = match[pred[2]]
local range = {offset_node:range()}
local start_row_offset = pred[3] or 0
local start_col_offset = pred[4] or 0
local end_row_offset = pred[5] or 0
local end_col_offset = pred[6] or 0
local key = pred[7] or "offset"
range[1] = range[1] + start_row_offset
range[2] = range[2] + start_col_offset
range[3] = range[3] + end_row_offset
range[4] = range[4] + end_col_offset
-- If this produces an invalid range, we just skip it.
if range[1] < range[3] or (range[1] == range[3] and range[2] <= range[4]) then
metadata[pred[2]][key] = range
end
end
}
--- Adds a new predicates to be used in queries --- Adds a new predicates to be used in queries
-- --
-- @param name the name of the predicate, without leading # -- @param name the name of the predicate, without leading #
@ -229,12 +278,25 @@ predicate_handlers["vim-match?"] = predicate_handlers["match?"]
-- signature will be (match, pattern, bufnr, predicate) -- signature will be (match, pattern, bufnr, predicate)
function M.add_predicate(name, handler, force) function M.add_predicate(name, handler, force)
if predicate_handlers[name] and not force then if predicate_handlers[name] and not force then
a.nvim_err_writeln(string.format("Overriding %s", name)) error(string.format("Overriding %s", name))
end end
predicate_handlers[name] = handler predicate_handlers[name] = handler
end end
--- Adds a new directive to be used in queries
--
-- @param name the name of the directive, without leading #
-- @param handler the handler function to be used
-- signature will be (match, pattern, bufnr, predicate)
function M.add_directive(name, handler, force)
if directive_handlers[name] and not force then
error(string.format("Overriding %s", name))
end
directive_handlers[name] = handler
end
--- Returns the list of currently supported predicates --- Returns the list of currently supported predicates
function M.list_predicates() function M.list_predicates()
return vim.tbl_keys(predicate_handlers) return vim.tbl_keys(predicate_handlers)
@ -244,6 +306,10 @@ local function xor(x, y)
return (x or y) and not (x and y) return (x or y) and not (x and y)
end end
local function is_directive(name)
return string.sub(name, -1) == "!"
end
function Query:match_preds(match, pattern, source) function Query:match_preds(match, pattern, source)
local preds = self.info.patterns[pattern] local preds = self.info.patterns[pattern]
@ -254,30 +320,52 @@ function Query:match_preds(match, pattern, source)
-- Also, tree-sitter strips the leading # from predicates for us. -- Also, tree-sitter strips the leading # from predicates for us.
local pred_name local pred_name
local is_not local is_not
if string.sub(pred[1], 1, 4) == "not-" then
pred_name = string.sub(pred[1], 5)
is_not = true
else
pred_name = pred[1]
is_not = false
end
local handler = predicate_handlers[pred_name] -- Skip over directives... they will get processed after all the predicates.
if not is_directive(pred[1]) then
if string.sub(pred[1], 1, 4) == "not-" then
pred_name = string.sub(pred[1], 5)
is_not = true
else
pred_name = pred[1]
is_not = false
end
if not handler then local handler = predicate_handlers[pred_name]
a.nvim_err_writeln(string.format("No handler for %s", pred[1]))
return false
end
local pred_matches = handler(match, pattern, source, pred) if not handler then
error(string.format("No handler for %s", pred[1]))
return false
end
if not xor(is_not, pred_matches) then local pred_matches = handler(match, pattern, source, pred)
return false
if not xor(is_not, pred_matches) then
return false
end
end end
end end
return true return true
end end
--- Applies directives against a match and pattern.
function Query:apply_directives(match, pattern, source, metadata)
local preds = self.info.patterns[pattern]
for _, pred in pairs(preds or {}) do
if is_directive(pred[1]) then
local handler = directive_handlers[pred[1]]
if not handler then
error(string.format("No handler for %s", pred[1]))
return
end
handler(match, pattern, source, pred, metadata)
end
end
end
--- Iterates of the captures of self on a given range. --- Iterates of the captures of self on a given range.
-- --
-- @param node The node under witch the search will occur -- @param node The node under witch the search will occur
@ -294,14 +382,18 @@ function Query:iter_captures(node, source, start, stop)
local raw_iter = node:_rawquery(self.query, true, start, stop) local raw_iter = node:_rawquery(self.query, true, start, stop)
local function iter() local function iter()
local capture, captured_node, match = raw_iter() local capture, captured_node, match = raw_iter()
local metadata = new_match_metadata()
if match ~= nil then if match ~= nil then
local active = self:match_preds(match, match.pattern, source) local active = self:match_preds(match, match.pattern, source)
match.active = active match.active = active
if not active then if not active then
return iter() -- tail call: try next match return iter() -- tail call: try next match
end end
self:apply_directives(match, match.pattern, source, metadata)
end end
return capture, captured_node return capture, captured_node, metadata
end end
return iter return iter
end end
@ -322,13 +414,17 @@ function Query:iter_matches(node, source, start, stop)
local raw_iter = node:_rawquery(self.query, false, start, stop) local raw_iter = node:_rawquery(self.query, false, start, stop)
local function iter() local function iter()
local pattern, match = raw_iter() local pattern, match = raw_iter()
local metadata = new_match_metadata()
if match ~= nil then if match ~= nil then
local active = self:match_preds(match, pattern, source) local active = self:match_preds(match, pattern, source)
if not active then if not active then
return iter() -- tail call: try next match return iter() -- tail call: try next match
end end
self:apply_directives(match, pattern, source, metadata)
end end
return pattern, match return pattern, match, metadata
end end
return iter return iter
end end

View File

@ -871,12 +871,12 @@ local hl_query = [[
before_each(function() before_each(function()
insert([[ insert([[
int x = INT_MAX; int x = INT_MAX;
#define READ_STRING(x, y) (char_u *)read_string((x), (size_t)(y)) #define READ_STRING(x, y) (char_u *)read_string((x), (size_t)(y))
#define READ_STRING_OK(x, y) (char_u *)read_string((x), (size_t)(y)) #define READ_STRING_OK(x, y) (char_u *)read_string((x), (size_t)(y))
#define VALUE 0 #define VALUE 123
#define VALUE1 1 #define VALUE1 123
#define VALUE2 2 #define VALUE2 123
]]) ]])
end) end)
@ -891,12 +891,12 @@ local hl_query = [[
eq("table", exec_lua("return type(parser:children().c)")) eq("table", exec_lua("return type(parser:children().c)"))
eq(5, exec_lua("return #parser:children().c:trees()")) eq(5, exec_lua("return #parser:children().c:trees()"))
eq({ eq({
{0, 2, 7, 0}, -- root tree {0, 0, 7, 0}, -- root tree
{3, 16, 3, 17}, -- VALUE 0 {3, 14, 3, 17}, -- VALUE 123
{4, 17, 4, 18}, -- VALUE1 1 {4, 15, 4, 18}, -- VALUE1 123
{5, 17, 5, 18}, -- VALUE2 2 {5, 15, 5, 18}, -- VALUE2 123
{1, 28, 1, 67}, -- READ_STRING(x, y) (char_u *)read_string((x), (size_t)(y)) {1, 26, 1, 65}, -- READ_STRING(x, y) (char_u *)read_string((x), (size_t)(y))
{2, 31, 2, 70} -- READ_STRING_OK(x, y) (char_u *)read_string((x), (size_t)(y)) {2, 29, 2, 68} -- READ_STRING_OK(x, y) (char_u *)read_string((x), (size_t)(y))
}, get_ranges()) }, get_ranges())
end) end)
end) end)
@ -912,15 +912,35 @@ local hl_query = [[
eq("table", exec_lua("return type(parser:children().c)")) eq("table", exec_lua("return type(parser:children().c)"))
eq(2, exec_lua("return #parser:children().c:trees()")) eq(2, exec_lua("return #parser:children().c:trees()"))
eq({ eq({
{0, 2, 7, 0}, -- root tree {0, 0, 7, 0}, -- root tree
{3, 16, 5, 18}, -- VALUE 0 {3, 14, 5, 18}, -- VALUE 123
-- VALUE1 1 -- VALUE1 123
-- VALUE2 2 -- VALUE2 123
{1, 28, 2, 70} -- READ_STRING(x, y) (char_u *)read_string((x), (size_t)(y)) {1, 26, 2, 68} -- READ_STRING(x, y) (char_u *)read_string((x), (size_t)(y))
-- READ_STRING_OK(x, y) (char_u *)read_string((x), (size_t)(y)) -- READ_STRING_OK(x, y) (char_u *)read_string((x), (size_t)(y))
}, get_ranges()) }, get_ranges())
end) end)
end) end)
describe("when using the offset directive", function()
it("should shift the range by the directive amount", function()
exec_lua([[
parser = vim.treesitter.get_parser(0, "c", {
queries = {
c = "(preproc_def ((preproc_arg) @c (#offset! @c 0 2 0 -1))) (preproc_function_def value: (preproc_arg) @c)"}})
]])
eq("table", exec_lua("return type(parser:children().c)"))
eq({
{0, 0, 7, 0}, -- root tree
{3, 15, 3, 16}, -- VALUE 123
{4, 16, 4, 17}, -- VALUE1 123
{5, 16, 5, 17}, -- VALUE2 123
{1, 26, 1, 65}, -- READ_STRING(x, y) (char_u *)read_string((x), (size_t)(y))
{2, 29, 2, 68} -- READ_STRING_OK(x, y) (char_u *)read_string((x), (size_t)(y))
}, get_ranges())
end)
end)
end) end)
describe("when getting the language for a range", function() describe("when getting the language for a range", function()
@ -944,4 +964,52 @@ int x = INT_MAX;
eq(result, true) eq(result, true)
end) end)
end) end)
describe("when getting/setting match data", function()
describe("when setting for the whole match", function()
it("should set/get the data correctly", function()
insert([[
int x = 3;
]])
local result = exec_lua([[
local result
query = vim.treesitter.parse_query("c", '((number_literal) @number (#set! "key" "value"))')
parser = vim.treesitter.get_parser(0, "c")
for pattern, match, metadata in query:iter_matches(parser:parse()[1]:root(), 0, 0, 1) do
result = metadata.key
end
return result
]])
eq(result, "value")
end)
end)
describe("when setting for a capture match", function()
it("should set/get the data correctly", function()
insert([[
int x = 3;
]])
local result = exec_lua([[
local result
query = vim.treesitter.parse_query("c", '((number_literal) @number (#set! @number "key" "value"))')
parser = vim.treesitter.get_parser(0, "c")
for pattern, match, metadata in query:iter_matches(parser:parse()[1]:root(), 0, 0, 1) do
result = metadata[pattern].key
end
return result
]])
eq(result, "value")
end)
end)
end)
end) end)