mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
Chore: Refactor OAuth/social package to service (#35403)
* Creating SocialService * Add GetOAuthProviders as socialService method * Add OAuthTokenService * Add GetOAuthHttpClient method to SocialService * Rename services, access socialMap from GetConnector * Fix tests by mocking oauthtoken methods * Move NewAuthService into Init * Move OAuthService to social pkg * Refactor OAuthService to OAuthProvider * Fix nil map error, rename file, simplify tests * Fix bug for Forward OAuth Identify * Remove file after rebase
This commit is contained in:
parent
55e763b4cd
commit
60ac54d969
@ -13,8 +13,10 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/grafana/grafana/pkg/login/social"
|
||||
"github.com/grafana/grafana/pkg/services/libraryelements"
|
||||
"github.com/grafana/grafana/pkg/services/librarypanels"
|
||||
"github.com/grafana/grafana/pkg/services/oauthtoken"
|
||||
|
||||
"github.com/grafana/grafana/pkg/api/routing"
|
||||
httpstatic "github.com/grafana/grafana/pkg/api/static"
|
||||
@ -105,6 +107,8 @@ type HTTPServer struct {
|
||||
Alertmanager *notifier.Alertmanager `inject:""`
|
||||
LibraryPanelService librarypanels.Service `inject:""`
|
||||
LibraryElementService libraryelements.Service `inject:""`
|
||||
SocialService social.Service `inject:""`
|
||||
OAuthTokenService *oauthtoken.Service `inject:""`
|
||||
Listener net.Listener
|
||||
}
|
||||
|
||||
|
@ -91,7 +91,8 @@ func (hs *HTTPServer) LoginView(c *models.ReqContext) {
|
||||
}
|
||||
|
||||
enabledOAuths := make(map[string]interface{})
|
||||
for key, oauth := range setting.OAuthService.OAuthInfos {
|
||||
providers := hs.SocialService.GetOAuthInfoProviders()
|
||||
for key, oauth := range providers {
|
||||
enabledOAuths[key] = map[string]string{"name": oauth.Name}
|
||||
}
|
||||
|
||||
@ -147,12 +148,12 @@ func (hs *HTTPServer) tryOAuthAutoLogin(c *models.ReqContext) bool {
|
||||
if !setting.OAuthAutoLogin {
|
||||
return false
|
||||
}
|
||||
oauthInfos := setting.OAuthService.OAuthInfos
|
||||
oauthInfos := hs.SocialService.GetOAuthInfoProviders()
|
||||
if len(oauthInfos) != 1 {
|
||||
log.Warnf("Skipping OAuth auto login because multiple OAuth providers are configured")
|
||||
return false
|
||||
}
|
||||
for key := range setting.OAuthService.OAuthInfos {
|
||||
for key := range oauthInfos {
|
||||
redirectUrl := hs.Cfg.AppSubURL + "/login/" + key
|
||||
log.Infof("OAuth auto login enabled. Redirecting to " + redirectUrl)
|
||||
c.Redirect(redirectUrl, 307)
|
||||
|
@ -41,7 +41,10 @@ func (hs *HTTPServer) OAuthLogin(ctx *models.ReqContext) {
|
||||
loginInfo := models.LoginInfo{
|
||||
AuthModule: "oauth",
|
||||
}
|
||||
if setting.OAuthService == nil {
|
||||
name := ctx.Params(":name")
|
||||
loginInfo.AuthModule = name
|
||||
provider := hs.SocialService.GetOAuthInfoProvider(name)
|
||||
if provider == nil {
|
||||
hs.handleOAuthLoginError(ctx, loginInfo, LoginError{
|
||||
HttpStatus: http.StatusNotFound,
|
||||
PublicMessage: "OAuth not enabled",
|
||||
@ -49,10 +52,8 @@ func (hs *HTTPServer) OAuthLogin(ctx *models.ReqContext) {
|
||||
return
|
||||
}
|
||||
|
||||
name := ctx.Params(":name")
|
||||
loginInfo.AuthModule = name
|
||||
connect, ok := social.SocialMap[name]
|
||||
if !ok {
|
||||
connect, err := hs.SocialService.GetConnector(name)
|
||||
if err != nil {
|
||||
hs.handleOAuthLoginError(ctx, loginInfo, LoginError{
|
||||
HttpStatus: http.StatusNotFound,
|
||||
PublicMessage: fmt.Sprintf("No OAuth with name %s configured", name),
|
||||
@ -80,12 +81,12 @@ func (hs *HTTPServer) OAuthLogin(ctx *models.ReqContext) {
|
||||
return
|
||||
}
|
||||
|
||||
hashedState := hashStatecode(state, setting.OAuthService.OAuthInfos[name].ClientSecret)
|
||||
hashedState := hashStatecode(state, provider.ClientSecret)
|
||||
cookies.WriteCookie(ctx.Resp, OauthStateCookieName, hashedState, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
|
||||
if setting.OAuthService.OAuthInfos[name].HostedDomain == "" {
|
||||
if provider.HostedDomain == "" {
|
||||
ctx.Redirect(connect.AuthCodeURL(state, oauth2.AccessTypeOnline))
|
||||
} else {
|
||||
ctx.Redirect(connect.AuthCodeURL(state, oauth2.SetAuthURLParam("hd", setting.OAuthService.OAuthInfos[name].HostedDomain), oauth2.AccessTypeOnline))
|
||||
ctx.Redirect(connect.AuthCodeURL(state, oauth2.SetAuthURLParam("hd", provider.HostedDomain), oauth2.AccessTypeOnline))
|
||||
}
|
||||
return
|
||||
}
|
||||
@ -103,7 +104,7 @@ func (hs *HTTPServer) OAuthLogin(ctx *models.ReqContext) {
|
||||
return
|
||||
}
|
||||
|
||||
queryState := hashStatecode(ctx.Query("state"), setting.OAuthService.OAuthInfos[name].ClientSecret)
|
||||
queryState := hashStatecode(ctx.Query("state"), provider.ClientSecret)
|
||||
oauthLogger.Info("state check", "queryState", queryState, "cookieState", cookieState)
|
||||
if cookieState != queryState {
|
||||
hs.handleOAuthLoginError(ctx, loginInfo, LoginError{
|
||||
@ -113,7 +114,7 @@ func (hs *HTTPServer) OAuthLogin(ctx *models.ReqContext) {
|
||||
return
|
||||
}
|
||||
|
||||
oauthClient, err := social.GetOAuthHttpClient(name)
|
||||
oauthClient, err := hs.SocialService.GetOAuthHttpClient(name)
|
||||
if err != nil {
|
||||
ctx.Logger.Error("Failed to create OAuth http client", "error", err)
|
||||
hs.handleOAuthLoginError(ctx, loginInfo, LoginError{
|
||||
|
@ -17,6 +17,7 @@ import (
|
||||
"github.com/grafana/grafana/pkg/components/simplejson"
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/login"
|
||||
"github.com/grafana/grafana/pkg/login/social"
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/services/auth"
|
||||
"github.com/grafana/grafana/pkg/services/hooks"
|
||||
@ -86,6 +87,16 @@ type redirectCase struct {
|
||||
redirectURL string
|
||||
}
|
||||
|
||||
var oAuthInfos = map[string]*social.OAuthInfo{
|
||||
"github": {
|
||||
ClientId: "fake",
|
||||
ClientSecret: "fakefake",
|
||||
Enabled: true,
|
||||
AllowSignup: true,
|
||||
Name: "github",
|
||||
},
|
||||
}
|
||||
|
||||
func TestLoginErrorCookieAPIEndpoint(t *testing.T) {
|
||||
fakeSetIndexViewData(t)
|
||||
|
||||
@ -97,6 +108,7 @@ func TestLoginErrorCookieAPIEndpoint(t *testing.T) {
|
||||
Cfg: cfg,
|
||||
SettingsProvider: &setting.OSSImpl{Cfg: cfg},
|
||||
License: &licensing.OSSLicensingService{},
|
||||
SocialService: &mockSocialService{},
|
||||
}
|
||||
|
||||
sc.defaultHandler = routing.Wrap(func(w http.ResponseWriter, c *models.ReqContext) {
|
||||
@ -106,23 +118,6 @@ func TestLoginErrorCookieAPIEndpoint(t *testing.T) {
|
||||
cfg.LoginCookieName = "grafana_session"
|
||||
setting.SecretKey = "login_testing"
|
||||
|
||||
origOAuthService := setting.OAuthService
|
||||
origOAuthAutoLogin := setting.OAuthAutoLogin
|
||||
t.Cleanup(func() {
|
||||
setting.OAuthService = origOAuthService
|
||||
setting.OAuthAutoLogin = origOAuthAutoLogin
|
||||
})
|
||||
setting.OAuthService = &setting.OAuther{
|
||||
OAuthInfos: map[string]*setting.OAuthInfo{
|
||||
"github": {
|
||||
ClientId: "fake",
|
||||
ClientSecret: "fakefake",
|
||||
Enabled: true,
|
||||
AllowSignup: true,
|
||||
Name: "github",
|
||||
},
|
||||
},
|
||||
}
|
||||
setting.OAuthAutoLogin = true
|
||||
|
||||
oauthError := errors.New("User not a member of one of the required organizations")
|
||||
@ -160,6 +155,7 @@ func TestLoginViewRedirect(t *testing.T) {
|
||||
Cfg: cfg,
|
||||
SettingsProvider: &setting.OSSImpl{Cfg: cfg},
|
||||
License: &licensing.OSSLicensingService{},
|
||||
SocialService: &mockSocialService{},
|
||||
}
|
||||
hs.Cfg.CookieSecure = true
|
||||
|
||||
@ -171,9 +167,6 @@ func TestLoginViewRedirect(t *testing.T) {
|
||||
hs.LoginView(c)
|
||||
})
|
||||
|
||||
setting.OAuthService = &setting.OAuther{}
|
||||
setting.OAuthService.OAuthInfos = make(map[string]*setting.OAuthInfo)
|
||||
|
||||
redirectCases := []redirectCase{
|
||||
{
|
||||
desc: "grafana relative url without subpath",
|
||||
@ -489,25 +482,27 @@ func TestLoginOAuthRedirect(t *testing.T) {
|
||||
|
||||
sc := setupScenarioContext(t, "/login")
|
||||
cfg := setting.NewCfg()
|
||||
mock := &mockSocialService{
|
||||
oAuthInfo: &social.OAuthInfo{
|
||||
ClientId: "fake",
|
||||
ClientSecret: "fakefake",
|
||||
Enabled: true,
|
||||
AllowSignup: true,
|
||||
Name: "github",
|
||||
},
|
||||
oAuthInfos: oAuthInfos,
|
||||
}
|
||||
hs := &HTTPServer{
|
||||
Cfg: cfg,
|
||||
SettingsProvider: &setting.OSSImpl{Cfg: cfg},
|
||||
License: &licensing.OSSLicensingService{},
|
||||
SocialService: mock,
|
||||
}
|
||||
|
||||
sc.defaultHandler = routing.Wrap(func(c *models.ReqContext) {
|
||||
hs.LoginView(c)
|
||||
})
|
||||
|
||||
setting.OAuthService = &setting.OAuther{}
|
||||
setting.OAuthService.OAuthInfos = make(map[string]*setting.OAuthInfo)
|
||||
setting.OAuthService.OAuthInfos["github"] = &setting.OAuthInfo{
|
||||
ClientId: "fake",
|
||||
ClientSecret: "fakefake",
|
||||
Enabled: true,
|
||||
AllowSignup: true,
|
||||
Name: "github",
|
||||
}
|
||||
setting.OAuthAutoLogin = true
|
||||
sc.m.Get(sc.url, sc.defaultHandler)
|
||||
sc.fakeReqNoAssertions("GET", sc.url).exec()
|
||||
@ -534,15 +529,6 @@ func TestLoginInternal(t *testing.T) {
|
||||
hs.LoginView(c)
|
||||
})
|
||||
|
||||
setting.OAuthService = &setting.OAuther{}
|
||||
setting.OAuthService.OAuthInfos = make(map[string]*setting.OAuthInfo)
|
||||
setting.OAuthService.OAuthInfos["github"] = &setting.OAuthInfo{
|
||||
ClientId: "fake",
|
||||
ClientSecret: "fakefake",
|
||||
Enabled: true,
|
||||
AllowSignup: true,
|
||||
Name: "github",
|
||||
}
|
||||
setting.OAuthAutoLogin = true
|
||||
sc.m.Get(sc.url, sc.defaultHandler)
|
||||
sc.fakeReqNoAssertions("GET", sc.url).exec()
|
||||
@ -586,6 +572,7 @@ func setupAuthProxyLoginTest(t *testing.T, enableLoginToken bool) *scenarioConte
|
||||
License: &licensing.OSSLicensingService{},
|
||||
AuthTokenService: auth.NewFakeUserAuthTokenService(),
|
||||
log: log.New("hello"),
|
||||
SocialService: &mockSocialService{},
|
||||
}
|
||||
|
||||
sc.defaultHandler = routing.Wrap(func(c *models.ReqContext) {
|
||||
@ -596,8 +583,6 @@ func setupAuthProxyLoginTest(t *testing.T, enableLoginToken bool) *scenarioConte
|
||||
hs.LoginView(c)
|
||||
})
|
||||
|
||||
setting.OAuthService = &setting.OAuther{}
|
||||
setting.OAuthService.OAuthInfos = make(map[string]*setting.OAuthInfo)
|
||||
sc.cfg.AuthProxyEnabled = true
|
||||
sc.cfg.AuthProxyEnableLoginToken = enableLoginToken
|
||||
|
||||
@ -713,3 +698,32 @@ func TestLoginPostRunLokingHook(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type mockSocialService struct {
|
||||
oAuthInfo *social.OAuthInfo
|
||||
oAuthInfos map[string]*social.OAuthInfo
|
||||
oAuthProviders map[string]bool
|
||||
httpClient *http.Client
|
||||
socialConnector social.SocialConnector
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockSocialService) GetOAuthInfoProvider(name string) *social.OAuthInfo {
|
||||
return m.oAuthInfo
|
||||
}
|
||||
|
||||
func (m *mockSocialService) GetOAuthInfoProviders() map[string]*social.OAuthInfo {
|
||||
return m.oAuthInfos
|
||||
}
|
||||
|
||||
func (m *mockSocialService) GetOAuthProviders() map[string]bool {
|
||||
return m.oAuthProviders
|
||||
}
|
||||
|
||||
func (m *mockSocialService) GetOAuthHttpClient(name string) (*http.Client, error) {
|
||||
return m.httpClient, m.err
|
||||
}
|
||||
|
||||
func (m *mockSocialService) GetConnector(string) (social.SocialConnector, error) {
|
||||
return m.socialConnector, m.err
|
||||
}
|
||||
|
@ -39,6 +39,7 @@ type DataSourceProxy struct {
|
||||
plugin *plugins.DataSourcePlugin
|
||||
cfg *setting.Cfg
|
||||
clientProvider httpclient.Provider
|
||||
oAuthTokenService oauthtoken.OAuthTokenService
|
||||
}
|
||||
|
||||
type handleResponseTransport struct {
|
||||
@ -71,7 +72,7 @@ func (lw *logWrapper) Write(p []byte) (n int, err error) {
|
||||
|
||||
// NewDataSourceProxy creates a new Datasource proxy
|
||||
func NewDataSourceProxy(ds *models.DataSource, plugin *plugins.DataSourcePlugin, ctx *models.ReqContext,
|
||||
proxyPath string, cfg *setting.Cfg, clientProvider httpclient.Provider) (*DataSourceProxy, error) {
|
||||
proxyPath string, cfg *setting.Cfg, clientProvider httpclient.Provider, oAuthTokenService oauthtoken.OAuthTokenService) (*DataSourceProxy, error) {
|
||||
targetURL, err := datasource.ValidateURL(ds.Type, ds.Url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -85,6 +86,7 @@ func NewDataSourceProxy(ds *models.DataSource, plugin *plugins.DataSourcePlugin,
|
||||
targetUrl: targetURL,
|
||||
cfg: cfg,
|
||||
clientProvider: clientProvider,
|
||||
oAuthTokenService: oAuthTokenService,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -237,8 +239,8 @@ func (proxy *DataSourceProxy) director(req *http.Request) {
|
||||
ApplyRoute(proxy.ctx.Req.Context(), req, proxy.proxyPath, proxy.route, proxy.ds, proxy.cfg)
|
||||
}
|
||||
|
||||
if oauthtoken.IsOAuthPassThruEnabled(proxy.ds) {
|
||||
if token := oauthtoken.GetCurrentOAuthToken(proxy.ctx.Req.Context(), proxy.ctx.SignedInUser); token != nil {
|
||||
if proxy.oAuthTokenService.IsOAuthPassThruEnabled(proxy.ds) {
|
||||
if token := proxy.oAuthTokenService.GetCurrentOAuthToken(proxy.ctx.Req.Context(), proxy.ctx.SignedInUser); token != nil {
|
||||
req.Header.Set("Authorization", fmt.Sprintf("%s %s", token.Type(), token.AccessToken))
|
||||
}
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ package pluginproxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
@ -16,9 +17,9 @@ import (
|
||||
"github.com/grafana/grafana/pkg/components/securejsondata"
|
||||
"github.com/grafana/grafana/pkg/components/simplejson"
|
||||
"github.com/grafana/grafana/pkg/infra/httpclient"
|
||||
"github.com/grafana/grafana/pkg/login/social"
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/plugins"
|
||||
"github.com/grafana/grafana/pkg/services/oauthtoken"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@ -114,7 +115,7 @@ func TestDataSourceProxy_routeRule(t *testing.T) {
|
||||
|
||||
t.Run("When matching route path", func(t *testing.T) {
|
||||
ctx, req := setUp()
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "api/v4/some/method", cfg, httpClientProvider)
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "api/v4/some/method", cfg, httpClientProvider, &oauthtoken.Service{})
|
||||
require.NoError(t, err)
|
||||
proxy.route = plugin.Routes[0]
|
||||
ApplyRoute(proxy.ctx.Req.Context(), req, proxy.proxyPath, proxy.route, proxy.ds, cfg)
|
||||
@ -125,7 +126,7 @@ func TestDataSourceProxy_routeRule(t *testing.T) {
|
||||
|
||||
t.Run("When matching route path and has dynamic url", func(t *testing.T) {
|
||||
ctx, req := setUp()
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "api/common/some/method", cfg, httpClientProvider)
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "api/common/some/method", cfg, httpClientProvider, &oauthtoken.Service{})
|
||||
require.NoError(t, err)
|
||||
proxy.route = plugin.Routes[3]
|
||||
ApplyRoute(proxy.ctx.Req.Context(), req, proxy.proxyPath, proxy.route, proxy.ds, cfg)
|
||||
@ -136,7 +137,7 @@ func TestDataSourceProxy_routeRule(t *testing.T) {
|
||||
|
||||
t.Run("When matching route path with no url", func(t *testing.T) {
|
||||
ctx, req := setUp()
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "", cfg, httpClientProvider)
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "", cfg, httpClientProvider, &oauthtoken.Service{})
|
||||
require.NoError(t, err)
|
||||
proxy.route = plugin.Routes[4]
|
||||
ApplyRoute(proxy.ctx.Req.Context(), req, proxy.proxyPath, proxy.route, proxy.ds, cfg)
|
||||
@ -146,7 +147,7 @@ func TestDataSourceProxy_routeRule(t *testing.T) {
|
||||
|
||||
t.Run("When matching route path and has dynamic body", func(t *testing.T) {
|
||||
ctx, req := setUp()
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "api/body", cfg, httpClientProvider)
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "api/body", cfg, httpClientProvider, &oauthtoken.Service{})
|
||||
require.NoError(t, err)
|
||||
proxy.route = plugin.Routes[5]
|
||||
ApplyRoute(proxy.ctx.Req.Context(), req, proxy.proxyPath, proxy.route, proxy.ds, cfg)
|
||||
@ -159,7 +160,7 @@ func TestDataSourceProxy_routeRule(t *testing.T) {
|
||||
t.Run("Validating request", func(t *testing.T) {
|
||||
t.Run("plugin route with valid role", func(t *testing.T) {
|
||||
ctx, _ := setUp()
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "api/v4/some/method", cfg, httpClientProvider)
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "api/v4/some/method", cfg, httpClientProvider, &oauthtoken.Service{})
|
||||
require.NoError(t, err)
|
||||
err = proxy.validateRequest()
|
||||
require.NoError(t, err)
|
||||
@ -167,7 +168,7 @@ func TestDataSourceProxy_routeRule(t *testing.T) {
|
||||
|
||||
t.Run("plugin route with admin role and user is editor", func(t *testing.T) {
|
||||
ctx, _ := setUp()
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "api/admin", cfg, httpClientProvider)
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "api/admin", cfg, httpClientProvider, &oauthtoken.Service{})
|
||||
require.NoError(t, err)
|
||||
err = proxy.validateRequest()
|
||||
require.Error(t, err)
|
||||
@ -176,7 +177,7 @@ func TestDataSourceProxy_routeRule(t *testing.T) {
|
||||
t.Run("plugin route with admin role and user is admin", func(t *testing.T) {
|
||||
ctx, _ := setUp()
|
||||
ctx.SignedInUser.OrgRole = models.ROLE_ADMIN
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "api/admin", cfg, httpClientProvider)
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "api/admin", cfg, httpClientProvider, &oauthtoken.Service{})
|
||||
require.NoError(t, err)
|
||||
err = proxy.validateRequest()
|
||||
require.NoError(t, err)
|
||||
@ -258,7 +259,7 @@ func TestDataSourceProxy_routeRule(t *testing.T) {
|
||||
|
||||
cfg := &setting.Cfg{}
|
||||
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "pathwithtoken1", cfg, httpClientProvider)
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "pathwithtoken1", cfg, httpClientProvider, &oauthtoken.Service{})
|
||||
require.NoError(t, err)
|
||||
ApplyRoute(proxy.ctx.Req.Context(), req, proxy.proxyPath, plugin.Routes[0], proxy.ds, cfg)
|
||||
|
||||
@ -273,7 +274,7 @@ func TestDataSourceProxy_routeRule(t *testing.T) {
|
||||
req, err := http.NewRequest("GET", "http://localhost/asd", nil)
|
||||
require.NoError(t, err)
|
||||
client = newFakeHTTPClient(t, json2)
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "pathwithtoken2", cfg, httpClientProvider)
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "pathwithtoken2", cfg, httpClientProvider, &oauthtoken.Service{})
|
||||
require.NoError(t, err)
|
||||
ApplyRoute(proxy.ctx.Req.Context(), req, proxy.proxyPath, plugin.Routes[1], proxy.ds, cfg)
|
||||
|
||||
@ -289,7 +290,7 @@ func TestDataSourceProxy_routeRule(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
client = newFakeHTTPClient(t, []byte{})
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "pathwithtoken1", cfg, httpClientProvider)
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "pathwithtoken1", cfg, httpClientProvider, &oauthtoken.Service{})
|
||||
require.NoError(t, err)
|
||||
ApplyRoute(proxy.ctx.Req.Context(), req, proxy.proxyPath, plugin.Routes[0], proxy.ds, cfg)
|
||||
|
||||
@ -309,7 +310,7 @@ func TestDataSourceProxy_routeRule(t *testing.T) {
|
||||
ds := &models.DataSource{Url: "htttp://graphite:8080", Type: models.DS_GRAPHITE}
|
||||
ctx := &models.ReqContext{}
|
||||
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "/render", &setting.Cfg{BuildVersion: "5.3.0"}, httpClientProvider)
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "/render", &setting.Cfg{BuildVersion: "5.3.0"}, httpClientProvider, &oauthtoken.Service{})
|
||||
require.NoError(t, err)
|
||||
req, err := http.NewRequest(http.MethodGet, "http://grafana.com/sub", nil)
|
||||
require.NoError(t, err)
|
||||
@ -334,7 +335,7 @@ func TestDataSourceProxy_routeRule(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx := &models.ReqContext{}
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "", &setting.Cfg{}, httpClientProvider)
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "", &setting.Cfg{}, httpClientProvider, &oauthtoken.Service{})
|
||||
require.NoError(t, err)
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "http://grafana.com/sub", nil)
|
||||
@ -357,7 +358,7 @@ func TestDataSourceProxy_routeRule(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx := &models.ReqContext{}
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "", &setting.Cfg{}, httpClientProvider)
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "", &setting.Cfg{}, httpClientProvider, &oauthtoken.Service{})
|
||||
require.NoError(t, err)
|
||||
|
||||
requestURL, err := url.Parse("http://grafana.com/sub")
|
||||
@ -384,7 +385,7 @@ func TestDataSourceProxy_routeRule(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx := &models.ReqContext{}
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "", &setting.Cfg{}, httpClientProvider)
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "", &setting.Cfg{}, httpClientProvider, &oauthtoken.Service{})
|
||||
require.NoError(t, err)
|
||||
|
||||
requestURL, err := url.Parse("http://grafana.com/sub")
|
||||
@ -405,7 +406,7 @@ func TestDataSourceProxy_routeRule(t *testing.T) {
|
||||
Url: "http://host/root/",
|
||||
}
|
||||
ctx := &models.ReqContext{}
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "/path/to/folder/", &setting.Cfg{}, httpClientProvider)
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "/path/to/folder/", &setting.Cfg{}, httpClientProvider, &oauthtoken.Service{})
|
||||
require.NoError(t, err)
|
||||
req, err := http.NewRequest(http.MethodGet, "http://grafana.com/sub", nil)
|
||||
req.Header.Set("Origin", "grafana.com")
|
||||
@ -423,19 +424,6 @@ func TestDataSourceProxy_routeRule(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("When proxying a datasource that has OAuth token pass-through enabled", func(t *testing.T) {
|
||||
social.SocialMap["generic_oauth"] = &social.SocialGenericOAuth{
|
||||
SocialBase: &social.SocialBase{
|
||||
Config: &oauth2.Config{},
|
||||
},
|
||||
}
|
||||
origAuthSvc := setting.OAuthService
|
||||
t.Cleanup(func() {
|
||||
setting.OAuthService = origAuthSvc
|
||||
})
|
||||
setting.OAuthService = &setting.OAuther{}
|
||||
setting.OAuthService.OAuthInfos = make(map[string]*setting.OAuthInfo)
|
||||
setting.OAuthService.OAuthInfos["generic_oauth"] = &setting.OAuthInfo{}
|
||||
|
||||
bus.AddHandler("test", func(query *models.GetAuthInfoQuery) error {
|
||||
query.Result = &models.UserAuth{
|
||||
Id: 1,
|
||||
@ -466,7 +454,16 @@ func TestDataSourceProxy_routeRule(t *testing.T) {
|
||||
Req: macaron.Request{Request: req},
|
||||
},
|
||||
}
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "/path/to/folder/", &setting.Cfg{}, httpClientProvider)
|
||||
mockAuthToken := mockOAuthTokenService{
|
||||
token: &oauth2.Token{
|
||||
AccessToken: "testtoken",
|
||||
RefreshToken: "testrefreshtoken",
|
||||
TokenType: "Bearer",
|
||||
Expiry: time.Now().AddDate(0, 0, 1),
|
||||
},
|
||||
oAuthEnabled: true,
|
||||
}
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "/path/to/folder/", &setting.Cfg{}, httpClientProvider, &mockAuthToken)
|
||||
require.NoError(t, err)
|
||||
req, err = http.NewRequest(http.MethodGet, "http://grafana.com/sub", nil)
|
||||
require.NoError(t, err)
|
||||
@ -600,7 +597,7 @@ func TestDataSourceProxy_requestHandling(t *testing.T) {
|
||||
|
||||
t.Run("When response header Set-Cookie is not set should remove proxied Set-Cookie header", func(t *testing.T) {
|
||||
ctx, ds := setUp(t)
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "/render", &setting.Cfg{}, httpClientProvider)
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "/render", &setting.Cfg{}, httpClientProvider, &oauthtoken.Service{})
|
||||
require.NoError(t, err)
|
||||
|
||||
proxy.HandleRequest()
|
||||
@ -615,7 +612,7 @@ func TestDataSourceProxy_requestHandling(t *testing.T) {
|
||||
"Set-Cookie": "important_cookie=important_value",
|
||||
},
|
||||
})
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "/render", &setting.Cfg{}, httpClientProvider)
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "/render", &setting.Cfg{}, httpClientProvider, &oauthtoken.Service{})
|
||||
require.NoError(t, err)
|
||||
|
||||
proxy.HandleRequest()
|
||||
@ -634,7 +631,7 @@ func TestDataSourceProxy_requestHandling(t *testing.T) {
|
||||
t.Log("Wrote 401 response")
|
||||
},
|
||||
})
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "/render", &setting.Cfg{}, httpClientProvider)
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "/render", &setting.Cfg{}, httpClientProvider, &oauthtoken.Service{})
|
||||
require.NoError(t, err)
|
||||
|
||||
proxy.HandleRequest()
|
||||
@ -656,7 +653,7 @@ func TestDataSourceProxy_requestHandling(t *testing.T) {
|
||||
})
|
||||
|
||||
ctx.Req.Request = httptest.NewRequest("GET", "/api/datasources/proxy/1/path/%2Ftest%2Ftest%2F?query=%2Ftest%2Ftest%2F", nil)
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "/path/%2Ftest%2Ftest%2F", &setting.Cfg{}, httpClientProvider)
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "/path/%2Ftest%2Ftest%2F", &setting.Cfg{}, httpClientProvider, &oauthtoken.Service{})
|
||||
require.NoError(t, err)
|
||||
|
||||
proxy.HandleRequest()
|
||||
@ -680,7 +677,7 @@ func TestNewDataSourceProxy_InvalidURL(t *testing.T) {
|
||||
}
|
||||
cfg := setting.Cfg{}
|
||||
plugin := plugins.DataSourcePlugin{}
|
||||
_, err := NewDataSourceProxy(&ds, &plugin, &ctx, "api/method", &cfg, httpclient.NewProvider())
|
||||
_, err := NewDataSourceProxy(&ds, &plugin, &ctx, "api/method", &cfg, httpclient.NewProvider(), &oauthtoken.Service{})
|
||||
require.Error(t, err)
|
||||
assert.True(t, strings.HasPrefix(err.Error(), `validation of data source URL "://host/root" failed`))
|
||||
}
|
||||
@ -699,7 +696,7 @@ func TestNewDataSourceProxy_ProtocolLessURL(t *testing.T) {
|
||||
cfg := setting.Cfg{}
|
||||
plugin := plugins.DataSourcePlugin{}
|
||||
|
||||
_, err := NewDataSourceProxy(&ds, &plugin, &ctx, "api/method", &cfg, httpclient.NewProvider())
|
||||
_, err := NewDataSourceProxy(&ds, &plugin, &ctx, "api/method", &cfg, httpclient.NewProvider(), &oauthtoken.Service{})
|
||||
|
||||
require.NoError(t, err)
|
||||
}
|
||||
@ -739,7 +736,7 @@ func TestNewDataSourceProxy_MSSQL(t *testing.T) {
|
||||
Url: tc.url,
|
||||
}
|
||||
|
||||
p, err := NewDataSourceProxy(&ds, &plugin, &ctx, "api/method", &cfg, httpclient.NewProvider())
|
||||
p, err := NewDataSourceProxy(&ds, &plugin, &ctx, "api/method", &cfg, httpclient.NewProvider(), &oauthtoken.Service{})
|
||||
if tc.err == nil {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, &url.URL{
|
||||
@ -777,7 +774,7 @@ func getDatasourceProxiedRequest(t *testing.T, ctx *models.ReqContext, cfg *sett
|
||||
Url: "http://host/root/",
|
||||
}
|
||||
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "", cfg, httpclient.NewProvider())
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "", cfg, httpclient.NewProvider(), &oauthtoken.Service{})
|
||||
require.NoError(t, err)
|
||||
req, err := http.NewRequest(http.MethodGet, "http://grafana.com/sub", nil)
|
||||
require.NoError(t, err)
|
||||
@ -888,7 +885,7 @@ func createAuthTest(t *testing.T, dsType string, authType string, authCheck stri
|
||||
func runDatasourceAuthTest(t *testing.T, test *testCase) {
|
||||
plugin := &plugins.DataSourcePlugin{}
|
||||
ctx := &models.ReqContext{}
|
||||
proxy, err := NewDataSourceProxy(test.datasource, plugin, ctx, "", &setting.Cfg{}, httpclient.NewProvider())
|
||||
proxy, err := NewDataSourceProxy(test.datasource, plugin, ctx, "", &setting.Cfg{}, httpclient.NewProvider(), &oauthtoken.Service{})
|
||||
require.NoError(t, err)
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "http://grafana.com/sub", nil)
|
||||
@ -929,9 +926,22 @@ func Test_PathCheck(t *testing.T) {
|
||||
return ctx, req
|
||||
}
|
||||
ctx, _ := setUp()
|
||||
proxy, err := NewDataSourceProxy(&models.DataSource{}, plugin, ctx, "b", &setting.Cfg{}, httpclient.NewProvider())
|
||||
proxy, err := NewDataSourceProxy(&models.DataSource{}, plugin, ctx, "b", &setting.Cfg{}, httpclient.NewProvider(), &oauthtoken.Service{})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Nil(t, proxy.validateRequest())
|
||||
require.Equal(t, plugin.Routes[1], proxy.route)
|
||||
}
|
||||
|
||||
type mockOAuthTokenService struct {
|
||||
token *oauth2.Token
|
||||
oAuthEnabled bool
|
||||
}
|
||||
|
||||
func (m *mockOAuthTokenService) GetCurrentOAuthToken(ctx context.Context, user *models.SignedInUser) *oauth2.Token {
|
||||
return m.token
|
||||
}
|
||||
|
||||
func (m *mockOAuthTokenService) IsOAuthPassThruEnabled(ds *models.DataSource) bool {
|
||||
return m.oAuthEnabled
|
||||
}
|
||||
|
@ -39,6 +39,7 @@ type UsageStatsService struct {
|
||||
AlertingUsageStats alerting.UsageStatsQuerier `inject:""`
|
||||
License models.Licensing `inject:""`
|
||||
PluginManager plugins.Manager `inject:""`
|
||||
SocialService social.Service `inject:""`
|
||||
|
||||
log log.Logger
|
||||
|
||||
@ -48,7 +49,7 @@ type UsageStatsService struct {
|
||||
}
|
||||
|
||||
func (uss *UsageStatsService) Init() error {
|
||||
uss.oauthProviders = social.GetOAuthProviders(uss.Cfg)
|
||||
uss.oauthProviders = uss.SocialService.GetOAuthProviders()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -13,6 +13,7 @@ import (
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/registry"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/util"
|
||||
)
|
||||
@ -21,6 +22,184 @@ var (
|
||||
logger = log.New("social")
|
||||
)
|
||||
|
||||
func init() {
|
||||
registry.RegisterService(&SocialService{})
|
||||
}
|
||||
|
||||
type SocialService struct {
|
||||
Cfg *setting.Cfg `inject:""`
|
||||
|
||||
socialMap map[string]SocialConnector
|
||||
oAuthProvider map[string]*OAuthInfo
|
||||
}
|
||||
|
||||
type OAuthInfo struct {
|
||||
ClientId, ClientSecret string
|
||||
Scopes []string
|
||||
AuthUrl, TokenUrl string
|
||||
Enabled bool
|
||||
EmailAttributeName string
|
||||
EmailAttributePath string
|
||||
RoleAttributePath string
|
||||
RoleAttributeStrict bool
|
||||
GroupsAttributePath string
|
||||
AllowedDomains []string
|
||||
HostedDomain string
|
||||
ApiUrl string
|
||||
AllowSignup bool
|
||||
Name string
|
||||
TlsClientCert string
|
||||
TlsClientKey string
|
||||
TlsClientCa string
|
||||
TlsSkipVerify bool
|
||||
}
|
||||
|
||||
func (ss *SocialService) Init() error {
|
||||
ss.oAuthProvider = make(map[string]*OAuthInfo)
|
||||
ss.socialMap = make(map[string]SocialConnector)
|
||||
|
||||
for _, name := range allOauthes {
|
||||
sec := ss.Cfg.Raw.Section("auth." + name)
|
||||
|
||||
info := &OAuthInfo{
|
||||
ClientId: sec.Key("client_id").String(),
|
||||
ClientSecret: sec.Key("client_secret").String(),
|
||||
Scopes: util.SplitString(sec.Key("scopes").String()),
|
||||
AuthUrl: sec.Key("auth_url").String(),
|
||||
TokenUrl: sec.Key("token_url").String(),
|
||||
ApiUrl: sec.Key("api_url").String(),
|
||||
Enabled: sec.Key("enabled").MustBool(),
|
||||
EmailAttributeName: sec.Key("email_attribute_name").String(),
|
||||
EmailAttributePath: sec.Key("email_attribute_path").String(),
|
||||
RoleAttributePath: sec.Key("role_attribute_path").String(),
|
||||
RoleAttributeStrict: sec.Key("role_attribute_strict").MustBool(),
|
||||
GroupsAttributePath: sec.Key("groups_attribute_path").String(),
|
||||
AllowedDomains: util.SplitString(sec.Key("allowed_domains").String()),
|
||||
HostedDomain: sec.Key("hosted_domain").String(),
|
||||
AllowSignup: sec.Key("allow_sign_up").MustBool(),
|
||||
Name: sec.Key("name").MustString(name),
|
||||
TlsClientCert: sec.Key("tls_client_cert").String(),
|
||||
TlsClientKey: sec.Key("tls_client_key").String(),
|
||||
TlsClientCa: sec.Key("tls_client_ca").String(),
|
||||
TlsSkipVerify: sec.Key("tls_skip_verify_insecure").MustBool(),
|
||||
}
|
||||
|
||||
// when empty_scopes parameter exists and is true, overwrite scope with empty value
|
||||
if sec.Key("empty_scopes").MustBool() {
|
||||
info.Scopes = []string{}
|
||||
}
|
||||
|
||||
if !info.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
if name == "grafananet" {
|
||||
name = grafanaCom
|
||||
}
|
||||
|
||||
ss.oAuthProvider[name] = info
|
||||
|
||||
config := oauth2.Config{
|
||||
ClientID: info.ClientId,
|
||||
ClientSecret: info.ClientSecret,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: info.AuthUrl,
|
||||
TokenURL: info.TokenUrl,
|
||||
AuthStyle: oauth2.AuthStyleAutoDetect,
|
||||
},
|
||||
RedirectURL: strings.TrimSuffix(ss.Cfg.AppURL, "/") + SocialBaseUrl + name,
|
||||
Scopes: info.Scopes,
|
||||
}
|
||||
|
||||
// GitHub.
|
||||
if name == "github" {
|
||||
ss.socialMap["github"] = &SocialGithub{
|
||||
SocialBase: newSocialBase(name, &config, info),
|
||||
apiUrl: info.ApiUrl,
|
||||
teamIds: sec.Key("team_ids").Ints(","),
|
||||
allowedOrganizations: util.SplitString(sec.Key("allowed_organizations").String()),
|
||||
}
|
||||
}
|
||||
|
||||
// GitLab.
|
||||
if name == "gitlab" {
|
||||
ss.socialMap["gitlab"] = &SocialGitlab{
|
||||
SocialBase: newSocialBase(name, &config, info),
|
||||
apiUrl: info.ApiUrl,
|
||||
allowedGroups: util.SplitString(sec.Key("allowed_groups").String()),
|
||||
}
|
||||
}
|
||||
|
||||
// Google.
|
||||
if name == "google" {
|
||||
ss.socialMap["google"] = &SocialGoogle{
|
||||
SocialBase: newSocialBase(name, &config, info),
|
||||
hostedDomain: info.HostedDomain,
|
||||
apiUrl: info.ApiUrl,
|
||||
}
|
||||
}
|
||||
|
||||
// AzureAD.
|
||||
if name == "azuread" {
|
||||
ss.socialMap["azuread"] = &SocialAzureAD{
|
||||
SocialBase: newSocialBase(name, &config, info),
|
||||
allowedGroups: util.SplitString(sec.Key("allowed_groups").String()),
|
||||
autoAssignOrgRole: ss.Cfg.AutoAssignOrgRole,
|
||||
}
|
||||
}
|
||||
|
||||
// Okta
|
||||
if name == "okta" {
|
||||
ss.socialMap["okta"] = &SocialOkta{
|
||||
SocialBase: newSocialBase(name, &config, info),
|
||||
apiUrl: info.ApiUrl,
|
||||
allowedGroups: util.SplitString(sec.Key("allowed_groups").String()),
|
||||
roleAttributePath: info.RoleAttributePath,
|
||||
roleAttributeStrict: info.RoleAttributeStrict,
|
||||
}
|
||||
}
|
||||
|
||||
// Generic - Uses the same scheme as GitHub.
|
||||
if name == "generic_oauth" {
|
||||
ss.socialMap["generic_oauth"] = &SocialGenericOAuth{
|
||||
SocialBase: newSocialBase(name, &config, info),
|
||||
apiUrl: info.ApiUrl,
|
||||
emailAttributeName: info.EmailAttributeName,
|
||||
emailAttributePath: info.EmailAttributePath,
|
||||
nameAttributePath: sec.Key("name_attribute_path").String(),
|
||||
roleAttributePath: info.RoleAttributePath,
|
||||
roleAttributeStrict: info.RoleAttributeStrict,
|
||||
groupsAttributePath: info.GroupsAttributePath,
|
||||
loginAttributePath: sec.Key("login_attribute_path").String(),
|
||||
idTokenAttributeName: sec.Key("id_token_attribute_name").String(),
|
||||
teamIds: sec.Key("team_ids").Ints(","),
|
||||
allowedOrganizations: util.SplitString(sec.Key("allowed_organizations").String()),
|
||||
}
|
||||
}
|
||||
|
||||
if name == grafanaCom {
|
||||
config = oauth2.Config{
|
||||
ClientID: info.ClientId,
|
||||
ClientSecret: info.ClientSecret,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: ss.Cfg.GrafanaComURL + "/oauth2/authorize",
|
||||
TokenURL: ss.Cfg.GrafanaComURL + "/api/oauth2/token",
|
||||
AuthStyle: oauth2.AuthStyleInHeader,
|
||||
},
|
||||
RedirectURL: strings.TrimSuffix(ss.Cfg.AppURL, "/") + SocialBaseUrl + name,
|
||||
Scopes: info.Scopes,
|
||||
}
|
||||
|
||||
ss.socialMap[grafanaCom] = &SocialGrafanaCom{
|
||||
SocialBase: newSocialBase(name, &config, info),
|
||||
url: ss.Cfg.GrafanaComURL,
|
||||
allowedOrganizations: util.SplitString(sec.Key("allowed_organizations").String()),
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type BasicUserInfo struct {
|
||||
Id string
|
||||
Name string
|
||||
@ -68,7 +247,15 @@ var (
|
||||
allOauthes = []string{"github", "gitlab", "google", "generic_oauth", "grafananet", grafanaCom, "azuread", "okta"}
|
||||
)
|
||||
|
||||
func newSocialBase(name string, config *oauth2.Config, info *setting.OAuthInfo) *SocialBase {
|
||||
type Service interface {
|
||||
GetOAuthProviders() map[string]bool
|
||||
GetOAuthHttpClient(string) (*http.Client, error)
|
||||
GetConnector(string) (SocialConnector, error)
|
||||
GetOAuthInfoProvider(string) *OAuthInfo
|
||||
GetOAuthInfoProviders() map[string]*OAuthInfo
|
||||
}
|
||||
|
||||
func newSocialBase(name string, config *oauth2.Config, info *OAuthInfo) *SocialBase {
|
||||
logger := log.New("oauth." + name)
|
||||
|
||||
return &SocialBase{
|
||||
@ -79,156 +266,11 @@ func newSocialBase(name string, config *oauth2.Config, info *setting.OAuthInfo)
|
||||
}
|
||||
}
|
||||
|
||||
func NewOAuthService(cfg *setting.Cfg) {
|
||||
setting.OAuthService = &setting.OAuther{}
|
||||
setting.OAuthService.OAuthInfos = make(map[string]*setting.OAuthInfo)
|
||||
|
||||
for _, name := range allOauthes {
|
||||
sec := cfg.Raw.Section("auth." + name)
|
||||
|
||||
info := &setting.OAuthInfo{
|
||||
ClientId: sec.Key("client_id").String(),
|
||||
ClientSecret: sec.Key("client_secret").String(),
|
||||
Scopes: util.SplitString(sec.Key("scopes").String()),
|
||||
AuthUrl: sec.Key("auth_url").String(),
|
||||
TokenUrl: sec.Key("token_url").String(),
|
||||
ApiUrl: sec.Key("api_url").String(),
|
||||
Enabled: sec.Key("enabled").MustBool(),
|
||||
EmailAttributeName: sec.Key("email_attribute_name").String(),
|
||||
EmailAttributePath: sec.Key("email_attribute_path").String(),
|
||||
RoleAttributePath: sec.Key("role_attribute_path").String(),
|
||||
RoleAttributeStrict: sec.Key("role_attribute_strict").MustBool(),
|
||||
GroupsAttributePath: sec.Key("groups_attribute_path").String(),
|
||||
AllowedDomains: util.SplitString(sec.Key("allowed_domains").String()),
|
||||
HostedDomain: sec.Key("hosted_domain").String(),
|
||||
AllowSignup: sec.Key("allow_sign_up").MustBool(),
|
||||
Name: sec.Key("name").MustString(name),
|
||||
TlsClientCert: sec.Key("tls_client_cert").String(),
|
||||
TlsClientKey: sec.Key("tls_client_key").String(),
|
||||
TlsClientCa: sec.Key("tls_client_ca").String(),
|
||||
TlsSkipVerify: sec.Key("tls_skip_verify_insecure").MustBool(),
|
||||
}
|
||||
|
||||
// when empty_scopes parameter exists and is true, overwrite scope with empty value
|
||||
if sec.Key("empty_scopes").MustBool() {
|
||||
info.Scopes = []string{}
|
||||
}
|
||||
|
||||
if !info.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
if name == "grafananet" {
|
||||
name = grafanaCom
|
||||
}
|
||||
|
||||
setting.OAuthService.OAuthInfos[name] = info
|
||||
|
||||
config := oauth2.Config{
|
||||
ClientID: info.ClientId,
|
||||
ClientSecret: info.ClientSecret,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: info.AuthUrl,
|
||||
TokenURL: info.TokenUrl,
|
||||
AuthStyle: oauth2.AuthStyleAutoDetect,
|
||||
},
|
||||
RedirectURL: strings.TrimSuffix(cfg.AppURL, "/") + SocialBaseUrl + name,
|
||||
Scopes: info.Scopes,
|
||||
}
|
||||
|
||||
// GitHub.
|
||||
if name == "github" {
|
||||
SocialMap["github"] = &SocialGithub{
|
||||
SocialBase: newSocialBase(name, &config, info),
|
||||
apiUrl: info.ApiUrl,
|
||||
teamIds: sec.Key("team_ids").Ints(","),
|
||||
allowedOrganizations: util.SplitString(sec.Key("allowed_organizations").String()),
|
||||
}
|
||||
}
|
||||
|
||||
// GitLab.
|
||||
if name == "gitlab" {
|
||||
SocialMap["gitlab"] = &SocialGitlab{
|
||||
SocialBase: newSocialBase(name, &config, info),
|
||||
apiUrl: info.ApiUrl,
|
||||
allowedGroups: util.SplitString(sec.Key("allowed_groups").String()),
|
||||
}
|
||||
}
|
||||
|
||||
// Google.
|
||||
if name == "google" {
|
||||
SocialMap["google"] = &SocialGoogle{
|
||||
SocialBase: newSocialBase(name, &config, info),
|
||||
hostedDomain: info.HostedDomain,
|
||||
apiUrl: info.ApiUrl,
|
||||
}
|
||||
}
|
||||
|
||||
// AzureAD.
|
||||
if name == "azuread" {
|
||||
SocialMap["azuread"] = &SocialAzureAD{
|
||||
SocialBase: newSocialBase(name, &config, info),
|
||||
allowedGroups: util.SplitString(sec.Key("allowed_groups").String()),
|
||||
autoAssignOrgRole: cfg.AutoAssignOrgRole,
|
||||
}
|
||||
}
|
||||
|
||||
// Okta
|
||||
if name == "okta" {
|
||||
SocialMap["okta"] = &SocialOkta{
|
||||
SocialBase: newSocialBase(name, &config, info),
|
||||
apiUrl: info.ApiUrl,
|
||||
allowedGroups: util.SplitString(sec.Key("allowed_groups").String()),
|
||||
roleAttributePath: info.RoleAttributePath,
|
||||
roleAttributeStrict: info.RoleAttributeStrict,
|
||||
}
|
||||
}
|
||||
|
||||
// Generic - Uses the same scheme as GitHub.
|
||||
if name == "generic_oauth" {
|
||||
SocialMap["generic_oauth"] = &SocialGenericOAuth{
|
||||
SocialBase: newSocialBase(name, &config, info),
|
||||
apiUrl: info.ApiUrl,
|
||||
emailAttributeName: info.EmailAttributeName,
|
||||
emailAttributePath: info.EmailAttributePath,
|
||||
nameAttributePath: sec.Key("name_attribute_path").String(),
|
||||
roleAttributePath: info.RoleAttributePath,
|
||||
roleAttributeStrict: info.RoleAttributeStrict,
|
||||
groupsAttributePath: info.GroupsAttributePath,
|
||||
loginAttributePath: sec.Key("login_attribute_path").String(),
|
||||
idTokenAttributeName: sec.Key("id_token_attribute_name").String(),
|
||||
teamIds: sec.Key("team_ids").Ints(","),
|
||||
allowedOrganizations: util.SplitString(sec.Key("allowed_organizations").String()),
|
||||
}
|
||||
}
|
||||
|
||||
if name == grafanaCom {
|
||||
config = oauth2.Config{
|
||||
ClientID: info.ClientId,
|
||||
ClientSecret: info.ClientSecret,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: cfg.GrafanaComURL + "/oauth2/authorize",
|
||||
TokenURL: cfg.GrafanaComURL + "/api/oauth2/token",
|
||||
AuthStyle: oauth2.AuthStyleInHeader,
|
||||
},
|
||||
RedirectURL: strings.TrimSuffix(cfg.AppURL, "/") + SocialBaseUrl + name,
|
||||
Scopes: info.Scopes,
|
||||
}
|
||||
|
||||
SocialMap[grafanaCom] = &SocialGrafanaCom{
|
||||
SocialBase: newSocialBase(name, &config, info),
|
||||
url: cfg.GrafanaComURL,
|
||||
allowedOrganizations: util.SplitString(sec.Key("allowed_organizations").String()),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetOAuthProviders returns available oauth providers and if they're enabled or not
|
||||
var GetOAuthProviders = func(cfg *setting.Cfg) map[string]bool {
|
||||
func (ss *SocialService) GetOAuthProviders() map[string]bool {
|
||||
result := map[string]bool{}
|
||||
|
||||
if cfg == nil || cfg.Raw == nil {
|
||||
if ss.Cfg == nil || ss.Cfg.Raw == nil {
|
||||
return result
|
||||
}
|
||||
|
||||
@ -237,7 +279,7 @@ var GetOAuthProviders = func(cfg *setting.Cfg) map[string]bool {
|
||||
name = grafanaCom
|
||||
}
|
||||
|
||||
sec := cfg.Raw.Section("auth." + name)
|
||||
sec := ss.Cfg.Raw.Section("auth." + name)
|
||||
if sec == nil {
|
||||
continue
|
||||
}
|
||||
@ -247,13 +289,10 @@ var GetOAuthProviders = func(cfg *setting.Cfg) map[string]bool {
|
||||
return result
|
||||
}
|
||||
|
||||
func GetOAuthHttpClient(name string) (*http.Client, error) {
|
||||
if setting.OAuthService == nil {
|
||||
return nil, fmt.Errorf("OAuth not enabled")
|
||||
}
|
||||
func (ss *SocialService) GetOAuthHttpClient(name string) (*http.Client, error) {
|
||||
// The socialMap keys don't have "oauth_" prefix, but everywhere else in the system does
|
||||
name = strings.TrimPrefix(name, "oauth_")
|
||||
info, ok := setting.OAuthService.OAuthInfos[name]
|
||||
info, ok := ss.oAuthProvider[name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("could not find %q in OAuth Settings", name)
|
||||
}
|
||||
@ -292,12 +331,20 @@ func GetOAuthHttpClient(name string) (*http.Client, error) {
|
||||
return oauthClient, nil
|
||||
}
|
||||
|
||||
func GetConnector(name string) (SocialConnector, error) {
|
||||
func (ss *SocialService) GetConnector(name string) (SocialConnector, error) {
|
||||
// The socialMap keys don't have "oauth_" prefix, but everywhere else in the system does
|
||||
provider := strings.TrimPrefix(name, "oauth_")
|
||||
connector, ok := SocialMap[provider]
|
||||
connector, ok := ss.socialMap[provider]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to find oauth provider for %q", name)
|
||||
}
|
||||
return connector, nil
|
||||
}
|
||||
|
||||
func (ss *SocialService) GetOAuthInfoProvider(name string) *OAuthInfo {
|
||||
return ss.oAuthProvider[name]
|
||||
}
|
||||
|
||||
func (ss *SocialService) GetOAuthInfoProviders() map[string]*OAuthInfo {
|
||||
return ss.oAuthProvider
|
||||
}
|
||||
|
@ -28,7 +28,7 @@ import (
|
||||
_ "github.com/grafana/grafana/pkg/infra/tracing"
|
||||
_ "github.com/grafana/grafana/pkg/infra/usagestats"
|
||||
"github.com/grafana/grafana/pkg/login"
|
||||
"github.com/grafana/grafana/pkg/login/social"
|
||||
_ "github.com/grafana/grafana/pkg/login/social"
|
||||
"github.com/grafana/grafana/pkg/middleware"
|
||||
_ "github.com/grafana/grafana/pkg/plugins/manager"
|
||||
"github.com/grafana/grafana/pkg/registry"
|
||||
@ -150,7 +150,6 @@ func (s *Server) init() error {
|
||||
}
|
||||
|
||||
login.Init()
|
||||
social.NewOAuthService(s.cfg)
|
||||
|
||||
services := s.serviceRegistry.GetServices()
|
||||
if err := s.buildServiceGraph(services); err != nil {
|
||||
|
@ -14,6 +14,7 @@ import (
|
||||
"github.com/grafana/grafana/pkg/plugins"
|
||||
"github.com/grafana/grafana/pkg/registry"
|
||||
"github.com/grafana/grafana/pkg/services/datasources"
|
||||
"github.com/grafana/grafana/pkg/services/oauthtoken"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
)
|
||||
|
||||
@ -27,6 +28,7 @@ type DatasourceProxyService struct {
|
||||
PluginManager plugins.Manager `inject:""`
|
||||
Cfg *setting.Cfg `inject:""`
|
||||
HTTPClientProvider httpclient.Provider `inject:""`
|
||||
OAuthTokenService *oauthtoken.Service `inject:""`
|
||||
}
|
||||
|
||||
func (p *DatasourceProxyService) Init() error {
|
||||
@ -68,7 +70,7 @@ func (p *DatasourceProxyService) ProxyDatasourceRequestWithID(c *models.ReqConte
|
||||
}
|
||||
|
||||
proxyPath := getProxyPath(c)
|
||||
proxy, err := pluginproxy.NewDataSourceProxy(ds, plugin, c, proxyPath, p.Cfg, p.HTTPClientProvider)
|
||||
proxy, err := pluginproxy.NewDataSourceProxy(ds, plugin, c, proxyPath, p.Cfg, p.HTTPClientProvider, p.OAuthTokenService)
|
||||
if err != nil {
|
||||
if errors.Is(err, datasource.URLValidationError{}) {
|
||||
c.JsonApiErr(http.StatusBadRequest, fmt.Sprintf("Invalid data source URL: %q", ds.Url), err)
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/login/social"
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/registry"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
@ -15,8 +16,25 @@ var (
|
||||
logger = log.New("oauthtoken")
|
||||
)
|
||||
|
||||
func init() {
|
||||
registry.RegisterService(&Service{})
|
||||
}
|
||||
|
||||
type Service struct {
|
||||
SocialService social.Service `inject:""`
|
||||
}
|
||||
|
||||
type OAuthTokenService interface {
|
||||
GetCurrentOAuthToken(context.Context, *models.SignedInUser) *oauth2.Token
|
||||
IsOAuthPassThruEnabled(*models.DataSource) bool
|
||||
}
|
||||
|
||||
func (o *Service) Init() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCurrentOAuthToken returns the OAuth token, if any, for the authenticated user. Will try to refresh the token if it has expired.
|
||||
func GetCurrentOAuthToken(ctx context.Context, user *models.SignedInUser) *oauth2.Token {
|
||||
func (o *Service) GetCurrentOAuthToken(ctx context.Context, user *models.SignedInUser) *oauth2.Token {
|
||||
if user == nil {
|
||||
// No user, therefore no token
|
||||
return nil
|
||||
@ -34,13 +52,13 @@ func GetCurrentOAuthToken(ctx context.Context, user *models.SignedInUser) *oauth
|
||||
}
|
||||
|
||||
authProvider := authInfoQuery.Result.AuthModule
|
||||
connect, err := social.GetConnector(authProvider)
|
||||
connect, err := o.SocialService.GetConnector(authProvider)
|
||||
if err != nil {
|
||||
logger.Error("failed to get OAuth connector", "provider", authProvider, "error", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
client, err := social.GetOAuthHttpClient(authProvider)
|
||||
client, err := o.SocialService.GetOAuthHttpClient(authProvider)
|
||||
if err != nil {
|
||||
logger.Error("failed to get OAuth http client", "provider", authProvider, "error", err)
|
||||
return nil
|
||||
@ -78,7 +96,7 @@ func GetCurrentOAuthToken(ctx context.Context, user *models.SignedInUser) *oauth
|
||||
}
|
||||
|
||||
// IsOAuthPassThruEnabled returns true if Forward OAuth Identity (oauthPassThru) is enabled for the provided data source.
|
||||
func IsOAuthPassThruEnabled(ds *models.DataSource) bool {
|
||||
func (o *Service) IsOAuthPassThruEnabled(ds *models.DataSource) bool {
|
||||
return ds.JsonData != nil && ds.JsonData.Get("oauthPassThru").MustBool()
|
||||
}
|
||||
|
@ -1,28 +0,0 @@
|
||||
package setting
|
||||
|
||||
type OAuthInfo struct {
|
||||
ClientId, ClientSecret string
|
||||
Scopes []string
|
||||
AuthUrl, TokenUrl string
|
||||
Enabled bool
|
||||
EmailAttributeName string
|
||||
EmailAttributePath string
|
||||
RoleAttributePath string
|
||||
RoleAttributeStrict bool
|
||||
GroupsAttributePath string
|
||||
AllowedDomains []string
|
||||
HostedDomain string
|
||||
ApiUrl string
|
||||
AllowSignup bool
|
||||
Name string
|
||||
TlsClientCert string
|
||||
TlsClientKey string
|
||||
TlsClientCa string
|
||||
TlsSkipVerify bool
|
||||
}
|
||||
|
||||
type OAuther struct {
|
||||
OAuthInfos map[string]*OAuthInfo
|
||||
}
|
||||
|
||||
var OAuthService *OAuther
|
@ -13,7 +13,7 @@ import (
|
||||
)
|
||||
|
||||
// nolint:staticcheck // plugins.DataQuery deprecated
|
||||
func dataPluginQueryAdapter(pluginID string, handler backend.QueryDataHandler) plugins.DataPluginFunc {
|
||||
func dataPluginQueryAdapter(pluginID string, handler backend.QueryDataHandler, oAuthService *oauthtoken.Service) plugins.DataPluginFunc {
|
||||
return plugins.DataPluginFunc(func(ctx context.Context, ds *models.DataSource, query plugins.DataQuery) (plugins.DataResponse, error) {
|
||||
instanceSettings, err := modelToInstanceSettings(ds)
|
||||
if err != nil {
|
||||
@ -24,8 +24,8 @@ func dataPluginQueryAdapter(pluginID string, handler backend.QueryDataHandler) p
|
||||
query.Headers = make(map[string]string)
|
||||
}
|
||||
|
||||
if oauthtoken.IsOAuthPassThruEnabled(ds) {
|
||||
if token := oauthtoken.GetCurrentOAuthToken(ctx, query.User); token != nil {
|
||||
if oAuthService.IsOAuthPassThruEnabled(ds) {
|
||||
if token := oAuthService.GetCurrentOAuthToken(ctx, query.User); token != nil {
|
||||
delete(query.Headers, "Authorization")
|
||||
query.Headers["Authorization"] = fmt.Sprintf("%s %s", token.Type(), token.AccessToken)
|
||||
}
|
||||
|
@ -9,6 +9,7 @@ import (
|
||||
"github.com/grafana/grafana/pkg/plugins"
|
||||
"github.com/grafana/grafana/pkg/plugins/backendplugin"
|
||||
"github.com/grafana/grafana/pkg/registry"
|
||||
"github.com/grafana/grafana/pkg/services/oauthtoken"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor"
|
||||
"github.com/grafana/grafana/pkg/tsdb/cloudmonitoring"
|
||||
@ -48,6 +49,7 @@ type Service struct {
|
||||
PluginManager plugins.Manager `inject:""`
|
||||
BackendPluginManager backendplugin.Manager `inject:""`
|
||||
HTTPClientProvider httpclient.Provider `inject:""`
|
||||
OAuthTokenService *oauthtoken.Service `inject:""`
|
||||
|
||||
//nolint: staticcheck // plugins.DataPlugin deprecated
|
||||
registry map[string]func(*models.DataSource) (plugins.DataPlugin, error)
|
||||
@ -81,7 +83,7 @@ func (s *Service) HandleRequest(ctx context.Context, ds *models.DataSource, quer
|
||||
return plugin.DataQuery(ctx, ds, query)
|
||||
}
|
||||
|
||||
return dataPluginQueryAdapter(ds.Type, s.BackendPluginManager).DataQuery(ctx, ds, query)
|
||||
return dataPluginQueryAdapter(ds.Type, s.BackendPluginManager, s.OAuthTokenService).DataQuery(ctx, ds, query)
|
||||
}
|
||||
|
||||
// RegisterQueryHandler registers a query handler factory.
|
||||
|
Loading…
Reference in New Issue
Block a user