treesitter: add node:field() to get field children

This commit is contained in:
Thomas Vigouroux 2020-08-14 16:01:10 +02:00
parent e123fd0a5d
commit 18217b987f
3 changed files with 60 additions and 0 deletions

View File

@ -628,6 +628,9 @@ tsnode:iter_children() *tsnode:iter_children()*
Returns the child node plus the eventual field name corresponding to
this child node.
tsnode:field({name}) *tsnode:field()*
Returns a table of the nodes corresponding to the {name} field.
tsnode:child_count() *tsnode:child_count()*
Get the node's number of children.

View File

@ -62,6 +62,7 @@ static struct luaL_Reg node_meta[] = {
{ "end_", node_end },
{ "type", node_type },
{ "symbol", node_symbol },
{ "field", node_field },
{ "named", node_named },
{ "missing", node_missing },
{ "has_error", node_has_error },
@ -653,6 +654,34 @@ static int node_symbol(lua_State *L)
return 1;
}
static int node_field(lua_State *L)
{
TSNode node;
if (!node_check(L, 1, &node)) {
return 0;
}
size_t name_len;
const char *field_name = luaL_checklstring(L, 2, &name_len);
TSTreeCursor cursor = ts_tree_cursor_new(node);
lua_newtable(L); // [table]
unsigned int curr_index = 0;
if (ts_tree_cursor_goto_first_child(&cursor)) {
do {
if (!STRCMP(field_name, ts_tree_cursor_current_field_name(&cursor))) {
push_node(L, ts_tree_cursor_current_node(&cursor), 1); // [table, node]
lua_rawseti(L, -2, ++curr_index);
}
} while (ts_tree_cursor_goto_next_sibling(&cursor));
}
ts_tree_cursor_delete(&cursor);
return 1;
}
static int node_named(lua_State *L)
{
TSNode node;

View File

@ -151,6 +151,34 @@ void ui_refresh(void)
}, res)
end)
it('allows to get a child by field', function()
if not check_parser() then return end
insert(test_text);
local res = exec_lua([[
parser = vim.treesitter.get_parser(0, "c")
func_node = parser:parse():root():child(0)
local res = {}
for _, node in ipairs(func_node:field("type")) do
table.insert(res, {node:type(), node:range()})
end
return res
]])
eq({{ "primitive_type", 0, 0, 0, 4 }}, res)
local res_fail = exec_lua([[
parser = vim.treesitter.get_parser(0, "c")
return #func_node:field("foo") == 0
]])
assert(res_fail)
end)
local query = [[
((call_expression function: (identifier) @minfunc (argument_list (identifier) @min_id)) (eq? @minfunc "MIN"))
"for" @keyword