diff --git a/pkg/api/common_test.go b/pkg/api/common_test.go index 0764fb0bfd6..590074f7193 100644 --- a/pkg/api/common_test.go +++ b/pkg/api/common_test.go @@ -93,6 +93,26 @@ func (sc *scenarioContext) fakeReqWithParams(method, url string, queryParams map return sc } +func (sc *scenarioContext) fakeReqNoAssertions(method, url string) *scenarioContext { + sc.resp = httptest.NewRecorder() + req, _ := http.NewRequest(method, url, nil) + sc.req = req + + return sc +} + +func (sc *scenarioContext) fakeReqNoAssertionsWithCookie(method, url string, cookie http.Cookie) *scenarioContext { + sc.resp = httptest.NewRecorder() + http.SetCookie(sc.resp, &cookie) + + req, _ := http.NewRequest(method, url, nil) + req.Header = http.Header{"Cookie": sc.resp.Header()["Set-Cookie"]} + + sc.req = req + + return sc +} + type scenarioContext struct { m *macaron.Macaron context *m.ReqContext diff --git a/pkg/api/login.go b/pkg/api/login.go index 37df4613212..61a6299b935 100644 --- a/pkg/api/login.go +++ b/pkg/api/login.go @@ -21,8 +21,14 @@ const ( LoginErrorCookieName = "login_error" ) +var setIndexViewData = (*HTTPServer).setIndexViewData + +var getViewIndex = func() string { + return ViewIndex +} + func (hs *HTTPServer) LoginView(c *models.ReqContext) { - viewData, err := hs.setIndexViewData(c) + viewData, err := setIndexViewData(hs, c) if err != nil { c.Handle(500, "Failed to get settings", err) return @@ -41,8 +47,14 @@ func (hs *HTTPServer) LoginView(c *models.ReqContext) { viewData.Settings["samlEnabled"] = hs.Cfg.SAMLEnabled if loginError, ok := tryGetEncryptedCookie(c, LoginErrorCookieName); ok { + //this cookie is only set whenever an OAuth login fails + //therefore the loginError should be passed to the view data + //and the view should return immediately before attempting + //to login again via OAuth and enter to a redirect loop deleteCookie(c, LoginErrorCookieName) viewData.Settings["loginError"] = loginError + c.HTML(200, getViewIndex(), viewData) + return } if tryOAuthAutoLogin(c) { diff --git a/pkg/api/login_test.go b/pkg/api/login_test.go new file mode 100644 index 00000000000..ab28848a43d --- /dev/null +++ b/pkg/api/login_test.go @@ -0,0 +1,135 @@ +package api + +import ( + "encoding/hex" + "errors" + "github.com/grafana/grafana/pkg/api/dtos" + "github.com/grafana/grafana/pkg/models" + "github.com/grafana/grafana/pkg/setting" + "github.com/grafana/grafana/pkg/util" + "github.com/stretchr/testify/assert" + "io/ioutil" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func mockSetIndexViewData() { + setIndexViewData = func(*HTTPServer, *models.ReqContext) (*dtos.IndexViewData, error) { + data := &dtos.IndexViewData{ + User: &dtos.CurrentUser{}, + Settings: map[string]interface{}{}, + NavTree: []*dtos.NavLink{}, + } + return data, nil + } +} + +func resetSetIndexViewData() { + setIndexViewData = (*HTTPServer).setIndexViewData +} + +func mockViewIndex() { + getViewIndex = func() string { + return "index-template" + } +} + +func resetViewIndex() { + getViewIndex = func() string { + return ViewIndex + } +} + +func getBody(resp *httptest.ResponseRecorder) (string, error) { + responseData, err := ioutil.ReadAll(resp.Body) + if err != nil { + return "", err + } + return string(responseData), nil +} + +func TestLoginErrorCookieApiEndpoint(t *testing.T) { + mockSetIndexViewData() + defer resetSetIndexViewData() + + mockViewIndex() + defer resetViewIndex() + + sc := setupScenarioContext("/login") + hs := &HTTPServer{ + Cfg: setting.NewCfg(), + } + + sc.defaultHandler = Wrap(func(w http.ResponseWriter, c *models.ReqContext) { + hs.LoginView(c) + }) + + setting.OAuthService = &setting.OAuther{} + setting.OAuthService.OAuthInfos = make(map[string]*setting.OAuthInfo) + setting.LoginCookieName = "grafana_session" + setting.SecretKey = "login_testing" + + setting.OAuthService = &setting.OAuther{} + setting.OAuthService.OAuthInfos = make(map[string]*setting.OAuthInfo) + setting.OAuthService.OAuthInfos["github"] = &setting.OAuthInfo{ + ClientId: "fake", + ClientSecret: "fakefake", + Enabled: true, + AllowSignup: true, + Name: "github", + } + setting.OAuthAutoLogin = true + + oauthError := errors.New("User not a member of one of the required organizations") + encryptedError, _ := util.Encrypt([]byte(oauthError.Error()), setting.SecretKey) + cookie := http.Cookie{ + Name: LoginErrorCookieName, + MaxAge: 60, + Value: hex.EncodeToString(encryptedError), + HttpOnly: true, + Path: setting.AppSubUrl + "/", + Secure: hs.Cfg.CookieSecure, + SameSite: hs.Cfg.CookieSameSite, + } + sc.m.Get(sc.url, sc.defaultHandler) + sc.fakeReqNoAssertionsWithCookie("GET", sc.url, cookie).exec() + assert.Equal(t, sc.resp.Code, 200) + + responseString, err := getBody(sc.resp) + assert.Nil(t, err) + assert.True(t, strings.Contains(responseString, oauthError.Error())) +} + +func TestLoginOAuthRedirect(t *testing.T) { + mockSetIndexViewData() + defer resetSetIndexViewData() + + sc := setupScenarioContext("/login") + hs := &HTTPServer{ + Cfg: setting.NewCfg(), + } + + sc.defaultHandler = Wrap(func(c *models.ReqContext) { + hs.LoginView(c) + }) + + setting.OAuthService = &setting.OAuther{} + setting.OAuthService.OAuthInfos = make(map[string]*setting.OAuthInfo) + setting.OAuthService.OAuthInfos["github"] = &setting.OAuthInfo{ + ClientId: "fake", + ClientSecret: "fakefake", + Enabled: true, + AllowSignup: true, + Name: "github", + } + setting.OAuthAutoLogin = true + sc.m.Get(sc.url, sc.defaultHandler) + sc.fakeReqNoAssertions("GET", sc.url).exec() + + assert.Equal(t, sc.resp.Code, 307) + location, ok := sc.resp.Header()["Location"] + assert.True(t, ok) + assert.Equal(t, location[0], "/login/github") +}