diff --git a/pkg/api/http_server.go b/pkg/api/http_server.go index 984eab23382..9f3f6b0f491 100644 --- a/pkg/api/http_server.go +++ b/pkg/api/http_server.go @@ -13,8 +13,10 @@ import ( "strings" "sync" + "github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/services/libraryelements" "github.com/grafana/grafana/pkg/services/librarypanels" + "github.com/grafana/grafana/pkg/services/oauthtoken" "github.com/grafana/grafana/pkg/api/routing" httpstatic "github.com/grafana/grafana/pkg/api/static" @@ -105,6 +107,8 @@ type HTTPServer struct { Alertmanager *notifier.Alertmanager `inject:""` LibraryPanelService librarypanels.Service `inject:""` LibraryElementService libraryelements.Service `inject:""` + SocialService social.Service `inject:""` + OAuthTokenService *oauthtoken.Service `inject:""` Listener net.Listener } diff --git a/pkg/api/login.go b/pkg/api/login.go index bf63b068c44..4fb6bcf5042 100644 --- a/pkg/api/login.go +++ b/pkg/api/login.go @@ -91,7 +91,8 @@ func (hs *HTTPServer) LoginView(c *models.ReqContext) { } enabledOAuths := make(map[string]interface{}) - for key, oauth := range setting.OAuthService.OAuthInfos { + providers := hs.SocialService.GetOAuthInfoProviders() + for key, oauth := range providers { enabledOAuths[key] = map[string]string{"name": oauth.Name} } @@ -147,12 +148,12 @@ func (hs *HTTPServer) tryOAuthAutoLogin(c *models.ReqContext) bool { if !setting.OAuthAutoLogin { return false } - oauthInfos := setting.OAuthService.OAuthInfos + oauthInfos := hs.SocialService.GetOAuthInfoProviders() if len(oauthInfos) != 1 { log.Warnf("Skipping OAuth auto login because multiple OAuth providers are configured") return false } - for key := range setting.OAuthService.OAuthInfos { + for key := range oauthInfos { redirectUrl := hs.Cfg.AppSubURL + "/login/" + key log.Infof("OAuth auto login enabled. Redirecting to " + redirectUrl) c.Redirect(redirectUrl, 307) diff --git a/pkg/api/login_oauth.go b/pkg/api/login_oauth.go index 1fce9b6f610..9bced0baa3e 100644 --- a/pkg/api/login_oauth.go +++ b/pkg/api/login_oauth.go @@ -41,7 +41,10 @@ func (hs *HTTPServer) OAuthLogin(ctx *models.ReqContext) { loginInfo := models.LoginInfo{ AuthModule: "oauth", } - if setting.OAuthService == nil { + name := ctx.Params(":name") + loginInfo.AuthModule = name + provider := hs.SocialService.GetOAuthInfoProvider(name) + if provider == nil { hs.handleOAuthLoginError(ctx, loginInfo, LoginError{ HttpStatus: http.StatusNotFound, PublicMessage: "OAuth not enabled", @@ -49,10 +52,8 @@ func (hs *HTTPServer) OAuthLogin(ctx *models.ReqContext) { return } - name := ctx.Params(":name") - loginInfo.AuthModule = name - connect, ok := social.SocialMap[name] - if !ok { + connect, err := hs.SocialService.GetConnector(name) + if err != nil { hs.handleOAuthLoginError(ctx, loginInfo, LoginError{ HttpStatus: http.StatusNotFound, PublicMessage: fmt.Sprintf("No OAuth with name %s configured", name), @@ -80,12 +81,12 @@ func (hs *HTTPServer) OAuthLogin(ctx *models.ReqContext) { return } - hashedState := hashStatecode(state, setting.OAuthService.OAuthInfos[name].ClientSecret) + hashedState := hashStatecode(state, provider.ClientSecret) cookies.WriteCookie(ctx.Resp, OauthStateCookieName, hashedState, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg) - if setting.OAuthService.OAuthInfos[name].HostedDomain == "" { + if provider.HostedDomain == "" { ctx.Redirect(connect.AuthCodeURL(state, oauth2.AccessTypeOnline)) } else { - ctx.Redirect(connect.AuthCodeURL(state, oauth2.SetAuthURLParam("hd", setting.OAuthService.OAuthInfos[name].HostedDomain), oauth2.AccessTypeOnline)) + ctx.Redirect(connect.AuthCodeURL(state, oauth2.SetAuthURLParam("hd", provider.HostedDomain), oauth2.AccessTypeOnline)) } return } @@ -103,7 +104,7 @@ func (hs *HTTPServer) OAuthLogin(ctx *models.ReqContext) { return } - queryState := hashStatecode(ctx.Query("state"), setting.OAuthService.OAuthInfos[name].ClientSecret) + queryState := hashStatecode(ctx.Query("state"), provider.ClientSecret) oauthLogger.Info("state check", "queryState", queryState, "cookieState", cookieState) if cookieState != queryState { hs.handleOAuthLoginError(ctx, loginInfo, LoginError{ @@ -113,7 +114,7 @@ func (hs *HTTPServer) OAuthLogin(ctx *models.ReqContext) { return } - oauthClient, err := social.GetOAuthHttpClient(name) + oauthClient, err := hs.SocialService.GetOAuthHttpClient(name) if err != nil { ctx.Logger.Error("Failed to create OAuth http client", "error", err) hs.handleOAuthLoginError(ctx, loginInfo, LoginError{ diff --git a/pkg/api/login_test.go b/pkg/api/login_test.go index 2d59cfb6783..bdeceada8a9 100644 --- a/pkg/api/login_test.go +++ b/pkg/api/login_test.go @@ -17,6 +17,7 @@ import ( "github.com/grafana/grafana/pkg/components/simplejson" "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/login" + "github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/services/auth" "github.com/grafana/grafana/pkg/services/hooks" @@ -86,6 +87,16 @@ type redirectCase struct { redirectURL string } +var oAuthInfos = map[string]*social.OAuthInfo{ + "github": { + ClientId: "fake", + ClientSecret: "fakefake", + Enabled: true, + AllowSignup: true, + Name: "github", + }, +} + func TestLoginErrorCookieAPIEndpoint(t *testing.T) { fakeSetIndexViewData(t) @@ -97,6 +108,7 @@ func TestLoginErrorCookieAPIEndpoint(t *testing.T) { Cfg: cfg, SettingsProvider: &setting.OSSImpl{Cfg: cfg}, License: &licensing.OSSLicensingService{}, + SocialService: &mockSocialService{}, } sc.defaultHandler = routing.Wrap(func(w http.ResponseWriter, c *models.ReqContext) { @@ -106,23 +118,6 @@ func TestLoginErrorCookieAPIEndpoint(t *testing.T) { cfg.LoginCookieName = "grafana_session" setting.SecretKey = "login_testing" - origOAuthService := setting.OAuthService - origOAuthAutoLogin := setting.OAuthAutoLogin - t.Cleanup(func() { - setting.OAuthService = origOAuthService - setting.OAuthAutoLogin = origOAuthAutoLogin - }) - setting.OAuthService = &setting.OAuther{ - OAuthInfos: map[string]*setting.OAuthInfo{ - "github": { - ClientId: "fake", - ClientSecret: "fakefake", - Enabled: true, - AllowSignup: true, - Name: "github", - }, - }, - } setting.OAuthAutoLogin = true oauthError := errors.New("User not a member of one of the required organizations") @@ -160,6 +155,7 @@ func TestLoginViewRedirect(t *testing.T) { Cfg: cfg, SettingsProvider: &setting.OSSImpl{Cfg: cfg}, License: &licensing.OSSLicensingService{}, + SocialService: &mockSocialService{}, } hs.Cfg.CookieSecure = true @@ -171,9 +167,6 @@ func TestLoginViewRedirect(t *testing.T) { hs.LoginView(c) }) - setting.OAuthService = &setting.OAuther{} - setting.OAuthService.OAuthInfos = make(map[string]*setting.OAuthInfo) - redirectCases := []redirectCase{ { desc: "grafana relative url without subpath", @@ -489,25 +482,27 @@ func TestLoginOAuthRedirect(t *testing.T) { sc := setupScenarioContext(t, "/login") cfg := setting.NewCfg() + mock := &mockSocialService{ + oAuthInfo: &social.OAuthInfo{ + ClientId: "fake", + ClientSecret: "fakefake", + Enabled: true, + AllowSignup: true, + Name: "github", + }, + oAuthInfos: oAuthInfos, + } hs := &HTTPServer{ Cfg: cfg, SettingsProvider: &setting.OSSImpl{Cfg: cfg}, License: &licensing.OSSLicensingService{}, + SocialService: mock, } sc.defaultHandler = routing.Wrap(func(c *models.ReqContext) { hs.LoginView(c) }) - setting.OAuthService = &setting.OAuther{} - setting.OAuthService.OAuthInfos = make(map[string]*setting.OAuthInfo) - setting.OAuthService.OAuthInfos["github"] = &setting.OAuthInfo{ - ClientId: "fake", - ClientSecret: "fakefake", - Enabled: true, - AllowSignup: true, - Name: "github", - } setting.OAuthAutoLogin = true sc.m.Get(sc.url, sc.defaultHandler) sc.fakeReqNoAssertions("GET", sc.url).exec() @@ -534,15 +529,6 @@ func TestLoginInternal(t *testing.T) { hs.LoginView(c) }) - setting.OAuthService = &setting.OAuther{} - setting.OAuthService.OAuthInfos = make(map[string]*setting.OAuthInfo) - setting.OAuthService.OAuthInfos["github"] = &setting.OAuthInfo{ - ClientId: "fake", - ClientSecret: "fakefake", - Enabled: true, - AllowSignup: true, - Name: "github", - } setting.OAuthAutoLogin = true sc.m.Get(sc.url, sc.defaultHandler) sc.fakeReqNoAssertions("GET", sc.url).exec() @@ -586,6 +572,7 @@ func setupAuthProxyLoginTest(t *testing.T, enableLoginToken bool) *scenarioConte License: &licensing.OSSLicensingService{}, AuthTokenService: auth.NewFakeUserAuthTokenService(), log: log.New("hello"), + SocialService: &mockSocialService{}, } sc.defaultHandler = routing.Wrap(func(c *models.ReqContext) { @@ -596,8 +583,6 @@ func setupAuthProxyLoginTest(t *testing.T, enableLoginToken bool) *scenarioConte hs.LoginView(c) }) - setting.OAuthService = &setting.OAuther{} - setting.OAuthService.OAuthInfos = make(map[string]*setting.OAuthInfo) sc.cfg.AuthProxyEnabled = true sc.cfg.AuthProxyEnableLoginToken = enableLoginToken @@ -713,3 +698,32 @@ func TestLoginPostRunLokingHook(t *testing.T) { }) } } + +type mockSocialService struct { + oAuthInfo *social.OAuthInfo + oAuthInfos map[string]*social.OAuthInfo + oAuthProviders map[string]bool + httpClient *http.Client + socialConnector social.SocialConnector + err error +} + +func (m *mockSocialService) GetOAuthInfoProvider(name string) *social.OAuthInfo { + return m.oAuthInfo +} + +func (m *mockSocialService) GetOAuthInfoProviders() map[string]*social.OAuthInfo { + return m.oAuthInfos +} + +func (m *mockSocialService) GetOAuthProviders() map[string]bool { + return m.oAuthProviders +} + +func (m *mockSocialService) GetOAuthHttpClient(name string) (*http.Client, error) { + return m.httpClient, m.err +} + +func (m *mockSocialService) GetConnector(string) (social.SocialConnector, error) { + return m.socialConnector, m.err +} diff --git a/pkg/api/pluginproxy/ds_proxy.go b/pkg/api/pluginproxy/ds_proxy.go index 2e20e6a3c1b..8d89ace0798 100644 --- a/pkg/api/pluginproxy/ds_proxy.go +++ b/pkg/api/pluginproxy/ds_proxy.go @@ -31,14 +31,15 @@ var ( ) type DataSourceProxy struct { - ds *models.DataSource - ctx *models.ReqContext - targetUrl *url.URL - proxyPath string - route *plugins.AppPluginRoute - plugin *plugins.DataSourcePlugin - cfg *setting.Cfg - clientProvider httpclient.Provider + ds *models.DataSource + ctx *models.ReqContext + targetUrl *url.URL + proxyPath string + route *plugins.AppPluginRoute + plugin *plugins.DataSourcePlugin + cfg *setting.Cfg + clientProvider httpclient.Provider + oAuthTokenService oauthtoken.OAuthTokenService } type handleResponseTransport struct { @@ -71,20 +72,21 @@ func (lw *logWrapper) Write(p []byte) (n int, err error) { // NewDataSourceProxy creates a new Datasource proxy func NewDataSourceProxy(ds *models.DataSource, plugin *plugins.DataSourcePlugin, ctx *models.ReqContext, - proxyPath string, cfg *setting.Cfg, clientProvider httpclient.Provider) (*DataSourceProxy, error) { + proxyPath string, cfg *setting.Cfg, clientProvider httpclient.Provider, oAuthTokenService oauthtoken.OAuthTokenService) (*DataSourceProxy, error) { targetURL, err := datasource.ValidateURL(ds.Type, ds.Url) if err != nil { return nil, err } return &DataSourceProxy{ - ds: ds, - plugin: plugin, - ctx: ctx, - proxyPath: proxyPath, - targetUrl: targetURL, - cfg: cfg, - clientProvider: clientProvider, + ds: ds, + plugin: plugin, + ctx: ctx, + proxyPath: proxyPath, + targetUrl: targetURL, + cfg: cfg, + clientProvider: clientProvider, + oAuthTokenService: oAuthTokenService, }, nil } @@ -237,8 +239,8 @@ func (proxy *DataSourceProxy) director(req *http.Request) { ApplyRoute(proxy.ctx.Req.Context(), req, proxy.proxyPath, proxy.route, proxy.ds, proxy.cfg) } - if oauthtoken.IsOAuthPassThruEnabled(proxy.ds) { - if token := oauthtoken.GetCurrentOAuthToken(proxy.ctx.Req.Context(), proxy.ctx.SignedInUser); token != nil { + if proxy.oAuthTokenService.IsOAuthPassThruEnabled(proxy.ds) { + if token := proxy.oAuthTokenService.GetCurrentOAuthToken(proxy.ctx.Req.Context(), proxy.ctx.SignedInUser); token != nil { req.Header.Set("Authorization", fmt.Sprintf("%s %s", token.Type(), token.AccessToken)) } } diff --git a/pkg/api/pluginproxy/ds_proxy_test.go b/pkg/api/pluginproxy/ds_proxy_test.go index a6650963f01..d84adc76e31 100644 --- a/pkg/api/pluginproxy/ds_proxy_test.go +++ b/pkg/api/pluginproxy/ds_proxy_test.go @@ -2,6 +2,7 @@ package pluginproxy import ( "bytes" + "context" "fmt" "io/ioutil" "net/http" @@ -16,9 +17,9 @@ import ( "github.com/grafana/grafana/pkg/components/securejsondata" "github.com/grafana/grafana/pkg/components/simplejson" "github.com/grafana/grafana/pkg/infra/httpclient" - "github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/plugins" + "github.com/grafana/grafana/pkg/services/oauthtoken" "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/util" "github.com/stretchr/testify/assert" @@ -114,7 +115,7 @@ func TestDataSourceProxy_routeRule(t *testing.T) { t.Run("When matching route path", func(t *testing.T) { ctx, req := setUp() - proxy, err := NewDataSourceProxy(ds, plugin, ctx, "api/v4/some/method", cfg, httpClientProvider) + proxy, err := NewDataSourceProxy(ds, plugin, ctx, "api/v4/some/method", cfg, httpClientProvider, &oauthtoken.Service{}) require.NoError(t, err) proxy.route = plugin.Routes[0] ApplyRoute(proxy.ctx.Req.Context(), req, proxy.proxyPath, proxy.route, proxy.ds, cfg) @@ -125,7 +126,7 @@ func TestDataSourceProxy_routeRule(t *testing.T) { t.Run("When matching route path and has dynamic url", func(t *testing.T) { ctx, req := setUp() - proxy, err := NewDataSourceProxy(ds, plugin, ctx, "api/common/some/method", cfg, httpClientProvider) + proxy, err := NewDataSourceProxy(ds, plugin, ctx, "api/common/some/method", cfg, httpClientProvider, &oauthtoken.Service{}) require.NoError(t, err) proxy.route = plugin.Routes[3] ApplyRoute(proxy.ctx.Req.Context(), req, proxy.proxyPath, proxy.route, proxy.ds, cfg) @@ -136,7 +137,7 @@ func TestDataSourceProxy_routeRule(t *testing.T) { t.Run("When matching route path with no url", func(t *testing.T) { ctx, req := setUp() - proxy, err := NewDataSourceProxy(ds, plugin, ctx, "", cfg, httpClientProvider) + proxy, err := NewDataSourceProxy(ds, plugin, ctx, "", cfg, httpClientProvider, &oauthtoken.Service{}) require.NoError(t, err) proxy.route = plugin.Routes[4] ApplyRoute(proxy.ctx.Req.Context(), req, proxy.proxyPath, proxy.route, proxy.ds, cfg) @@ -146,7 +147,7 @@ func TestDataSourceProxy_routeRule(t *testing.T) { t.Run("When matching route path and has dynamic body", func(t *testing.T) { ctx, req := setUp() - proxy, err := NewDataSourceProxy(ds, plugin, ctx, "api/body", cfg, httpClientProvider) + proxy, err := NewDataSourceProxy(ds, plugin, ctx, "api/body", cfg, httpClientProvider, &oauthtoken.Service{}) require.NoError(t, err) proxy.route = plugin.Routes[5] ApplyRoute(proxy.ctx.Req.Context(), req, proxy.proxyPath, proxy.route, proxy.ds, cfg) @@ -159,7 +160,7 @@ func TestDataSourceProxy_routeRule(t *testing.T) { t.Run("Validating request", func(t *testing.T) { t.Run("plugin route with valid role", func(t *testing.T) { ctx, _ := setUp() - proxy, err := NewDataSourceProxy(ds, plugin, ctx, "api/v4/some/method", cfg, httpClientProvider) + proxy, err := NewDataSourceProxy(ds, plugin, ctx, "api/v4/some/method", cfg, httpClientProvider, &oauthtoken.Service{}) require.NoError(t, err) err = proxy.validateRequest() require.NoError(t, err) @@ -167,7 +168,7 @@ func TestDataSourceProxy_routeRule(t *testing.T) { t.Run("plugin route with admin role and user is editor", func(t *testing.T) { ctx, _ := setUp() - proxy, err := NewDataSourceProxy(ds, plugin, ctx, "api/admin", cfg, httpClientProvider) + proxy, err := NewDataSourceProxy(ds, plugin, ctx, "api/admin", cfg, httpClientProvider, &oauthtoken.Service{}) require.NoError(t, err) err = proxy.validateRequest() require.Error(t, err) @@ -176,7 +177,7 @@ func TestDataSourceProxy_routeRule(t *testing.T) { t.Run("plugin route with admin role and user is admin", func(t *testing.T) { ctx, _ := setUp() ctx.SignedInUser.OrgRole = models.ROLE_ADMIN - proxy, err := NewDataSourceProxy(ds, plugin, ctx, "api/admin", cfg, httpClientProvider) + proxy, err := NewDataSourceProxy(ds, plugin, ctx, "api/admin", cfg, httpClientProvider, &oauthtoken.Service{}) require.NoError(t, err) err = proxy.validateRequest() require.NoError(t, err) @@ -258,7 +259,7 @@ func TestDataSourceProxy_routeRule(t *testing.T) { cfg := &setting.Cfg{} - proxy, err := NewDataSourceProxy(ds, plugin, ctx, "pathwithtoken1", cfg, httpClientProvider) + proxy, err := NewDataSourceProxy(ds, plugin, ctx, "pathwithtoken1", cfg, httpClientProvider, &oauthtoken.Service{}) require.NoError(t, err) ApplyRoute(proxy.ctx.Req.Context(), req, proxy.proxyPath, plugin.Routes[0], proxy.ds, cfg) @@ -273,7 +274,7 @@ func TestDataSourceProxy_routeRule(t *testing.T) { req, err := http.NewRequest("GET", "http://localhost/asd", nil) require.NoError(t, err) client = newFakeHTTPClient(t, json2) - proxy, err := NewDataSourceProxy(ds, plugin, ctx, "pathwithtoken2", cfg, httpClientProvider) + proxy, err := NewDataSourceProxy(ds, plugin, ctx, "pathwithtoken2", cfg, httpClientProvider, &oauthtoken.Service{}) require.NoError(t, err) ApplyRoute(proxy.ctx.Req.Context(), req, proxy.proxyPath, plugin.Routes[1], proxy.ds, cfg) @@ -289,7 +290,7 @@ func TestDataSourceProxy_routeRule(t *testing.T) { require.NoError(t, err) client = newFakeHTTPClient(t, []byte{}) - proxy, err := NewDataSourceProxy(ds, plugin, ctx, "pathwithtoken1", cfg, httpClientProvider) + proxy, err := NewDataSourceProxy(ds, plugin, ctx, "pathwithtoken1", cfg, httpClientProvider, &oauthtoken.Service{}) require.NoError(t, err) ApplyRoute(proxy.ctx.Req.Context(), req, proxy.proxyPath, plugin.Routes[0], proxy.ds, cfg) @@ -309,7 +310,7 @@ func TestDataSourceProxy_routeRule(t *testing.T) { ds := &models.DataSource{Url: "htttp://graphite:8080", Type: models.DS_GRAPHITE} ctx := &models.ReqContext{} - proxy, err := NewDataSourceProxy(ds, plugin, ctx, "/render", &setting.Cfg{BuildVersion: "5.3.0"}, httpClientProvider) + proxy, err := NewDataSourceProxy(ds, plugin, ctx, "/render", &setting.Cfg{BuildVersion: "5.3.0"}, httpClientProvider, &oauthtoken.Service{}) require.NoError(t, err) req, err := http.NewRequest(http.MethodGet, "http://grafana.com/sub", nil) require.NoError(t, err) @@ -334,7 +335,7 @@ func TestDataSourceProxy_routeRule(t *testing.T) { } ctx := &models.ReqContext{} - proxy, err := NewDataSourceProxy(ds, plugin, ctx, "", &setting.Cfg{}, httpClientProvider) + proxy, err := NewDataSourceProxy(ds, plugin, ctx, "", &setting.Cfg{}, httpClientProvider, &oauthtoken.Service{}) require.NoError(t, err) req, err := http.NewRequest(http.MethodGet, "http://grafana.com/sub", nil) @@ -357,7 +358,7 @@ func TestDataSourceProxy_routeRule(t *testing.T) { } ctx := &models.ReqContext{} - proxy, err := NewDataSourceProxy(ds, plugin, ctx, "", &setting.Cfg{}, httpClientProvider) + proxy, err := NewDataSourceProxy(ds, plugin, ctx, "", &setting.Cfg{}, httpClientProvider, &oauthtoken.Service{}) require.NoError(t, err) requestURL, err := url.Parse("http://grafana.com/sub") @@ -384,7 +385,7 @@ func TestDataSourceProxy_routeRule(t *testing.T) { } ctx := &models.ReqContext{} - proxy, err := NewDataSourceProxy(ds, plugin, ctx, "", &setting.Cfg{}, httpClientProvider) + proxy, err := NewDataSourceProxy(ds, plugin, ctx, "", &setting.Cfg{}, httpClientProvider, &oauthtoken.Service{}) require.NoError(t, err) requestURL, err := url.Parse("http://grafana.com/sub") @@ -405,7 +406,7 @@ func TestDataSourceProxy_routeRule(t *testing.T) { Url: "http://host/root/", } ctx := &models.ReqContext{} - proxy, err := NewDataSourceProxy(ds, plugin, ctx, "/path/to/folder/", &setting.Cfg{}, httpClientProvider) + proxy, err := NewDataSourceProxy(ds, plugin, ctx, "/path/to/folder/", &setting.Cfg{}, httpClientProvider, &oauthtoken.Service{}) require.NoError(t, err) req, err := http.NewRequest(http.MethodGet, "http://grafana.com/sub", nil) req.Header.Set("Origin", "grafana.com") @@ -423,19 +424,6 @@ func TestDataSourceProxy_routeRule(t *testing.T) { }) t.Run("When proxying a datasource that has OAuth token pass-through enabled", func(t *testing.T) { - social.SocialMap["generic_oauth"] = &social.SocialGenericOAuth{ - SocialBase: &social.SocialBase{ - Config: &oauth2.Config{}, - }, - } - origAuthSvc := setting.OAuthService - t.Cleanup(func() { - setting.OAuthService = origAuthSvc - }) - setting.OAuthService = &setting.OAuther{} - setting.OAuthService.OAuthInfos = make(map[string]*setting.OAuthInfo) - setting.OAuthService.OAuthInfos["generic_oauth"] = &setting.OAuthInfo{} - bus.AddHandler("test", func(query *models.GetAuthInfoQuery) error { query.Result = &models.UserAuth{ Id: 1, @@ -466,7 +454,16 @@ func TestDataSourceProxy_routeRule(t *testing.T) { Req: macaron.Request{Request: req}, }, } - proxy, err := NewDataSourceProxy(ds, plugin, ctx, "/path/to/folder/", &setting.Cfg{}, httpClientProvider) + mockAuthToken := mockOAuthTokenService{ + token: &oauth2.Token{ + AccessToken: "testtoken", + RefreshToken: "testrefreshtoken", + TokenType: "Bearer", + Expiry: time.Now().AddDate(0, 0, 1), + }, + oAuthEnabled: true, + } + proxy, err := NewDataSourceProxy(ds, plugin, ctx, "/path/to/folder/", &setting.Cfg{}, httpClientProvider, &mockAuthToken) require.NoError(t, err) req, err = http.NewRequest(http.MethodGet, "http://grafana.com/sub", nil) require.NoError(t, err) @@ -600,7 +597,7 @@ func TestDataSourceProxy_requestHandling(t *testing.T) { t.Run("When response header Set-Cookie is not set should remove proxied Set-Cookie header", func(t *testing.T) { ctx, ds := setUp(t) - proxy, err := NewDataSourceProxy(ds, plugin, ctx, "/render", &setting.Cfg{}, httpClientProvider) + proxy, err := NewDataSourceProxy(ds, plugin, ctx, "/render", &setting.Cfg{}, httpClientProvider, &oauthtoken.Service{}) require.NoError(t, err) proxy.HandleRequest() @@ -615,7 +612,7 @@ func TestDataSourceProxy_requestHandling(t *testing.T) { "Set-Cookie": "important_cookie=important_value", }, }) - proxy, err := NewDataSourceProxy(ds, plugin, ctx, "/render", &setting.Cfg{}, httpClientProvider) + proxy, err := NewDataSourceProxy(ds, plugin, ctx, "/render", &setting.Cfg{}, httpClientProvider, &oauthtoken.Service{}) require.NoError(t, err) proxy.HandleRequest() @@ -634,7 +631,7 @@ func TestDataSourceProxy_requestHandling(t *testing.T) { t.Log("Wrote 401 response") }, }) - proxy, err := NewDataSourceProxy(ds, plugin, ctx, "/render", &setting.Cfg{}, httpClientProvider) + proxy, err := NewDataSourceProxy(ds, plugin, ctx, "/render", &setting.Cfg{}, httpClientProvider, &oauthtoken.Service{}) require.NoError(t, err) proxy.HandleRequest() @@ -656,7 +653,7 @@ func TestDataSourceProxy_requestHandling(t *testing.T) { }) ctx.Req.Request = httptest.NewRequest("GET", "/api/datasources/proxy/1/path/%2Ftest%2Ftest%2F?query=%2Ftest%2Ftest%2F", nil) - proxy, err := NewDataSourceProxy(ds, plugin, ctx, "/path/%2Ftest%2Ftest%2F", &setting.Cfg{}, httpClientProvider) + proxy, err := NewDataSourceProxy(ds, plugin, ctx, "/path/%2Ftest%2Ftest%2F", &setting.Cfg{}, httpClientProvider, &oauthtoken.Service{}) require.NoError(t, err) proxy.HandleRequest() @@ -680,7 +677,7 @@ func TestNewDataSourceProxy_InvalidURL(t *testing.T) { } cfg := setting.Cfg{} plugin := plugins.DataSourcePlugin{} - _, err := NewDataSourceProxy(&ds, &plugin, &ctx, "api/method", &cfg, httpclient.NewProvider()) + _, err := NewDataSourceProxy(&ds, &plugin, &ctx, "api/method", &cfg, httpclient.NewProvider(), &oauthtoken.Service{}) require.Error(t, err) assert.True(t, strings.HasPrefix(err.Error(), `validation of data source URL "://host/root" failed`)) } @@ -699,7 +696,7 @@ func TestNewDataSourceProxy_ProtocolLessURL(t *testing.T) { cfg := setting.Cfg{} plugin := plugins.DataSourcePlugin{} - _, err := NewDataSourceProxy(&ds, &plugin, &ctx, "api/method", &cfg, httpclient.NewProvider()) + _, err := NewDataSourceProxy(&ds, &plugin, &ctx, "api/method", &cfg, httpclient.NewProvider(), &oauthtoken.Service{}) require.NoError(t, err) } @@ -739,7 +736,7 @@ func TestNewDataSourceProxy_MSSQL(t *testing.T) { Url: tc.url, } - p, err := NewDataSourceProxy(&ds, &plugin, &ctx, "api/method", &cfg, httpclient.NewProvider()) + p, err := NewDataSourceProxy(&ds, &plugin, &ctx, "api/method", &cfg, httpclient.NewProvider(), &oauthtoken.Service{}) if tc.err == nil { require.NoError(t, err) assert.Equal(t, &url.URL{ @@ -777,7 +774,7 @@ func getDatasourceProxiedRequest(t *testing.T, ctx *models.ReqContext, cfg *sett Url: "http://host/root/", } - proxy, err := NewDataSourceProxy(ds, plugin, ctx, "", cfg, httpclient.NewProvider()) + proxy, err := NewDataSourceProxy(ds, plugin, ctx, "", cfg, httpclient.NewProvider(), &oauthtoken.Service{}) require.NoError(t, err) req, err := http.NewRequest(http.MethodGet, "http://grafana.com/sub", nil) require.NoError(t, err) @@ -888,7 +885,7 @@ func createAuthTest(t *testing.T, dsType string, authType string, authCheck stri func runDatasourceAuthTest(t *testing.T, test *testCase) { plugin := &plugins.DataSourcePlugin{} ctx := &models.ReqContext{} - proxy, err := NewDataSourceProxy(test.datasource, plugin, ctx, "", &setting.Cfg{}, httpclient.NewProvider()) + proxy, err := NewDataSourceProxy(test.datasource, plugin, ctx, "", &setting.Cfg{}, httpclient.NewProvider(), &oauthtoken.Service{}) require.NoError(t, err) req, err := http.NewRequest(http.MethodGet, "http://grafana.com/sub", nil) @@ -929,9 +926,22 @@ func Test_PathCheck(t *testing.T) { return ctx, req } ctx, _ := setUp() - proxy, err := NewDataSourceProxy(&models.DataSource{}, plugin, ctx, "b", &setting.Cfg{}, httpclient.NewProvider()) + proxy, err := NewDataSourceProxy(&models.DataSource{}, plugin, ctx, "b", &setting.Cfg{}, httpclient.NewProvider(), &oauthtoken.Service{}) require.NoError(t, err) require.Nil(t, proxy.validateRequest()) require.Equal(t, plugin.Routes[1], proxy.route) } + +type mockOAuthTokenService struct { + token *oauth2.Token + oAuthEnabled bool +} + +func (m *mockOAuthTokenService) GetCurrentOAuthToken(ctx context.Context, user *models.SignedInUser) *oauth2.Token { + return m.token +} + +func (m *mockOAuthTokenService) IsOAuthPassThruEnabled(ds *models.DataSource) bool { + return m.oAuthEnabled +} diff --git a/pkg/infra/usagestats/service.go b/pkg/infra/usagestats/service.go index 5d452df4ca3..2888df3c379 100644 --- a/pkg/infra/usagestats/service.go +++ b/pkg/infra/usagestats/service.go @@ -39,6 +39,7 @@ type UsageStatsService struct { AlertingUsageStats alerting.UsageStatsQuerier `inject:""` License models.Licensing `inject:""` PluginManager plugins.Manager `inject:""` + SocialService social.Service `inject:""` log log.Logger @@ -48,7 +49,7 @@ type UsageStatsService struct { } func (uss *UsageStatsService) Init() error { - uss.oauthProviders = social.GetOAuthProviders(uss.Cfg) + uss.oauthProviders = uss.SocialService.GetOAuthProviders() return nil } diff --git a/pkg/login/social/social.go b/pkg/login/social/social.go index 5d87dd4eee3..26f7050472a 100644 --- a/pkg/login/social/social.go +++ b/pkg/login/social/social.go @@ -13,6 +13,7 @@ import ( "golang.org/x/oauth2" "github.com/grafana/grafana/pkg/infra/log" + "github.com/grafana/grafana/pkg/registry" "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/util" ) @@ -21,6 +22,184 @@ var ( logger = log.New("social") ) +func init() { + registry.RegisterService(&SocialService{}) +} + +type SocialService struct { + Cfg *setting.Cfg `inject:""` + + socialMap map[string]SocialConnector + oAuthProvider map[string]*OAuthInfo +} + +type OAuthInfo struct { + ClientId, ClientSecret string + Scopes []string + AuthUrl, TokenUrl string + Enabled bool + EmailAttributeName string + EmailAttributePath string + RoleAttributePath string + RoleAttributeStrict bool + GroupsAttributePath string + AllowedDomains []string + HostedDomain string + ApiUrl string + AllowSignup bool + Name string + TlsClientCert string + TlsClientKey string + TlsClientCa string + TlsSkipVerify bool +} + +func (ss *SocialService) Init() error { + ss.oAuthProvider = make(map[string]*OAuthInfo) + ss.socialMap = make(map[string]SocialConnector) + + for _, name := range allOauthes { + sec := ss.Cfg.Raw.Section("auth." + name) + + info := &OAuthInfo{ + ClientId: sec.Key("client_id").String(), + ClientSecret: sec.Key("client_secret").String(), + Scopes: util.SplitString(sec.Key("scopes").String()), + AuthUrl: sec.Key("auth_url").String(), + TokenUrl: sec.Key("token_url").String(), + ApiUrl: sec.Key("api_url").String(), + Enabled: sec.Key("enabled").MustBool(), + EmailAttributeName: sec.Key("email_attribute_name").String(), + EmailAttributePath: sec.Key("email_attribute_path").String(), + RoleAttributePath: sec.Key("role_attribute_path").String(), + RoleAttributeStrict: sec.Key("role_attribute_strict").MustBool(), + GroupsAttributePath: sec.Key("groups_attribute_path").String(), + AllowedDomains: util.SplitString(sec.Key("allowed_domains").String()), + HostedDomain: sec.Key("hosted_domain").String(), + AllowSignup: sec.Key("allow_sign_up").MustBool(), + Name: sec.Key("name").MustString(name), + TlsClientCert: sec.Key("tls_client_cert").String(), + TlsClientKey: sec.Key("tls_client_key").String(), + TlsClientCa: sec.Key("tls_client_ca").String(), + TlsSkipVerify: sec.Key("tls_skip_verify_insecure").MustBool(), + } + + // when empty_scopes parameter exists and is true, overwrite scope with empty value + if sec.Key("empty_scopes").MustBool() { + info.Scopes = []string{} + } + + if !info.Enabled { + continue + } + + if name == "grafananet" { + name = grafanaCom + } + + ss.oAuthProvider[name] = info + + config := oauth2.Config{ + ClientID: info.ClientId, + ClientSecret: info.ClientSecret, + Endpoint: oauth2.Endpoint{ + AuthURL: info.AuthUrl, + TokenURL: info.TokenUrl, + AuthStyle: oauth2.AuthStyleAutoDetect, + }, + RedirectURL: strings.TrimSuffix(ss.Cfg.AppURL, "/") + SocialBaseUrl + name, + Scopes: info.Scopes, + } + + // GitHub. + if name == "github" { + ss.socialMap["github"] = &SocialGithub{ + SocialBase: newSocialBase(name, &config, info), + apiUrl: info.ApiUrl, + teamIds: sec.Key("team_ids").Ints(","), + allowedOrganizations: util.SplitString(sec.Key("allowed_organizations").String()), + } + } + + // GitLab. + if name == "gitlab" { + ss.socialMap["gitlab"] = &SocialGitlab{ + SocialBase: newSocialBase(name, &config, info), + apiUrl: info.ApiUrl, + allowedGroups: util.SplitString(sec.Key("allowed_groups").String()), + } + } + + // Google. + if name == "google" { + ss.socialMap["google"] = &SocialGoogle{ + SocialBase: newSocialBase(name, &config, info), + hostedDomain: info.HostedDomain, + apiUrl: info.ApiUrl, + } + } + + // AzureAD. + if name == "azuread" { + ss.socialMap["azuread"] = &SocialAzureAD{ + SocialBase: newSocialBase(name, &config, info), + allowedGroups: util.SplitString(sec.Key("allowed_groups").String()), + autoAssignOrgRole: ss.Cfg.AutoAssignOrgRole, + } + } + + // Okta + if name == "okta" { + ss.socialMap["okta"] = &SocialOkta{ + SocialBase: newSocialBase(name, &config, info), + apiUrl: info.ApiUrl, + allowedGroups: util.SplitString(sec.Key("allowed_groups").String()), + roleAttributePath: info.RoleAttributePath, + roleAttributeStrict: info.RoleAttributeStrict, + } + } + + // Generic - Uses the same scheme as GitHub. + if name == "generic_oauth" { + ss.socialMap["generic_oauth"] = &SocialGenericOAuth{ + SocialBase: newSocialBase(name, &config, info), + apiUrl: info.ApiUrl, + emailAttributeName: info.EmailAttributeName, + emailAttributePath: info.EmailAttributePath, + nameAttributePath: sec.Key("name_attribute_path").String(), + roleAttributePath: info.RoleAttributePath, + roleAttributeStrict: info.RoleAttributeStrict, + groupsAttributePath: info.GroupsAttributePath, + loginAttributePath: sec.Key("login_attribute_path").String(), + idTokenAttributeName: sec.Key("id_token_attribute_name").String(), + teamIds: sec.Key("team_ids").Ints(","), + allowedOrganizations: util.SplitString(sec.Key("allowed_organizations").String()), + } + } + + if name == grafanaCom { + config = oauth2.Config{ + ClientID: info.ClientId, + ClientSecret: info.ClientSecret, + Endpoint: oauth2.Endpoint{ + AuthURL: ss.Cfg.GrafanaComURL + "/oauth2/authorize", + TokenURL: ss.Cfg.GrafanaComURL + "/api/oauth2/token", + AuthStyle: oauth2.AuthStyleInHeader, + }, + RedirectURL: strings.TrimSuffix(ss.Cfg.AppURL, "/") + SocialBaseUrl + name, + Scopes: info.Scopes, + } + + ss.socialMap[grafanaCom] = &SocialGrafanaCom{ + SocialBase: newSocialBase(name, &config, info), + url: ss.Cfg.GrafanaComURL, + allowedOrganizations: util.SplitString(sec.Key("allowed_organizations").String()), + } + } + } + return nil +} + type BasicUserInfo struct { Id string Name string @@ -68,7 +247,15 @@ var ( allOauthes = []string{"github", "gitlab", "google", "generic_oauth", "grafananet", grafanaCom, "azuread", "okta"} ) -func newSocialBase(name string, config *oauth2.Config, info *setting.OAuthInfo) *SocialBase { +type Service interface { + GetOAuthProviders() map[string]bool + GetOAuthHttpClient(string) (*http.Client, error) + GetConnector(string) (SocialConnector, error) + GetOAuthInfoProvider(string) *OAuthInfo + GetOAuthInfoProviders() map[string]*OAuthInfo +} + +func newSocialBase(name string, config *oauth2.Config, info *OAuthInfo) *SocialBase { logger := log.New("oauth." + name) return &SocialBase{ @@ -79,156 +266,11 @@ func newSocialBase(name string, config *oauth2.Config, info *setting.OAuthInfo) } } -func NewOAuthService(cfg *setting.Cfg) { - setting.OAuthService = &setting.OAuther{} - setting.OAuthService.OAuthInfos = make(map[string]*setting.OAuthInfo) - - for _, name := range allOauthes { - sec := cfg.Raw.Section("auth." + name) - - info := &setting.OAuthInfo{ - ClientId: sec.Key("client_id").String(), - ClientSecret: sec.Key("client_secret").String(), - Scopes: util.SplitString(sec.Key("scopes").String()), - AuthUrl: sec.Key("auth_url").String(), - TokenUrl: sec.Key("token_url").String(), - ApiUrl: sec.Key("api_url").String(), - Enabled: sec.Key("enabled").MustBool(), - EmailAttributeName: sec.Key("email_attribute_name").String(), - EmailAttributePath: sec.Key("email_attribute_path").String(), - RoleAttributePath: sec.Key("role_attribute_path").String(), - RoleAttributeStrict: sec.Key("role_attribute_strict").MustBool(), - GroupsAttributePath: sec.Key("groups_attribute_path").String(), - AllowedDomains: util.SplitString(sec.Key("allowed_domains").String()), - HostedDomain: sec.Key("hosted_domain").String(), - AllowSignup: sec.Key("allow_sign_up").MustBool(), - Name: sec.Key("name").MustString(name), - TlsClientCert: sec.Key("tls_client_cert").String(), - TlsClientKey: sec.Key("tls_client_key").String(), - TlsClientCa: sec.Key("tls_client_ca").String(), - TlsSkipVerify: sec.Key("tls_skip_verify_insecure").MustBool(), - } - - // when empty_scopes parameter exists and is true, overwrite scope with empty value - if sec.Key("empty_scopes").MustBool() { - info.Scopes = []string{} - } - - if !info.Enabled { - continue - } - - if name == "grafananet" { - name = grafanaCom - } - - setting.OAuthService.OAuthInfos[name] = info - - config := oauth2.Config{ - ClientID: info.ClientId, - ClientSecret: info.ClientSecret, - Endpoint: oauth2.Endpoint{ - AuthURL: info.AuthUrl, - TokenURL: info.TokenUrl, - AuthStyle: oauth2.AuthStyleAutoDetect, - }, - RedirectURL: strings.TrimSuffix(cfg.AppURL, "/") + SocialBaseUrl + name, - Scopes: info.Scopes, - } - - // GitHub. - if name == "github" { - SocialMap["github"] = &SocialGithub{ - SocialBase: newSocialBase(name, &config, info), - apiUrl: info.ApiUrl, - teamIds: sec.Key("team_ids").Ints(","), - allowedOrganizations: util.SplitString(sec.Key("allowed_organizations").String()), - } - } - - // GitLab. - if name == "gitlab" { - SocialMap["gitlab"] = &SocialGitlab{ - SocialBase: newSocialBase(name, &config, info), - apiUrl: info.ApiUrl, - allowedGroups: util.SplitString(sec.Key("allowed_groups").String()), - } - } - - // Google. - if name == "google" { - SocialMap["google"] = &SocialGoogle{ - SocialBase: newSocialBase(name, &config, info), - hostedDomain: info.HostedDomain, - apiUrl: info.ApiUrl, - } - } - - // AzureAD. - if name == "azuread" { - SocialMap["azuread"] = &SocialAzureAD{ - SocialBase: newSocialBase(name, &config, info), - allowedGroups: util.SplitString(sec.Key("allowed_groups").String()), - autoAssignOrgRole: cfg.AutoAssignOrgRole, - } - } - - // Okta - if name == "okta" { - SocialMap["okta"] = &SocialOkta{ - SocialBase: newSocialBase(name, &config, info), - apiUrl: info.ApiUrl, - allowedGroups: util.SplitString(sec.Key("allowed_groups").String()), - roleAttributePath: info.RoleAttributePath, - roleAttributeStrict: info.RoleAttributeStrict, - } - } - - // Generic - Uses the same scheme as GitHub. - if name == "generic_oauth" { - SocialMap["generic_oauth"] = &SocialGenericOAuth{ - SocialBase: newSocialBase(name, &config, info), - apiUrl: info.ApiUrl, - emailAttributeName: info.EmailAttributeName, - emailAttributePath: info.EmailAttributePath, - nameAttributePath: sec.Key("name_attribute_path").String(), - roleAttributePath: info.RoleAttributePath, - roleAttributeStrict: info.RoleAttributeStrict, - groupsAttributePath: info.GroupsAttributePath, - loginAttributePath: sec.Key("login_attribute_path").String(), - idTokenAttributeName: sec.Key("id_token_attribute_name").String(), - teamIds: sec.Key("team_ids").Ints(","), - allowedOrganizations: util.SplitString(sec.Key("allowed_organizations").String()), - } - } - - if name == grafanaCom { - config = oauth2.Config{ - ClientID: info.ClientId, - ClientSecret: info.ClientSecret, - Endpoint: oauth2.Endpoint{ - AuthURL: cfg.GrafanaComURL + "/oauth2/authorize", - TokenURL: cfg.GrafanaComURL + "/api/oauth2/token", - AuthStyle: oauth2.AuthStyleInHeader, - }, - RedirectURL: strings.TrimSuffix(cfg.AppURL, "/") + SocialBaseUrl + name, - Scopes: info.Scopes, - } - - SocialMap[grafanaCom] = &SocialGrafanaCom{ - SocialBase: newSocialBase(name, &config, info), - url: cfg.GrafanaComURL, - allowedOrganizations: util.SplitString(sec.Key("allowed_organizations").String()), - } - } - } -} - // GetOAuthProviders returns available oauth providers and if they're enabled or not -var GetOAuthProviders = func(cfg *setting.Cfg) map[string]bool { +func (ss *SocialService) GetOAuthProviders() map[string]bool { result := map[string]bool{} - if cfg == nil || cfg.Raw == nil { + if ss.Cfg == nil || ss.Cfg.Raw == nil { return result } @@ -237,7 +279,7 @@ var GetOAuthProviders = func(cfg *setting.Cfg) map[string]bool { name = grafanaCom } - sec := cfg.Raw.Section("auth." + name) + sec := ss.Cfg.Raw.Section("auth." + name) if sec == nil { continue } @@ -247,13 +289,10 @@ var GetOAuthProviders = func(cfg *setting.Cfg) map[string]bool { return result } -func GetOAuthHttpClient(name string) (*http.Client, error) { - if setting.OAuthService == nil { - return nil, fmt.Errorf("OAuth not enabled") - } +func (ss *SocialService) GetOAuthHttpClient(name string) (*http.Client, error) { // The socialMap keys don't have "oauth_" prefix, but everywhere else in the system does name = strings.TrimPrefix(name, "oauth_") - info, ok := setting.OAuthService.OAuthInfos[name] + info, ok := ss.oAuthProvider[name] if !ok { return nil, fmt.Errorf("could not find %q in OAuth Settings", name) } @@ -292,12 +331,20 @@ func GetOAuthHttpClient(name string) (*http.Client, error) { return oauthClient, nil } -func GetConnector(name string) (SocialConnector, error) { +func (ss *SocialService) GetConnector(name string) (SocialConnector, error) { // The socialMap keys don't have "oauth_" prefix, but everywhere else in the system does provider := strings.TrimPrefix(name, "oauth_") - connector, ok := SocialMap[provider] + connector, ok := ss.socialMap[provider] if !ok { return nil, fmt.Errorf("failed to find oauth provider for %q", name) } return connector, nil } + +func (ss *SocialService) GetOAuthInfoProvider(name string) *OAuthInfo { + return ss.oAuthProvider[name] +} + +func (ss *SocialService) GetOAuthInfoProviders() map[string]*OAuthInfo { + return ss.oAuthProvider +} diff --git a/pkg/server/server.go b/pkg/server/server.go index fbe2ac3a1c0..674ec0f8cef 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -28,7 +28,7 @@ import ( _ "github.com/grafana/grafana/pkg/infra/tracing" _ "github.com/grafana/grafana/pkg/infra/usagestats" "github.com/grafana/grafana/pkg/login" - "github.com/grafana/grafana/pkg/login/social" + _ "github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/middleware" _ "github.com/grafana/grafana/pkg/plugins/manager" "github.com/grafana/grafana/pkg/registry" @@ -150,7 +150,6 @@ func (s *Server) init() error { } login.Init() - social.NewOAuthService(s.cfg) services := s.serviceRegistry.GetServices() if err := s.buildServiceGraph(services); err != nil { diff --git a/pkg/services/datasourceproxy/datasourceproxy.go b/pkg/services/datasourceproxy/datasourceproxy.go index 177e92b2005..ca88494eb86 100644 --- a/pkg/services/datasourceproxy/datasourceproxy.go +++ b/pkg/services/datasourceproxy/datasourceproxy.go @@ -14,6 +14,7 @@ import ( "github.com/grafana/grafana/pkg/plugins" "github.com/grafana/grafana/pkg/registry" "github.com/grafana/grafana/pkg/services/datasources" + "github.com/grafana/grafana/pkg/services/oauthtoken" "github.com/grafana/grafana/pkg/setting" ) @@ -27,6 +28,7 @@ type DatasourceProxyService struct { PluginManager plugins.Manager `inject:""` Cfg *setting.Cfg `inject:""` HTTPClientProvider httpclient.Provider `inject:""` + OAuthTokenService *oauthtoken.Service `inject:""` } func (p *DatasourceProxyService) Init() error { @@ -68,7 +70,7 @@ func (p *DatasourceProxyService) ProxyDatasourceRequestWithID(c *models.ReqConte } proxyPath := getProxyPath(c) - proxy, err := pluginproxy.NewDataSourceProxy(ds, plugin, c, proxyPath, p.Cfg, p.HTTPClientProvider) + proxy, err := pluginproxy.NewDataSourceProxy(ds, plugin, c, proxyPath, p.Cfg, p.HTTPClientProvider, p.OAuthTokenService) if err != nil { if errors.Is(err, datasource.URLValidationError{}) { c.JsonApiErr(http.StatusBadRequest, fmt.Sprintf("Invalid data source URL: %q", ds.Url), err) diff --git a/pkg/services/oauthtoken/oauth_token_util.go b/pkg/services/oauthtoken/oauth_token.go similarity index 81% rename from pkg/services/oauthtoken/oauth_token_util.go rename to pkg/services/oauthtoken/oauth_token.go index 3fd246ed0ad..0e85e2844dc 100644 --- a/pkg/services/oauthtoken/oauth_token_util.go +++ b/pkg/services/oauthtoken/oauth_token.go @@ -8,6 +8,7 @@ import ( "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/models" + "github.com/grafana/grafana/pkg/registry" "golang.org/x/oauth2" ) @@ -15,8 +16,25 @@ var ( logger = log.New("oauthtoken") ) +func init() { + registry.RegisterService(&Service{}) +} + +type Service struct { + SocialService social.Service `inject:""` +} + +type OAuthTokenService interface { + GetCurrentOAuthToken(context.Context, *models.SignedInUser) *oauth2.Token + IsOAuthPassThruEnabled(*models.DataSource) bool +} + +func (o *Service) Init() error { + return nil +} + // GetCurrentOAuthToken returns the OAuth token, if any, for the authenticated user. Will try to refresh the token if it has expired. -func GetCurrentOAuthToken(ctx context.Context, user *models.SignedInUser) *oauth2.Token { +func (o *Service) GetCurrentOAuthToken(ctx context.Context, user *models.SignedInUser) *oauth2.Token { if user == nil { // No user, therefore no token return nil @@ -34,13 +52,13 @@ func GetCurrentOAuthToken(ctx context.Context, user *models.SignedInUser) *oauth } authProvider := authInfoQuery.Result.AuthModule - connect, err := social.GetConnector(authProvider) + connect, err := o.SocialService.GetConnector(authProvider) if err != nil { logger.Error("failed to get OAuth connector", "provider", authProvider, "error", err) return nil } - client, err := social.GetOAuthHttpClient(authProvider) + client, err := o.SocialService.GetOAuthHttpClient(authProvider) if err != nil { logger.Error("failed to get OAuth http client", "provider", authProvider, "error", err) return nil @@ -78,7 +96,7 @@ func GetCurrentOAuthToken(ctx context.Context, user *models.SignedInUser) *oauth } // IsOAuthPassThruEnabled returns true if Forward OAuth Identity (oauthPassThru) is enabled for the provided data source. -func IsOAuthPassThruEnabled(ds *models.DataSource) bool { +func (o *Service) IsOAuthPassThruEnabled(ds *models.DataSource) bool { return ds.JsonData != nil && ds.JsonData.Get("oauthPassThru").MustBool() } diff --git a/pkg/setting/setting_oauth.go b/pkg/setting/setting_oauth.go deleted file mode 100644 index d388ff0ab52..00000000000 --- a/pkg/setting/setting_oauth.go +++ /dev/null @@ -1,28 +0,0 @@ -package setting - -type OAuthInfo struct { - ClientId, ClientSecret string - Scopes []string - AuthUrl, TokenUrl string - Enabled bool - EmailAttributeName string - EmailAttributePath string - RoleAttributePath string - RoleAttributeStrict bool - GroupsAttributePath string - AllowedDomains []string - HostedDomain string - ApiUrl string - AllowSignup bool - Name string - TlsClientCert string - TlsClientKey string - TlsClientCa string - TlsSkipVerify bool -} - -type OAuther struct { - OAuthInfos map[string]*OAuthInfo -} - -var OAuthService *OAuther diff --git a/pkg/tsdb/data_plugin_adapter.go b/pkg/tsdb/data_plugin_adapter.go index db1febc6774..9ae7517b053 100644 --- a/pkg/tsdb/data_plugin_adapter.go +++ b/pkg/tsdb/data_plugin_adapter.go @@ -13,7 +13,7 @@ import ( ) // nolint:staticcheck // plugins.DataQuery deprecated -func dataPluginQueryAdapter(pluginID string, handler backend.QueryDataHandler) plugins.DataPluginFunc { +func dataPluginQueryAdapter(pluginID string, handler backend.QueryDataHandler, oAuthService *oauthtoken.Service) plugins.DataPluginFunc { return plugins.DataPluginFunc(func(ctx context.Context, ds *models.DataSource, query plugins.DataQuery) (plugins.DataResponse, error) { instanceSettings, err := modelToInstanceSettings(ds) if err != nil { @@ -24,8 +24,8 @@ func dataPluginQueryAdapter(pluginID string, handler backend.QueryDataHandler) p query.Headers = make(map[string]string) } - if oauthtoken.IsOAuthPassThruEnabled(ds) { - if token := oauthtoken.GetCurrentOAuthToken(ctx, query.User); token != nil { + if oAuthService.IsOAuthPassThruEnabled(ds) { + if token := oAuthService.GetCurrentOAuthToken(ctx, query.User); token != nil { delete(query.Headers, "Authorization") query.Headers["Authorization"] = fmt.Sprintf("%s %s", token.Type(), token.AccessToken) } diff --git a/pkg/tsdb/service.go b/pkg/tsdb/service.go index 59eda1a1c1d..d505d173857 100644 --- a/pkg/tsdb/service.go +++ b/pkg/tsdb/service.go @@ -9,6 +9,7 @@ import ( "github.com/grafana/grafana/pkg/plugins" "github.com/grafana/grafana/pkg/plugins/backendplugin" "github.com/grafana/grafana/pkg/registry" + "github.com/grafana/grafana/pkg/services/oauthtoken" "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/tsdb/azuremonitor" "github.com/grafana/grafana/pkg/tsdb/cloudmonitoring" @@ -48,6 +49,7 @@ type Service struct { PluginManager plugins.Manager `inject:""` BackendPluginManager backendplugin.Manager `inject:""` HTTPClientProvider httpclient.Provider `inject:""` + OAuthTokenService *oauthtoken.Service `inject:""` //nolint: staticcheck // plugins.DataPlugin deprecated registry map[string]func(*models.DataSource) (plugins.DataPlugin, error) @@ -81,7 +83,7 @@ func (s *Service) HandleRequest(ctx context.Context, ds *models.DataSource, quer return plugin.DataQuery(ctx, ds, query) } - return dataPluginQueryAdapter(ds.Type, s.BackendPluginManager).DataQuery(ctx, ds, query) + return dataPluginQueryAdapter(ds.Type, s.BackendPluginManager, s.OAuthTokenService).DataQuery(ctx, ds, query) } // RegisterQueryHandler registers a query handler factory.