From 95a4c4a4d65c2a7e8fbbe6eb5346108f6a242c48 Mon Sep 17 00:00:00 2001 From: Karl Persson Date: Mon, 13 Jun 2022 16:59:15 +0200 Subject: [PATCH] OAuth: Redirect to login if no oauth module is found or if module is not configured (#50661) * OAuth: Redirect to login if no oauth module is found or if module is not configured * OAuth: Update test to check for location header --- pkg/api/login_oauth.go | 10 ++-------- pkg/api/login_oauth_test.go | 19 +++++++++++-------- 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/pkg/api/login_oauth.go b/pkg/api/login_oauth.go index 1803c168e95..62c0d6e966d 100644 --- a/pkg/api/login_oauth.go +++ b/pkg/api/login_oauth.go @@ -75,19 +75,13 @@ func (hs *HTTPServer) OAuthLogin(ctx *models.ReqContext) { loginInfo.AuthModule = name provider := hs.SocialService.GetOAuthInfoProvider(name) if provider == nil { - hs.handleOAuthLoginError(ctx, loginInfo, LoginError{ - HttpStatus: http.StatusNotFound, - PublicMessage: "OAuth not enabled", - }) + hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, errors.New("OAuth not enabled")) return } connect, err := hs.SocialService.GetConnector(name) if err != nil { - hs.handleOAuthLoginError(ctx, loginInfo, LoginError{ - HttpStatus: http.StatusNotFound, - PublicMessage: fmt.Sprintf("No OAuth with name %s configured", name), - }) + hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, fmt.Errorf("no OAuth with name %s configured", name)) return } diff --git a/pkg/api/login_oauth_test.go b/pkg/api/login_oauth_test.go index 2f67bfdfbcc..cd5b2e6f2f2 100644 --- a/pkg/api/login_oauth_test.go +++ b/pkg/api/login_oauth_test.go @@ -9,6 +9,8 @@ import ( "path/filepath" "testing" + "github.com/grafana/grafana/pkg/services/secrets/fakes" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -31,11 +33,12 @@ func setupOAuthTest(t *testing.T, cfg *setting.Cfg) *web.Mux { sqlStore := sqlstore.InitTestDB(t) hs := &HTTPServer{ - Cfg: cfg, - License: &licensing.OSSLicensingService{Cfg: cfg}, - SQLStore: sqlStore, - SocialService: social.ProvideService(cfg), - HooksService: hooks.ProvideService(), + Cfg: cfg, + License: &licensing.OSSLicensingService{Cfg: cfg}, + SQLStore: sqlStore, + SocialService: social.ProvideService(cfg), + HooksService: hooks.ProvideService(), + SecretsService: fakes.NewFakeSecretsService(), } m := web.New() @@ -55,9 +58,9 @@ func TestOAuthLogin_UnknownProvider(t *testing.T) { recorder := httptest.NewRecorder() m.ServeHTTP(recorder, req) - - assert.Equal(t, http.StatusNotFound, recorder.Code) - assert.Contains(t, recorder.Body.String(), "OAuth not enabled") + // expect to be redirected to /login + assert.Equal(t, http.StatusFound, recorder.Code) + assert.Equal(t, "/login", recorder.Header().Get("Location")) } func TestOAuthLogin_Base(t *testing.T) {