Merge pull request #15205 from seanlaff/12556-oauth-pass-thru

Add oauth pass-thru option for datasources
This commit is contained in:
Daniel Lee
2019-03-25 21:52:20 +01:00
committed by GitHub
10 changed files with 379 additions and 15 deletions

View File

@@ -25,4 +25,21 @@ func addUserAuthMigrations(mg *Migrator) {
mg.AddMigration("alter user_auth.auth_id to length 190", NewRawSqlMigration("").
Postgres("ALTER TABLE user_auth ALTER COLUMN auth_id TYPE VARCHAR(190);").
Mysql("ALTER TABLE user_auth MODIFY auth_id VARCHAR(190);"))
mg.AddMigration("Add OAuth access token to user_auth", NewAddColumnMigration(userAuthV1, &Column{
Name: "o_auth_access_token", Type: DB_Text, Nullable: true,
}))
mg.AddMigration("Add OAuth refresh token to user_auth", NewAddColumnMigration(userAuthV1, &Column{
Name: "o_auth_refresh_token", Type: DB_Text, Nullable: true,
}))
mg.AddMigration("Add OAuth token type to user_auth", NewAddColumnMigration(userAuthV1, &Column{
Name: "o_auth_token_type", Type: DB_Text, Nullable: true,
}))
mg.AddMigration("Add OAuth expiry to user_auth", NewAddColumnMigration(userAuthV1, &Column{
Name: "o_auth_expiry", Type: DB_DateTime, Nullable: true,
}))
mg.AddMigration("Add index to user_id column in user_auth", NewAddIndexMigration(userAuthV1, &Index{
Cols: []string{"user_id"},
}))
}

View File

