mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
Merge pull request #15205 from seanlaff/12556-oauth-pass-thru
Add oauth pass-thru option for datasources
This commit is contained in:
@@ -25,4 +25,21 @@ func addUserAuthMigrations(mg *Migrator) {
|
||||
mg.AddMigration("alter user_auth.auth_id to length 190", NewRawSqlMigration("").
|
||||
Postgres("ALTER TABLE user_auth ALTER COLUMN auth_id TYPE VARCHAR(190);").
|
||||
Mysql("ALTER TABLE user_auth MODIFY auth_id VARCHAR(190);"))
|
||||
|
||||
mg.AddMigration("Add OAuth access token to user_auth", NewAddColumnMigration(userAuthV1, &Column{
|
||||
Name: "o_auth_access_token", Type: DB_Text, Nullable: true,
|
||||
}))
|
||||
mg.AddMigration("Add OAuth refresh token to user_auth", NewAddColumnMigration(userAuthV1, &Column{
|
||||
Name: "o_auth_refresh_token", Type: DB_Text, Nullable: true,
|
||||
}))
|
||||
mg.AddMigration("Add OAuth token type to user_auth", NewAddColumnMigration(userAuthV1, &Column{
|
||||
Name: "o_auth_token_type", Type: DB_Text, Nullable: true,
|
||||
}))
|
||||
mg.AddMigration("Add OAuth expiry to user_auth", NewAddColumnMigration(userAuthV1, &Column{
|
||||
Name: "o_auth_expiry", Type: DB_DateTime, Nullable: true,
|
||||
}))
|
||||
|
||||
mg.AddMigration("Add index to user_id column in user_auth", NewAddIndexMigration(userAuthV1, &Index{
|
||||
Cols: []string{"user_id"},
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -1,16 +1,22 @@
|
||||
package sqlstore
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"time"
|
||||
|
||||
"github.com/grafana/grafana/pkg/bus"
|
||||
m "github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/util"
|
||||
)
|
||||
|
||||
var getTime = time.Now
|
||||
|
||||
func init() {
|
||||
bus.AddHandler("sql", GetUserByAuthInfo)
|
||||
bus.AddHandler("sql", GetAuthInfo)
|
||||
bus.AddHandler("sql", SetAuthInfo)
|
||||
bus.AddHandler("sql", UpdateAuthInfo)
|
||||
bus.AddHandler("sql", DeleteAuthInfo)
|
||||
}
|
||||
|
||||
@@ -94,7 +100,7 @@ func GetUserByAuthInfo(query *m.GetUserByAuthInfoQuery) error {
|
||||
}
|
||||
|
||||
// create authInfo record to link accounts
|
||||
if authQuery.Result == nil && query.AuthModule != "" && query.AuthId != "" {
|
||||
if authQuery.Result == nil && query.AuthModule != "" {
|
||||
cmd2 := &m.SetAuthInfoCommand{
|
||||
UserId: user.Id,
|
||||
AuthModule: query.AuthModule,
|
||||
@@ -111,10 +117,11 @@ func GetUserByAuthInfo(query *m.GetUserByAuthInfoQuery) error {
|
||||
|
||||
func GetAuthInfo(query *m.GetAuthInfoQuery) error {
|
||||
userAuth := &m.UserAuth{
|
||||
UserId: query.UserId,
|
||||
AuthModule: query.AuthModule,
|
||||
AuthId: query.AuthId,
|
||||
}
|
||||
has, err := x.Get(userAuth)
|
||||
has, err := x.Desc("created").Get(userAuth)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -122,6 +129,22 @@ func GetAuthInfo(query *m.GetAuthInfoQuery) error {
|
||||
return m.ErrUserNotFound
|
||||
}
|
||||
|
||||
secretAccessToken, err := decodeAndDecrypt(userAuth.OAuthAccessToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
secretRefreshToken, err := decodeAndDecrypt(userAuth.OAuthRefreshToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
secretTokenType, err := decodeAndDecrypt(userAuth.OAuthTokenType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userAuth.OAuthAccessToken = secretAccessToken
|
||||
userAuth.OAuthRefreshToken = secretRefreshToken
|
||||
userAuth.OAuthTokenType = secretTokenType
|
||||
|
||||
query.Result = userAuth
|
||||
return nil
|
||||
}
|
||||
@@ -132,7 +155,27 @@ func SetAuthInfo(cmd *m.SetAuthInfoCommand) error {
|
||||
UserId: cmd.UserId,
|
||||
AuthModule: cmd.AuthModule,
|
||||
AuthId: cmd.AuthId,
|
||||
Created: time.Now(),
|
||||
Created: getTime(),
|
||||
}
|
||||
|
||||
if cmd.OAuthToken != nil {
|
||||
secretAccessToken, err := encryptAndEncode(cmd.OAuthToken.AccessToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
secretRefreshToken, err := encryptAndEncode(cmd.OAuthToken.RefreshToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
secretTokenType, err := encryptAndEncode(cmd.OAuthToken.TokenType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
authUser.OAuthAccessToken = secretAccessToken
|
||||
authUser.OAuthRefreshToken = secretRefreshToken
|
||||
authUser.OAuthTokenType = secretTokenType
|
||||
authUser.OAuthExpiry = cmd.OAuthToken.Expiry
|
||||
}
|
||||
|
||||
_, err := sess.Insert(authUser)
|
||||
@@ -140,9 +183,76 @@ func SetAuthInfo(cmd *m.SetAuthInfoCommand) error {
|
||||
})
|
||||
}
|
||||
|
||||
func UpdateAuthInfo(cmd *m.UpdateAuthInfoCommand) error {
|
||||
return inTransaction(func(sess *DBSession) error {
|
||||
authUser := &m.UserAuth{
|
||||
UserId: cmd.UserId,
|
||||
AuthModule: cmd.AuthModule,
|
||||
AuthId: cmd.AuthId,
|
||||
Created: getTime(),
|
||||
}
|
||||
|
||||
if cmd.OAuthToken != nil {
|
||||
secretAccessToken, err := encryptAndEncode(cmd.OAuthToken.AccessToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
secretRefreshToken, err := encryptAndEncode(cmd.OAuthToken.RefreshToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
secretTokenType, err := encryptAndEncode(cmd.OAuthToken.TokenType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
authUser.OAuthAccessToken = secretAccessToken
|
||||
authUser.OAuthRefreshToken = secretRefreshToken
|
||||
authUser.OAuthTokenType = secretTokenType
|
||||
authUser.OAuthExpiry = cmd.OAuthToken.Expiry
|
||||
}
|
||||
|
||||
cond := &m.UserAuth{
|
||||
UserId: cmd.UserId,
|
||||
AuthModule: cmd.AuthModule,
|
||||
}
|
||||
|
||||
_, err := sess.Update(authUser, cond)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func DeleteAuthInfo(cmd *m.DeleteAuthInfoCommand) error {
|
||||
return inTransaction(func(sess *DBSession) error {
|
||||
_, err := sess.Delete(cmd.UserAuth)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
// decodeAndDecrypt will decode the string with the standard bas64 decoder
|
||||
// and then decrypt it with grafana's secretKey
|
||||
func decodeAndDecrypt(s string) (string, error) {
|
||||
// Bail out if empty string since it'll cause a segfault in util.Decrypt
|
||||
if s == "" {
|
||||
return "", nil
|
||||
}
|
||||
decoded, err := base64.StdEncoding.DecodeString(s)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
decrypted, err := util.Decrypt(decoded, setting.SecretKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(decrypted), nil
|
||||
}
|
||||
|
||||
// encryptAndEncode will encrypt a string with grafana's secretKey, and
|
||||
// then encode it with the standard bas64 encoder
|
||||
func encryptAndEncode(s string) (string, error) {
|
||||
encrypted, err := util.Encrypt([]byte(s), setting.SecretKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(encrypted), nil
|
||||
}
|
||||
|
||||
@@ -4,8 +4,10 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
. "github.com/smartystreets/goconvey/convey"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
m "github.com/grafana/grafana/pkg/models"
|
||||
)
|
||||
@@ -126,5 +128,97 @@ func TestUserAuth(t *testing.T) {
|
||||
So(err, ShouldEqual, m.ErrUserNotFound)
|
||||
So(query.Result, ShouldBeNil)
|
||||
})
|
||||
|
||||
Convey("Can set & retrieve oauth token information", func() {
|
||||
token := &oauth2.Token{
|
||||
AccessToken: "testaccess",
|
||||
RefreshToken: "testrefresh",
|
||||
Expiry: time.Now(),
|
||||
TokenType: "Bearer",
|
||||
}
|
||||
|
||||
// Find a user to set tokens on
|
||||
login := "loginuser0"
|
||||
|
||||
// Calling GetUserByAuthInfoQuery on an existing user will populate an entry in the user_auth table
|
||||
query := &m.GetUserByAuthInfoQuery{Login: login, AuthModule: "test", AuthId: "test"}
|
||||
err = GetUserByAuthInfo(query)
|
||||
|
||||
So(err, ShouldBeNil)
|
||||
So(query.Result.Login, ShouldEqual, login)
|
||||
|
||||
cmd := &m.UpdateAuthInfoCommand{
|
||||
UserId: query.Result.Id,
|
||||
AuthId: query.AuthId,
|
||||
AuthModule: query.AuthModule,
|
||||
OAuthToken: token,
|
||||
}
|
||||
err = UpdateAuthInfo(cmd)
|
||||
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
getAuthQuery := &m.GetAuthInfoQuery{
|
||||
UserId: query.Result.Id,
|
||||
}
|
||||
|
||||
err = GetAuthInfo(getAuthQuery)
|
||||
|
||||
So(err, ShouldBeNil)
|
||||
So(getAuthQuery.Result.OAuthAccessToken, ShouldEqual, token.AccessToken)
|
||||
So(getAuthQuery.Result.OAuthRefreshToken, ShouldEqual, token.RefreshToken)
|
||||
So(getAuthQuery.Result.OAuthTokenType, ShouldEqual, token.TokenType)
|
||||
|
||||
})
|
||||
|
||||
Convey("Always return the most recently used auth_module", func() {
|
||||
// Find a user to set tokens on
|
||||
login := "loginuser0"
|
||||
|
||||
// Calling GetUserByAuthInfoQuery on an existing user will populate an entry in the user_auth table
|
||||
// Make the first log-in during the past
|
||||
getTime = func() time.Time { return time.Now().AddDate(0, 0, -2) }
|
||||
query := &m.GetUserByAuthInfoQuery{Login: login, AuthModule: "test1", AuthId: "test1"}
|
||||
err = GetUserByAuthInfo(query)
|
||||
getTime = time.Now
|
||||
|
||||
So(err, ShouldBeNil)
|
||||
So(query.Result.Login, ShouldEqual, login)
|
||||
|
||||
// Add a second auth module for this user
|
||||
// Have this module's last log-in be more recent
|
||||
getTime = func() time.Time { return time.Now().AddDate(0, 0, -1) }
|
||||
query = &m.GetUserByAuthInfoQuery{Login: login, AuthModule: "test2", AuthId: "test2"}
|
||||
err = GetUserByAuthInfo(query)
|
||||
getTime = time.Now
|
||||
|
||||
So(err, ShouldBeNil)
|
||||
So(query.Result.Login, ShouldEqual, login)
|
||||
|
||||
// Get the latest entry by not supply an authmodule or authid
|
||||
getAuthQuery := &m.GetAuthInfoQuery{
|
||||
UserId: query.Result.Id,
|
||||
}
|
||||
|
||||
err = GetAuthInfo(getAuthQuery)
|
||||
|
||||
So(err, ShouldBeNil)
|
||||
So(getAuthQuery.Result.AuthModule, ShouldEqual, "test2")
|
||||
|
||||
// "log in" again with the first auth module
|
||||
updateAuthCmd := &m.UpdateAuthInfoCommand{UserId: query.Result.Id, AuthModule: "test1", AuthId: "test1"}
|
||||
err = UpdateAuthInfo(updateAuthCmd)
|
||||
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
// Get the latest entry by not supply an authmodule or authid
|
||||
getAuthQuery = &m.GetAuthInfoQuery{
|
||||
UserId: query.Result.Id,
|
||||
}
|
||||
|
||||
err = GetAuthInfo(getAuthQuery)
|
||||
|
||||
So(err, ShouldBeNil)
|
||||
So(getAuthQuery.Result.AuthModule, ShouldEqual, "test1")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user