Merge pull request #22594 from lewis6991/perf/treefold

This commit is contained in:
Lewis Russell 2023-03-10 13:35:07 +00:00 committed by GitHub
commit 845efb8e12
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 287 additions and 148 deletions

View File

@ -99,13 +99,28 @@ function M.get_parser(bufnr, lang, opts)
if bufnr == nil or bufnr == 0 then if bufnr == nil or bufnr == 0 then
bufnr = a.nvim_get_current_buf() bufnr = a.nvim_get_current_buf()
end end
if lang == nil then if lang == nil then
local ft = vim.bo[bufnr].filetype local ft = vim.bo[bufnr].filetype
lang = language.get_lang(ft) or ft if ft ~= '' then
-- TODO(lewis6991): we should error here and not default to ft lang = language.get_lang(ft) or ft
-- if not lang then -- TODO(lewis6991): we should error here and not default to ft
-- error(string.format('filetype %s of buffer %d is not associated with any lang', ft, bufnr)) -- if not lang then
-- end -- error(string.format('filetype %s of buffer %d is not associated with any lang', ft, bufnr))
-- end
else
if parsers[bufnr] then
return parsers[bufnr]
end
error(
string.format(
'There is no parser available for buffer %d and one could not be'
.. ' created because lang could not be determined. Either pass lang'
.. ' or set the buffer filetype',
bufnr
)
)
end
end end
if parsers[bufnr] == nil or parsers[bufnr]:lang() ~= lang then if parsers[bufnr] == nil or parsers[bufnr]:lang() ~= lang then

View File

@ -1,139 +1,157 @@
local Range = require('vim.treesitter._range')
local api = vim.api local api = vim.api
local M = {} ---@class FoldInfo
---@field levels table<integer,string>
---@field levels0 table<integer,integer>
---@field private start_counts table<integer,integer>
---@field private stop_counts table<integer,integer>
local FoldInfo = {}
FoldInfo.__index = FoldInfo
--- Memoizes a function based on the buffer tick of the provided bufnr. function FoldInfo.new()
--- The cache entry is cleared when the buffer is detached to avoid memory leaks. return setmetatable({
---@generic F: function start_counts = {},
---@param fn F fn to memoize, taking the bufnr as first argument stop_counts = {},
---@return F levels0 = {},
local function memoize_by_changedtick(fn) levels = {},
---@type table<integer,{result:any,last_tick:integer}> }, FoldInfo)
local cache = {} end
---@param bufnr integer ---@param srow integer
return function(bufnr, ...) ---@param erow integer
local tick = api.nvim_buf_get_changedtick(bufnr) function FoldInfo:invalidate_range(srow, erow)
for i = srow, erow do
if cache[bufnr] then self.start_counts[i + 1] = nil
if cache[bufnr].last_tick == tick then self.stop_counts[i + 1] = nil
return cache[bufnr].result self.levels0[i + 1] = nil
end self.levels[i + 1] = nil
else
local function detach_handler()
cache[bufnr] = nil
end
-- Clean up logic only!
api.nvim_buf_attach(bufnr, false, {
on_detach = detach_handler,
on_reload = detach_handler,
})
end
cache[bufnr] = {
result = fn(bufnr, ...),
last_tick = tick,
}
return cache[bufnr].result
end end
end end
---@param bufnr integer ---@param srow integer
---@param capture string ---@param erow integer
---@param query_name string function FoldInfo:remove_range(srow, erow)
---@param callback fun(id: integer, node:TSNode, metadata: TSMetadata) for i = erow - 1, srow, -1 do
local function iter_matches_with_capture(bufnr, capture, query_name, callback) table.remove(self.levels, i + 1)
local parser = vim.treesitter.get_parser(bufnr) table.remove(self.levels0, i + 1)
table.remove(self.start_counts, i + 1)
if not parser then table.remove(self.stop_counts, i + 1)
return
end end
end
parser:for_each_tree(function(tree, lang_tree) ---@param srow integer
local lang = lang_tree:lang() ---@param erow integer
local query = vim.treesitter.query.get_query(lang, query_name) function FoldInfo:add_range(srow, erow)
if query then for i = srow, erow - 1 do
local root = tree:root() table.insert(self.levels, i + 1, '-1')
local start, _, stop = root:range() table.insert(self.levels0, i + 1, -1)
for _, match, metadata in query:iter_matches(root, bufnr, start, stop) do table.insert(self.start_counts, i + 1, nil)
for id, node in pairs(match) do table.insert(self.stop_counts, i + 1, nil)
if query.captures[id] == capture then end
callback(id, node, metadata) end
end
end ---@param lnum integer
end function FoldInfo:add_start(lnum)
end self.start_counts[lnum] = (self.start_counts[lnum] or 0) + 1
end) end
--- Records one more fold ending on the given line.
---@param lnum integer
function FoldInfo:add_stop(lnum)
  local stops = self.stop_counts
  local seen = stops[lnum] or 0
  stops[lnum] = seen + 1