@@ -1,16 +1,22 @@
package sqlstore
import (
"encoding/base64"
"time"
"github.com/grafana/grafana/pkg/bus"
m "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/util"
)
var getTime = time.Now
func init() {
bus.AddHandler("sql", GetUserByAuthInfo)
bus.AddHandler("sql", GetAuthInfo)
bus.AddHandler("sql", SetAuthInfo)
bus.AddHandler("sql", UpdateAuthInfo)
bus.AddHandler("sql", DeleteAuthInfo)
}
@@ -94,7 +100,7 @@ func GetUserByAuthInfo(query *m.GetUserByAuthInfoQuery) error {
}
// create authInfo record to link accounts
if authQuery.Result == nil && query.AuthModule != "" && query.AuthId != "" {
if authQuery.Result == nil && query.AuthModule != "" {
cmd2 := &m.SetAuthInfoCommand{
UserId: user.Id,
AuthModule: query.AuthModule,
@@ -111,10 +117,11 @@ func GetUserByAuthInfo(query *m.GetUserByAuthInfoQuery) error {
func GetAuthInfo(query *m.GetAuthInfoQuery) error {
userAuth := &m.UserAuth{
UserId: query.UserId,
AuthModule: query.AuthModule,
AuthId: query.AuthId,
}
has, err := x.Get(userAuth)
has, err := x.Desc("created").Get(userAuth)
if err != nil {
return err
}
@@ -122,6 +129,22 @@ func GetAuthInfo(query *m.GetAuthInfoQuery) error {
return m.ErrUserNotFound
}
secretAccessToken, err := decodeAndDecrypt(userAuth.OAuthAccessToken)
if err != nil {
return err
}
secretRefreshToken, err := decodeAndDecrypt(userAuth.OAuthRefreshToken)
if err != nil {
return err
}
secretTokenType, err := decodeAndDecrypt(userAuth.OAuthTokenType)
if err != nil {
return err
}
userAuth.OAuthAccessToken = secretAccessToken
userAuth.OAuthRefreshToken = secretRefreshToken
userAuth.OAuthTokenType = secretTokenType
query.Result = userAuth
return nil
}
@@ -132,7 +155,27 @@ func SetAuthInfo(cmd *m.SetAuthInfoCommand) error {
UserId: cmd.UserId,
AuthModule: cmd.AuthModule,
AuthId: cmd.AuthId,
Created: time.Now(),
Created: getTime(),
}
if cmd.OAuthToken != nil {
secretAccessToken, err := encryptAndEncode(cmd.OAuthToken.AccessToken)
if err != nil {
return err
}
secretRefreshToken, err := encryptAndEncode(cmd.OAuthToken.RefreshToken)
if err != nil {
return err
}
secretTokenType, err := encryptAndEncode(cmd.OAuthToken.TokenType)
if err != nil {
return err
}
authUser.OAuthAccessToken = secretAccessToken
authUser.OAuthRefreshToken = secretRefreshToken
authUser.OAuthTokenType = secretTokenType
authUser.OAuthExpiry = cmd.OAuthToken.Expiry
}
_, err := sess.Insert(authUser)
@@ -140,9 +183,76 @@ func SetAuthInfo(cmd *m.SetAuthInfoCommand) error {
})
}
func UpdateAuthInfo(cmd *m.UpdateAuthInfoCommand) error {
return inTransaction(func(sess *DBSession) error {
authUser := &m.UserAuth{
UserId: cmd.UserId,
AuthModule: cmd.AuthModule,
AuthId: cmd.AuthId,
Created: getTime(),
}
if cmd.OAuthToken != nil {
secretAccessToken, err := encryptAndEncode(cmd.OAuthToken.AccessToken)
if err != nil {
return err
}
secretRefreshToken, err := encryptAndEncode(cmd.OAuthToken.RefreshToken)
if err != nil {
return err
}
secretTokenType, err := encryptAndEncode(cmd.OAuthToken.TokenType)
if err != nil {
return err
}
authUser.OAuthAccessToken = secretAccessToken
authUser.OAuthRefreshToken = secretRefreshToken
authUser.OAuthTokenType = secretTokenType
authUser.OAuthExpiry = cmd.OAuthToken.Expiry
}
cond := &m.UserAuth{
UserId: cmd.UserId,
AuthModule: cmd.AuthModule,
}
_, err := sess.Update(authUser, cond)
return err
})
}
func DeleteAuthInfo(cmd *m.DeleteAuthInfoCommand) error {
return inTransaction(func(sess *DBSession) error {
_, err := sess.Delete(cmd.UserAuth)
return err
})
}
// decodeAndDecrypt will decode the string with the standard bas64 decoder
// and then decrypt it with grafana's secretKey
func decodeAndDecrypt(s string) (string, error) {
// Bail out if empty string since it'll cause a segfault in util.Decrypt
if s == "" {
return "", nil
}
decoded, err := base64.StdEncoding.DecodeString(s)
if err != nil {
return "", err
}
decrypted, err := util.Decrypt(decoded, setting.SecretKey)
if err != nil {
return "", err
}
return string(decrypted), nil
}
// encryptAndEncode will encrypt a string with grafana's secretKey, and
// then encode it with the standard bas64 encoder
func encryptAndEncode(s string) (string, error) {
encrypted, err := util.Encrypt([]byte(s), setting.SecretKey)
if err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(encrypted), nil
}

View File

@@ -4,8 +4,10 @@ import (
"context"
"fmt"
"testing"
"time"
. "github.com/smartystreets/goconvey/convey"
"golang.org/x/oauth2"
m "github.com/grafana/grafana/pkg/models"
)
@@ -126,5 +128,97 @@ func TestUserAuth(t *testing.T) {
So(err, ShouldEqual, m.ErrUserNotFound)
So(query.Result, ShouldBeNil)
})
Convey("Can set & retrieve oauth token information", func() {
token := &oauth2.Token{
AccessToken: "testaccess",
RefreshToken: "testrefresh",
Expiry: time.Now(),
TokenType: "Bearer",
}
// Find a user to set tokens on
login := "loginuser0"
// Calling GetUserByAuthInfoQuery on an existing user will populate an entry in the user_auth table
query := &m.GetUserByAuthInfoQuery{Login: login, AuthModule: "test", AuthId: "test"}
err = GetUserByAuthInfo(query)
So(err, ShouldBeNil)
So(query.Result.Login, ShouldEqual, login)
cmd := &m.UpdateAuthInfoCommand{
UserId: query.Result.Id,
AuthId: query.AuthId,
AuthModule: query.AuthModule,
OAuthToken: token,
}
err = UpdateAuthInfo(cmd)
So(err, ShouldBeNil)
getAuthQuery := &m.GetAuthInfoQuery{
UserId: query.Result.Id,
}
err = GetAuthInfo(getAuthQuery)
So(err, ShouldBeNil)
So(getAuthQuery.Result.OAuthAccessToken, ShouldEqual, token.AccessToken)
So(getAuthQuery.Result.OAuthRefreshToken, ShouldEqual, token.RefreshToken)
So(getAuthQuery.Result.OAuthTokenType, ShouldEqual, token.TokenType)
})
Convey("Always return the most recently used auth_module", func() {
// Find a user to set tokens on
login := "loginuser0"
// Calling GetUserByAuthInfoQuery on an existing user will populate an entry in the user_auth table
// Make the first log-in during the past
getTime = func() time.Time { return time.Now().AddDate(0, 0, -2) }
query := &m.GetUserByAuthInfoQuery{Login: login, AuthModule: "test1", AuthId: "test1"}
err = GetUserByAuthInfo(query)
getTime = time.Now
So(err, ShouldBeNil)
So(query.Result.Login, ShouldEqual, login)
// Add a second auth module for this user
// Have this module's last log-in be more recent
getTime = func() time.Time { return time.Now().AddDate(0, 0, -1) }
query = &m.GetUserByAuthInfoQuery{Login: login, AuthModule: "test2", AuthId: "test2"}
err = GetUserByAuthInfo(query)
getTime = time.Now
So(err, ShouldBeNil)
So(query.Result.Login, ShouldEqual, login)
// Get the latest entry by not supply an authmodule or authid
getAuthQuery := &m.GetAuthInfoQuery{
UserId: query.Result.Id,
}
err = GetAuthInfo(getAuthQuery)
So(err, ShouldBeNil)
So(getAuthQuery.Result.AuthModule, ShouldEqual, "test2")
// "log in" again with the first auth module
updateAuthCmd := &m.UpdateAuthInfoCommand{UserId: query.Result.Id, AuthModule: "test1", AuthId: "test1"}
err = UpdateAuthInfo(updateAuthCmd)
So(err, ShouldBeNil)
// Get the latest entry by not supply an authmodule or authid
getAuthQuery = &m.GetAuthInfoQuery{
UserId: query.Result.Id,
}
err = GetAuthInfo(getAuthQuery)
So(err, ShouldBeNil)
So(getAuthQuery.Result.AuthModule, ShouldEqual, "test1")
})
})
}