diff --git a/server/cmd/mmctl/commands/completion.go b/server/cmd/mmctl/commands/completion.go index 666cb48311..f58685de4f 100644 --- a/server/cmd/mmctl/commands/completion.go +++ b/server/cmd/mmctl/commands/completion.go @@ -4,8 +4,12 @@ package commands import ( + "context" "os" + "strings" + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/v8/cmd/mmctl/client" "github.com/spf13/cobra" ) @@ -202,3 +206,69 @@ __mmctl_bash_source <(__mmctl_convert_bash_to_zsh) return nil } + +func noCompletion(_ *cobra.Command, _ []string, _ string) ([]string, cobra.ShellCompDirective) { + return nil, cobra.ShellCompDirectiveNoFileComp +} + +type validateArgsFn func(ctx context.Context, c client.Client, cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) + +func validateArgsWithClient(fn validateArgsFn) func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { //nolint:unused // Remove with https://github.com/mattermost/mattermost/pull/25633 + return func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + ctx, cancel := context.WithTimeout(context.Background(), shellCompleteTimeout) + defer cancel() + + c, _, _, err := getClient(ctx) + if err != nil { + return nil, cobra.ShellCompDirectiveError + } + + return fn(ctx, c, cmd, args, toComplete) + } +} + +type fetcher[T any] func(ctx context.Context, c client.Client, page int, perPage int) ([]T, *model.Response, error) // fetcher calls the Mattermost API to fetch a list of entities T. +type matcher[T any] func(t T) []string // matcher returns list of field that are T uses for shell completion. + +func fetchAndComplete[T any](f fetcher[T], m matcher[T]) validateArgsFn { + return func(ctx context.Context, c client.Client, cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + res := []string{} + + if toComplete == "" { + return res, cobra.ShellCompDirectiveNoFileComp + } + + var page int + for { + entities, _, err := f(ctx, c, page, perPage) + if err != nil { + // Return what we got so far + return res, cobra.ShellCompDirectiveNoFileComp + } + + for _, e := range entities { + for _, field := range m(e) { + if strings.HasPrefix(field, toComplete) { + res = append(res, field) + + // Only complete one field per entity. + break + } + } + } + + if len(res) > shellCompletionMaxItems { + res = res[:shellCompletionMaxItems] + break + } + + if len(entities) < perPage { + break + } + + page++ + } + + return res, cobra.ShellCompDirectiveNoFileComp + } +} diff --git a/server/cmd/mmctl/commands/completion_test.go b/server/cmd/mmctl/commands/completion_test.go new file mode 100644 index 0000000000..9fce818b0a --- /dev/null +++ b/server/cmd/mmctl/commands/completion_test.go @@ -0,0 +1,159 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package commands + +import ( + "context" + "errors" + "fmt" + "testing" + + "github.com/spf13/cobra" + "github.com/stretchr/testify/assert" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/v8/cmd/mmctl/client" +) + +func TestFetchAndComplete(t *testing.T) { + type user struct { + name string + position string + } + + createUsers := func(page, perPage int) []user { + ret := []user{} + for i := perPage * page; i < perPage*(page+1); i++ { + ret = append(ret, user{ + name: fmt.Sprintf("name_%d", i), + position: fmt.Sprintf("position_%d", i), + }) + } + return ret + } + + listNames := func(n int) []string { + ret := []string{} + for i := 0; i < n; i++ { + ret = append(ret, fmt.Sprintf("name_%d", i)) + } + return ret + } + + for name, tc := range map[string]struct { + fetcher func(ctx context.Context, c client.Client, page int, perPage int) ([]user, *model.Response, error) + matcher func(t user) []string + toComplete string + ExpectedCompletion []string + ExpectedDirective cobra.ShellCompDirective // Defaults to cobra.ShellCompDirectiveNoFileComp + }{ + "empty query leads to empty result": { + fetcher: func(ctx context.Context, c client.Client, page int, perPage int) ([]user, *model.Response, error) { + return []user{{name: "bob"}, {name: "alice"}}, nil, nil + }, + matcher: func(t user) []string { + return []string{t.name} + }, + toComplete: "", + ExpectedCompletion: []string{}, + }, + "no matches": { + fetcher: func(ctx context.Context, c client.Client, page int, perPage int) ([]user, *model.Response, error) { + return []user{{name: "bob"}, {name: "alice"}}, nil, nil + }, + matcher: func(t user) []string { + return []string{t.name} + }, + toComplete: "x", + ExpectedCompletion: []string{}, + }, + "one element matches": { + fetcher: func(ctx context.Context, c client.Client, page int, perPage int) ([]user, *model.Response, error) { + return []user{{name: "bob"}, {name: "alice"}}, nil, nil + }, + matcher: func(t user) []string { + return []string{t.name} + }, + toComplete: "b", + ExpectedCompletion: []string{"bob"}, + }, + "two element matches": { + fetcher: func(ctx context.Context, c client.Client, page int, perPage int) ([]user, *model.Response, error) { + return []user{{name: "anne"}, {name: "alice"}}, nil, nil + }, + matcher: func(t user) []string { + return []string{t.name} + }, + toComplete: "a", + ExpectedCompletion: []string{"anne", "alice"}, + }, + "only match one fields per element": { + fetcher: func(ctx context.Context, c client.Client, page int, perPage int) ([]user, *model.Response, error) { + return []user{{name: "bob", position: "backend"}, {name: "alice"}}, nil, nil + }, + matcher: func(t user) []string { + return []string{t.name, t.position} + }, + toComplete: "b", + ExpectedCompletion: []string{"bob"}, + }, + "error ignored returns": { + fetcher: func(ctx context.Context, c client.Client, page int, perPage int) ([]user, *model.Response, error) { + return []user{{name: "bob", position: "backend"}, {name: "alice"}}, nil, errors.New("some error") + }, + matcher: func(t user) []string { + return []string{t.name, t.position} + }, + toComplete: "b", + ExpectedCompletion: []string{}, + }, + "limit to 50": { + fetcher: func(ctx context.Context, c client.Client, page int, perPage int) ([]user, *model.Response, error) { + return createUsers(0, 200), nil, nil + }, + matcher: func(t user) []string { + return []string{t.name} + }, + toComplete: "name", + ExpectedCompletion: listNames(50), + }, + "request multipile pages": { + fetcher: func(ctx context.Context, c client.Client, page int, perPage int) ([]user, *model.Response, error) { + return createUsers(page, perPage), nil, nil + }, + matcher: func(t user) []string { + return []string{t.name} + }, + toComplete: "name_4", + ExpectedCompletion: []string{ + "name_4", "name_40", "name_41", "name_42", "name_43", "name_44", "name_45", "name_46", "name_47", "name_48", "name_49", + "name_400", "name_401", "name_402", "name_403", "name_404", "name_405", "name_406", "name_407", "name_408", "name_409", "name_410", "name_411", "name_412", + "name_413", "name_414", "name_415", "name_416", "name_417", "name_418", "name_419", "name_420", "name_421", "name_422", "name_423", "name_424", "name_425", + "name_426", "name_427", "name_428", "name_429", "name_430", "name_431", "name_432", "name_433", "name_434", "name_435", "name_436", "name_437", "name_438", + }, + }, + } { + t.Run(name, func(t *testing.T) { + name := name // TODO: Remove once go1.22 is used + tc := tc // TODO: Remove once go1.22 is used + t.Parallel() + + comp, directive := fetchAndComplete[user](tc.fetcher, tc.matcher)(context.Background(), nil, nil, nil, tc.toComplete) + assert.Equal(t, tc.ExpectedCompletion, comp, name) + + expectedDirective := cobra.ShellCompDirectiveNoFileComp + if tc.ExpectedDirective != 0 { + expectedDirective = tc.ExpectedDirective + } + + assert.Equal(t, expectedDirective, directive, name) + }) + } +} + +func TestNoCompletion(t *testing.T) { + comp, directive := noCompletion(nil, nil, "any") + assert.Nil(t, comp) + assert.Equal(t, cobra.ShellCompDirectiveNoFileComp, directive) +} diff --git a/server/cmd/mmctl/commands/init.go b/server/cmd/mmctl/commands/init.go index 1f6c995e00..a990e92846 100644 --- a/server/cmd/mmctl/commands/init.go +++ b/server/cmd/mmctl/commands/init.go @@ -12,6 +12,7 @@ import ( "os" "runtime" "strings" + "time" "github.com/Masterminds/semver/v3" "github.com/pkg/errors" @@ -24,6 +25,13 @@ import ( "github.com/mattermost/mattermost/server/v8/cmd/mmctl/printer" ) +const ( + perPage = 200 + + shellCompletionMaxItems = 50 // Maximum number of items that will be loaded and shown in shell completion. + shellCompleteTimeout = 5 * time.Second +) + var ( insecureSignatureAlgorithms = map[x509.SignatureAlgorithm]bool{ x509.SHA1WithRSA: true, @@ -62,22 +70,34 @@ func CheckVersionMatch(version, serverVersion string) (bool, error) { return true, nil } +func getClient(ctx context.Context) (*model.Client4, string, bool, error) { + if viper.GetBool("local") { + c, err := InitUnixClient(viper.GetString("local-socket-path")) + if err != nil { + return nil, "", true, err + } + printer.SetServerAddres("local instance") + + return c, "", true, nil + } + + c, serverVersion, err := InitClient(ctx, viper.GetBool("insecure-sha1-intermediate"), viper.GetBool("insecure-tls-version")) + if err != nil { + return nil, "", false, err + } + + return c, serverVersion, false, nil +} + func withClient(fn func(c client.Client, cmd *cobra.Command, args []string) error) func(cmd *cobra.Command, args []string) error { return func(cmd *cobra.Command, args []string) error { - if viper.GetBool("local") { - c, err := InitUnixClient(viper.GetString("local-socket-path")) - if err != nil { - return err - } - printer.SetServerAddres("local instance") - return fn(c, cmd, args) - } - ctx := context.TODO() - - c, serverVersion, err := InitClient(ctx, viper.GetBool("insecure-sha1-intermediate"), viper.GetBool("insecure-tls-version")) + c, serverVersion, local, err := getClient(ctx) if err != nil { - return err + return fmt.Errorf("failed to create client: %w", err) + } + if local { + return fn(c, cmd, args) } if Version != "unspecified" { // unspecified version indicates that we are on dev mode.