diff --git a/pkg/middleware/middleware.go b/pkg/middleware/middleware.go index 2bb0f8a49d5..37f04011af7 100644 --- a/pkg/middleware/middleware.go +++ b/pkg/middleware/middleware.go @@ -163,24 +163,11 @@ func initContextWithBasicAuth(ctx *models.ReqContext, orgId int64) bool { return true } - loginQuery := models.GetUserByLoginQuery{LoginOrEmail: username} - if err := bus.Dispatch(&loginQuery); err != nil { - ctx.Logger.Debug( - "Failed to look up the username", - "username", username, - ) - ctx.JsonApiErr(401, errStringInvalidUsernamePassword, err) - - return true - } - - user := loginQuery.Result - loginUserQuery := models.LoginUserQuery{ + authQuery := models.LoginUserQuery{ Username: username, Password: password, - User: user, } - if err := bus.Dispatch(&loginUserQuery); err != nil { + if err := bus.Dispatch(&authQuery); err != nil { ctx.Logger.Debug( "Failed to authorize the user", "username", username, @@ -190,6 +177,8 @@ func initContextWithBasicAuth(ctx *models.ReqContext, orgId int64) bool { return true } + user := authQuery.User + query := models.GetSignedInUserQuery{UserId: user.Id, OrgId: orgId} if err := bus.Dispatch(&query); err != nil { ctx.Logger.Error( diff --git a/pkg/middleware/middleware_basic_auth_test.go b/pkg/middleware/middleware_basic_auth_test.go new file mode 100644 index 00000000000..ef2631c1dbb --- /dev/null +++ b/pkg/middleware/middleware_basic_auth_test.go @@ -0,0 +1,143 @@ +package middleware + +import ( + "encoding/json" + "testing" + + . "github.com/smartystreets/goconvey/convey" + + "github.com/grafana/grafana/pkg/bus" + authLogin "github.com/grafana/grafana/pkg/login" + "github.com/grafana/grafana/pkg/models" + "github.com/grafana/grafana/pkg/setting" + "github.com/grafana/grafana/pkg/util" +) + +func TestMiddlewareBasicAuth(t *testing.T) { + Convey("Given the basic auth", t, func() { + var oldBasicAuthEnabled = setting.BasicAuthEnabled + var oldDisableBruteForceLoginProtection = setting.DisableBruteForceLoginProtection + var id int64 = 12 + + Convey("Setup", func() { + setting.BasicAuthEnabled = true + setting.DisableBruteForceLoginProtection = true + bus.ClearBusHandlers() + }) + + middlewareScenario(t, "Valid API key", func(sc *scenarioContext) { + var orgID int64 = 2 + keyhash := util.EncodePassword("v5nAwpMafFP6znaS4urhdWDLS5511M42", "asd") + + bus.AddHandler("test", func(query *models.GetApiKeyByNameQuery) error { + query.Result = &models.ApiKey{OrgId: orgID, Role: models.ROLE_EDITOR, Key: keyhash} + return nil + }) + + authHeader := util.GetBasicAuthHeader("api_key", "eyJrIjoidjVuQXdwTWFmRlA2em5hUzR1cmhkV0RMUzU1MTFNNDIiLCJuIjoiYXNkIiwiaWQiOjF9") + sc.fakeReq("GET", "/").withAuthorizationHeader(authHeader).exec() + + Convey("Should return 200", func() { + So(sc.resp.Code, ShouldEqual, 200) + }) + + Convey("Should init middleware context", func() { + So(sc.context.IsSignedIn, ShouldEqual, true) + So(sc.context.OrgId, ShouldEqual, orgID) + So(sc.context.OrgRole, ShouldEqual, models.ROLE_EDITOR) + }) + }) + + middlewareScenario(t, "Handle auth", func(sc *scenarioContext) { + var password = "MyPass" + var salt = "Salt" + var orgID int64 = 2 + + bus.AddHandler("grafana-auth", func(query *models.LoginUserQuery) error { + query.User = &models.User{ + Password: util.EncodePassword(password, salt), + Salt: salt, + } + return nil + }) + + bus.AddHandler("get-sign-user", func(query *models.GetSignedInUserQuery) error { + query.Result = &models.SignedInUser{OrgId: orgID, UserId: id} + return nil + }) + + authHeader := util.GetBasicAuthHeader("myUser", password) + sc.fakeReq("GET", "/").withAuthorizationHeader(authHeader).exec() + + Convey("Should init middleware context with users", func() { + So(sc.context.IsSignedIn, ShouldEqual, true) + So(sc.context.OrgId, ShouldEqual, orgID) + So(sc.context.UserId, ShouldEqual, id) + }) + + bus.ClearBusHandlers() + }) + + middlewareScenario(t, "Auth sequence", func(sc *scenarioContext) { + var password = "MyPass" + var salt = "Salt" + + authLogin.Init() + + bus.AddHandler("user-query", func(query *models.GetUserByLoginQuery) error { + query.Result = &models.User{ + Password: util.EncodePassword(password, salt), + Id: id, + Salt: salt, + } + return nil + }) + + bus.AddHandler("get-sign-user", func(query *models.GetSignedInUserQuery) error { + query.Result = &models.SignedInUser{UserId: query.UserId} + return nil + }) + + authHeader := util.GetBasicAuthHeader("myUser", password) + sc.fakeReq("GET", "/").withAuthorizationHeader(authHeader).exec() + + Convey("Should init middleware context with user", func() { + So(sc.context.IsSignedIn, ShouldEqual, true) + So(sc.context.UserId, ShouldEqual, id) + }) + }) + + middlewareScenario(t, "Should return error if user is not found", func(sc *scenarioContext) { + sc.fakeReq("GET", "/") + sc.req.SetBasicAuth("user", "password") + sc.exec() + + err := json.NewDecoder(sc.resp.Body).Decode(&sc.respJson) + So(err, ShouldNotBeNil) + + So(sc.resp.Code, ShouldEqual, 401) + So(sc.respJson["message"], ShouldEqual, errStringInvalidUsernamePassword) + }) + + middlewareScenario(t, "Should return error if user & password do not match", func(sc *scenarioContext) { + bus.AddHandler("user-query", func(loginUserQuery *models.GetUserByLoginQuery) error { + return nil + }) + + sc.fakeReq("GET", "/") + sc.req.SetBasicAuth("killa", "gorilla") + sc.exec() + + err := json.NewDecoder(sc.resp.Body).Decode(&sc.respJson) + So(err, ShouldNotBeNil) + + So(sc.resp.Code, ShouldEqual, 401) + So(sc.respJson["message"], ShouldEqual, errStringInvalidUsernamePassword) + }) + + Convey("Destroy", func() { + setting.BasicAuthEnabled = oldBasicAuthEnabled + setting.DisableBruteForceLoginProtection = oldDisableBruteForceLoginProtection + }) + }) +} diff --git a/pkg/middleware/middleware_test.go b/pkg/middleware/middleware_test.go index 502b5fff617..000aba1024d 100644 --- a/pkg/middleware/middleware_test.go +++ b/pkg/middleware/middleware_test.go @@ -3,10 +3,8 @@ package middleware import ( "context" "encoding/base32" - "encoding/json" "fmt" "net/http" - "net/http/httptest" "path/filepath" "testing" "time" @@ -476,95 +474,6 @@ func TestMiddlewareContext(t *testing.T) { }) } -func TestMiddlewareBasicAuth(t *testing.T) { - Convey("Given the basic auth", t, func() { - old := setting.BasicAuthEnabled - - Convey("Setup", func() { - setting.BasicAuthEnabled = true - }) - - middlewareScenario(t, "Valid API key", func(sc *scenarioContext) { - keyhash := util.EncodePassword("v5nAwpMafFP6znaS4urhdWDLS5511M42", "asd") - - bus.AddHandler("test", func(query *models.GetApiKeyByNameQuery) error { - query.Result = &models.ApiKey{OrgId: 12, Role: models.ROLE_EDITOR, Key: keyhash} - return nil - }) - - authHeader := util.GetBasicAuthHeader("api_key", "eyJrIjoidjVuQXdwTWFmRlA2em5hUzR1cmhkV0RMUzU1MTFNNDIiLCJuIjoiYXNkIiwiaWQiOjF9") - sc.fakeReq("GET", "/").withAuthorizationHeader(authHeader).exec() - - Convey("Should return 200", func() { - So(sc.resp.Code, ShouldEqual, 200) - }) - - Convey("Should init middleware context", func() { - So(sc.context.IsSignedIn, ShouldEqual, true) - So(sc.context.OrgId, ShouldEqual, 12) - So(sc.context.OrgRole, ShouldEqual, models.ROLE_EDITOR) - }) - }) - - middlewareScenario(t, "Handle auth", func(sc *scenarioContext) { - - bus.AddHandler("test", func(query *models.GetUserByLoginQuery) error { - query.Result = &models.User{ - Password: util.EncodePassword("myPass", "Salt"), - Salt: "Salt", - } - return nil - }) - - bus.AddHandler("test", func(loginUserQuery *models.LoginUserQuery) error { - return nil - }) - - bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error { - query.Result = &models.SignedInUser{OrgId: 2, UserId: 12} - return nil - }) - - authHeader := util.GetBasicAuthHeader("myUser", "myPass") - sc.fakeReq("GET", "/").withAuthorizationHeader(authHeader).exec() - - Convey("Should init middleware context with user", func() { - So(sc.context.IsSignedIn, ShouldEqual, true) - So(sc.context.OrgId, ShouldEqual, 2) - So(sc.context.UserId, ShouldEqual, 12) - }) - }) - - middlewareScenario(t, "Should return error if user is not found", func(sc *scenarioContext) { - sc.fakeReqWithBasicAuth("GET", "/", "test", "test").exec() - - err := json.NewDecoder(sc.resp.Body).Decode(&sc.respJson) - So(err, ShouldNotBeNil) - - So(sc.resp.Code, ShouldEqual, 401) - So(sc.respJson["message"], ShouldEqual, errStringInvalidUsernamePassword) - }) - - middlewareScenario(t, "Should return error if user & password do not match", func(sc *scenarioContext) { - bus.AddHandler("test", func(loginUserQuery *models.GetUserByLoginQuery) error { - return nil - }) - - sc.fakeReqWithBasicAuth("GET", "/", "test", "test").exec() - - err := json.NewDecoder(sc.resp.Body).Decode(&sc.respJson) - So(err, ShouldNotBeNil) - - So(sc.resp.Code, ShouldEqual, 401) - So(sc.respJson["message"], ShouldEqual, errStringInvalidUsernamePassword) - }) - - Convey("Destroy", func() { - setting.BasicAuthEnabled = old - }) - }) -} - func middlewareScenario(t *testing.T, desc string, fn scenarioFunc) { Convey(desc, func() { defer bus.ClearBusHandlers() @@ -602,100 +511,3 @@ func middlewareScenario(t *testing.T, desc string, fn scenarioFunc) { fn(sc) }) } - -type scenarioContext struct { - m *macaron.Macaron - context *models.ReqContext - resp *httptest.ResponseRecorder - apiKey string - authHeader string - tokenSessionCookie string - respJson map[string]interface{} - handlerFunc handlerFunc - defaultHandler macaron.Handler - url string - userAuthTokenService *auth.FakeUserAuthTokenService - remoteCacheService *remotecache.RemoteCache - - req *http.Request -} - -func (sc *scenarioContext) withValidApiKey() *scenarioContext { - sc.apiKey = "eyJrIjoidjVuQXdwTWFmRlA2em5hUzR1cmhkV0RMUzU1MTFNNDIiLCJuIjoiYXNkIiwiaWQiOjF9" - return sc -} - -func (sc *scenarioContext) withTokenSessionCookie(unhashedToken string) *scenarioContext { - sc.tokenSessionCookie = unhashedToken - return sc -} - -func (sc *scenarioContext) withAuthorizationHeader(authHeader string) *scenarioContext { - sc.authHeader = authHeader - return sc -} - -func (sc *scenarioContext) fakeReq(method, url string) *scenarioContext { - sc.resp = httptest.NewRecorder() - req, err := http.NewRequest(method, url, nil) - So(err, ShouldBeNil) - sc.req = req - - return sc -} - -func (sc *scenarioContext) fakeReqWithBasicAuth(method, url, user, password string) *scenarioContext { - sc.resp = httptest.NewRecorder() - req, err := http.NewRequest(method, url, nil) - req.SetBasicAuth(user, password) - So(err, ShouldBeNil) - sc.req = req - - return sc -} - -func (sc *scenarioContext) fakeReqWithParams(method, url string, queryParams map[string]string) *scenarioContext { - sc.resp = httptest.NewRecorder() - req, err := http.NewRequest(method, url, nil) - q := req.URL.Query() - for k, v := range queryParams { - q.Add(k, v) - } - req.URL.RawQuery = q.Encode() - So(err, ShouldBeNil) - sc.req = req - - return sc -} - -func (sc *scenarioContext) handler(fn handlerFunc) *scenarioContext { - sc.handlerFunc = fn - return sc -} - -func (sc *scenarioContext) exec() { - if sc.apiKey != "" { - sc.req.Header.Add("Authorization", "Bearer "+sc.apiKey) - } - - if sc.authHeader != "" { - sc.req.Header.Add("Authorization", sc.authHeader) - } - - if sc.tokenSessionCookie != "" { - sc.req.AddCookie(&http.Cookie{ - Name: setting.LoginCookieName, - Value: sc.tokenSessionCookie, - }) - } - - sc.m.ServeHTTP(sc.resp, sc.req) - - if sc.resp.Header().Get("Content-Type") == "application/json; charset=UTF-8" { - err := json.NewDecoder(sc.resp.Body).Decode(&sc.respJson) - So(err, ShouldBeNil) - } -} - -type scenarioFunc func(c *scenarioContext) -type handlerFunc func(c *models.ReqContext) diff --git a/pkg/middleware/testing.go b/pkg/middleware/testing.go new file mode 100644 index 00000000000..8a55933a606 --- /dev/null +++ b/pkg/middleware/testing.go @@ -0,0 +1,102 @@ +package middleware + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + + . "github.com/smartystreets/goconvey/convey" + "gopkg.in/macaron.v1" + + "github.com/grafana/grafana/pkg/infra/remotecache" + "github.com/grafana/grafana/pkg/models" + "github.com/grafana/grafana/pkg/services/auth" + "github.com/grafana/grafana/pkg/setting" +) + +type scenarioContext struct { + m *macaron.Macaron + context *models.ReqContext + resp *httptest.ResponseRecorder + apiKey string + authHeader string + tokenSessionCookie string + respJson map[string]interface{} + handlerFunc handlerFunc + defaultHandler macaron.Handler + url string + userAuthTokenService *auth.FakeUserAuthTokenService + remoteCacheService *remotecache.RemoteCache + + req *http.Request +} + +func (sc *scenarioContext) withValidApiKey() *scenarioContext { + sc.apiKey = "eyJrIjoidjVuQXdwTWFmRlA2em5hUzR1cmhkV0RMUzU1MTFNNDIiLCJuIjoiYXNkIiwiaWQiOjF9" + return sc +} + +func (sc *scenarioContext) withTokenSessionCookie(unhashedToken string) *scenarioContext { + sc.tokenSessionCookie = unhashedToken + return sc +} + +func (sc *scenarioContext) withAuthorizationHeader(authHeader string) *scenarioContext { + sc.authHeader = authHeader + return sc +} + +func (sc *scenarioContext) fakeReq(method, url string) *scenarioContext { + sc.resp = httptest.NewRecorder() + req, err := http.NewRequest(method, url, nil) + So(err, ShouldBeNil) + sc.req = req + + return sc +} + +func (sc *scenarioContext) fakeReqWithParams(method, url string, queryParams map[string]string) *scenarioContext { + sc.resp = httptest.NewRecorder() + req, err := http.NewRequest(method, url, nil) + q := req.URL.Query() + for k, v := range queryParams { + q.Add(k, v) + } + req.URL.RawQuery = q.Encode() + So(err, ShouldBeNil) + sc.req = req + + return sc +} + +func (sc *scenarioContext) handler(fn handlerFunc) *scenarioContext { + sc.handlerFunc = fn + return sc +} + +func (sc *scenarioContext) exec() { + if sc.apiKey != "" { + sc.req.Header.Add("Authorization", "Bearer "+sc.apiKey) + } + + if sc.authHeader != "" { + sc.req.Header.Add("Authorization", sc.authHeader) + } + + if sc.tokenSessionCookie != "" { + sc.req.AddCookie(&http.Cookie{ + Name: setting.LoginCookieName, + Value: sc.tokenSessionCookie, + }) + } + + sc.m.ServeHTTP(sc.resp, sc.req) + + if sc.resp.Header().Get("Content-Type") == "application/json; charset=UTF-8" { + err := json.NewDecoder(sc.resp.Body).Decode(&sc.respJson) + So(err, ShouldBeNil) + } +} + +type scenarioFunc func(c *scenarioContext) +type handlerFunc func(c *models.ReqContext)