Chore: Remove global encryption calls from authinfoservice (#38591)

* Add encryption service

* Add tests for encryption service

* Inject encryption service into http server

* Replace encryption global function usage in login tests

* Migrate to Wire

* Refactor authinfoservice to use encryption service

Co-authored-by: Joan López de la Franca Beltran <joanjan14@gmail.com>
Co-authored-by: Joan López de la Franca Beltran <5459617+joanlopez@users.noreply.github.com>
Co-authored-by: Emil Tullstedt <emil.tullstedt@grafana.com>
This commit is contained in:
Tania B 2021-08-31 15:00:13 +03:00 committed by GitHub
parent a5d11a3bef
commit bfde29d107
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 28 additions and 24 deletions

View File

@ -110,7 +110,7 @@ func TestLoginErrorCookieAPIEndpoint(t *testing.T) {
SettingsProvider: &setting.OSSImpl{Cfg: cfg}, SettingsProvider: &setting.OSSImpl{Cfg: cfg},
License: &licensing.OSSLicensingService{}, License: &licensing.OSSLicensingService{},
SocialService: &mockSocialService{}, SocialService: &mockSocialService{},
EncryptionService: &ossencryption.Service{}, EncryptionService: ossencryption.ProvideService(),
} }
sc.defaultHandler = routing.Wrap(func(w http.ResponseWriter, c *models.ReqContext) { sc.defaultHandler = routing.Wrap(func(w http.ResponseWriter, c *models.ReqContext) {

View File

@ -9,7 +9,6 @@ import (
"github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/util"
) )
var getTime = time.Now var getTime = time.Now
@ -60,15 +59,15 @@ func (s *Implementation) GetAuthInfo(query *models.GetAuthInfoQuery) error {
return models.ErrUserNotFound return models.ErrUserNotFound
} }
secretAccessToken, err := decodeAndDecrypt(userAuth.OAuthAccessToken) secretAccessToken, err := s.decodeAndDecrypt(userAuth.OAuthAccessToken)
if err != nil { if err != nil {
return err return err
} }
secretRefreshToken, err := decodeAndDecrypt(userAuth.OAuthRefreshToken) secretRefreshToken, err := s.decodeAndDecrypt(userAuth.OAuthRefreshToken)
if err != nil { if err != nil {
return err return err
} }
secretTokenType, err := decodeAndDecrypt(userAuth.OAuthTokenType) secretTokenType, err := s.decodeAndDecrypt(userAuth.OAuthTokenType)
if err != nil { if err != nil {
return err return err
} }
@ -90,15 +89,15 @@ func (s *Implementation) SetAuthInfo(cmd *models.SetAuthInfoCommand) error {
} }
if cmd.OAuthToken != nil { if cmd.OAuthToken != nil {
secretAccessToken, err := encryptAndEncode(cmd.OAuthToken.AccessToken) secretAccessToken, err := s.encryptAndEncode(cmd.OAuthToken.AccessToken)
if err != nil { if err != nil {
return err return err
} }
secretRefreshToken, err := encryptAndEncode(cmd.OAuthToken.RefreshToken) secretRefreshToken, err := s.encryptAndEncode(cmd.OAuthToken.RefreshToken)
if err != nil { if err != nil {
return err return err
} }
secretTokenType, err := encryptAndEncode(cmd.OAuthToken.TokenType) secretTokenType, err := s.encryptAndEncode(cmd.OAuthToken.TokenType)
if err != nil { if err != nil {
return err return err
} }
@ -124,15 +123,15 @@ func (s *Implementation) UpdateAuthInfo(cmd *models.UpdateAuthInfoCommand) error
} }
if cmd.OAuthToken != nil { if cmd.OAuthToken != nil {
secretAccessToken, err := encryptAndEncode(cmd.OAuthToken.AccessToken) secretAccessToken, err := s.encryptAndEncode(cmd.OAuthToken.AccessToken)
if err != nil { if err != nil {
return err return err
} }
secretRefreshToken, err := encryptAndEncode(cmd.OAuthToken.RefreshToken) secretRefreshToken, err := s.encryptAndEncode(cmd.OAuthToken.RefreshToken)
if err != nil { if err != nil {
return err return err
} }
secretTokenType, err := encryptAndEncode(cmd.OAuthToken.TokenType) secretTokenType, err := s.encryptAndEncode(cmd.OAuthToken.TokenType)
if err != nil { if err != nil {
return err return err
} }
@ -160,18 +159,17 @@ func (s *Implementation) DeleteAuthInfo(cmd *models.DeleteAuthInfoCommand) error
}) })
} }
// decodeAndDecrypt will decode the string with the standard bas64 decoder // decodeAndDecrypt will decode the string with the standard base64 decoder and then decrypt it
// and then decrypt it with grafana's secretKey func (s *Implementation) decodeAndDecrypt(str string) (string, error) {
func decodeAndDecrypt(s string) (string, error) {
// Bail out if empty string since it'll cause a segfault in util.Decrypt // Bail out if empty string since it'll cause a segfault in util.Decrypt
if s == "" { if str == "" {
return "", nil return "", nil
} }
decoded, err := base64.StdEncoding.DecodeString(s) decoded, err := base64.StdEncoding.DecodeString(str)
if err != nil { if err != nil {
return "", err return "", err
} }
decrypted, err := util.Decrypt(decoded, setting.SecretKey) decrypted, err := s.EncryptionService.Decrypt(decoded, setting.SecretKey)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -180,8 +178,8 @@ func decodeAndDecrypt(s string) (string, error) {
// encryptAndEncode will encrypt a string with grafana's secretKey, and // encryptAndEncode will encrypt a string with grafana's secretKey, and
// then encode it with the standard bas64 encoder // then encode it with the standard bas64 encoder
func encryptAndEncode(s string) (string, error) { func (s *Implementation) encryptAndEncode(str string) (string, error) {
encrypted, err := util.Encrypt([]byte(s), setting.SecretKey) encrypted, err := s.EncryptionService.Encrypt([]byte(str), setting.SecretKey)
if err != nil { if err != nil {
return "", err return "", err
} }

View File

@ -4,6 +4,8 @@ import (
"context" "context"
"errors" "errors"
"github.com/grafana/grafana/pkg/services/encryption"
"github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/bus"
"github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/models"
@ -17,15 +19,17 @@ type Implementation struct {
Bus bus.Bus Bus bus.Bus
SQLStore *sqlstore.SQLStore SQLStore *sqlstore.SQLStore
UserProtectionService login.UserProtectionService UserProtectionService login.UserProtectionService
EncryptionService encryption.Service
logger log.Logger logger log.Logger
} }
func ProvideAuthInfoService(bus bus.Bus, store *sqlstore.SQLStore, service login.UserProtectionService) *Implementation { func ProvideAuthInfoService(bus bus.Bus, store *sqlstore.SQLStore, userProtectionService login.UserProtectionService,
encryptionService encryption.Service) *Implementation {
s := &Implementation{ s := &Implementation{
Bus: bus, Bus: bus,
SQLStore: store, SQLStore: store,
UserProtectionService: service, UserProtectionService: userProtectionService,
EncryptionService: encryptionService,
logger: log.New("login.authinfo"), logger: log.New("login.authinfo"),
} }

View File

@ -8,6 +8,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/grafana/grafana/pkg/services/encryption/ossencryption"
"github.com/grafana/grafana/pkg/services/sqlstore" "github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/bus"
@ -20,7 +22,7 @@ import (
//nolint:goconst //nolint:goconst
func TestUserAuth(t *testing.T) { func TestUserAuth(t *testing.T) {
sqlStore := sqlstore.InitTestDB(t) sqlStore := sqlstore.InitTestDB(t)
srv := ProvideAuthInfoService(bus.New(), sqlStore, &OSSUserProtectionImpl{}) srv := ProvideAuthInfoService(bus.New(), sqlStore, &OSSUserProtectionImpl{}, ossencryption.ProvideService())
t.Run("Given 5 users", func(t *testing.T) { t.Run("Given 5 users", func(t *testing.T) {
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {