package authnimpl import ( "context" "errors" "net" "net/http" "net/url" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/infra/tracing" "github.com/grafana/grafana/pkg/models/usertoken" "github.com/grafana/grafana/pkg/services/auth" "github.com/grafana/grafana/pkg/services/auth/authtest" "github.com/grafana/grafana/pkg/services/auth/identity" "github.com/grafana/grafana/pkg/services/authn" "github.com/grafana/grafana/pkg/services/authn/authntest" "github.com/grafana/grafana/pkg/services/login" "github.com/grafana/grafana/pkg/services/login/authinfotest" "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/setting" ) func TestService_Authenticate(t *testing.T) { type TestCase struct { desc string clients []authn.Client expectedIdentity *authn.Identity expectedErrors []error } var ( firstErr = errors.New("first") lastErr = errors.New("last") ) tests := []TestCase{ { desc: "should succeed with authentication for configured client", clients: []authn.Client{ &authntest.FakeClient{ExpectedTest: true, ExpectedIdentity: &authn.Identity{ID: "user:1"}}, }, expectedIdentity: &authn.Identity{ID: "user:1"}, }, { desc: "should succeed with authentication for second client when first test fail", clients: []authn.Client{ &authntest.FakeClient{ExpectedName: "1", ExpectedPriority: 1, ExpectedTest: false}, &authntest.FakeClient{ExpectedName: "2", ExpectedPriority: 2, ExpectedTest: true, ExpectedIdentity: &authn.Identity{ID: "user:2"}}, }, expectedIdentity: &authn.Identity{ID: "user:2"}, }, { desc: "should succeed with authentication for third client when error happened in first", clients: []authn.Client{ &authntest.FakeClient{ExpectedName: "1", ExpectedPriority: 2, ExpectedTest: false}, &authntest.FakeClient{ExpectedName: "2", ExpectedPriority: 1, ExpectedTest: true, ExpectedErr: errors.New("some error")}, &authntest.FakeClient{ExpectedName: "3", ExpectedPriority: 3, ExpectedTest: true, ExpectedIdentity: &authn.Identity{ID: "user:3"}}, }, expectedIdentity: &authn.Identity{ID: "user:3"}, }, { desc: "should return error when no client could authenticate the request", clients: []authn.Client{ &authntest.FakeClient{ExpectedName: "1", ExpectedPriority: 2, ExpectedTest: false}, &authntest.FakeClient{ExpectedName: "2", ExpectedPriority: 1, ExpectedTest: false}, &authntest.FakeClient{ExpectedName: "3", ExpectedPriority: 3, ExpectedTest: false}, }, expectedErrors: []error{errCantAuthenticateReq}, }, { desc: "should return all errors in chain", clients: []authn.Client{ &authntest.FakeClient{ExpectedName: "1", ExpectedPriority: 2, ExpectedTest: false}, &authntest.FakeClient{ExpectedName: "2", ExpectedPriority: 1, ExpectedTest: true, ExpectedErr: firstErr}, &authntest.FakeClient{ExpectedName: "3", ExpectedPriority: 3, ExpectedTest: true, ExpectedErr: lastErr}, }, expectedErrors: []error{firstErr, lastErr}, }, { desc: "should return error on disabled identity", clients: []authn.Client{ &authntest.FakeClient{ExpectedName: "1", ExpectedTest: true, ExpectedIdentity: &authn.Identity{IsDisabled: true}}, }, expectedErrors: []error{errDisabledIdentity}, }, } for _, tt := range tests { t.Run(tt.desc, func(t *testing.T) { svc := setupTests(t, func(svc *Service) { for _, c := range tt.clients { svc.RegisterClient(c) } }) identity, err := svc.Authenticate(context.Background(), &authn.Request{}) if len(tt.expectedErrors) == 0 { assert.NoError(t, err) assert.EqualValues(t, tt.expectedIdentity, identity) } else { for _, e := range tt.expectedErrors { assert.ErrorIs(t, err, e) } assert.Nil(t, identity) } }) } } func TestService_OrgID(t *testing.T) { type TestCase struct { desc string req *authn.Request expectedOrgID int64 } tests := []TestCase{ { desc: "should set org id when present in header", req: &authn.Request{HTTPRequest: &http.Request{ Header: map[string][]string{orgIDHeaderName: {"1"}}, URL: &url.URL{}, }}, expectedOrgID: 1, }, { desc: "should set org id when present in url", req: &authn.Request{HTTPRequest: &http.Request{ Header: map[string][]string{}, URL: mustParseURL("http://localhost/?targetOrgId=2"), }}, expectedOrgID: 2, }, { desc: "should prioritise org id from url when present in both header and url", req: &authn.Request{HTTPRequest: &http.Request{ Header: map[string][]string{orgIDHeaderName: {"1"}}, URL: mustParseURL("http://localhost/?targetOrgId=2"), }}, expectedOrgID: 2, }, { desc: "should set org id to 0 when missing in both header and url", req: &authn.Request{HTTPRequest: &http.Request{ Header: map[string][]string{}, URL: &url.URL{}, }}, expectedOrgID: 0, }, } for _, tt := range tests { t.Run(tt.desc, func(t *testing.T) { var calledWith int64 s := setupTests(t, func(svc *Service) { svc.RegisterClient(authntest.MockClient{ AuthenticateFunc: func(ctx context.Context, r *authn.Request) (*authn.Identity, error) { calledWith = r.OrgID return &authn.Identity{}, nil }, TestFunc: func(ctx context.Context, r *authn.Request) bool { return true }, }) }) _, _ = s.Authenticate(context.Background(), tt.req) assert.Equal(t, tt.expectedOrgID, calledWith) }) } } func TestService_HookClient(t *testing.T) { hookCalled := false s := setupTests(t, func(svc *Service) { svc.RegisterClient(&authntest.MockClient{ AuthenticateFunc: func(ctx context.Context, r *authn.Request) (*authn.Identity, error) { return &authn.Identity{}, nil }, TestFunc: func(ctx context.Context, r *authn.Request) bool { return true }, HookFunc: func(ctx context.Context, identity *authn.Identity, r *authn.Request) error { hookCalled = true return nil }, }) }) _, _ = s.Authenticate(context.Background(), &authn.Request{}) require.True(t, hookCalled) } func TestService_Login(t *testing.T) { type TestCase struct { desc string client string expectedClientOK bool expectedClientErr error expectedClientIdentity *authn.Identity expectedSessionErr error expectedErr error expectedIdentity *authn.Identity } tests := []TestCase{ { desc: "should login for valid request", client: "fake", expectedClientOK: true, expectedClientIdentity: &authn.Identity{ ID: "user:1", }, expectedIdentity: &authn.Identity{ ID: "user:1", SessionToken: &auth.UserToken{UserId: 1}, }, }, { desc: "should not login with invalid client", client: "invalid", expectedErr: authn.ErrClientNotConfigured, }, { desc: "should not login non user identity", client: "fake", expectedClientOK: true, expectedClientIdentity: &authn.Identity{ID: "apikey:1"}, expectedErr: authn.ErrUnsupportedIdentity, }, } for _, tt := range tests { t.Run(tt.desc, func(t *testing.T) { s := setupTests(t, func(svc *Service) { svc.RegisterClient(&authntest.FakeClient{ ExpectedName: "fake", ExpectedErr: tt.expectedClientErr, ExpectedTest: tt.expectedClientOK, ExpectedIdentity: tt.expectedClientIdentity, }) svc.sessionService = &authtest.FakeUserAuthTokenService{ CreateTokenProvider: func(ctx context.Context, user *user.User, clientIP net.IP, userAgent string) (*auth.UserToken, error) { if tt.expectedSessionErr != nil { return nil, tt.expectedSessionErr } return &auth.UserToken{UserId: user.ID}, nil }, } }) identity, err := s.Login(context.Background(), tt.client, &authn.Request{HTTPRequest: &http.Request{ Header: map[string][]string{}, URL: &url.URL{}, }}) assert.ErrorIs(t, err, tt.expectedErr) assert.EqualValues(t, tt.expectedIdentity, identity) }) } } func TestService_RedirectURL(t *testing.T) { type testCase struct { desc string client string expectedErr error } tests := []testCase{ { desc: "should generate url for valid redirect client", client: "redirect", }, { desc: "should return error on non existing client", client: "non-existing", expectedErr: authn.ErrClientNotConfigured, }, { desc: "should return error when client don't support the redirect interface", client: "non-redirect", expectedErr: authn.ErrUnsupportedClient, }, } for _, tt := range tests { t.Run(tt.desc, func(t *testing.T) { service := setupTests(t, func(svc *Service) { svc.RegisterClient(authntest.FakeRedirectClient{ExpectedName: "redirect"}) svc.RegisterClient(&authntest.FakeClient{ExpectedName: "non-redirect"}) }) _, err := service.RedirectURL(context.Background(), tt.client, nil) assert.ErrorIs(t, err, tt.expectedErr) }) } } func TestService_Logout(t *testing.T) { type TestCase struct { desc string identity *authn.Identity sessionToken *usertoken.UserToken info *login.UserAuth client authn.Client expectedErr error expectedTokenRevoked bool expectedRedirect *authn.Redirect } tests := []TestCase{ { desc: "should redirect to default redirect url when identity is not a user", identity: &authn.Identity{ID: authn.NamespacedID(authn.NamespaceServiceAccount, 1)}, expectedRedirect: &authn.Redirect{URL: "http://localhost:3000/login"}, }, { desc: "should redirect to default redirect url when no external provider was used to authenticate", identity: &authn.Identity{ID: authn.NamespacedID(authn.NamespaceUser, 1)}, expectedRedirect: &authn.Redirect{URL: "http://localhost:3000/login"}, expectedTokenRevoked: true, }, { desc: "should redirect to default redirect url when client is not found", identity: &authn.Identity{ID: authn.NamespacedID(authn.NamespaceUser, 1)}, info: &login.UserAuth{AuthModule: "notFound"}, expectedRedirect: &authn.Redirect{URL: "http://localhost:3000/login"}, expectedTokenRevoked: true, }, { desc: "should redirect to default redirect url when client do not implement logout extension", identity: &authn.Identity{ID: authn.NamespacedID(authn.NamespaceUser, 1)}, info: &login.UserAuth{AuthModule: "azuread"}, expectedRedirect: &authn.Redirect{URL: "http://localhost:3000/login"}, client: &authntest.FakeClient{ExpectedName: "auth.client.azuread"}, expectedTokenRevoked: true, }, { desc: "should redirect to client specific url", identity: &authn.Identity{ID: authn.NamespacedID(authn.NamespaceUser, 1)}, info: &login.UserAuth{AuthModule: "azuread"}, expectedRedirect: &authn.Redirect{URL: "http://idp.com/logout"}, client: &authntest.MockClient{ NameFunc: func() string { return "auth.client.azuread" }, LogoutFunc: func(ctx context.Context, _ identity.Requester, _ *login.UserAuth) (*authn.Redirect, bool) { return &authn.Redirect{URL: "http://idp.com/logout"}, true }, }, expectedTokenRevoked: true, }, } for _, tt := range tests { t.Run(tt.desc, func(t *testing.T) { var tokenRevoked bool s := setupTests(t, func(svc *Service) { if tt.client != nil { svc.RegisterClient(tt.client) } svc.cfg.AppSubURL = "http://localhost:3000" svc.authInfoService = &authinfotest.FakeService{ ExpectedUserAuth: tt.info, } svc.sessionService = &authtest.FakeUserAuthTokenService{ RevokeTokenProvider: func(_ context.Context, sessionToken *auth.UserToken, soft bool) error { tokenRevoked = true assert.EqualValues(t, tt.sessionToken, sessionToken) assert.False(t, soft) return nil }, } }) redirect, err := s.Logout(context.Background(), tt.identity, tt.sessionToken) assert.ErrorIs(t, err, tt.expectedErr) assert.EqualValues(t, tt.expectedRedirect, redirect) assert.Equal(t, tt.expectedTokenRevoked, tokenRevoked) }) } } func mustParseURL(s string) *url.URL { u, err := url.Parse(s) if err != nil { panic(err) } return u } func setupTests(t *testing.T, opts ...func(svc *Service)) *Service { t.Helper() s := &Service{ log: log.NewNopLogger(), cfg: setting.NewCfg(), clients: map[string]authn.Client{}, clientQueue: newQueue[authn.ContextAwareClient](), tracer: tracing.InitializeTracerForTest(), metrics: newMetrics(nil), postAuthHooks: newQueue[authn.PostAuthHookFn](), postLoginHooks: newQueue[authn.PostLoginHookFn](), } for _, o := range opts { o(s) } return s }