Chore: Refactor OAuth/social package to service (#35403)

* Creating SocialService

* Add GetOAuthProviders as socialService method

* Add OAuthTokenService

* Add GetOAuthHttpClient method to SocialService

* Rename services, access socialMap from GetConnector

* Fix tests by mocking oauthtoken methods

* Move NewAuthService into Init

* Move OAuthService to social pkg

* Refactor OAuthService to OAuthProvider

* Fix nil map error, rename file, simplify tests

* Fix bug for Forward OAuth Identify

* Remove file after rebase
This commit is contained in:
idafurjes 2021-07-07 08:54:17 +02:00 committed by GitHub
parent 55e763b4cd
commit 60ac54d969
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 380 additions and 307 deletions

View File

@ -13,8 +13,10 @@ import (
"strings"
"sync"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/services/libraryelements"
"github.com/grafana/grafana/pkg/services/librarypanels"
"github.com/grafana/grafana/pkg/services/oauthtoken"
"github.com/grafana/grafana/pkg/api/routing"
httpstatic "github.com/grafana/grafana/pkg/api/static"
@ -105,6 +107,8 @@ type HTTPServer struct {
Alertmanager *notifier.Alertmanager `inject:""`
LibraryPanelService librarypanels.Service `inject:""`
LibraryElementService libraryelements.Service `inject:""`
SocialService social.Service `inject:""`
OAuthTokenService *oauthtoken.Service `inject:""`
Listener net.Listener
}

View File

@ -91,7 +91,8 @@ func (hs *HTTPServer) LoginView(c *models.ReqContext) {
}
enabledOAuths := make(map[string]interface{})
for key, oauth := range setting.OAuthService.OAuthInfos {
providers := hs.SocialService.GetOAuthInfoProviders()
for key, oauth := range providers {
enabledOAuths[key] = map[string]string{"name": oauth.Name}
}
@ -147,12 +148,12 @@ func (hs *HTTPServer) tryOAuthAutoLogin(c *models.ReqContext) bool {
if !setting.OAuthAutoLogin {
return false
}
oauthInfos := setting.OAuthService.OAuthInfos
oauthInfos := hs.SocialService.GetOAuthInfoProviders()
if len(oauthInfos) != 1 {
log.Warnf("Skipping OAuth auto login because multiple OAuth providers are configured")
return false
}
for key := range setting.OAuthService.OAuthInfos {
for key := range oauthInfos {
redirectUrl := hs.Cfg.AppSubURL + "/login/" + key
log.Infof("OAuth auto login enabled. Redirecting to " + redirectUrl)
c.Redirect(redirectUrl, 307)

View File

@ -41,7 +41,10 @@ func (hs *HTTPServer) OAuthLogin(ctx *models.ReqContext) {
loginInfo := models.LoginInfo{
AuthModule: "oauth",
}
if setting.OAuthService == nil {
name := ctx.Params(":name")
loginInfo.AuthModule = name
provider := hs.SocialService.GetOAuthInfoProvider(name)
if provider == nil {
hs.handleOAuthLoginError(ctx, loginInfo, LoginError{
HttpStatus: http.StatusNotFound,
PublicMessage: "OAuth not enabled",
@ -49,10 +52,8 @@ func (hs *HTTPServer) OAuthLogin(ctx *models.ReqContext) {
return
}
name := ctx.Params(":name")
loginInfo.AuthModule = name
connect, ok := social.SocialMap[name]
if !ok {
connect, err := hs.SocialService.GetConnector(name)
if err != nil {
hs.handleOAuthLoginError(ctx, loginInfo, LoginError{
HttpStatus: http.StatusNotFound,
PublicMessage: fmt.Sprintf("No OAuth with name %s configured", name),
@ -80,12 +81,12 @@ func (hs *HTTPServer) OAuthLogin(ctx *models.ReqContext) {
return
}
hashedState := hashStatecode(state, setting.OAuthService.OAuthInfos[name].ClientSecret)
hashedState := hashStatecode(state, provider.ClientSecret)
cookies.WriteCookie(ctx.Resp, OauthStateCookieName, hashedState, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
if setting.OAuthService.OAuthInfos[name].HostedDomain == "" {
if provider.HostedDomain == "" {
ctx.Redirect(connect.AuthCodeURL(state, oauth2.AccessTypeOnline))
} else {
ctx.Redirect(connect.AuthCodeURL(state, oauth2.SetAuthURLParam("hd", setting.OAuthService.OAuthInfos[name].HostedDomain), oauth2.AccessTypeOnline))
ctx.Redirect(connect.AuthCodeURL(state, oauth2.SetAuthURLParam("hd", provider.HostedDomain), oauth2.AccessTypeOnline))
}
return
}
@ -103,7 +104,7 @@ func (hs *HTTPServer) OAuthLogin(ctx *models.ReqContext) {
return
}
queryState := hashStatecode(ctx.Query("state"), setting.OAuthService.OAuthInfos[name].ClientSecret)
queryState := hashStatecode(ctx.Query("state"), provider.ClientSecret)
oauthLogger.Info("state check", "queryState", queryState, "cookieState", cookieState)
if cookieState != queryState {
hs.handleOAuthLoginError(ctx, loginInfo, LoginError{
@ -113,7 +114,7 @@ func (hs *HTTPServer) OAuthLogin(ctx *models.ReqContext) {
return
}
oauthClient, err := social.GetOAuthHttpClient(name)
oauthClient, err := hs.SocialService.GetOAuthHttpClient(name)
if err != nil {
ctx.Logger.Error("Failed to create OAuth http client", "error", err)
hs.handleOAuthLoginError(ctx, loginInfo, LoginError{

View File

@ -17,6 +17,7 @@ import (
"github.com/grafana/grafana/pkg/components/simplejson"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/login"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/services/hooks"
@ -86,6 +87,16 @@ type redirectCase struct {
redirectURL string
}
var oAuthInfos = map[string]*social.OAuthInfo{
"github": {
ClientId: "fake",
ClientSecret: "fakefake",
Enabled: true,
AllowSignup: true,
Name: "github",
},
}
func TestLoginErrorCookieAPIEndpoint(t *testing.T) {
fakeSetIndexViewData(t)
@ -97,6 +108,7 @@ func TestLoginErrorCookieAPIEndpoint(t *testing.T) {
Cfg: cfg,
SettingsProvider: &setting.OSSImpl{Cfg: cfg},
License: &licensing.OSSLicensingService{},
SocialService: &mockSocialService{},
}
sc.defaultHandler = routing.Wrap(func(w http.ResponseWriter, c *models.ReqContext) {
@ -106,23 +118,6 @@ func TestLoginErrorCookieAPIEndpoint(t *testing.T) {
cfg.LoginCookieName = "grafana_session"
setting.SecretKey = "login_testing"
origOAuthService := setting.OAuthService
origOAuthAutoLogin := setting.OAuthAutoLogin
t.Cleanup(func() {
setting.OAuthService = origOAuthService
setting.OAuthAutoLogin = origOAuthAutoLogin
})
setting.OAuthService = &setting.OAuther{
OAuthInfos: map[string]*setting.OAuthInfo{
"github": {
ClientId: "fake",
ClientSecret: "fakefake",
Enabled: true,
AllowSignup: true,
Name: "github",
},
},
}
setting.OAuthAutoLogin = true
oauthError := errors.New("User not a member of one of the required organizations")
@ -160,6 +155,7 @@ func TestLoginViewRedirect(t *testing.T) {
Cfg: cfg,
SettingsProvider: &setting.OSSImpl{Cfg: cfg},
License: &licensing.OSSLicensingService{},
SocialService: &mockSocialService{},
}
hs.Cfg.CookieSecure = true
@ -171,9 +167,6 @@ func TestLoginViewRedirect(t *testing.T) {
hs.LoginView(c)
})
setting.OAuthService = &setting.OAuther{}
setting.OAuthService.OAuthInfos = make(map[string]*setting.OAuthInfo)
redirectCases := []redirectCase{
{
desc: "grafana relative url without subpath",
@ -489,25 +482,27 @@ func TestLoginOAuthRedirect(t *testing.T) {
sc := setupScenarioContext(t, "/login")
cfg := setting.NewCfg()
mock := &mockSocialService{
oAuthInfo: &social.OAuthInfo{
ClientId: "fake",
ClientSecret: "fakefake",
Enabled: true,
AllowSignup: true,
Name: "github",
},
oAuthInfos: oAuthInfos,
}
hs := &HTTPServer{
Cfg: cfg,
SettingsProvider: &setting.OSSImpl{Cfg: cfg},
License: &licensing.OSSLicensingService{},
SocialService: mock,
}
sc.defaultHandler = routing.Wrap(func(c *models.ReqContext) {
hs.LoginView(c)
})
setting.OAuthService = &setting.OAuther{}
setting.OAuthService.OAuthInfos = make(map[string]*setting.OAuthInfo)
setting.OAuthService.OAuthInfos["github"] = &setting.OAuthInfo{
ClientId: "fake",
ClientSecret: "fakefake",
Enabled: true,
AllowSignup: true,
Name: "github",
}
setting.OAuthAutoLogin = true
sc.m.Get(sc.url, sc.defaultHandler)
sc.fakeReqNoAssertions("GET", sc.url).exec()
@ -534,15 +529,6 @@ func TestLoginInternal(t *testing.T) {
hs.LoginView(c)
})
setting.OAuthService = &setting.OAuther{}
setting.OAuthService.OAuthInfos = make(map[string]*setting.OAuthInfo)
setting.OAuthService.OAuthInfos["github"] = &setting.OAuthInfo{
ClientId: "fake",
ClientSecret: "fakefake",
Enabled: true,
AllowSignup: true,
Name: "github",
}
setting.OAuthAutoLogin = true
sc.m.Get(sc.url, sc.defaultHandler)
sc.fakeReqNoAssertions("GET", sc.url).exec()
@ -586,6 +572,7 @@ func setupAuthProxyLoginTest(t *testing.T, enableLoginToken bool) *scenarioConte
License: &licensing.OSSLicensingService{},
AuthTokenService: auth.NewFakeUserAuthTokenService(),
log: log.New("hello"),
SocialService: &mockSocialService{},
}
sc.defaultHandler = routing.Wrap(func(c *models.ReqContext) {
@ -596,8 +583,6 @@ func setupAuthProxyLoginTest(t *testing.T, enableLoginToken bool) *scenarioConte
hs.LoginView(c)
})
setting.OAuthService = &setting.OAuther{}
setting.OAuthService.OAuthInfos = make(map[string]*setting.OAuthInfo)
sc.cfg.AuthProxyEnabled = true
sc.cfg.AuthProxyEnableLoginToken = enableLoginToken
@ -713,3 +698,32 @@ func TestLoginPostRunLokingHook(t *testing.T) {
})
}
}
type mockSocialService struct {
oAuthInfo *social.OAuthInfo
oAuthInfos map[string]*social.OAuthInfo
oAuthProviders map[string]bool
httpClient *http.Client
socialConnector social.SocialConnector
err error
}
func (m *mockSocialService) GetOAuthInfoProvider(name string) *social.OAuthInfo {
return m.oAuthInfo
}
func (m *mockSocialService) GetOAuthInfoProviders() map[string]*social.OAuthInfo {
return m.oAuthInfos
}
func (m *mockSocialService) GetOAuthProviders() map[string]bool {
return m.oAuthProviders
}
func (m *mockSocialService) GetOAuthHttpClient(name string) (*http.Client, error) {
return m.httpClient, m.err
}
func (m *mockSocialService) GetConnector(string) (social.SocialConnector, error) {
return m.socialConnector, m.err
}

View File

@ -39,6 +39,7 @@ type DataSourceProxy struct {
plugin *plugins.DataSourcePlugin
cfg *setting.Cfg
clientProvider httpclient.Provider
oAuthTokenService oauthtoken.OAuthTokenService
}
type handleResponseTransport struct {
@ -71,7 +72,7 @@ func (lw *logWrapper) Write(p []byte) (n int, err error) {
// NewDataSourceProxy creates a new Datasource proxy
func NewDataSourceProxy(ds *models.DataSource, plugin *plugins.DataSourcePlugin, ctx *models.ReqContext,
proxyPath string, cfg *setting.Cfg, clientProvider httpclient.Provider) (*DataSourceProxy, error) {
proxyPath string, cfg *setting.Cfg, clientProvider httpclient.Provider, oAuthTokenService oauthtoken.OAuthTokenService) (*DataSourceProxy, error) {
targetURL, err := datasource.ValidateURL(ds.Type, ds.Url)
if err != nil {
return nil, err
@ -85,6 +86,7 @@ func NewDataSourceProxy(ds *models.DataSource, plugin *plugins.DataSourcePlugin,
targetUrl: targetURL,
cfg: cfg,
clientProvider: clientProvider,
oAuthTokenService: oAuthTokenService,
}, nil
}
@ -237,8 +239,8 @@ func (proxy *DataSourceProxy) director(req *http.Request) {
ApplyRoute(proxy.ctx.Req.Context(), req, proxy.proxyPath, proxy.route, proxy.ds, proxy.cfg)
}
if oauthtoken.IsOAuthPassThruEnabled(proxy.ds) {
if token := oauthtoken.GetCurrentOAuthToken(proxy.ctx.Req.Context(), proxy.ctx.SignedInUser); token != nil {
if proxy.oAuthTokenService.IsOAuthPassThruEnabled(proxy.ds) {
if token := proxy.oAuthTokenService.GetCurrentOAuthToken(proxy.ctx.Req.Context(), proxy.ctx.SignedInUser); token != nil {
req.Header.Set("Authorization", fmt.Sprintf("%s %s", token.Type(), token.AccessToken))
}
}

View File

@ -2,6 +2,7 @@ package pluginproxy
import (
"bytes"
"context"
"fmt"
"io/ioutil"
"net/http"
@ -16,9 +17,9 @@ import (
"github.com/grafana/grafana/pkg/components/securejsondata"
"github.com/grafana/grafana/pkg/components/simplejson"
"github.com/grafana/grafana/pkg/infra/httpclient"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/services/oauthtoken"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/util"
"github.com/stretchr/testify/assert"
@ -114,7 +115,7 @@ func TestDataSourceProxy_routeRule(t *testing.T) {
t.Run("When matching route path", func(t *testing.T) {
ctx, req := setUp()
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "api/v4/some/method", cfg, httpClientProvider)
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "api/v4/some/method", cfg, httpClientProvider, &oauthtoken.Service{})
require.NoError(t, err)
proxy.route = plugin.Routes[0]
ApplyRoute(proxy.ctx.Req.Context(), req, proxy.proxyPath, proxy.route, proxy.ds, cfg)
@ -125,7 +126,7 @@ func TestDataSourceProxy_routeRule(t *testing.T) {
t.Run("When matching route path and has dynamic url", func(t *testing.T) {
ctx, req := setUp()
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "api/common/some/method", cfg, httpClientProvider)
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "api/common/some/method", cfg, httpClientProvider, &oauthtoken.Service{})
require.NoError(t, err)
proxy.route = plugin.Routes[3]
ApplyRoute(proxy.ctx.Req.Context(), req, proxy.proxyPath, proxy.route, proxy.ds, cfg)
@ -136,7 +137,7 @@ func TestDataSourceProxy_routeRule(t *testing.T) {
t.Run("When matching route path with no url", func(t *testing.T) {
ctx, req := setUp()
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "", cfg, httpClientProvider)
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "", cfg, httpClientProvider, &oauthtoken.Service{})
require.NoError(t, err)
proxy.route = plugin.Routes[4]
ApplyRoute(proxy.ctx.Req.Context(), req, proxy.proxyPath, proxy.route, proxy.ds, cfg)
@ -146,7 +147,7 @@ func TestDataSourceProxy_routeRule(t *testing.T) {
t.Run("When matching route path and has dynamic body", func(t *testing.T) {
ctx, req := setUp()
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "api/body", cfg, httpClientProvider)
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "api/body", cfg, httpClientProvider, &oauthtoken.Service{})
require.NoError(t, err)
proxy.route = plugin.Routes[5]
ApplyRoute(proxy.ctx.Req.Context(), req, proxy.proxyPath, proxy.route, proxy.ds, cfg)
@ -159,7 +160,7 @@ func TestDataSourceProxy_routeRule(t *testing.T) {
t.Run("Validating request", func(t *testing.T) {
t.Run("plugin route with valid role", func(t *testing.T) {
ctx, _ := setUp()
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "api/v4/some/method", cfg, httpClientProvider)
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "api/v4/some/method", cfg, httpClientProvider, &oauthtoken.Service{})
require.NoError(t, err)
err = proxy.validateRequest()
require.NoError(t, err)
@ -167,7 +168,7 @@ func TestDataSourceProxy_routeRule(t *testing.T) {
t.Run("plugin route with admin role and user is editor", func(t *testing.T) {
ctx, _ := setUp()
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "api/admin", cfg, httpClientProvider)
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "api/admin", cfg, httpClientProvider, &oauthtoken.Service{})
require.NoError(t, err)
err = proxy.validateRequest()
require.Error(t, err)
@ -176,7 +177,7 @@ func TestDataSourceProxy_routeRule(t *testing.T) {
t.Run("plugin route with admin role and user is admin", func(t *testing.T) {
ctx, _ := setUp()
ctx.SignedInUser.OrgRole = models.ROLE_ADMIN
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "api/admin", cfg, httpClientProvider)
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "api/admin", cfg, httpClientProvider, &oauthtoken.Service{})
require.NoError(t, err)
err = proxy.validateRequest()
require.NoError(t, err)
@ -258,7 +259,7 @@ func TestDataSourceProxy_routeRule(t *testing.T) {
cfg := &setting.Cfg{}
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "pathwithtoken1", cfg, httpClientProvider)
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "pathwithtoken1", cfg, httpClientProvider, &oauthtoken.Service{})
require.NoError(t, err)
ApplyRoute(proxy.ctx.Req.Context(), req, proxy.proxyPath, plugin.Routes[0], proxy.ds, cfg)
@ -273,7 +274,7 @@ func TestDataSourceProxy_routeRule(t *testing.T) {
req, err := http.NewRequest("GET", "http://localhost/asd", nil)
require.NoError(t, err)
client = newFakeHTTPClient(t, json2)
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "pathwithtoken2", cfg, httpClientProvider)
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "pathwithtoken2", cfg, httpClientProvider, &oauthtoken.Service{})
require.NoError(t, err)
ApplyRoute(proxy.ctx.Req.Context(), req, proxy.proxyPath, plugin.Routes[1], proxy.ds, cfg)
@ -289,7 +290,7 @@ func TestDataSourceProxy_routeRule(t *testing.T) {
require.NoError(t, err)
client = newFakeHTTPClient(t, []byte{})
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "pathwithtoken1", cfg, httpClientProvider)
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "pathwithtoken1", cfg, httpClientProvider, &oauthtoken.Service{})
require.NoError(t, err)
ApplyRoute(proxy.ctx.Req.Context(), req, proxy.proxyPath, plugin.Routes[0], proxy.ds, cfg)
@ -309,7 +310,7 @@ func TestDataSourceProxy_routeRule(t *testing.T) {
ds := &models.DataSource{Url: "htttp://graphite:8080", Type: models.DS_GRAPHITE}
ctx := &models.ReqContext{}
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "/render", &setting.Cfg{BuildVersion: "5.3.0"}, httpClientProvider)
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "/render", &setting.Cfg{BuildVersion: "5.3.0"}, httpClientProvider, &oauthtoken.Service{})
require.NoError(t, err)
req, err := http.NewRequest(http.MethodGet, "http://grafana.com/sub", nil)
require.NoError(t, err)
@ -334,7 +335,7 @@ func TestDataSourceProxy_routeRule(t *testing.T) {
}
ctx := &models.ReqContext{}
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "", &setting.Cfg{}, httpClientProvider)
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "", &setting.Cfg{}, httpClientProvider, &oauthtoken.Service{})
require.NoError(t, err)
req, err := http.NewRequest(http.MethodGet, "http://grafana.com/sub", nil)
@ -357,7 +358,7 @@ func TestDataSourceProxy_routeRule(t *testing.T) {
}
ctx := &models.ReqContext{}
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "", &setting.Cfg{}, httpClientProvider)
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "", &setting.Cfg{}, httpClientProvider, &oauthtoken.Service{})
require.NoError(t, err)
requestURL, err := url.Parse("http://grafana.com/sub")
@ -384,7 +385,7 @@ func TestDataSourceProxy_routeRule(t *testing.T) {
}
ctx := &models.ReqContext{}
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "", &setting.Cfg{}, httpClientProvider)
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "", &setting.Cfg{}, httpClientProvider, &oauthtoken.Service{})
require.NoError(t, err)
requestURL, err := url.Parse("http://grafana.com/sub")
@ -405,7 +406,7 @@ func TestDataSourceProxy_routeRule(t *testing.T) {
Url: "http://host/root/",
}
ctx := &models.ReqContext{}
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "/path/to/folder/", &setting.Cfg{}, httpClientProvider)
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "/path/to/folder/", &setting.Cfg{}, httpClientProvider, &oauthtoken.Service{})
require.NoError(t, err)
req, err := http.NewRequest(http.MethodGet, "http://grafana.com/sub", nil)
req.Header.Set("Origin", "grafana.com")
@ -423,19 +424,6 @@ func TestDataSourceProxy_routeRule(t *testing.T) {
})
t.Run("When proxying a datasource that has OAuth token pass-through enabled", func(t *testing.T) {
social.SocialMap["generic_oauth"] = &social.SocialGenericOAuth{
SocialBase: &social.SocialBase{
Config: &oauth2.Config{},
},
}
origAuthSvc := setting.OAuthService
t.Cleanup(func() {
setting.OAuthService = origAuthSvc
})
setting.OAuthService = &setting.OAuther{}
setting.OAuthService.OAuthInfos = make(map[string]*setting.OAuthInfo)
setting.OAuthService.OAuthInfos["generic_oauth"] = &setting.OAuthInfo{}
bus.AddHandler("test", func(query *models.GetAuthInfoQuery) error {
query.Result = &models.UserAuth{
Id: 1,
@ -466,7 +454,16 @@ func TestDataSourceProxy_routeRule(t *testing.T) {
Req: macaron.Request{Request: req},
},
}
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "/path/to/folder/", &setting.Cfg{}, httpClientProvider)
mockAuthToken := mockOAuthTokenService{
token: &oauth2.Token{
AccessToken: "testtoken",
RefreshToken: "testrefreshtoken",
TokenType: "Bearer",
Expiry: time.Now().AddDate(0, 0, 1),
},
oAuthEnabled: true,
}
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "/path/to/folder/", &setting.Cfg{}, httpClientProvider, &mockAuthToken)
require.NoError(t, err)
req, err = http.NewRequest(http.MethodGet, "http://grafana.com/sub", nil)
require.NoError(t, err)
@ -600,7 +597,7 @@ func TestDataSourceProxy_requestHandling(t *testing.T) {
t.Run("When response header Set-Cookie is not set should remove proxied Set-Cookie header", func(t *testing.T) {
ctx, ds := setUp(t)
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "/render", &setting.Cfg{}, httpClientProvider)
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "/render", &setting.Cfg{}, httpClientProvider, &oauthtoken.Service{})
require.NoError(t, err)
proxy.HandleRequest()
@ -615,7 +612,7 @@ func TestDataSourceProxy_requestHandling(t *testing.T) {
"Set-Cookie": "important_cookie=important_value",
},
})
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "/render", &setting.Cfg{}, httpClientProvider)
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "/render", &setting.Cfg{}, httpClientProvider, &oauthtoken.Service{})
require.NoError(t, err)
proxy.HandleRequest()
@ -634,7 +631,7 @@ func TestDataSourceProxy_requestHandling(t *testing.T) {
t.Log("Wrote 401 response")
},
})
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "/render", &setting.Cfg{}, httpClientProvider)
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "/render", &setting.Cfg{}, httpClientProvider, &oauthtoken.Service{})
require.NoError(t, err)
proxy.HandleRequest()
@ -656,7 +653,7 @@ func TestDataSourceProxy_requestHandling(t *testing.T) {
})
ctx.Req.Request = httptest.NewRequest("GET", "/api/datasources/proxy/1/path/%2Ftest%2Ftest%2F?query=%2Ftest%2Ftest%2F", nil)
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "/path/%2Ftest%2Ftest%2F", &setting.Cfg{}, httpClientProvider)
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "/path/%2Ftest%2Ftest%2F", &setting.Cfg{}, httpClientProvider, &oauthtoken.Service{})
require.NoError(t, err)
proxy.HandleRequest()
@ -680,7 +677,7 @@ func TestNewDataSourceProxy_InvalidURL(t *testing.T) {
}
cfg := setting.Cfg{}
plugin := plugins.DataSourcePlugin{}
_, err := NewDataSourceProxy(&ds, &plugin, &ctx, "api/method", &cfg, httpclient.NewProvider())
_, err := NewDataSourceProxy(&ds, &plugin, &ctx, "api/method", &cfg, httpclient.NewProvider(), &oauthtoken.Service{})
require.Error(t, err)
assert.True(t, strings.HasPrefix(err.Error(), `validation of data source URL "://host/root" failed`))
}
@ -699,7 +696,7 @@ func TestNewDataSourceProxy_ProtocolLessURL(t *testing.T) {
cfg := setting.Cfg{}
plugin := plugins.DataSourcePlugin{}
_, err := NewDataSourceProxy(&ds, &plugin, &ctx, "api/method", &cfg, httpclient.NewProvider())
_, err := NewDataSourceProxy(&ds, &plugin, &ctx, "api/method", &cfg, httpclient.NewProvider(), &oauthtoken.Service{})
require.NoError(t, err)
}
@ -739,7 +736,7 @@ func TestNewDataSourceProxy_MSSQL(t *testing.T) {
Url: tc.url,
}
p, err := NewDataSourceProxy(&ds, &plugin, &ctx, "api/method", &cfg, httpclient.NewProvider())
p, err := NewDataSourceProxy(&ds, &plugin, &ctx, "api/method", &cfg, httpclient.NewProvider(), &oauthtoken.Service{})
if tc.err == nil {
require.NoError(t, err)
assert.Equal(t, &url.URL{
@ -777,7 +774,7 @@ func getDatasourceProxiedRequest(t *testing.T, ctx *models.ReqContext, cfg *sett
Url: "http://host/root/",
}
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "", cfg, httpclient.NewProvider())
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "", cfg, httpclient.NewProvider(), &oauthtoken.Service{})
require.NoError(t, err)
req, err := http.NewRequest(http.MethodGet, "http://grafana.com/sub", nil)
require.NoError(t, err)
@ -888,7 +885,7 @@ func createAuthTest(t *testing.T, dsType string, authType string, authCheck stri
func runDatasourceAuthTest(t *testing.T, test *testCase) {
plugin := &plugins.DataSourcePlugin{}
ctx := &models.ReqContext{}
proxy, err := NewDataSourceProxy(test.datasource, plugin, ctx, "", &setting.Cfg{}, httpclient.NewProvider())
proxy, err := NewDataSourceProxy(test.datasource, plugin, ctx, "", &setting.Cfg{}, httpclient.NewProvider(), &oauthtoken.Service{})
require.NoError(t, err)
req, err := http.NewRequest(http.MethodGet, "http://grafana.com/sub", nil)
@ -929,9 +926,22 @@ func Test_PathCheck(t *testing.T) {
return ctx, req
}
ctx, _ := setUp()
proxy, err := NewDataSourceProxy(&models.DataSource{}, plugin, ctx, "b", &setting.Cfg{}, httpclient.NewProvider())
proxy, err := NewDataSourceProxy(&models.DataSource{}, plugin, ctx, "b", &setting.Cfg{}, httpclient.NewProvider(), &oauthtoken.Service{})
require.NoError(t, err)
require.Nil(t, proxy.validateRequest())
require.Equal(t, plugin.Routes[1], proxy.route)
}
type mockOAuthTokenService struct {
token *oauth2.Token
oAuthEnabled bool
}
func (m *mockOAuthTokenService) GetCurrentOAuthToken(ctx context.Context, user *models.SignedInUser) *oauth2.Token {
return m.token
}
func (m *mockOAuthTokenService) IsOAuthPassThruEnabled(ds *models.DataSource) bool {
return m.oAuthEnabled
}

View File

@ -39,6 +39,7 @@ type UsageStatsService struct {
AlertingUsageStats alerting.UsageStatsQuerier `inject:""`
License models.Licensing `inject:""`
PluginManager plugins.Manager `inject:""`
SocialService social.Service `inject:""`
log log.Logger
@ -48,7 +49,7 @@ type UsageStatsService struct {
}
func (uss *UsageStatsService) Init() error {
uss.oauthProviders = social.GetOAuthProviders(uss.Cfg)
uss.oauthProviders = uss.SocialService.GetOAuthProviders()
return nil
}

View File

@ -13,6 +13,7 @@ import (
"golang.org/x/oauth2"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/registry"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/util"
)
@ -21,6 +22,184 @@ var (
logger = log.New("social")
)
func init() {
registry.RegisterService(&SocialService{})
}
type SocialService struct {
Cfg *setting.Cfg `inject:""`
socialMap map[string]SocialConnector
oAuthProvider map[string]*OAuthInfo
}
type OAuthInfo struct {
ClientId, ClientSecret string
Scopes []string
AuthUrl, TokenUrl string
Enabled bool
EmailAttributeName string
EmailAttributePath string
RoleAttributePath string
RoleAttributeStrict bool
GroupsAttributePath string
AllowedDomains []string
HostedDomain string
ApiUrl string
AllowSignup bool
Name string
TlsClientCert string
TlsClientKey string
TlsClientCa string
TlsSkipVerify bool
}
func (ss *SocialService) Init() error {
ss.oAuthProvider = make(map[string]*OAuthInfo)
ss.socialMap = make(map[string]SocialConnector)
for _, name := range allOauthes {
sec := ss.Cfg.Raw.Section("auth." + name)
info := &OAuthInfo{
ClientId: sec.Key("client_id").String(),
ClientSecret: sec.Key("client_secret").String(),
Scopes: util.SplitString(sec.Key("scopes").String()),
AuthUrl: sec.Key("auth_url").String(),
TokenUrl: sec.Key("token_url").String(),
ApiUrl: sec.Key("api_url").String(),
Enabled: sec.Key("enabled").MustBool(),
EmailAttributeName: sec.Key("email_attribute_name").String(),
EmailAttributePath: sec.Key("email_attribute_path").String(),
RoleAttributePath: sec.Key("role_attribute_path").String(),
RoleAttributeStrict: sec.Key("role_attribute_strict").MustBool(),
GroupsAttributePath: sec.Key("groups_attribute_path").String(),
AllowedDomains: util.SplitString(sec.Key("allowed_domains").String()),
HostedDomain: sec.Key("hosted_domain").String(),
AllowSignup: sec.Key("allow_sign_up").MustBool(),
Name: sec.Key("name").MustString(name),
TlsClientCert: sec.Key("tls_client_cert").String(),
TlsClientKey: sec.Key("tls_client_key").String(),
TlsClientCa: sec.Key("tls_client_ca").String(),
TlsSkipVerify: sec.Key("tls_skip_verify_insecure").MustBool(),
}
// when empty_scopes parameter exists and is true, overwrite scope with empty value
if sec.Key("empty_scopes").MustBool() {
info.Scopes = []string{}
}
if !info.Enabled {
continue
}
if name == "grafananet" {
name = grafanaCom
}
ss.oAuthProvider[name] = info
config := oauth2.Config{
ClientID: info.ClientId,
ClientSecret: info.ClientSecret,
Endpoint: oauth2.Endpoint{
AuthURL: info.AuthUrl,
TokenURL: info.TokenUrl,
AuthStyle: oauth2.AuthStyleAutoDetect,
},
RedirectURL: strings.TrimSuffix(ss.Cfg.AppURL, "/") + SocialBaseUrl + name,
Scopes: info.Scopes,
}
// GitHub.
if name == "github" {
ss.socialMap["github"] = &SocialGithub{
SocialBase: newSocialBase(name, &config, info),
apiUrl: info.ApiUrl,
teamIds: sec.Key("team_ids").Ints(","),
allowedOrganizations: util.SplitString(sec.Key("allowed_organizations").String()),
}
}
// GitLab.
if name == "gitlab" {
ss.socialMap["gitlab"] = &SocialGitlab{
SocialBase: newSocialBase(name, &config, info),
apiUrl: info.ApiUrl,
allowedGroups: util.SplitString(sec.Key("allowed_groups").String()),
}
}
// Google.
if name == "google" {
ss.socialMap["google"] = &SocialGoogle{
SocialBase: newSocialBase(name, &config, info),
hostedDomain: info.HostedDomain,
apiUrl: info.ApiUrl,
}
}
// AzureAD.
if name == "azuread" {
ss.socialMap["azuread"] = &SocialAzureAD{
SocialBase: newSocialBase(name, &config, info),
allowedGroups: util.SplitString(sec.Key("allowed_groups").String()),
autoAssignOrgRole: ss.Cfg.AutoAssignOrgRole,
}
}
// Okta
if name == "okta" {
ss.socialMap["okta"] = &SocialOkta{
SocialBase: newSocialBase(name, &config, info),
apiUrl: info.ApiUrl,
allowedGroups: util.SplitString(sec.Key("allowed_groups").String()),
roleAttributePath: info.RoleAttributePath,
roleAttributeStrict: info.RoleAttributeStrict,
}
}
// Generic - Uses the same scheme as GitHub.
if name == "generic_oauth" {
ss.socialMap["generic_oauth"] = &SocialGenericOAuth{
SocialBase: newSocialBase(name, &config, info),
apiUrl: info.ApiUrl,
emailAttributeName: info.EmailAttributeName,
emailAttributePath: info.EmailAttributePath,
nameAttributePath: sec.Key("name_attribute_path").String(),
roleAttributePath: info.RoleAttributePath,
roleAttributeStrict: info.RoleAttributeStrict,
groupsAttributePath: info.GroupsAttributePath,
loginAttributePath: sec.Key("login_attribute_path").String(),
idTokenAttributeName: sec.Key("id_token_attribute_name").String(),
teamIds: sec.Key("team_ids").Ints(","),
allowedOrganizations: util.SplitString(sec.Key("allowed_organizations").String()),
}
}
if name == grafanaCom {
config = oauth2.Config{
ClientID: info.ClientId,
ClientSecret: info.ClientSecret,
Endpoint: oauth2.Endpoint{
AuthURL: ss.Cfg.GrafanaComURL + "/oauth2/authorize",
TokenURL: ss.Cfg.GrafanaComURL + "/api/oauth2/token",
AuthStyle: oauth2.AuthStyleInHeader,
},
RedirectURL: strings.TrimSuffix(ss.Cfg.AppURL, "/") + SocialBaseUrl + name,
Scopes: info.Scopes,
}
ss.socialMap[grafanaCom] = &SocialGrafanaCom{
SocialBase: newSocialBase(name, &config, info),
url: ss.Cfg.GrafanaComURL,
allowedOrganizations: util.SplitString(sec.Key("allowed_organizations").String()),
}
}
}
return nil
}
type BasicUserInfo struct {
Id string
Name string
@ -68,7 +247,15 @@ var (
allOauthes = []string{"github", "gitlab", "google", "generic_oauth", "grafananet", grafanaCom, "azuread", "okta"}
)
func newSocialBase(name string, config *oauth2.Config, info *setting.OAuthInfo) *SocialBase {
type Service interface {
GetOAuthProviders() map[string]bool
GetOAuthHttpClient(string) (*http.Client, error)
GetConnector(string) (SocialConnector, error)
GetOAuthInfoProvider(string) *OAuthInfo
GetOAuthInfoProviders() map[string]*OAuthInfo
}
func newSocialBase(name string, config *oauth2.Config, info *OAuthInfo) *SocialBase {
logger := log.New("oauth." + name)
return &SocialBase{
@ -79,156 +266,11 @@ func newSocialBase(name string, config *oauth2.Config, info *setting.OAuthInfo)
}
}
func NewOAuthService(cfg *setting.Cfg) {
setting.OAuthService = &setting.OAuther{}
setting.OAuthService.OAuthInfos = make(map[string]*setting.OAuthInfo)
for _, name := range allOauthes {
sec := cfg.Raw.Section("auth." + name)
info := &setting.OAuthInfo{
ClientId: sec.Key("client_id").String(),
ClientSecret: sec.Key("client_secret").String(),
Scopes: util.SplitString(sec.Key("scopes").String()),
AuthUrl: sec.Key("auth_url").String(),
TokenUrl: sec.Key("token_url").String(),
ApiUrl: sec.Key("api_url").String(),
Enabled: sec.Key("enabled").MustBool(),
EmailAttributeName: sec.Key("email_attribute_name").String(),
EmailAttributePath: sec.Key("email_attribute_path").String(),
RoleAttributePath: sec.Key("role_attribute_path").String(),
RoleAttributeStrict: sec.Key("role_attribute_strict").MustBool(),
GroupsAttributePath: sec.Key("groups_attribute_path").String(),
AllowedDomains: util.SplitString(sec.Key("allowed_domains").String()),
HostedDomain: sec.Key("hosted_domain").String(),
AllowSignup: sec.Key("allow_sign_up").MustBool(),
Name: sec.Key("name").MustString(name),
TlsClientCert: sec.Key("tls_client_cert").String(),
TlsClientKey: sec.Key("tls_client_key").String(),
TlsClientCa: sec.Key("tls_client_ca").String(),
TlsSkipVerify: sec.Key("tls_skip_verify_insecure").MustBool(),
}
// when empty_scopes parameter exists and is true, overwrite scope with empty value
if sec.Key("empty_scopes").MustBool() {
info.Scopes = []string{}
}
if !info.Enabled {
continue
}
if name == "grafananet" {
name = grafanaCom
}
setting.OAuthService.OAuthInfos[name] = info
config := oauth2.Config{
ClientID: info.ClientId,
ClientSecret: info.ClientSecret,
Endpoint: oauth2.Endpoint{
AuthURL: info.AuthUrl,
TokenURL: info.TokenUrl,
AuthStyle: oauth2.AuthStyleAutoDetect,
},
RedirectURL: strings.TrimSuffix(cfg.AppURL, "/") + SocialBaseUrl + name,
Scopes: info.Scopes,
}
// GitHub.
if name == "github" {
SocialMap["github"] = &SocialGithub{
SocialBase: newSocialBase(name, &config, info),
apiUrl: info.ApiUrl,
teamIds: sec.Key("team_ids").Ints(","),
allowedOrganizations: util.SplitString(sec.Key("allowed_organizations").String()),
}
}
// GitLab.
if name == "gitlab" {
SocialMap["gitlab"] = &SocialGitlab{
SocialBase: newSocialBase(name, &config, info),
apiUrl: info.ApiUrl,
allowedGroups: util.SplitString(sec.Key("allowed_groups").String()),
}
}
// Google.
if name == "google" {
SocialMap["google"] = &SocialGoogle{
SocialBase: newSocialBase(name, &config, info),
hostedDomain: info.HostedDomain,
apiUrl: info.ApiUrl,
}
}
// AzureAD.
if name == "azuread" {
SocialMap["azuread"] = &SocialAzureAD{
SocialBase: newSocialBase(name, &config, info),
allowedGroups: util.SplitString(sec.Key("allowed_groups").String()),
autoAssignOrgRole: cfg.AutoAssignOrgRole,
}
}
// Okta
if name == "okta" {
SocialMap["okta"] = &SocialOkta{
SocialBase: newSocialBase(name, &config, info),
apiUrl: info.ApiUrl,
allowedGroups: util.SplitString(sec.Key("allowed_groups").String()),
roleAttributePath: info.RoleAttributePath,
roleAttributeStrict: info.RoleAttributeStrict,
}
}
// Generic - Uses the same scheme as GitHub.
if name == "generic_oauth" {
SocialMap["generic_oauth"] = &SocialGenericOAuth{
SocialBase: newSocialBase(name, &config, info),
apiUrl: info.ApiUrl,
emailAttributeName: info.EmailAttributeName,
emailAttributePath: info.EmailAttributePath,
nameAttributePath: sec.Key("name_attribute_path").String(),
roleAttributePath: info.RoleAttributePath,
roleAttributeStrict: info.RoleAttributeStrict,
groupsAttributePath: info.GroupsAttributePath,
loginAttributePath: sec.Key("login_attribute_path").String(),
idTokenAttributeName: sec.Key("id_token_attribute_name").String(),
teamIds: sec.Key("team_ids").Ints(","),
allowedOrganizations: util.SplitString(sec.Key("allowed_organizations").String()),
}
}
if name == grafanaCom {
config = oauth2.Config{
ClientID: info.ClientId,
ClientSecret: info.ClientSecret,
Endpoint: oauth2.Endpoint{
AuthURL: cfg.GrafanaComURL + "/oauth2/authorize",
TokenURL: cfg.GrafanaComURL + "/api/oauth2/token",
AuthStyle: oauth2.AuthStyleInHeader,
},
RedirectURL: strings.TrimSuffix(cfg.AppURL, "/") + SocialBaseUrl + name,
Scopes: info.Scopes,
}
SocialMap[grafanaCom] = &SocialGrafanaCom{
SocialBase: newSocialBase(name, &config, info),
url: cfg.GrafanaComURL,
allowedOrganizations: util.SplitString(sec.Key("allowed_organizations").String()),
}
}
}
}
// GetOAuthProviders returns available oauth providers and if they're enabled or not
var GetOAuthProviders = func(cfg *setting.Cfg) map[string]bool {
func (ss *SocialService) GetOAuthProviders() map[string]bool {
result := map[string]bool{}
if cfg == nil || cfg.Raw == nil {
if ss.Cfg == nil || ss.Cfg.Raw == nil {
return result
}
@ -237,7 +279,7 @@ var GetOAuthProviders = func(cfg *setting.Cfg) map[string]bool {
name = grafanaCom
}
sec := cfg.Raw.Section("auth." + name)
sec := ss.Cfg.Raw.Section("auth." + name)
if sec == nil {
continue
}
@ -247,13 +289,10 @@ var GetOAuthProviders = func(cfg *setting.Cfg) map[string]bool {
return result
}
func GetOAuthHttpClient(name string) (*http.Client, error) {
if setting.OAuthService == nil {
return nil, fmt.Errorf("OAuth not enabled")
}
func (ss *SocialService) GetOAuthHttpClient(name string) (*http.Client, error) {
// The socialMap keys don't have "oauth_" prefix, but everywhere else in the system does
name = strings.TrimPrefix(name, "oauth_")
info, ok := setting.OAuthService.OAuthInfos[name]
info, ok := ss.oAuthProvider[name]
if !ok {
return nil, fmt.Errorf("could not find %q in OAuth Settings", name)
}
@ -292,12 +331,20 @@ func GetOAuthHttpClient(name string) (*http.Client, error) {
return oauthClient, nil
}
func GetConnector(name string) (SocialConnector, error) {
func (ss *SocialService) GetConnector(name string) (SocialConnector, error) {
// The socialMap keys don't have "oauth_" prefix, but everywhere else in the system does
provider := strings.TrimPrefix(name, "oauth_")
connector, ok := SocialMap[provider]
connector, ok := ss.socialMap[provider]
if !ok {
return nil, fmt.Errorf("failed to find oauth provider for %q", name)
}
return connector, nil
}
func (ss *SocialService) GetOAuthInfoProvider(name string) *OAuthInfo {
return ss.oAuthProvider[name]
}
func (ss *SocialService) GetOAuthInfoProviders() map[string]*OAuthInfo {
return ss.oAuthProvider
}

View File

@ -28,7 +28,7 @@ import (
_ "github.com/grafana/grafana/pkg/infra/tracing"
_ "github.com/grafana/grafana/pkg/infra/usagestats"
"github.com/grafana/grafana/pkg/login"
"github.com/grafana/grafana/pkg/login/social"
_ "github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/middleware"
_ "github.com/grafana/grafana/pkg/plugins/manager"
"github.com/grafana/grafana/pkg/registry"
@ -150,7 +150,6 @@ func (s *Server) init() error {
}
login.Init()
social.NewOAuthService(s.cfg)
services := s.serviceRegistry.GetServices()
if err := s.buildServiceGraph(services); err != nil {

View File

@ -14,6 +14,7 @@ import (
"github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/registry"
"github.com/grafana/grafana/pkg/services/datasources"
"github.com/grafana/grafana/pkg/services/oauthtoken"
"github.com/grafana/grafana/pkg/setting"
)
@ -27,6 +28,7 @@ type DatasourceProxyService struct {
PluginManager plugins.Manager `inject:""`
Cfg *setting.Cfg `inject:""`
HTTPClientProvider httpclient.Provider `inject:""`
OAuthTokenService *oauthtoken.Service `inject:""`
}
func (p *DatasourceProxyService) Init() error {
@ -68,7 +70,7 @@ func (p *DatasourceProxyService) ProxyDatasourceRequestWithID(c *models.ReqConte
}
proxyPath := getProxyPath(c)
proxy, err := pluginproxy.NewDataSourceProxy(ds, plugin, c, proxyPath, p.Cfg, p.HTTPClientProvider)
proxy, err := pluginproxy.NewDataSourceProxy(ds, plugin, c, proxyPath, p.Cfg, p.HTTPClientProvider, p.OAuthTokenService)
if err != nil {
if errors.Is(err, datasource.URLValidationError{}) {
c.JsonApiErr(http.StatusBadRequest, fmt.Sprintf("Invalid data source URL: %q", ds.Url), err)

View File

@ -8,6 +8,7 @@ import (
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/registry"
"golang.org/x/oauth2"
)
@ -15,8 +16,25 @@ var (
logger = log.New("oauthtoken")
)
func init() {
registry.RegisterService(&Service{})
}
type Service struct {
SocialService social.Service `inject:""`
}
type OAuthTokenService interface {
GetCurrentOAuthToken(context.Context, *models.SignedInUser) *oauth2.Token
IsOAuthPassThruEnabled(*models.DataSource) bool
}
func (o *Service) Init() error {
return nil
}
// GetCurrentOAuthToken returns the OAuth token, if any, for the authenticated user. Will try to refresh the token if it has expired.
func GetCurrentOAuthToken(ctx context.Context, user *models.SignedInUser) *oauth2.Token {
func (o *Service) GetCurrentOAuthToken(ctx context.Context, user *models.SignedInUser) *oauth2.Token {
if user == nil {
// No user, therefore no token
return nil
@ -34,13 +52,13 @@ func GetCurrentOAuthToken(ctx context.Context, user *models.SignedInUser) *oauth
}
authProvider := authInfoQuery.Result.AuthModule
connect, err := social.GetConnector(authProvider)
connect, err := o.SocialService.GetConnector(authProvider)
if err != nil {
logger.Error("failed to get OAuth connector", "provider", authProvider, "error", err)
return nil
}
client, err := social.GetOAuthHttpClient(authProvider)
client, err := o.SocialService.GetOAuthHttpClient(authProvider)
if err != nil {
logger.Error("failed to get OAuth http client", "provider", authProvider, "error", err)
return nil
@ -78,7 +96,7 @@ func GetCurrentOAuthToken(ctx context.Context, user *models.SignedInUser) *oauth
}
// IsOAuthPassThruEnabled returns true if Forward OAuth Identity (oauthPassThru) is enabled for the provided data source.
func IsOAuthPassThruEnabled(ds *models.DataSource) bool {
func (o *Service) IsOAuthPassThruEnabled(ds *models.DataSource) bool {
return ds.JsonData != nil && ds.JsonData.Get("oauthPassThru").MustBool()
}

View File

@ -1,28 +0,0 @@
package setting
type OAuthInfo struct {
ClientId, ClientSecret string
Scopes []string
AuthUrl, TokenUrl string
Enabled bool
EmailAttributeName string
EmailAttributePath string
RoleAttributePath string
RoleAttributeStrict bool
GroupsAttributePath string
AllowedDomains []string
HostedDomain string
ApiUrl string
AllowSignup bool
Name string
TlsClientCert string
TlsClientKey string
TlsClientCa string
TlsSkipVerify bool
}
type OAuther struct {
OAuthInfos map[string]*OAuthInfo
}
var OAuthService *OAuther

View File

@ -13,7 +13,7 @@ import (
)
// nolint:staticcheck // plugins.DataQuery deprecated
func dataPluginQueryAdapter(pluginID string, handler backend.QueryDataHandler) plugins.DataPluginFunc {
func dataPluginQueryAdapter(pluginID string, handler backend.QueryDataHandler, oAuthService *oauthtoken.Service) plugins.DataPluginFunc {
return plugins.DataPluginFunc(func(ctx context.Context, ds *models.DataSource, query plugins.DataQuery) (plugins.DataResponse, error) {
instanceSettings, err := modelToInstanceSettings(ds)
if err != nil {
@ -24,8 +24,8 @@ func dataPluginQueryAdapter(pluginID string, handler backend.QueryDataHandler) p
query.Headers = make(map[string]string)
}
if oauthtoken.IsOAuthPassThruEnabled(ds) {
if token := oauthtoken.GetCurrentOAuthToken(ctx, query.User); token != nil {
if oAuthService.IsOAuthPassThruEnabled(ds) {
if token := oAuthService.GetCurrentOAuthToken(ctx, query.User); token != nil {
delete(query.Headers, "Authorization")
query.Headers["Authorization"] = fmt.Sprintf("%s %s", token.Type(), token.AccessToken)
}

View File

@ -9,6 +9,7 @@ import (
"github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/plugins/backendplugin"
"github.com/grafana/grafana/pkg/registry"
"github.com/grafana/grafana/pkg/services/oauthtoken"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/tsdb/azuremonitor"
"github.com/grafana/grafana/pkg/tsdb/cloudmonitoring"
@ -48,6 +49,7 @@ type Service struct {
PluginManager plugins.Manager `inject:""`
BackendPluginManager backendplugin.Manager `inject:""`
HTTPClientProvider httpclient.Provider `inject:""`
OAuthTokenService *oauthtoken.Service `inject:""`
//nolint: staticcheck // plugins.DataPlugin deprecated
registry map[string]func(*models.DataSource) (plugins.DataPlugin, error)
@ -81,7 +83,7 @@ func (s *Service) HandleRequest(ctx context.Context, ds *models.DataSource, quer
return plugin.DataQuery(ctx, ds, query)
}
return dataPluginQueryAdapter(ds.Type, s.BackendPluginManager).DataQuery(ctx, ds, query)
return dataPluginQueryAdapter(ds.Type, s.BackendPluginManager, s.OAuthTokenService).DataQuery(ctx, ds, query)
}
// RegisterQueryHandler registers a query handler factory.