diff --git a/pkg/api/login_oauth.go b/pkg/api/login_oauth.go index bc222361b25..3f1a0cadaa9 100644 --- a/pkg/api/login_oauth.go +++ b/pkg/api/login_oauth.go @@ -3,6 +3,8 @@ package api import ( "errors" "fmt" + "crypto/rand" + "encoding/base64" "golang.org/x/oauth2" @@ -14,6 +16,12 @@ import ( "github.com/grafana/grafana/pkg/social" ) +func GenStateString() string { + rnd := make([]byte, 32) + rand.Read(rnd) + return base64.StdEncoding.EncodeToString(rnd) +} + func OAuthLogin(ctx *middleware.Context) { if setting.OAuthService == nil { ctx.Handle(404, "login.OAuthLogin(oauth service not enabled)", nil) @@ -29,7 +37,17 @@ func OAuthLogin(ctx *middleware.Context) { code := ctx.Query("code") if code == "" { - ctx.Redirect(connect.AuthCodeURL("", oauth2.AccessTypeOnline)) + state := GenStateString() + ctx.Session.Set(middleware.SESS_KEY_OAUTH_STATE, state) + ctx.Redirect(connect.AuthCodeURL(state, oauth2.AccessTypeOnline)) + return + } + + // verify state string + savedState := ctx.Session.Get(middleware.SESS_KEY_OAUTH_STATE).(string) + queryState := ctx.Query("state") + if savedState != queryState { + ctx.Handle(500, "login.OAuthLogin(state mismatch)", nil) return } diff --git a/pkg/middleware/session.go b/pkg/middleware/session.go index ee6462be37a..d575189f4de 100644 --- a/pkg/middleware/session.go +++ b/pkg/middleware/session.go @@ -13,6 +13,7 @@ import ( const ( SESS_KEY_USERID = "uid" + SESS_KEY_OAUTH_STATE = "state" ) var sessionManager *session.Manager