package middleware

import (
	"errors"
	"fmt"
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/grafana/grafana/pkg/infra/log/logtest"
	"github.com/grafana/grafana/pkg/infra/tracing"
	"github.com/grafana/grafana/pkg/plugins"
	"github.com/grafana/grafana/pkg/services/accesscontrol/actest"
	"github.com/grafana/grafana/pkg/services/authn"
	"github.com/grafana/grafana/pkg/services/authn/authntest"
	"github.com/grafana/grafana/pkg/services/contexthandler"
	contextmodel "github.com/grafana/grafana/pkg/services/contexthandler/model"
	"github.com/grafana/grafana/pkg/services/featuremgmt"
	"github.com/grafana/grafana/pkg/services/org"
	"github.com/grafana/grafana/pkg/services/pluginsintegration/pluginstore"
	"github.com/grafana/grafana/pkg/setting"
	"github.com/grafana/grafana/pkg/web"
)

// setupAuthMiddlewareTest builds a ContextHandler backed by a fake authn
// service that always returns the given identity and error, so each test case
// controls the authentication outcome the middleware under test observes.
func setupAuthMiddlewareTest(t *testing.T, identity *authn.Identity, authErr error) *contexthandler.ContextHandler {
	t.Helper()
	return contexthandler.ProvideService(setting.NewCfg(), tracing.InitializeTracerForTest(), featuremgmt.WithFeatures(), &authntest.FakeService{
		ExpectedErr:      authErr,
		ExpectedIdentity: identity,
	})
}

// TestAuth_Middleware checks that the sign-in middlewares either let a request
// through to the route handler or short-circuit it with a redirect / 401,
// depending on the authenticated identity and the requested path.
func TestAuth_Middleware(t *testing.T) {
	type testCase struct {
		desc            string
		identity        *authn.Identity
		path            string
		authErr         error
		authMiddleware  web.Handler
		expectedReached bool // whether the route handler itself should run
		expectedCode    int
	}

	tests := []testCase{
		{
			desc:           "ReqSignedIn should redirect unauthenticated request to secure endpoint",
			path:           "/secure",
			authMiddleware: ReqSignedIn,
			authErr:        errors.New("no auth"),
			expectedCode:   http.StatusFound,
		},
		{
			desc:           "ReqSignedIn should return 401 for api endpoint",
			path:           "/api/secure",
			authMiddleware: ReqSignedIn,
			authErr:        errors.New("no auth"),
			expectedCode:   http.StatusUnauthorized,
		},
		{
			desc:            "ReqSignedIn should return 200 for anonymous user",
			path:            "/api/secure",
			authMiddleware:  ReqSignedIn,
			identity:        &authn.Identity{ID: authn.AnonymousNamespaceID},
			expectedReached: true,
			expectedCode:    http.StatusOK,
		},
		{
			desc:            "ReqSignedIn should redirect anonymous user with forceLogin query string",
			path:            "/secure?forceLogin=true",
			authMiddleware:  ReqSignedIn,
			identity:        &authn.Identity{ID: authn.AnonymousNamespaceID},
			expectedReached: false,
			expectedCode:    http.StatusFound,
		},
		{
			desc:            "ReqSignedIn should redirect anonymous user when orgId in query string is different from currently used",
			path:            "/secure?orgId=2",
			authMiddleware:  ReqSignedIn,
			identity:        &authn.Identity{ID: authn.AnonymousNamespaceID, OrgID: 1},
			expectedReached: false,
			expectedCode:    http.StatusFound,
		},
		{
			desc:            "ReqSignedInNoAnonymous should return 401 for anonymous user",
			path:            "/api/secure",
			authMiddleware:  ReqSignedInNoAnonymous,
			identity:        &authn.Identity{ID: authn.AnonymousNamespaceID},
			expectedReached: false,
			expectedCode:    http.StatusUnauthorized,
		},
		{
			desc:            "ReqSignedInNoAnonymous should return 200 for authenticated user",
			path:            "/api/secure",
			authMiddleware:  ReqSignedInNoAnonymous,
			identity:        &authn.Identity{ID: authn.MustParseNamespaceID("user:1")},
			expectedReached: true,
			expectedCode:    http.StatusOK,
		},
		{
			desc:            "snapshot public mode disabled should return 200 for authenticated user",
			path:            "/api/secure",
			authMiddleware:  SnapshotPublicModeOrSignedIn(&setting.Cfg{SnapshotPublicMode: false}),
			identity:        &authn.Identity{ID: authn.MustParseNamespaceID("user:1")},
			expectedReached: true,
			expectedCode:    http.StatusOK,
		},
		{
			desc:            "snapshot public mode disabled should return 401 for unauthenticated request",
			path:            "/api/secure",
			authMiddleware:  SnapshotPublicModeOrSignedIn(&setting.Cfg{SnapshotPublicMode: false}),
			authErr:         errors.New("no auth"),
			expectedReached: false,
			expectedCode:    http.StatusUnauthorized,
		},
		{
			desc:            "snapshot public mode enabled should return 200 for unauthenticated request",
			path:            "/api/secure",
			authMiddleware:  SnapshotPublicModeOrSignedIn(&setting.Cfg{SnapshotPublicMode: true}),
			authErr:         errors.New("no auth"),
			expectedReached: true,
			expectedCode:    http.StatusOK,
		},
	}

	for _, tt := range tests {
		t.Run(tt.desc, func(t *testing.T) {
			ctxHandler := setupAuthMiddlewareTest(t, tt.identity, tt.authErr)

			server := web.New()
			server.Use(ctxHandler.Middleware)
			server.Use(tt.authMiddleware)

			// reached records whether the request made it past the middleware
			// chain into the final route handler.
			var reached bool
			handler := func(c *contextmodel.ReqContext) {
				reached = true
				c.Resp.WriteHeader(http.StatusOK)
			}
			server.Get("/secure", handler)
			server.Get("/api/secure", handler)

			req, err := http.NewRequest(http.MethodGet, tt.path, nil)
			require.NoError(t, err)
			recorder := httptest.NewRecorder()
			server.ServeHTTP(recorder, req)

			res := recorder.Result()
			assert.Equal(t, tt.expectedReached, reached)
			assert.Equal(t, tt.expectedCode, res.StatusCode)
			require.NoError(t, res.Body.Close())
		})
	}
}

