refactor(treesitter): delegate region calculation to treesitter (#22553)

This commit is contained in:
Lewis Russell 2023-03-08 17:22:28 +00:00 committed by GitHub
parent 898f902e00
commit 276b647fdb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 323 additions and 189 deletions

View File

@ -1037,6 +1037,9 @@ LanguageTree:included_regions({self}) *LanguageTree:included_regions()*
Parameters: ~ Parameters: ~
• {self} • {self}
Return: ~
integer[][]
LanguageTree:invalidate({self}, {reload}) *LanguageTree:invalidate()* LanguageTree:invalidate({self}, {reload}) *LanguageTree:invalidate()*
Invalidates this parser and all its children Invalidates this parser and all its children
@ -1044,12 +1047,17 @@ LanguageTree:invalidate({self}, {reload}) *LanguageTree:invalidate()*
• {reload} (boolean|nil) • {reload} (boolean|nil)
• {self} • {self}
LanguageTree:is_valid({self}) *LanguageTree:is_valid()* *LanguageTree:is_valid()*
LanguageTree:is_valid({self}, {exclude_children})
Determines whether this tree is valid. If the tree is invalid, call `parse()` . This will return the updated tree. Determines whether this tree is valid. If the tree is invalid, call `parse()` . This will return the updated tree.
Parameters: ~ Parameters: ~
• {exclude_children} (boolean|nil)
• {self} • {self}
Return: ~
(boolean)
LanguageTree:lang({self}) *LanguageTree:lang()* LanguageTree:lang({self}) *LanguageTree:lang()*
Gets the language of this tree node. Gets the language of this tree node.

View File

@ -3,7 +3,7 @@
---@class TSNode ---@class TSNode
---@field id fun(self: TSNode): integer ---@field id fun(self: TSNode): integer
---@field tree fun(self: TSNode): TSTree ---@field tree fun(self: TSNode): TSTree
---@field range fun(self: TSNode): integer, integer, integer, integer ---@field range fun(self: TSNode, include_bytes: boolean?): integer, integer, integer, integer, integer, integer
---@field start fun(self: TSNode): integer, integer, integer ---@field start fun(self: TSNode): integer, integer, integer
---@field end_ fun(self: TSNode): integer, integer, integer ---@field end_ fun(self: TSNode): integer, integer, integer
---@field type fun(self: TSNode): string ---@field type fun(self: TSNode): string
@ -43,9 +43,9 @@ function TSNode:_rawquery(query, captures, start, end_) end
function TSNode:_rawquery(query, captures, start, end_) end function TSNode:_rawquery(query, captures, start, end_) end
---@class TSParser ---@class TSParser
---@field parse fun(self: TSParser, tree, source: integer|string): TSTree, Range4[] ---@field parse fun(self: TSParser, tree: TSTree?, source: integer|string, include_bytes: boolean?): TSTree, integer[]
---@field reset fun(self: TSParser) ---@field reset fun(self: TSParser)
---@field included_ranges fun(self: TSParser): Range4[] ---@field included_ranges fun(self: TSParser, include_bytes: boolean?): integer[]
---@field set_included_ranges fun(self: TSParser, ranges: Range6[]) ---@field set_included_ranges fun(self: TSParser, ranges: Range6[])
---@field set_timeout fun(self: TSParser, timeout: integer) ---@field set_timeout fun(self: TSParser, timeout: integer)
---@field timeout fun(self: TSParser): integer ---@field timeout fun(self: TSParser): integer
@ -54,6 +54,7 @@ function TSNode:_rawquery(query, captures, start, end_) end
---@field root fun(self: TSTree): TSNode ---@field root fun(self: TSTree): TSNode
---@field edit fun(self: TSTree, _: integer, _: integer, _: integer, _: integer, _: integer, _: integer, _: integer, _: integer, _:integer) ---@field edit fun(self: TSTree, _: integer, _: integer, _: integer, _: integer, _: integer, _: integer, _: integer, _: integer, _:integer)
---@field copy fun(self: TSTree): TSTree ---@field copy fun(self: TSTree): TSTree
---@field included_ranges fun(self: TSTree, include_bytes: boolean?): integer[]
---@return integer ---@return integer
vim._ts_get_language_version = function() end vim._ts_get_language_version = function() end

View File

@ -78,11 +78,8 @@ end
---@param r2 Range4|Range6 ---@param r2 Range4|Range6
---@return boolean ---@return boolean
function M.intercepts(r1, r2) function M.intercepts(r1, r2)
local off_1 = #r1 == 6 and 1 or 0 local srow_1, scol_1, erow_1, ecol_1 = M.unpack4(r1)
local off_2 = #r1 == 6 and 1 or 0 local srow_2, scol_2, erow_2, ecol_2 = M.unpack4(r2)
local srow_1, scol_1, erow_1, ecol_1 = r1[1], r1[2], r1[3 + off_1], r1[4 + off_1]
local srow_2, scol_2, erow_2, ecol_2 = r2[1], r2[2], r2[3 + off_2], r2[4 + off_2]
-- r1 is above r2 -- r1 is above r2
if M.cmp_pos.le(erow_1, ecol_1, srow_2, scol_2) then if M.cmp_pos.le(erow_1, ecol_1, srow_2, scol_2) then
@ -97,16 +94,21 @@ function M.intercepts(r1, r2)
return true return true
end end
---@private
---@param r Range4|Range6
---@return integer, integer, integer, integer
function M.unpack4(r)
local off_1 = #r == 6 and 1 or 0
return r[1], r[2], r[3 + off_1], r[4 + off_1]
end
---@private ---@private
---@param r1 Range4|Range6 ---@param r1 Range4|Range6
---@param r2 Range4|Range6 ---@param r2 Range4|Range6
---@return boolean whether r1 contains r2 ---@return boolean whether r1 contains r2
function M.contains(r1, r2) function M.contains(r1, r2)
local off_1 = #r1 == 6 and 1 or 0 local srow_1, scol_1, erow_1, ecol_1 = M.unpack4(r1)
local off_2 = #r1 == 6 and 1 or 0 local srow_2, scol_2, erow_2, ecol_2 = M.unpack4(r2)
local srow_1, scol_1, erow_1, ecol_1 = r1[1], r1[2], r1[3 + off_1], r1[4 + off_1]
local srow_2, scol_2, erow_2, ecol_2 = r2[1], r2[2], r2[3 + off_2], r2[4 + off_2]
-- start doesn't fit -- start doesn't fit
if M.cmp_pos.gt(srow_1, scol_1, srow_2, scol_2) then if M.cmp_pos.gt(srow_1, scol_1, srow_2, scol_2) then
@ -123,9 +125,13 @@ end
---@private ---@private
---@param source integer|string ---@param source integer|string
---@param range Range4 ---@param range Range4|Range6
---@return Range6 ---@return Range6
function M.add_bytes(source, range) function M.add_bytes(source, range)
if type(range) == 'table' and #range == 6 then
return range --[[@as Range6]]
end
local start_row, start_col, end_row, end_col = range[1], range[2], range[3], range[4] local start_row, start_col, end_row, end_col = range[1], range[2], range[3], range[4]
local start_byte = 0 local start_byte = 0
local end_byte = 0 local end_byte = 0

View File

@ -57,13 +57,13 @@ local Range = require('vim.treesitter._range')
---@field private _injection_query Query Queries defining injected languages ---@field private _injection_query Query Queries defining injected languages
---@field private _opts table Options ---@field private _opts table Options
---@field private _parser TSParser Parser for language ---@field private _parser TSParser Parser for language
---@field private _regions Range6[][] List of regions this tree should manage and parse ---@field private _regions Range6[][]?
---List of regions this tree should manage and parse. If nil then regions are
---taken from _trees. This is mostly a short-lived cache for included_regions()
---@field private _lang string Language name ---@field private _lang string Language name
---@field private _source (integer|string) Buffer or string to parse ---@field private _source (integer|string) Buffer or string to parse
---@field private _trees TSTree[] Reference to parsed tree (one for each language) ---@field private _trees TSTree[] Reference to parsed tree (one for each language)
---@field private _valid boolean|table<integer,boolean> If the parsed tree is valid ---@field private _valid boolean|table<integer,boolean> If the parsed tree is valid
--- TODO(lewis6991): combine _regions, _valid and _trees
---@field private _is_child boolean
local LanguageTree = {} local LanguageTree = {}
---@class LanguageTreeOpts ---@class LanguageTreeOpts
@ -98,7 +98,6 @@ function LanguageTree.new(source, lang, opts)
_source = source, _source = source,
_lang = lang, _lang = lang,
_children = {}, _children = {},
_regions = {},
_trees = {}, _trees = {},
_opts = opts, _opts = opts,
_injection_query = injections[lang] and query.parse_query(lang, injections[lang]) _injection_query = injections[lang] and query.parse_query(lang, injections[lang])
@ -117,6 +116,48 @@ function LanguageTree.new(source, lang, opts)
return self return self
end end
---@private
---Measure execution time of a function
---@generic R1, R2, R3
---@param f fun(): R1, R2, R2
---@return integer, R1, R2, R3
local function tcall(f, ...)
local start = vim.loop.hrtime()
---@diagnostic disable-next-line
local r = { f(...) }
local duration = (vim.loop.hrtime() - start) / 1000000
return duration, unpack(r)
end
---@private
---@vararg any
function LanguageTree:_log(...)
if vim.g.__ts_debug == nil then
return
end
local args = { ... }
if type(args[1]) == 'function' then
args = { args[1]() }
end
local info = debug.getinfo(2, 'nl')
local nregions = #self:included_regions()
local prefix =
string.format('%s:%d: [%s:%d] ', info.name, info.currentline, self:lang(), nregions)
a.nvim_out_write(prefix)
for _, x in ipairs(args) do
if type(x) == 'string' then
a.nvim_out_write(x)
else
a.nvim_out_write(vim.inspect(x, { newline = ' ', indent = '' }))
end
a.nvim_out_write(' ')
end
a.nvim_out_write('\n')
end
--- Invalidates this parser and all its children --- Invalidates this parser and all its children
---@param reload boolean|nil ---@param reload boolean|nil
function LanguageTree:invalidate(reload) function LanguageTree:invalidate(reload)
@ -146,7 +187,9 @@ end
--- Determines whether this tree is valid. --- Determines whether this tree is valid.
--- If the tree is invalid, call `parse()`. --- If the tree is invalid, call `parse()`.
--- This will return the updated tree. --- This will return the updated tree.
function LanguageTree:is_valid() ---@param exclude_children boolean|nil
---@return boolean
function LanguageTree:is_valid(exclude_children)
local valid = self._valid local valid = self._valid
if type(valid) == 'table' then if type(valid) == 'table' then
@ -155,9 +198,18 @@ function LanguageTree:is_valid()
return false return false
end end
end end
return true
end end
if not exclude_children then
for _, child in pairs(self._children) do
if not child:is_valid(exclude_children) then
return false
end
end
end
assert(type(valid) == 'boolean')
return valid return valid
end end
@ -171,16 +223,6 @@ function LanguageTree:source()
return self._source return self._source
end end
---@private
---This is only exposed so it can be wrapped for profiling
---@param old_tree TSTree
---@return TSTree, integer[]
function LanguageTree:_parse_tree(old_tree)
local tree, tree_changes = self._parser:parse(old_tree, self._source)
self:_do_callback('changedtree', tree_changes, tree)
return tree, tree_changes
end
--- Parses all defined regions using a treesitter parser --- Parses all defined regions using a treesitter parser
--- for the language this tree represents. --- for the language this tree represents.
--- This will run the injection query for this language to --- This will run the injection query for this language to
@ -190,31 +232,39 @@ end
---@return table|nil Change list ---@return table|nil Change list
function LanguageTree:parse() function LanguageTree:parse()
if self:is_valid() then if self:is_valid() then
self:_log('valid')
return self._trees return self._trees
end end
local changes = {} local changes = {}
-- If there are no ranges, set to an empty list -- Collect some stats
-- so the included ranges in the parser are cleared. local regions_parsed = 0
if #self._regions > 0 then local total_parse_time = 0
for i, ranges in ipairs(self._regions) do
--- At least 1 region is invalid
if not self:is_valid(true) then
-- If there are no ranges, set to an empty list
-- so the included ranges in the parser are cleared.
for i, ranges in ipairs(self:included_regions()) do
if not self._valid or not self._valid[i] then if not self._valid or not self._valid[i] then
self._parser:set_included_ranges(ranges) self._parser:set_included_ranges(ranges)
local tree, tree_changes = self:_parse_tree(self._trees[i]) local parse_time, tree, tree_changes =
tcall(self._parser.parse, self._parser, self._trees[i], self._source)
self:_do_callback('changedtree', tree_changes, tree)
self._trees[i] = tree self._trees[i] = tree
vim.list_extend(changes, tree_changes) vim.list_extend(changes, tree_changes)
total_parse_time = total_parse_time + parse_time
regions_parsed = regions_parsed + 1
end end
end end
else
local tree, tree_changes = self:_parse_tree(self._trees[1])
self._trees = { tree }
changes = tree_changes
end end
local injections_by_lang = self:_get_injections()
local seen_langs = {} ---@type table<string,boolean> local seen_langs = {} ---@type table<string,boolean>
local query_time, injections_by_lang = tcall(self._get_injections, self)
for lang, injection_ranges in pairs(injections_by_lang) do for lang, injection_ranges in pairs(injections_by_lang) do
local has_lang = pcall(language.add, lang) local has_lang = pcall(language.add, lang)
@ -229,15 +279,6 @@ function LanguageTree:parse()
end end
child:set_included_regions(injection_ranges) child:set_included_regions(injection_ranges)
local _, child_changes = child:parse()
-- Propagate any child changes so they are included in the
-- the change list for the callback.
if child_changes then
vim.list_extend(changes, child_changes)
end
seen_langs[lang] = true seen_langs[lang] = true
end end
end end
@ -248,6 +289,23 @@ function LanguageTree:parse()
end end
end end
self:_log({
changes = changes,
regions_parsed = regions_parsed,
parse_time = total_parse_time,
query_time = query_time,
})
self:for_each_child(function(child)
local _, child_changes = child:parse()
-- Propagate any child changes so they are included in the
-- the change list for the callback.
if child_changes then
vim.list_extend(changes, child_changes)
end
end)
self._valid = true self._valid = true
return self._trees, changes return self._trees, changes
@ -295,8 +353,6 @@ function LanguageTree:add_child(lang)
end end
self._children[lang] = LanguageTree.new(self._source, lang, self._opts) self._children[lang] = LanguageTree.new(self._source, lang, self._opts)
self._children[lang]._is_child = true
self:invalidate() self:invalidate()
self:_do_callback('child_added', self._children[lang]) self:_do_callback('child_added', self._children[lang])
@ -331,6 +387,53 @@ function LanguageTree:destroy()
end end
end end
---@private
---@param region Range6[]
local function region_tostr(region)
local srow, scol = region[1][1], region[1][2]
local erow, ecol = region[#region][4], region[#region][5]
return string.format('[%d:%d-%d:%d]', srow, scol, erow, ecol)
end
---@private
---Sets self._valid properly and efficiently
---@param fn fun(index: integer, region: Range6[]): boolean
function LanguageTree:_validate_regions(fn)
if not self._valid then
return
end
if type(self._valid) ~= 'table' then
self._valid = {}
end
local all_valid = true
for i, region in ipairs(self:included_regions()) do
if self._valid[i] == nil then
self._valid[i] = true
end
if self._valid[i] then
self._valid[i] = fn(i, region)
if not self._valid[i] then
self:_log(function()
return 'invalidating region', i, region_tostr(region)
end)
end
end
if not self._valid[i] then
all_valid = false
end
end
-- Compress the valid value to 'true' if there are no invalid regions
if all_valid then
self._valid = all_valid
end
end
--- Sets the included regions that should be parsed by this |LanguageTree|. --- Sets the included regions that should be parsed by this |LanguageTree|.
--- A region is a set of nodes and/or ranges that will be parsed in the same context. --- A region is a set of nodes and/or ranges that will be parsed in the same context.
--- ---
@ -357,56 +460,57 @@ function LanguageTree:set_included_regions(regions)
end end
end end
if #self._regions ~= #regions then if #self:included_regions() ~= #regions then
self._trees = {} self._trees = {}
self:invalidate() self:invalidate()
elseif self._valid ~= false then else
if self._valid == true then self:_validate_regions(function(i, region)
self._valid = {} return vim.deep_equal(regions[i], region)
for i = 1, #regions do end)
self._valid[i] = true
end
end
for i = 1, #regions do
if not vim.deep_equal(self._regions[i], regions[i]) then
self._valid[i] = false
end
if not self._valid[i] then
self._trees[i] = nil
end
end
end end
self._regions = regions self._regions = regions
end end
--- Gets the set of included regions ---Gets the set of included regions
---@return integer[][]
function LanguageTree:included_regions() function LanguageTree:included_regions()
return self._regions if self._regions then
return self._regions
end
if #self._trees == 0 then
return { {} }
end
local regions = {} ---@type Range6[][]
for i, _ in ipairs(self._trees) do
regions[i] = self._trees[i]:included_ranges(true)
end
self._regions = regions
return regions
end end
---@private ---@private
---@param node TSNode ---@param node TSNode
---@param id integer ---@param source integer|string
---@param metadata TSMetadata ---@param metadata TSMetadata
---@return Range4 ---@return Range6
local function get_range_from_metadata(node, id, metadata) local function get_range_from_metadata(node, source, metadata)
if metadata[id] and metadata[id].range then if metadata and metadata.range then
return metadata[id].range --[[@as Range4]] return Range.add_bytes(source, metadata.range --[[@as Range4|Range6]])
end end
return { node:range() } return { node:range(true) }
end end
---@private ---@private
--- TODO(lewis6991): cleanup of the node_range interface --- TODO(lewis6991): cleanup of the node_range interface
---@param node TSNode ---@param node TSNode
---@param id integer ---@param source string|integer
---@param metadata TSMetadata ---@param metadata TSMetadata
---@return Range4[] ---@return Range4[]
local function get_node_ranges(node, id, metadata, include_children) local function get_node_ranges(node, source, metadata, include_children)
local range = get_range_from_metadata(node, id, metadata) local range = get_range_from_metadata(node, source, metadata)
if include_children then if include_children then
return { range } return { range }
@ -414,7 +518,7 @@ local function get_node_ranges(node, id, metadata, include_children)
local ranges = {} ---@type Range4[] local ranges = {} ---@type Range4[]
local srow, scol, erow, ecol = range[1], range[2], range[3], range[4] local srow, scol, erow, ecol = Range.unpack4(range)
for i = 0, node:named_child_count() - 1 do for i = 0, node:named_child_count() - 1 do
local child = node:named_child(i) local child = node:named_child(i)
@ -498,7 +602,7 @@ function LanguageTree:_get_injection(match, metadata)
if name == 'injection.language' then if name == 'injection.language' then
lang = get_node_text(node, self._source, metadata[id]) lang = get_node_text(node, self._source, metadata[id])
elseif name == 'injection.content' then elseif name == 'injection.content' then
ranges = get_node_ranges(node, id, metadata, include_children) ranges = get_node_ranges(node, self._source, metadata[id], include_children)
end end
end end
@ -545,7 +649,7 @@ function LanguageTree:_get_injection_deprecated(match, metadata)
elseif name == 'combined' then elseif name == 'combined' then
combined = true combined = true
elseif name == 'content' and #ranges == 0 then elseif name == 'content' and #ranges == 0 then
table.insert(ranges, get_range_from_metadata(node, id, metadata)) table.insert(ranges, get_range_from_metadata(node, self._source, metadata[id]))
-- Ignore any tags that start with "_" -- Ignore any tags that start with "_"
-- Allows for other tags to be used in matches -- Allows for other tags to be used in matches
elseif string.sub(name, 1, 1) ~= '_' then elseif string.sub(name, 1, 1) ~= '_' then
@ -554,7 +658,7 @@ function LanguageTree:_get_injection_deprecated(match, metadata)
end end
if #ranges == 0 then if #ranges == 0 then
table.insert(ranges, get_range_from_metadata(node, id, metadata)) table.insert(ranges, get_range_from_metadata(node, self._source, metadata[id]))
end end
end end
end end
@ -569,7 +673,7 @@ end
--- TODO: Allow for an offset predicate to tailor the injection range --- TODO: Allow for an offset predicate to tailor the injection range
--- instead of using the entire nodes range. --- instead of using the entire nodes range.
---@private ---@private
---@return table<string, Range4[][]> ---@return table<string, Range6[][]>
function LanguageTree:_get_injections() function LanguageTree:_get_injections()
if not self._injection_query then if not self._injection_query then
return {} return {}
@ -594,7 +698,7 @@ function LanguageTree:_get_injections()
end end
end end
---@type table<string,Range4[][]> ---@type table<string,Range6[][]>
local result = {} local result = {}
-- Generate a map by lang of node lists. -- Generate a map by lang of node lists.
@ -634,42 +738,51 @@ function LanguageTree:_do_callback(cb_name, ...)
end end
---@private ---@private
---@param regions Range6[][] function LanguageTree:_edit(
---@param old_range Range6 start_byte,
---@param new_range Range6 end_byte_old,
---@return table<integer,boolean> region indices to invalidate end_byte_new,
local function update_regions(regions, old_range, new_range) start_row,
---@type table<integer,boolean> start_col,
local valid = {} end_row_old,
end_col_old,
for i, ranges in ipairs(regions or {}) do end_row_new,
valid[i] = true end_col_new
for j, r in ipairs(ranges) do )
if Range.intercepts(r, old_range) then for _, tree in ipairs(self._trees) do
valid[i] = false tree:edit(
break start_byte,
end end_byte_old,
end_byte_new,
-- Range after change. Adjust start_row,
if Range.cmp_pos.gt(r[1], r[2], old_range[4], old_range[5]) then start_col,
local byte_offset = new_range[6] - old_range[6] end_row_old,
local row_offset = new_range[4] - old_range[4] end_col_old,
end_row_new,
-- Update the range to avoid invalidation in set_included_regions() end_col_new
-- which will compare the regions against the parsed injection regions )
ranges[j] = {
r[1] + row_offset,
r[2],
r[3] + byte_offset,
r[4] + row_offset,
r[5],
r[6] + byte_offset,
}
end
end
end end
return valid self._regions = nil
local changed_range = {
start_row,
start_col,
start_byte,
end_row_old,
end_col_old,
end_byte_old,
}
-- Validate regions after editing the tree
self:_validate_regions(function(_, region)
for _, r in ipairs(region) do
if Range.intercepts(r, changed_range) then
return false
end
end
return true
end)
end end
---@private ---@private
@ -700,49 +813,26 @@ function LanguageTree:_on_bytes(
local old_end_col = old_col + ((old_row == 0) and start_col or 0) local old_end_col = old_col + ((old_row == 0) and start_col or 0)
local new_end_col = new_col + ((new_row == 0) and start_col or 0) local new_end_col = new_col + ((new_row == 0) and start_col or 0)
local old_range = { self:_log(
'on_bytes',
bufnr,
changed_tick,
start_row, start_row,
start_col, start_col,
start_byte, start_byte,
start_row + old_row, old_row,
old_end_col, old_col,
start_byte + old_byte, old_byte,
} new_row,
new_col,
local new_range = { new_byte
start_row, )
start_col,
start_byte,
start_row + new_row,
new_end_col,
start_byte + new_byte,
}
if #self._regions == 0 then
self._valid = false
else
self._valid = update_regions(self._regions, old_range, new_range)
end
for _, child in pairs(self._children) do
child:_on_bytes(
bufnr,
changed_tick,
start_row,
start_col,
start_byte,
old_row,
old_col,
old_byte,
new_row,
new_col,
new_byte
)
end
-- Edit trees together BEFORE emitting a bytes callback. -- Edit trees together BEFORE emitting a bytes callback.
for _, tree in ipairs(self._trees) do ---@private
tree:edit( self:for_each_child(function(child)
---@diagnostic disable-next-line:invisible
child:_edit(
start_byte, start_byte,
start_byte + old_byte, start_byte + old_byte,
start_byte + new_byte, start_byte + new_byte,
@ -753,24 +843,22 @@ function LanguageTree:_on_bytes(
start_row + new_row, start_row + new_row,
new_end_col new_end_col
) )
end end, true)
if not self._is_child then self:_do_callback(
self:_do_callback( 'bytes',
'bytes', bufnr,
bufnr, changed_tick,
changed_tick, start_row,
start_row, start_col,
start_col, start_byte,
start_byte, old_row,
old_row, old_col,
old_col, old_byte,
old_byte, new_row,
new_row, new_col,
new_col, new_byte
new_byte )
)
end
end end
---@private ---@private

View File

@ -277,6 +277,7 @@ end
---@return (string[]|string|nil) ---@return (string[]|string|nil)
function M.get_node_text(node, source, opts) function M.get_node_text(node, source, opts)
opts = opts or {} opts = opts or {}
-- TODO(lewis6991): concat only works when source is number.
local concat = vim.F.if_nil(opts.concat, true) local concat = vim.F.if_nil(opts.concat, true)
local metadata = opts.metadata or {} local metadata = opts.metadata or {}

View File

@ -64,6 +64,7 @@ static struct luaL_Reg tree_meta[] = {
{ "__tostring", tree_tostring }, { "__tostring", tree_tostring },
{ "root", tree_root }, { "root", tree_root },
{ "edit", tree_edit }, { "edit", tree_edit },
{ "included_ranges", tree_get_ranges },
{ "copy", tree_copy }, { "copy", tree_copy },
{ NULL, NULL } { NULL, NULL }
}; };
@ -364,19 +365,29 @@ static const char *input_cb(void *payload, uint32_t byte_index, TSPoint position
#undef BUFSIZE #undef BUFSIZE
} }
static void push_ranges(lua_State *L, const TSRange *ranges, const size_t length) static void push_ranges(lua_State *L, const TSRange *ranges, const size_t length,
bool include_bytes)
{ {
lua_createtable(L, (int)length, 0); lua_createtable(L, (int)length, 0);
for (size_t i = 0; i < length; i++) { for (size_t i = 0; i < length; i++) {
lua_createtable(L, 4, 0); lua_createtable(L, include_bytes ? 6 : 4, 0);
int j = 1;
lua_pushinteger(L, ranges[i].start_point.row); lua_pushinteger(L, ranges[i].start_point.row);
lua_rawseti(L, -2, 1); lua_rawseti(L, -2, j++);
lua_pushinteger(L, ranges[i].start_point.column); lua_pushinteger(L, ranges[i].start_point.column);
lua_rawseti(L, -2, 2); lua_rawseti(L, -2, j++);
if (include_bytes) {
lua_pushinteger(L, ranges[i].start_byte);
lua_rawseti(L, -2, j++);
}
lua_pushinteger(L, ranges[i].end_point.row); lua_pushinteger(L, ranges[i].end_point.row);
lua_rawseti(L, -2, 3); lua_rawseti(L, -2, j++);
lua_pushinteger(L, ranges[i].end_point.column); lua_pushinteger(L, ranges[i].end_point.column);
lua_rawseti(L, -2, 4); lua_rawseti(L, -2, j++);
if (include_bytes) {
lua_pushinteger(L, ranges[i].end_byte);
lua_rawseti(L, -2, j++);
}
lua_rawseti(L, -2, (int)(i + 1)); lua_rawseti(L, -2, (int)(i + 1));
} }
@ -395,6 +406,8 @@ static int parser_parse(lua_State *L)
old_tree = tmp ? *tmp : NULL; old_tree = tmp ? *tmp : NULL;
} }
bool include_bytes = (lua_gettop(L) >= 3) && lua_toboolean(L, 3);
TSTree *new_tree = NULL; TSTree *new_tree = NULL;
size_t len; size_t len;
const char *str; const char *str;
@ -445,7 +458,7 @@ static int parser_parse(lua_State *L)
push_tree(L, new_tree, false); // [tree] push_tree(L, new_tree, false); // [tree]
push_ranges(L, changed, n_ranges); // [tree, ranges] push_ranges(L, changed, n_ranges, include_bytes); // [tree, ranges]
xfree(changed); xfree(changed);
return 2; return 2;
@ -500,6 +513,24 @@ static int tree_edit(lua_State *L)
return 0; return 0;
} }
static int tree_get_ranges(lua_State *L)
{
TSTree **tree = tree_check(L, 1);
if (!(*tree)) {
return 0;
}
bool include_bytes = (lua_gettop(L) >= 2) && lua_toboolean(L, 2);
uint32_t len;
TSRange *ranges = ts_tree_included_ranges(*tree, &len);
push_ranges(L, ranges, len, include_bytes);
xfree(ranges);
return 1;
}
// Use the top of the stack (without popping it) to create a TSRange, it can be // Use the top of the stack (without popping it) to create a TSRange, it can be
// either a lua table or a TSNode // either a lua table or a TSNode
static void range_from_lua(lua_State *L, TSRange *range) static void range_from_lua(lua_State *L, TSRange *range)
@ -605,10 +636,12 @@ static int parser_get_ranges(lua_State *L)
return 0; return 0;
} }
bool include_bytes = (lua_gettop(L) >= 2) && lua_toboolean(L, 2);
uint32_t len; uint32_t len;
const TSRange *ranges = ts_parser_included_ranges(*p, &len); const TSRange *ranges = ts_parser_included_ranges(*p, &len);
push_ranges(L, ranges, len); push_ranges(L, ranges, len, include_bytes);
return 1; return 1;
} }
@ -783,10 +816,7 @@ static int node_range(lua_State *L)
return 0; return 0;
} }
bool include_bytes = false; bool include_bytes = (lua_gettop(L) >= 2) && lua_toboolean(L, 2);
if (lua_gettop(L) >= 2) {
include_bytes = lua_toboolean(L, 2);
}
TSPoint start = ts_node_start_point(node); TSPoint start = ts_node_start_point(node);
TSPoint end = ts_node_end_point(node); TSPoint end = ts_node_end_point(node);