diff --git a/pkg/api/login_test.go b/pkg/api/login_test.go index dc06e65c5ba..9e396c220f5 100644 --- a/pkg/api/login_test.go +++ b/pkg/api/login_test.go @@ -110,7 +110,7 @@ func TestLoginErrorCookieAPIEndpoint(t *testing.T) { SettingsProvider: &setting.OSSImpl{Cfg: cfg}, License: &licensing.OSSLicensingService{}, SocialService: &mockSocialService{}, - EncryptionService: &ossencryption.Service{}, + EncryptionService: ossencryption.ProvideService(), } sc.defaultHandler = routing.Wrap(func(w http.ResponseWriter, c *models.ReqContext) { diff --git a/pkg/services/login/authinfoservice/database.go b/pkg/services/login/authinfoservice/database.go index e29a50af536..157fa387053 100644 --- a/pkg/services/login/authinfoservice/database.go +++ b/pkg/services/login/authinfoservice/database.go @@ -9,7 +9,6 @@ import ( "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/setting" - "github.com/grafana/grafana/pkg/util" ) var getTime = time.Now @@ -60,15 +59,15 @@ func (s *Implementation) GetAuthInfo(query *models.GetAuthInfoQuery) error { return models.ErrUserNotFound } - secretAccessToken, err := decodeAndDecrypt(userAuth.OAuthAccessToken) + secretAccessToken, err := s.decodeAndDecrypt(userAuth.OAuthAccessToken) if err != nil { return err } - secretRefreshToken, err := decodeAndDecrypt(userAuth.OAuthRefreshToken) + secretRefreshToken, err := s.decodeAndDecrypt(userAuth.OAuthRefreshToken) if err != nil { return err } - secretTokenType, err := decodeAndDecrypt(userAuth.OAuthTokenType) + secretTokenType, err := s.decodeAndDecrypt(userAuth.OAuthTokenType) if err != nil { return err } @@ -90,15 +89,15 @@ func (s *Implementation) SetAuthInfo(cmd *models.SetAuthInfoCommand) error { } if cmd.OAuthToken != nil { - secretAccessToken, err := encryptAndEncode(cmd.OAuthToken.AccessToken) + secretAccessToken, err := s.encryptAndEncode(cmd.OAuthToken.AccessToken) if err != nil { return err } - secretRefreshToken, err := encryptAndEncode(cmd.OAuthToken.RefreshToken) + secretRefreshToken, err := s.encryptAndEncode(cmd.OAuthToken.RefreshToken) if err != nil { return err } - secretTokenType, err := encryptAndEncode(cmd.OAuthToken.TokenType) + secretTokenType, err := s.encryptAndEncode(cmd.OAuthToken.TokenType) if err != nil { return err } @@ -124,15 +123,15 @@ func (s *Implementation) UpdateAuthInfo(cmd *models.UpdateAuthInfoCommand) error } if cmd.OAuthToken != nil { - secretAccessToken, err := encryptAndEncode(cmd.OAuthToken.AccessToken) + secretAccessToken, err := s.encryptAndEncode(cmd.OAuthToken.AccessToken) if err != nil { return err } - secretRefreshToken, err := encryptAndEncode(cmd.OAuthToken.RefreshToken) + secretRefreshToken, err := s.encryptAndEncode(cmd.OAuthToken.RefreshToken) if err != nil { return err } - secretTokenType, err := encryptAndEncode(cmd.OAuthToken.TokenType) + secretTokenType, err := s.encryptAndEncode(cmd.OAuthToken.TokenType) if err != nil { return err } @@ -160,18 +159,17 @@ func (s *Implementation) DeleteAuthInfo(cmd *models.DeleteAuthInfoCommand) error }) } -// decodeAndDecrypt will decode the string with the standard bas64 decoder -// and then decrypt it with grafana's secretKey -func decodeAndDecrypt(s string) (string, error) { +// decodeAndDecrypt will decode the string with the standard base64 decoder and then decrypt it +func (s *Implementation) decodeAndDecrypt(str string) (string, error) { // Bail out if empty string since it'll cause a segfault in util.Decrypt - if s == "" { + if str == "" { return "", nil } - decoded, err := base64.StdEncoding.DecodeString(s) + decoded, err := base64.StdEncoding.DecodeString(str) if err != nil { return "", err } - decrypted, err := util.Decrypt(decoded, setting.SecretKey) + decrypted, err := s.EncryptionService.Decrypt(decoded, setting.SecretKey) if err != nil { return "", err } @@ -180,8 +178,8 @@ func decodeAndDecrypt(s string) (string, error) { // encryptAndEncode will encrypt a string with grafana's secretKey, and // then encode it with the standard bas64 encoder -func encryptAndEncode(s string) (string, error) { - encrypted, err := util.Encrypt([]byte(s), setting.SecretKey) +func (s *Implementation) encryptAndEncode(str string) (string, error) { + encrypted, err := s.EncryptionService.Encrypt([]byte(str), setting.SecretKey) if err != nil { return "", err } diff --git a/pkg/services/login/authinfoservice/service.go b/pkg/services/login/authinfoservice/service.go index 95d72f3d8fd..c691709ab40 100644 --- a/pkg/services/login/authinfoservice/service.go +++ b/pkg/services/login/authinfoservice/service.go @@ -4,6 +4,8 @@ import ( "context" "errors" + "github.com/grafana/grafana/pkg/services/encryption" + "github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/models" @@ -17,15 +19,17 @@ type Implementation struct { Bus bus.Bus SQLStore *sqlstore.SQLStore UserProtectionService login.UserProtectionService - - logger log.Logger + EncryptionService encryption.Service + logger log.Logger } -func ProvideAuthInfoService(bus bus.Bus, store *sqlstore.SQLStore, service login.UserProtectionService) *Implementation { +func ProvideAuthInfoService(bus bus.Bus, store *sqlstore.SQLStore, userProtectionService login.UserProtectionService, + encryptionService encryption.Service) *Implementation { s := &Implementation{ Bus: bus, SQLStore: store, - UserProtectionService: service, + UserProtectionService: userProtectionService, + EncryptionService: encryptionService, logger: log.New("login.authinfo"), } diff --git a/pkg/services/login/authinfoservice/user_auth_test.go b/pkg/services/login/authinfoservice/user_auth_test.go index 64b6f1d96a3..cd93970a91d 100644 --- a/pkg/services/login/authinfoservice/user_auth_test.go +++ b/pkg/services/login/authinfoservice/user_auth_test.go @@ -8,6 +8,8 @@ import ( "testing" "time" + "github.com/grafana/grafana/pkg/services/encryption/ossencryption" + "github.com/grafana/grafana/pkg/services/sqlstore" "github.com/grafana/grafana/pkg/bus" @@ -20,7 +22,7 @@ import ( //nolint:goconst func TestUserAuth(t *testing.T) { sqlStore := sqlstore.InitTestDB(t) - srv := ProvideAuthInfoService(bus.New(), sqlStore, &OSSUserProtectionImpl{}) + srv := ProvideAuthInfoService(bus.New(), sqlStore, &OSSUserProtectionImpl{}, ossencryption.ProvideService()) t.Run("Given 5 users", func(t *testing.T) { for i := 0; i < 5; i++ {