// TestRoleAppPluginAuth verifies the app-plugin route guard: role-based access
// to plugin pages, pass-through when the plugin or page is unknown, and
// RBAC-action based access when the page declares an action.
func TestRoleAppPluginAuth(t *testing.T) {
	t.Run("Verify user's role when requesting app route which requires role", func(t *testing.T) {
		// Redirects for denied requests go to the app sub-URL, so pin it and
		// restore the global afterwards.
		appSubURL := setting.AppSubUrl
		setting.AppSubUrl = "/grafana/"
		t.Cleanup(func() {
			setting.AppSubUrl = appSubURL
		})

		tcs := []struct {
			roleRequired org.RoleType
			role         org.RoleType
			expStatus    int
			expBody      string
			expLocation  string
		}{
			{roleRequired: org.RoleViewer, role: org.RoleAdmin, expStatus: http.StatusOK, expBody: ""},
			{roleRequired: org.RoleAdmin, role: org.RoleAdmin, expStatus: http.StatusOK, expBody: ""},
			{roleRequired: org.RoleAdmin, role: org.RoleViewer, expStatus: http.StatusFound, expBody: "Found.\n\n", expLocation: "/grafana/"},
			{roleRequired: "", role: org.RoleViewer, expStatus: http.StatusOK, expBody: ""},
			{roleRequired: org.RoleEditor, role: "", expStatus: http.StatusFound, expBody: "Found.\n\n", expLocation: "/grafana/"},
		}

		const path = "/a/test-app/test"
		for i, tc := range tcs {
			t.Run(fmt.Sprintf("testcase %d", i), func(t *testing.T) {
				ps := pluginstore.NewFakePluginStore(pluginstore.Plugin{
					JSONData: plugins.JSONData{
						ID: "test-app",
						Includes: []*plugins.Includes{
							{
								Type: "page",
								Role: tc.roleRequired,
								Path: path,
							},
						},
					},
				})

				middlewareScenario(t, t.Name(), func(t *testing.T, sc *scenarioContext) {
					sc.withIdentity(&authn.Identity{
						OrgRoles: map[int64]org.RoleType{
							0: tc.role,
						},
					})
					features := featuremgmt.WithFeatures()
					logger := &logtest.Fake{}
					ac := &actest.FakeAccessControl{}

					sc.m.Get("/a/:id/*", RoleAppPluginAuth(ac, ps, features, logger), func(c *contextmodel.ReqContext) {
						c.JSON(http.StatusOK, map[string]any{})
					})
					sc.fakeReq(http.MethodGet, path).exec()
					assert.Equal(t, tc.expStatus, sc.resp.Code)
					assert.Equal(t, tc.expBody, sc.resp.Body.String())
					assert.Equal(t, tc.expLocation, sc.resp.Header().Get("Location"))
				})
			})
		}
	})

	// We return success in this case because the frontend takes care of rendering the 404 page
	middlewareScenario(t, "Plugin is not found returns success", func(t *testing.T, sc *scenarioContext) {
		sc.withIdentity(&authn.Identity{
			OrgRoles: map[int64]org.RoleType{
				0: org.RoleViewer,
			},
		})
		features := featuremgmt.WithFeatures()
		logger := &logtest.Fake{}
		ac := &actest.FakeAccessControl{}

		sc.m.Get("/a/:id/*", RoleAppPluginAuth(ac, &pluginstore.FakePluginStore{}, features, logger), func(c *contextmodel.ReqContext) {
			c.JSON(http.StatusOK, map[string]any{})
		})
		sc.fakeReq(http.MethodGet, "/a/test-app/test").exec()
		assert.Equal(t, http.StatusOK, sc.resp.Code)
		assert.Equal(t, "", sc.resp.Body.String())
	})

	// We return success in this case because the frontend takes care of rendering the right page based on its router
	middlewareScenario(t, "Plugin page not found returns success", func(t *testing.T, sc *scenarioContext) {
		sc.withIdentity(&authn.Identity{
			OrgRoles: map[int64]org.RoleType{
				0: org.RoleViewer,
			},
		})
		features := featuremgmt.WithFeatures()
		logger := &logtest.Fake{}
		ac := &actest.FakeAccessControl{}

		ps := pluginstore.NewFakePluginStore(pluginstore.Plugin{
			JSONData: plugins.JSONData{
				ID: "test-app",
				Includes: []*plugins.Includes{
					{
						Type: "page",
						Role: org.RoleViewer,
						Path: "/a/test-app/test",
					},
				},
			},
		})
		sc.m.Get("/a/:id/*", RoleAppPluginAuth(ac, ps, features, logger), func(c *contextmodel.ReqContext) {
			c.JSON(http.StatusOK, map[string]any{})
		})
		sc.fakeReq(http.MethodGet, "/a/test-app/notExistingPath").exec()
		assert.Equal(t, http.StatusOK, sc.resp.Code)
		assert.Equal(t, "", sc.resp.Body.String())
	})

	t.Run("Plugin include with RBAC", func(t *testing.T) {
		tcs := []struct {
			name        string
			evalResult  bool
			evalErr     error
			expStatus   int
			expBody     string
			expLocation string
		}{
			{
				name:        "Unsuccessful RBAC eval will result in a redirect",
				evalResult:  false,
				expStatus:   http.StatusFound,
				expBody:     "Found.\n\n",
				expLocation: "/",
			},
			{
				name:        "An RBAC eval error will result in a redirect",
				evalErr:     errors.New("eval error"),
				expStatus:   http.StatusFound,
				expBody:     "Found.\n\n",
				expLocation: "/",
			},
			{
				name:        "Successful RBAC eval will result in a successful request",
				evalResult:  true,
				expStatus:   http.StatusOK,
				expBody:     "",
				expLocation: "",
			},
		}

		for _, tc := range tcs {
			// Use the case name so each scenario gets a distinct, identifiable subtest.
			middlewareScenario(t, tc.name, func(t *testing.T, sc *scenarioContext) {
				sc.withIdentity(&authn.Identity{
					OrgRoles: map[int64]org.RoleType{
						0: org.RoleViewer,
					},
				})
				logger := &logtest.Fake{}
				features := featuremgmt.WithFeatures(featuremgmt.FlagAccessControlOnCall)
				ac := &actest.FakeAccessControl{
					ExpectedEvaluate: tc.evalResult,
					ExpectedErr:      tc.evalErr,
				}

				path := "/a/test-app/test"
				ps := pluginstore.NewFakePluginStore(pluginstore.Plugin{
					JSONData: plugins.JSONData{
						ID: "test-app",
						Includes: []*plugins.Includes{
							{
								Type:   "page",
								Role:   org.RoleViewer,
								Path:   path,
								Action: "test-app.test:read",
							},
						},
					},
				})

				sc.m.Get("/a/:id/*", RoleAppPluginAuth(ac, ps, features, logger), func(c *contextmodel.ReqContext) {
					c.JSON(http.StatusOK, map[string]any{})
				})
				sc.fakeReq(http.MethodGet, path).exec()
				assert.Equal(t, tc.expStatus, sc.resp.Code)
				assert.Equal(t, tc.expBody, sc.resp.Body.String())
				assert.Equal(t, tc.expLocation, sc.resp.Header().Get("Location"))
			})
		}
	})
}

// TestRemoveForceLoginparams checks that removeForceLoginParams strips the
// forceLogin=true query parameter while leaving every other parameter intact.
func TestRemoveForceLoginparams(t *testing.T) {
	tcs := []struct {
		inp string
		exp string
	}{
		{inp: "/?forceLogin=true", exp: "/?"},
		{inp: "/d/dash/dash-title?orgId=1&forceLogin=true", exp: "/d/dash/dash-title?orgId=1"},
		{inp: "/?kiosk&forceLogin=true", exp: "/?kiosk"},
		{inp: "/d/dash/dash-title?orgId=1&kiosk&forceLogin=true", exp: "/d/dash/dash-title?orgId=1&kiosk"},
		{inp: "/d/dash/dash-title?orgId=1&forceLogin=true&kiosk", exp: "/d/dash/dash-title?orgId=1&kiosk"},
		{inp: "/d/dash/dash-title?forceLogin=true&kiosk", exp: "/d/dash/dash-title?&kiosk"},
	}
	for i, tc := range tcs {
		t.Run(fmt.Sprintf("testcase %d", i), func(t *testing.T) {
			require.Equal(t, tc.exp, removeForceLoginParams(tc.inp))
		})
	}
}