end
--- Number of folds recorded as starting on the given line (0 when none).
---@param lnum integer
---@return integer
function FoldInfo:get_start(lnum)
  local count = self.start_counts[lnum]
  if not count then
    return 0
  end
  return count
end
---@param lnum integer
---@return integer
function FoldInfo:get_stop(lnum)
return self.stop_counts[lnum] or 0
end end
---@private ---@private
--- TODO(lewis6991): copied from languagetree.lua. Consolidate --- TODO(lewis6991): copied from languagetree.lua. Consolidate
---@param node TSNode ---@param node TSNode
---@param id integer
---@param metadata TSMetadata ---@param metadata TSMetadata
---@return Range ---@return Range4
local function get_range_from_metadata(node, id, metadata) local function get_range_from_metadata(node, metadata)
if metadata[id] and metadata[id].range then if metadata and metadata.range then
return metadata[id].range --[[@as Range]] return metadata.range --[[@as Range4]]
end end
return { node:range() } return { node:range() }
end end
-- This is cached on buf tick to avoid computing that multiple times local function trim_level(level)
-- Especially not for every line in the file when `zx` is hit
---@param bufnr integer
---@return table<integer,string>
local folds_levels = memoize_by_changedtick(function(bufnr)
local max_fold_level = vim.wo.foldnestmax local max_fold_level = vim.wo.foldnestmax
local function trim_level(level) if level > max_fold_level then
if level > max_fold_level then return max_fold_level
return max_fold_level
end
return level
end end
return level
end
-- start..stop is an inclusive range ---@param bufnr integer
local start_counts = {} ---@type table<integer,integer> ---@param info FoldInfo
local stop_counts = {} ---@type table<integer,integer> ---@param srow integer?
---@param erow integer?
local function get_folds_levels(bufnr, info, srow, erow)
srow = srow or 0
erow = erow or api.nvim_buf_line_count(bufnr)
info:invalidate_range(srow, erow)
local prev_start = -1 local prev_start = -1
local prev_stop = -1 local prev_stop = -1
local min_fold_lines = vim.wo.foldminlines vim.treesitter.get_parser(bufnr):for_each_tree(function(tree, ltree)
local query = vim.treesitter.query.get_query(ltree:lang(), 'folds')
iter_matches_with_capture(bufnr, 'fold', 'folds', function(id, node, metadata) if not query then
local range = get_range_from_metadata(node, id, metadata) return
local start, stop, stop_col = range[1], range[3], range[4]
if stop_col == 0 then
stop = stop - 1
end end
local fold_length = stop - start + 1 -- erow in query is end-exclusive
local q_erow = erow and erow + 1 or -1
-- Fold only multiline nodes that are not exactly the same as previously met folds for id, node, metadata in query:iter_captures(tree:root(), bufnr, srow or 0, q_erow) do
-- Checking against just the previously found fold is sufficient if nodes if query.captures[id] == 'fold' then
-- are returned in preorder or postorder when traversing tree local range = get_range_from_metadata(node, metadata[id])
if fold_length > min_fold_lines and not (start == prev_start and stop == prev_stop) then local start, _, stop, stop_col = Range.unpack4(range)
start_counts[start] = (start_counts[start] or 0) + 1
stop_counts[stop] = (stop_counts[stop] or 0) + 1 if stop_col == 0 then
prev_start = start stop = stop - 1
prev_stop = stop end
local fold_length = stop - start + 1
-- Fold only multiline nodes that are not exactly the same as previously met folds
-- Checking against just the previously found fold is sufficient if nodes
-- are returned in preorder or postorder when traversing tree
if
fold_length > vim.wo.foldminlines and not (start == prev_start and stop == prev_stop)
then
info:add_start(start + 1)
info:add_stop(stop + 1)
prev_start = start
prev_stop = stop
end
end
end end
end) end)
---@type table<integer,string> local current_level = info.levels0[srow] or 0
local levels = {}
local current_level = 0
-- We now have the list of fold opening and closing, fill the gaps and mark where fold start -- We now have the list of fold opening and closing, fill the gaps and mark where fold start
for lnum = 0, api.nvim_buf_line_count(bufnr) do for lnum = srow + 1, erow + 1 do
local last_trimmed_level = trim_level(current_level) local last_trimmed_level = trim_level(current_level)
current_level = current_level + (start_counts[lnum] or 0) current_level = current_level + info:get_start(lnum)
info.levels0[lnum] = current_level
local trimmed_level = trim_level(current_level) local trimmed_level = trim_level(current_level)
current_level = current_level - (stop_counts[lnum] or 0) current_level = current_level - info:get_stop(lnum)
-- Determine if it's the start/end of a fold -- Determine if it's the start/end of a fold
-- NB: vim's fold-expr interface does not have a mechanism to indicate that -- NB: vim's fold-expr interface does not have a mechanism to indicate that
@ -148,11 +166,61 @@ local folds_levels = memoize_by_changedtick(function(bufnr)
prefix = '>' prefix = '>'
end end
levels[lnum + 1] = prefix .. tostring(trimmed_level) info.levels[lnum] = prefix .. tostring(trimmed_level)
end
end
local M = {}
---@type table<integer,FoldInfo>
local foldinfos = {}
local function recompute_folds()
if api.nvim_get_mode().mode == 'i' then
-- foldUpdate() is guarded in insert mode. So update folds on InsertLeave
api.nvim_create_autocmd('InsertLeave', {
once = true,
callback = vim._foldupdate,
})
return
end end
return levels vim._foldupdate()
end) end
--- Recomputes the cached fold levels for every changed region reported by the
--- parser, then triggers a fold update.
---@param bufnr integer
---@param foldinfo FoldInfo
---@param tree_changes Range4[]
local function on_changedtree(bufnr, foldinfo, tree_changes)
  local function refresh_changed_regions()
    for _, region in ipairs(tree_changes) do
      -- Only the row components of the changed range are needed.
      local first_row, _, last_row = Range.unpack4(region)
      get_folds_levels(bufnr, foldinfo, first_row, last_row)
    end
    recompute_folds()
  end

  -- For some reason, queries seem to use the old buffer state in on_bytes.
  -- Get around this by scheduling and manually updating folds.
  vim.schedule(refresh_changed_regions)
end
--- Keeps the per-line fold tables in `foldinfo` aligned with the buffer when an
--- edit changes the number of rows.
---
--- `old_row` and `new_row` are added to `start_row` to obtain the end row of the
--- edited region before and after the change. When both are equal the line
--- count is unchanged and nothing is done here.
---@param bufnr integer
---@param foldinfo FoldInfo
---@param start_row integer
---@param old_row integer
---@param new_row integer
local function on_bytes(bufnr, foldinfo, start_row, old_row, new_row)
  local end_row_old = start_row + old_row
  local end_row_new = start_row + new_row

  if new_row < old_row then
    -- Rows were deleted: drop the entries for [end_row_new, end_row_old).
    -- remove_range(srow, erow) iterates from erow - 1 down to srow, so the
    -- smaller row must be passed first; with the arguments the other way
    -- round the loop body never runs and stale entries are left behind.
    foldinfo:remove_range(end_row_new, end_row_old)
  elseif new_row > old_row then
    -- Rows were inserted: make room for them, then compute their levels.
    foldinfo:add_range(start_row, end_row_new)
    -- Deferred rather than run inline; presumably for the same reason noted in
    -- on_changedtree (queries seeing the old buffer state inside on_bytes).
    vim.schedule(function()
      get_folds_levels(bufnr, foldinfo, start_row, end_row_new)
      recompute_folds()
    end)
  end
end
---@param lnum integer|nil ---@param lnum integer|nil
---@return string ---@return string
@ -165,9 +233,27 @@ function M.foldexpr(lnum)
return '0' return '0'
end end
local levels = folds_levels(bufnr) or {} if not foldinfos[bufnr] then
foldinfos[bufnr] = FoldInfo.new()
get_folds_levels(bufnr, foldinfos[bufnr])
return levels[lnum] or '0' local parser = vim.treesitter.get_parser(bufnr)
parser:register_cbs({
on_changedtree = function(tree_changes)
on_changedtree(bufnr, foldinfos[bufnr], tree_changes)
end,
on_bytes = function(_, _, start_row, _, _, old_row, _, _, new_row, _, _)
on_bytes(bufnr, foldinfos[bufnr], start_row, old_row, new_row)
end,
on_detach = function()
foldinfos[bufnr] = nil
end,
})
end
return foldinfos[bufnr].levels[lnum] or '0'
end end
return M return M

View File

@ -60,16 +60,6 @@ function M.add(lang, opts)
filetype = { filetype, { 'string', 'table' }, true }, filetype = { filetype, { 'string', 'table' }, true },
}) })
if filetype == '' then
error(string.format("'%s' is not a valid filetype", filetype))
elseif type(filetype) == 'table' then
for _, f in ipairs(filetype) do
if f == '' then
error(string.format("'%s' is not a valid filetype", filetype))
end
end
end
M.register(lang, filetype or lang) M.register(lang, filetype or lang)
if vim._ts_has_language(lang) then if vim._ts_has_language(lang) then
@ -109,7 +99,9 @@ function M.register(lang, filetype)
end end
for _, f in ipairs(filetypes) do for _, f in ipairs(filetypes) do
ft_to_lang[f] = lang if f ~= '' then
ft_to_lang[f] = lang
end
end end
end end

