From b88206d98f064478bc6bbe4b5a195235cb128a77 Mon Sep 17 00:00:00 2001 From: Carl Bergquist Date: Wed, 8 Feb 2023 10:30:20 +0100 Subject: [PATCH] Cache: Refactor cache clients to use byte array (#62930) Signed-off-by: bergquist --- pkg/middleware/middleware_test.go | 3 +- pkg/services/auth/jwt/key_sets.go | 6 ++-- .../contexthandler/auth_proxy_test.go | 12 +++++-- .../contexthandler/authproxy/authproxy.go | 19 +++++++---- .../authproxy/authproxy_test.go | 7 ++-- pkg/services/rendering/auth.go | 33 +++++++++++-------- pkg/services/rendering/rendering.go | 8 ++--- 7 files changed, 56 insertions(+), 32 deletions(-) diff --git a/pkg/middleware/middleware_test.go b/pkg/middleware/middleware_test.go index 98d34f4d498..e8dab37c49a 100644 --- a/pkg/middleware/middleware_test.go +++ b/pkg/middleware/middleware_test.go @@ -612,7 +612,8 @@ func TestMiddlewareContext(t *testing.T) { h, err := authproxy.HashCacheKey(hdrName + "-" + group) require.NoError(t, err) key := fmt.Sprintf(authproxy.CachePrefix, h) - err = sc.remoteCacheService.Set(context.Background(), key, userID, 0) + userIdBytes := []byte(strconv.FormatInt(userID, 10)) + err = sc.remoteCacheService.SetByteArray(context.Background(), key, userIdBytes, 0) require.NoError(t, err) sc.fakeReq("GET", "/") diff --git a/pkg/services/auth/jwt/key_sets.go b/pkg/services/auth/jwt/key_sets.go index 175b60c14d2..5ebb6cb9d0d 100644 --- a/pkg/services/auth/jwt/key_sets.go +++ b/pkg/services/auth/jwt/key_sets.go @@ -171,8 +171,8 @@ func (ks *keySetHTTP) getJWKS(ctx context.Context) (keySetJWKS, error) { var jwks keySetJWKS if ks.cacheExpiration > 0 { - if val, err := ks.cache.Get(ctx, ks.cacheKey); err == nil { - err := json.Unmarshal(val.([]byte), &jwks) + if val, err := ks.cache.GetByteArray(ctx, ks.cacheKey); err == nil { + err := json.Unmarshal(val, &jwks) return jwks, err } } @@ -200,7 +200,7 @@ func (ks *keySetHTTP) getJWKS(ctx context.Context) (keySetJWKS, error) { } if ks.cacheExpiration > 0 { - err = ks.cache.Set(ctx, ks.cacheKey, jsonBuf.Bytes(), ks.cacheExpiration) + err = ks.cache.SetByteArray(ctx, ks.cacheKey, jsonBuf.Bytes(), ks.cacheExpiration) } return jwks, err } diff --git a/pkg/services/contexthandler/auth_proxy_test.go b/pkg/services/contexthandler/auth_proxy_test.go index 773e6b36ee0..18c1b479609 100644 --- a/pkg/services/contexthandler/auth_proxy_test.go +++ b/pkg/services/contexthandler/auth_proxy_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/http" + "strconv" "testing" "github.com/stretchr/testify/require" @@ -53,7 +54,8 @@ func TestInitContextWithAuthProxy_CachedInvalidUserID(t *testing.T) { key := fmt.Sprintf(authproxy.CachePrefix, h) t.Logf("Injecting stale user ID in cache with key %q", key) - err = svc.RemoteCache.Set(context.Background(), key, int64(33), 0) + userIdPayload := []byte(strconv.FormatInt(int64(33), 10)) + err = svc.RemoteCache.SetByteArray(context.Background(), key, userIdPayload, 0) require.NoError(t, err) authEnabled := svc.initContextWithAuthProxy(ctx, orgID) @@ -62,9 +64,13 @@ func TestInitContextWithAuthProxy_CachedInvalidUserID(t *testing.T) { require.Equal(t, userID, ctx.SignedInUser.UserID) require.True(t, ctx.IsSignedIn) - i, err := svc.RemoteCache.Get(context.Background(), key) + cachedByteArray, err := svc.RemoteCache.GetByteArray(context.Background(), key) require.NoError(t, err) - require.Equal(t, userID, i.(int64)) + + cacheUserId, err := strconv.ParseInt(string(cachedByteArray), 10, 64) + + require.NoError(t, err) + require.Equal(t, userID, cacheUserId) } type fakeRenderService struct { diff --git a/pkg/services/contexthandler/authproxy/authproxy.go b/pkg/services/contexthandler/authproxy/authproxy.go index 95a340c7cc8..321a59f4e44 100644 --- a/pkg/services/contexthandler/authproxy/authproxy.go +++ b/pkg/services/contexthandler/authproxy/authproxy.go @@ -10,6 +10,7 @@ import ( "net/mail" "path" "reflect" + "strconv" "strings" "time" @@ -201,14 +202,19 @@ func (auth *AuthProxy) getUserViaCache(reqCtx *contextmodel.ReqContext) (int64, return 0, err } auth.logger.Debug("Getting user ID via auth cache", "cacheKey", cacheKey) - userID, err := auth.remoteCache.Get(reqCtx.Req.Context(), cacheKey) + cachedValue, err := auth.remoteCache.GetByteArray(reqCtx.Req.Context(), cacheKey) + if err != nil { + return 0, err + } + + userId, err := strconv.ParseInt(string(cachedValue), 10, 64) if err != nil { auth.logger.Debug("Failed getting user ID via auth cache", "error", err) return 0, err } - auth.logger.Debug("Successfully got user ID via auth cache", "id", userID) - return userID.(int64), nil + auth.logger.Debug("Successfully got user ID via auth cache", "id", cachedValue) + return userId, nil } // RemoveUserFromCache removes user from cache. @@ -363,14 +369,15 @@ func (auth *AuthProxy) Remember(reqCtx *contextmodel.ReqContext, id int64) error } // Check if user already in cache - userID, err := auth.remoteCache.Get(reqCtx.Req.Context(), key) - if err == nil && userID != nil { + cachedValue, err := auth.remoteCache.GetByteArray(reqCtx.Req.Context(), key) + if err == nil && len(cachedValue) != 0 { return nil } expiration := time.Duration(auth.cfg.AuthProxySyncTTL) * time.Minute - if err := auth.remoteCache.Set(reqCtx.Req.Context(), key, id, expiration); err != nil { + userIdPayload := []byte(strconv.FormatInt(id, 10)) + if err := auth.remoteCache.SetByteArray(reqCtx.Req.Context(), key, userIdPayload, expiration); err != nil { return err } diff --git a/pkg/services/contexthandler/authproxy/authproxy_test.go b/pkg/services/contexthandler/authproxy/authproxy_test.go index 109f1c878f5..20cad6d7d8f 100644 --- a/pkg/services/contexthandler/authproxy/authproxy_test.go +++ b/pkg/services/contexthandler/authproxy/authproxy_test.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net/http" + "strconv" "testing" "github.com/stretchr/testify/assert" @@ -60,7 +61,8 @@ func TestMiddlewareContext(t *testing.T) { h, err := HashCacheKey(hdrName) require.NoError(t, err) key := fmt.Sprintf(CachePrefix, h) - err = cache.Set(context.Background(), key, id, 0) + userIdPayload := []byte(strconv.FormatInt(id, 10)) + err = cache.SetByteArray(context.Background(), key, userIdPayload, 0) require.NoError(t, err) // Set up the middleware auth, reqCtx := prepareMiddleware(t, cache, nil) @@ -82,7 +84,8 @@ func TestMiddlewareContext(t *testing.T) { h, err := HashCacheKey(hdrName + "-" + group + "-" + role) require.NoError(t, err) key := fmt.Sprintf(CachePrefix, h) - err = cache.Set(context.Background(), key, id, 0) + userIdPayload := []byte(strconv.FormatInt(id, 10)) + err = cache.SetByteArray(context.Background(), key, userIdPayload, 0) require.NoError(t, err) auth, reqCtx := prepareMiddleware(t, cache, func(req *http.Request, cfg *setting.Cfg) { diff --git a/pkg/services/rendering/auth.go b/pkg/services/rendering/auth.go index f68067c5ed9..307430333a8 100644 --- a/pkg/services/rendering/auth.go +++ b/pkg/services/rendering/auth.go @@ -1,7 +1,9 @@ package rendering import ( + "bytes" "context" + "encoding/gob" "fmt" "time" @@ -13,33 +15,38 @@ import ( const renderKeyPrefix = "render-%s" type RenderUser struct { - OrgID int64 - UserID int64 - OrgRole string + OrgID int64 `json:"org_id"` + UserID int64 `json:"user_id"` + OrgRole string `json:"org_role"` } func (rs *RenderingService) GetRenderUser(ctx context.Context, key string) (*RenderUser, bool) { - val, err := rs.RemoteCacheService.Get(ctx, fmt.Sprintf(renderKeyPrefix, key)) + val, err := rs.RemoteCacheService.GetByteArray(ctx, fmt.Sprintf(renderKeyPrefix, key)) if err != nil { rs.log.Error("Failed to get render key from cache", "error", err) } - - if val != nil { - if user, ok := val.(*RenderUser); ok { - return user, true - } + ru := &RenderUser{} + buf := bytes.NewBuffer(val) + err = gob.NewDecoder(buf).Decode(&ru) + if err != nil { + return nil, false } - return nil, false + return ru, true } func setRenderKey(cache *remotecache.RemoteCache, ctx context.Context, opts AuthOpts, renderKey string, expiry time.Duration) error { - err := cache.Set(ctx, fmt.Sprintf(renderKeyPrefix, renderKey), &RenderUser{ + buf := bytes.NewBuffer(nil) + err := gob.NewEncoder(buf).Encode(&RenderUser{ OrgID: opts.OrgID, UserID: opts.UserID, OrgRole: string(opts.OrgRole), - }, expiry) - return err + }) + if err != nil { + return err + } + + return cache.SetByteArray(ctx, fmt.Sprintf(renderKeyPrefix, renderKey), buf.Bytes(), expiry) } func generateAndSetRenderKey(cache *remotecache.RemoteCache, ctx context.Context, opts AuthOpts, expiry time.Duration) (string, error) { diff --git a/pkg/services/rendering/rendering.go b/pkg/services/rendering/rendering.go index 3b68e8be9f2..d2cc08f127a 100644 --- a/pkg/services/rendering/rendering.go +++ b/pkg/services/rendering/rendering.go @@ -2,6 +2,7 @@ package rendering import ( "context" + "encoding/gob" "errors" "fmt" "math" @@ -22,10 +23,6 @@ import ( "github.com/grafana/grafana/pkg/util" ) -func init() { - remotecache.Register(&RenderUser{}) -} - var _ Service = (*RenderingService)(nil) const ServiceName = "RenderingService" @@ -113,6 +110,9 @@ func ProvideService(cfg *setting.Cfg, remoteCache *remotecache.RemoteCache, rm p domain: domain, sanitizeURL: sanitizeURL, } + + gob.Register(&RenderUser{}) + return s, nil }