mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
RBAC: Redirect to /login when forceLogin is set (#56469)
This commit is contained in:
parent
b622a87aee
commit
bb479e030a
@ -3,12 +3,18 @@ package accesscontrol
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"text/template"
|
||||
"time"
|
||||
|
||||
"github.com/grafana/grafana/pkg/middleware/cookies"
|
||||
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
@ -23,6 +29,24 @@ func Middleware(ac AccessControl) func(web.Handler, Evaluator) web.Handler {
|
||||
}
|
||||
|
||||
return func(c *models.ReqContext) {
|
||||
if c.AllowAnonymous {
|
||||
forceLogin, _ := strconv.ParseBool(c.Req.URL.Query().Get("forceLogin")) // ignoring error, assuming false for non-true values is ok.
|
||||
orgID, err := strconv.ParseInt(c.Req.URL.Query().Get("orgId"), 10, 64)
|
||||
if err == nil && orgID > 0 && orgID != c.OrgID {
|
||||
forceLogin = true
|
||||
}
|
||||
|
||||
if !c.IsSignedIn && forceLogin {
|
||||
unauthorized(c, nil)
|
||||
}
|
||||
}
|
||||
|
||||
var revokedErr *models.TokenRevokedError
|
||||
if errors.As(c.LookupTokenErr, &revokedErr) {
|
||||
unauthorized(c, revokedErr)
|
||||
return
|
||||
}
|
||||
|
||||
authorize(c, ac, c.SignedInUser, evaluator)
|
||||
}
|
||||
}
|
||||
@ -80,6 +104,47 @@ func deny(c *models.ReqContext, evaluator Evaluator, err error) {
|
||||
})
|
||||
}
|
||||
|
||||
func unauthorized(c *models.ReqContext, err error) {
|
||||
if c.IsApiRequest() {
|
||||
response := map[string]interface{}{
|
||||
"message": "Unauthorized",
|
||||
}
|
||||
|
||||
var revokedErr *models.TokenRevokedError
|
||||
if errors.As(err, &revokedErr) {
|
||||
response["message"] = "Token revoked"
|
||||
response["error"] = map[string]interface{}{
|
||||
"id": "ERR_TOKEN_REVOKED",
|
||||
"maxConcurrentSessions": revokedErr.MaxConcurrentSessions,
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusUnauthorized, response)
|
||||
return
|
||||
}
|
||||
|
||||
writeRedirectCookie(c)
|
||||
c.Redirect(setting.AppSubUrl + "/login")
|
||||
}
|
||||
|
||||
func writeRedirectCookie(c *models.ReqContext) {
|
||||
redirectTo := c.Req.RequestURI
|
||||
if setting.AppSubUrl != "" && !strings.HasPrefix(redirectTo, setting.AppSubUrl) {
|
||||
redirectTo = setting.AppSubUrl + c.Req.RequestURI
|
||||
}
|
||||
|
||||
// remove any forceLogin=true params
|
||||
redirectTo = removeForceLoginParams(redirectTo)
|
||||
|
||||
cookies.WriteCookie(c.Resp, "redirect_to", url.QueryEscape(redirectTo), 0, nil)
|
||||
}
|
||||
|
||||
var forceLoginParamsRegexp = regexp.MustCompile(`&?forceLogin=true`)
|
||||
|
||||
func removeForceLoginParams(str string) string {
|
||||
return forceLoginParamsRegexp.ReplaceAllString(str, "")
|
||||
}
|
||||
|
||||
func newID() string {
|
||||
// Less ambiguity than alphanumerical.
|
||||
numerical := []byte("0123456789")
|
||||
|
@ -83,7 +83,53 @@ func TestMiddleware(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func contextProvider() web.Handler {
|
||||
func TestMiddleware_forceLogin(t *testing.T) {
|
||||
tests := []struct {
|
||||
url string
|
||||
redirectToLogin bool
|
||||
}{
|
||||
{url: "/endpoint?forceLogin=true", redirectToLogin: true},
|
||||
{url: "/endpoint?forceLogin=false"},
|
||||
{url: "/endpoint"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
var endpointCalled bool
|
||||
|
||||
server := web.New()
|
||||
server.UseMiddleware(web.Renderer("../../public/views", "[[", "]]"))
|
||||
|
||||
server.Get("/endpoint", func(c *models.ReqContext) {
|
||||
endpointCalled = true
|
||||
c.Resp.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
ac := mock.New().WithPermissions([]accesscontrol.Permission{{Action: "endpoint:read", Scope: "endpoint:1"}})
|
||||
server.Use(contextProvider(func(c *models.ReqContext) {
|
||||
c.AllowAnonymous = true
|
||||
c.SignedInUser.IsAnonymous = true
|
||||
c.IsSignedIn = false
|
||||
}))
|
||||
server.Use(
|
||||
accesscontrol.Middleware(ac)(nil, accesscontrol.EvalPermission("endpoint:read", "endpoint:1")),
|
||||
)
|
||||
|
||||
request, err := http.NewRequest(http.MethodGet, tc.url, nil)
|
||||
assert.NoError(t, err)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
server.ServeHTTP(recorder, request)
|
||||
|
||||
expectedCode := http.StatusOK
|
||||
if tc.redirectToLogin {
|
||||
expectedCode = http.StatusFound
|
||||
}
|
||||
assert.Equal(t, expectedCode, recorder.Code)
|
||||
assert.Equal(t, !tc.redirectToLogin, endpointCalled, "/endpoint should be called?")
|
||||
}
|
||||
}
|
||||
|
||||
func contextProvider(modifiers ...func(c *models.ReqContext)) web.Handler {
|
||||
return func(c *web.Context) {
|
||||
reqCtx := &models.ReqContext{
|
||||
Context: c,
|
||||
@ -92,6 +138,9 @@ func contextProvider() web.Handler {
|
||||
IsSignedIn: true,
|
||||
SkipCache: true,
|
||||
}
|
||||
for _, modifier := range modifiers {
|
||||
modifier(reqCtx)
|
||||
}
|
||||
c.Req = c.Req.WithContext(ctxkey.Set(c.Req.Context(), reqCtx))
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user