View File

@ -26,6 +26,7 @@
#include "nvim/eval/typval.h" #include "nvim/eval/typval.h"
#include "nvim/eval/typval_defs.h" #include "nvim/eval/typval_defs.h"
#include "nvim/ex_eval.h" #include "nvim/ex_eval.h"
#include "nvim/fold.h"
#include "nvim/globals.h" #include "nvim/globals.h"
#include "nvim/lua/converter.h" #include "nvim/lua/converter.h"
#include "nvim/lua/spell.h" #include "nvim/lua/spell.h"
@ -528,6 +529,31 @@ static int nlua_iconv(lua_State *lstate)
return 1; return 1;
} }
// Like 'zx' but don't call newFoldLevel()
//
// Lua C function: marks the folds of the current window (curwin) as invalid
// and then calls foldOpenCursor(). It reads no arguments from the Lua stack
// and pushes no results.
static int nlua_foldupdate(lua_State *lstate)
{
curwin->w_foldinvalid = true; // recompute folds
foldOpenCursor();
// Number of Lua return values.
return 0;
}
// Access to internal functions. For use in runtime/
//
// Registers the C functions below as fields of the table that is on top of
// the Lua stack on entry (each lua_setfield targets index -2 right after a
// push). NOTE(review): presumably that table is `vim` -- confirm at the call
// site.
static void nlua_state_add_internal(lua_State *const lstate)
{
// _getvar
lua_pushcfunction(lstate, &nlua_getvar);
lua_setfield(lstate, -2, "_getvar");
// _setvar
lua_pushcfunction(lstate, &nlua_setvar);
lua_setfield(lstate, -2, "_setvar");
// _foldupdate
lua_pushcfunction(lstate, &nlua_foldupdate);
lua_setfield(lstate, -2, "_foldupdate");
}
void nlua_state_add_stdlib(lua_State *const lstate, bool is_thread) void nlua_state_add_stdlib(lua_State *const lstate, bool is_thread)
{ {
if (!is_thread) { if (!is_thread) {
@ -562,14 +588,6 @@ void nlua_state_add_stdlib(lua_State *const lstate, bool is_thread)
lua_setfield(lstate, -2, "__index"); // [meta] lua_setfield(lstate, -2, "__index"); // [meta]
lua_pop(lstate, 1); // don't use metatable now lua_pop(lstate, 1); // don't use metatable now
// _getvar
lua_pushcfunction(lstate, &nlua_getvar);
lua_setfield(lstate, -2, "_getvar");
// _setvar
lua_pushcfunction(lstate, &nlua_setvar);
lua_setfield(lstate, -2, "_setvar");
// vim.spell // vim.spell
luaopen_spell(lstate); luaopen_spell(lstate);
lua_setfield(lstate, -2, "spell"); lua_setfield(lstate, -2, "spell");
@ -578,6 +596,8 @@ void nlua_state_add_stdlib(lua_State *const lstate, bool is_thread)
// depends on p_ambw, p_emoji // depends on p_ambw, p_emoji
lua_pushcfunction(lstate, &nlua_iconv); lua_pushcfunction(lstate, &nlua_iconv);
lua_setfield(lstate, -2, "iconv"); lua_setfield(lstate, -2, "iconv");
nlua_state_add_internal(lstate);
} }
// vim.mpack // vim.mpack

View File

@ -405,8 +405,6 @@ static int parser_parse(lua_State *L)
old_tree = tmp ? *tmp : NULL; old_tree = tmp ? *tmp : NULL;
} }
bool include_bytes = (lua_gettop(L) >= 3) && lua_toboolean(L, 3);
TSTree *new_tree = NULL; TSTree *new_tree = NULL;
size_t len; size_t len;
const char *str; const char *str;
@ -443,6 +441,8 @@ static int parser_parse(lua_State *L)
return luaL_argerror(L, 3, "expected either string or buffer handle"); return luaL_argerror(L, 3, "expected either string or buffer handle");
} }
bool include_bytes = (lua_gettop(L) >= 4) && lua_toboolean(L, 4);
// Sometimes parsing fails (timeout, or wrong parser ABI) // Sometimes parsing fails (timeout, or wrong parser ABI)
// In those case, just return an error. // In those case, just return an error.
if (!new_tree) { if (!new_tree) {

View File

@ -36,11 +36,6 @@ describe('treesitter language API', function()
pcall_err(exec_lua, 'vim.treesitter.add("/foo/")')) pcall_err(exec_lua, 'vim.treesitter.add("/foo/")'))
end) end)
it('shows error for invalid filetype', function()
eq('.../language.lua:0: \'\' is not a valid filetype',
pcall_err(exec_lua, [[vim.treesitter.add('foo', { filetype = '' })]]))
end)
it('inspects language', function() it('inspects language', function()
local keys, fields, symbols = unpack(exec_lua([[ local keys, fields, symbols = unpack(exec_lua([[
local lang = vim.treesitter.inspect_language('c') local lang = vim.treesitter.inspect_language('c')

View File

@ -128,7 +128,9 @@ void ui_refresh(void)
it('does not get parser for empty filetype', function() it('does not get parser for empty filetype', function()
insert(test_text); insert(test_text);
eq(".../language.lua:0: '' is not a valid filetype", eq('.../treesitter.lua:0: There is no parser available for buffer 1 and one'
.. ' could not be created because lang could not be determined. Either'
.. ' pass lang or set the buffer filetype',
pcall_err(exec_lua, 'vim.treesitter.get_parser(0)')) pcall_err(exec_lua, 'vim.treesitter.get_parser(0)'))
-- Must provide language for buffers with an empty filetype -- Must provide language for buffers with an empty filetype
@ -886,18 +888,20 @@ int x = INT_MAX;
it("can fold via foldexpr", function() it("can fold via foldexpr", function()
insert(test_text) insert(test_text)
local levels = exec_lua([[ local function get_fold_levels()
vim.opt.filetype = 'c' return exec_lua([[
vim.treesitter.get_parser(0, "c") local res = {}
local res = {} for i = 1, vim.api.nvim_buf_line_count(0) do
for i = 1, vim.api.nvim_buf_line_count(0) do res[i] = vim.treesitter.foldexpr(i)
res[i] = vim.treesitter.foldexpr(i) end
end return res
return res ]])
]]) end
exec_lua([[vim.treesitter.get_parser(0, "c")]])
eq({ eq({
[1] = '>1', [1] = '>1',
[2] = '1', [2] = '1',
[3] = '1', [3] = '1',
[4] = '1', [4] = '1',
@ -915,6 +919,33 @@ int x = INT_MAX;
[16] = '3', [16] = '3',
[17] = '3', [17] = '3',
[18] = '2', [18] = '2',
[19] = '1' }, levels) [19] = '1' }, get_fold_levels())
helpers.command('1,2d')
helpers.poke_eventloop()
exec_lua([[vim.treesitter.get_parser():parse()]])
helpers.poke_eventloop()
helpers.sleep(100)
eq({
[1] = '0',
[2] = '0',
[3] = '>1',
[4] = '1',
[5] = '1',
[6] = '0',
[7] = '0',
[8] = '>1',
[9] = '1',
[10] = '1',
[11] = '1',
[12] = '1',
[13] = '>2',
[14] = '2',
[15] = '2',
[16] = '1',
[17] = '0' }, get_fold_levels())
end) end)
end) end)