From 9f5a8bf92607aedb93361ca908d3c83542f07ca5 Mon Sep 17 00:00:00 2001 From: Jo Date: Tue, 23 Jan 2024 15:26:38 +0100 Subject: [PATCH] AuthInfo: Revert #81013. Fix cache invalidation (#81050) * Revert "Auth: Revert "Auth: Cache Auth Info" (#81013)" This reverts commit ce84f7c5405b9696bbe443409b5130db490856cf. * fix cache invalidation during user takeover * fix incomplete test --- pkg/api/user_test.go | 4 +- pkg/infra/remotecache/redis_storage.go | 11 +- pkg/services/login/authinfoimpl/service.go | 164 +++++++++++++++++++- pkg/services/oauthtoken/oauth_token_test.go | 22 ++- 4 files changed, 185 insertions(+), 16 deletions(-) diff --git a/pkg/api/user_test.go b/pkg/api/user_test.go index fa1fc14c2ee..2b9bfa01951 100644 --- a/pkg/api/user_test.go +++ b/pkg/api/user_test.go @@ -18,6 +18,7 @@ import ( "github.com/grafana/grafana/pkg/components/simplejson" "github.com/grafana/grafana/pkg/infra/db" "github.com/grafana/grafana/pkg/infra/db/dbtest" + "github.com/grafana/grafana/pkg/infra/remotecache" "github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/login/social/socialtest" "github.com/grafana/grafana/pkg/services/accesscontrol/acimpl" @@ -64,8 +65,7 @@ func TestUserAPIEndpoint_userLoggedIn(t *testing.T) { secretsService := secretsManager.SetupTestService(t, database.ProvideSecretsStore(sqlStore)) authInfoStore := authinfoimpl.ProvideStore(sqlStore, secretsService) srv := authinfoimpl.ProvideService( - authInfoStore, - ) + authInfoStore, remotecache.NewFakeCacheStorage(), secretsService) hs.authInfoService = srv orgSvc, err := orgimpl.ProvideService(sqlStore, sqlStore.Cfg, quotatest.New(false, nil)) require.NoError(t, err) diff --git a/pkg/infra/remotecache/redis_storage.go b/pkg/infra/remotecache/redis_storage.go index 13684530525..e97679c04f6 100644 --- a/pkg/infra/remotecache/redis_storage.go +++ b/pkg/infra/remotecache/redis_storage.go @@ -3,6 +3,7 @@ package remotecache import ( "context" "crypto/tls" + "errors" "fmt" "strconv" "strings" @@ -93,7 +94,15 @@ func (s *redisStorage) Set(ctx context.Context, key string, data []byte, expires // GetByteArray returns the value as byte array func (s *redisStorage) Get(ctx context.Context, key string) ([]byte, error) { - return s.c.Get(ctx, key).Bytes() + item, err := s.c.Get(ctx, key).Bytes() + if err != nil { + if errors.Is(err, redis.Nil) { + return nil, ErrCacheItemNotFound + } + return nil, err + } + + return item, nil } // Delete delete a key from session. diff --git a/pkg/services/login/authinfoimpl/service.go b/pkg/services/login/authinfoimpl/service.go index bb4759febe0..ef82d4a150b 100644 --- a/pkg/services/login/authinfoimpl/service.go +++ b/pkg/services/login/authinfoimpl/service.go @@ -2,27 +2,68 @@ package authinfoimpl import ( "context" + "encoding/json" + "errors" + "strconv" + "time" "github.com/grafana/grafana/pkg/infra/log" + "github.com/grafana/grafana/pkg/infra/remotecache" "github.com/grafana/grafana/pkg/services/login" + "github.com/grafana/grafana/pkg/services/secrets" + "github.com/grafana/grafana/pkg/services/user" ) type Service struct { authInfoStore login.Store logger log.Logger + remoteCache remotecache.CacheStorage + secretService secrets.Service } -func ProvideService(authInfoStore login.Store) *Service { +const remoteCachePrefix = "authinfo-" +const remoteCacheTTL = 60 * time.Hour + +var errMissingParameters = errors.New("user ID and auth ID must be set") + +func ProvideService(authInfoStore login.Store, + remoteCache remotecache.CacheStorage, + secretService secrets.Service) *Service { s := &Service{ authInfoStore: authInfoStore, logger: log.New("login.authinfo"), + remoteCache: remoteCache, + secretService: secretService, } return s } func (s *Service) GetAuthInfo(ctx context.Context, query *login.GetAuthInfoQuery) (*login.UserAuth, error) { - return s.authInfoStore.GetAuthInfo(ctx, query) + if query.UserId == 0 && query.AuthId == "" { + return nil, user.ErrUserNotFound + } + + authInfo, err := s.getAuthInfoFromCache(ctx, query) + if err != nil && !errors.Is(err, remotecache.ErrCacheItemNotFound) { + s.logger.Error("failed to retrieve auth info from cache", "error", err) + } else if authInfo != nil { + return authInfo, nil + } + + authInfo, err = s.authInfoStore.GetAuthInfo(ctx, query) + if err != nil { + return nil, err + } + + err = s.setAuthInfoInCache(ctx, query, authInfo) + if err != nil { + s.logger.Error("failed to set auth info in cache", "error", err) + } else { + s.logger.Debug("auth info set in cache", "cacheKey", generateCacheKey(query)) + } + + return authInfo, nil } func (s *Service) GetUserLabels(ctx context.Context, query login.GetUserLabelsQuery) (map[int64]string, error) { @@ -32,14 +73,127 @@ func (s *Service) GetUserLabels(ctx context.Context, query login.GetUserLabelsQu return s.authInfoStore.GetUserLabels(ctx, query) } +func (s *Service) setAuthInfoInCache(ctx context.Context, query *login.GetAuthInfoQuery, info *login.UserAuth) error { + cacheKey := generateCacheKey(query) + infoJSON, err := json.Marshal(info) + if err != nil { + return err + } + + encryptedInfo, err := s.secretService.Encrypt(ctx, infoJSON, secrets.WithoutScope()) + if err != nil { + return err + } + + return s.remoteCache.Set(ctx, cacheKey, encryptedInfo, remoteCacheTTL) +} + +func (s *Service) getAuthInfoFromCache(ctx context.Context, query *login.GetAuthInfoQuery) (*login.UserAuth, error) { + // check if we have the auth info in the remote cache + cacheKey := generateCacheKey(query) + item, err := s.remoteCache.Get(ctx, cacheKey) + if err != nil { + return nil, err + } + + info := &login.UserAuth{} + itemJSON, err := s.secretService.Decrypt(ctx, item) + if err != nil { + return nil, err + } + + if err := json.Unmarshal(itemJSON, info); err != nil { + return nil, err + } + + s.logger.Debug("auth info retrieved from cache", "cacheKey", cacheKey) + + return info, nil +} + +func generateCacheKey(query *login.GetAuthInfoQuery) string { + cacheKey := remoteCachePrefix + strconv.FormatInt(query.UserId, 10) + "-" + + query.AuthModule + "-" + query.AuthId + return cacheKey +} + func (s *Service) UpdateAuthInfo(ctx context.Context, cmd *login.UpdateAuthInfoCommand) error { - return s.authInfoStore.UpdateAuthInfo(ctx, cmd) + if cmd.UserId == 0 || cmd.AuthId == "" { + return errMissingParameters + } + + err := s.authInfoStore.UpdateAuthInfo(ctx, cmd) + if err != nil { + return err + } + + s.deleteUserAuthInfoInCache(ctx, &login.GetAuthInfoQuery{ + AuthModule: cmd.AuthModule, + AuthId: cmd.AuthId, + UserId: cmd.UserId, + }) + + return nil } func (s *Service) SetAuthInfo(ctx context.Context, cmd *login.SetAuthInfoCommand) error { - return s.authInfoStore.SetAuthInfo(ctx, cmd) + if cmd.UserId == 0 || cmd.AuthId == "" { + return errMissingParameters + } + + err := s.authInfoStore.SetAuthInfo(ctx, cmd) + if err != nil { + return err + } + + s.deleteUserAuthInfoInCache(ctx, &login.GetAuthInfoQuery{ + AuthModule: cmd.AuthModule, + AuthId: cmd.AuthId, + UserId: cmd.UserId, + }) + + return nil } func (s *Service) DeleteUserAuthInfo(ctx context.Context, userID int64) error { - return s.authInfoStore.DeleteUserAuthInfo(ctx, userID) + err := s.authInfoStore.DeleteUserAuthInfo(ctx, userID) + if err != nil { + return err + } + + err = s.remoteCache.Delete(ctx, generateCacheKey(&login.GetAuthInfoQuery{ + UserId: userID, + })) + if err != nil { + s.logger.Error("failed to delete auth info from cache", "error", err) + } + + return nil +} + +func (s *Service) deleteUserAuthInfoInCache(ctx context.Context, query *login.GetAuthInfoQuery) { + err := s.remoteCache.Delete(ctx, generateCacheKey(&login.GetAuthInfoQuery{ + AuthModule: query.AuthModule, + AuthId: query.AuthId, + })) + if err != nil { + s.logger.Warn("failed to delete auth info from cache", "error", err) + } + + errN := s.remoteCache.Delete(ctx, generateCacheKey( + &login.GetAuthInfoQuery{ + UserId: query.UserId, + })) + if errN != nil { + s.logger.Warn("failed to delete user auth info from cache", "error", errN) + } + + errA := s.remoteCache.Delete(ctx, generateCacheKey( + &login.GetAuthInfoQuery{ + UserId: query.UserId, + AuthModule: query.AuthModule, + })) + if errA != nil { + s.logger.Warn("failed to delete user module auth info from cache", "error", errA) + } } diff --git a/pkg/services/oauthtoken/oauth_token_test.go b/pkg/services/oauthtoken/oauth_token_test.go index 28e56398ac6..ed7422449ee 100644 --- a/pkg/services/oauthtoken/oauth_token_test.go +++ b/pkg/services/oauthtoken/oauth_token_test.go @@ -10,12 +10,16 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" "golang.org/x/oauth2" "golang.org/x/sync/singleflight" + "github.com/grafana/grafana/pkg/infra/remotecache" "github.com/grafana/grafana/pkg/login/social/socialtest" "github.com/grafana/grafana/pkg/services/login" "github.com/grafana/grafana/pkg/services/login/authinfoimpl" + "github.com/grafana/grafana/pkg/services/secrets/fakes" + secretsManager "github.com/grafana/grafana/pkg/services/secrets/manager" "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/setting" ) @@ -114,13 +118,12 @@ func TestService_TryTokenRefresh_ValidToken(t *testing.T) { socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(token)) err := srv.TryTokenRefresh(ctx, usr) - assert.Nil(t, err) + require.Nil(t, err) socialConnector.AssertNumberOfCalls(t, "TokenSource", 1) - authInfoQuery := &login.GetAuthInfoQuery{} + authInfoQuery := &login.GetAuthInfoQuery{UserId: 1} resultUsr, err := srv.AuthInfoService.GetAuthInfo(ctx, authInfoQuery) - - assert.Nil(t, err) + require.Nil(t, err) // User's token data had not been updated assert.Equal(t, resultUsr.OAuthAccessToken, token.AccessToken) @@ -175,6 +178,8 @@ func TestService_TryTokenRefresh_ExpiredToken(t *testing.T) { usr := &login.UserAuth{ AuthModule: "oauth_generic_oauth", + UserId: 1, + AuthId: "test", OAuthAccessToken: token.AccessToken, OAuthRefreshToken: token.RefreshToken, OAuthExpiry: token.Expiry, @@ -187,13 +192,13 @@ func TestService_TryTokenRefresh_ExpiredToken(t *testing.T) { err := srv.TryTokenRefresh(ctx, usr) - assert.Nil(t, err) + require.Nil(t, err) socialConnector.AssertNumberOfCalls(t, "TokenSource", 1) - authInfoQuery := &login.GetAuthInfoQuery{} + authInfoQuery := &login.GetAuthInfoQuery{UserId: 1} authInfo, err := srv.AuthInfoService.GetAuthInfo(ctx, authInfoQuery) - assert.Nil(t, err) + require.Nil(t, err) // newToken should be returned after the .Token() call, therefore the User had to be updated assert.Equal(t, authInfo.OAuthAccessToken, newToken.AccessToken) @@ -229,7 +234,8 @@ func setupOAuthTokenService(t *testing.T) (*Service, *FakeAuthInfoStore, *social } authInfoStore := &FakeAuthInfoStore{} - authInfoService := authinfoimpl.ProvideService(authInfoStore) + authInfoService := authinfoimpl.ProvideService(authInfoStore, remotecache.NewFakeCacheStorage(), + secretsManager.SetupTestService(t, fakes.NewFakeSecretsStore())) return &Service{ Cfg: setting.NewCfg(), SocialService: socialService,