Cache: Refactor cache clients to use byte array (#62930)

Signed-off-by: bergquist <carl.bergquist@gmail.com>
This commit is contained in:
Carl Bergquist 2023-02-08 10:30:20 +01:00 committed by GitHub
parent 2804acd264
commit b88206d98f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 56 additions and 32 deletions

View File

@ -612,7 +612,8 @@ func TestMiddlewareContext(t *testing.T) {
h, err := authproxy.HashCacheKey(hdrName + "-" + group) h, err := authproxy.HashCacheKey(hdrName + "-" + group)
require.NoError(t, err) require.NoError(t, err)
key := fmt.Sprintf(authproxy.CachePrefix, h) key := fmt.Sprintf(authproxy.CachePrefix, h)
err = sc.remoteCacheService.Set(context.Background(), key, userID, 0) userIdBytes := []byte(strconv.FormatInt(userID, 10))
err = sc.remoteCacheService.SetByteArray(context.Background(), key, userIdBytes, 0)
require.NoError(t, err) require.NoError(t, err)
sc.fakeReq("GET", "/") sc.fakeReq("GET", "/")

View File

@ -171,8 +171,8 @@ func (ks *keySetHTTP) getJWKS(ctx context.Context) (keySetJWKS, error) {
var jwks keySetJWKS var jwks keySetJWKS
if ks.cacheExpiration > 0 { if ks.cacheExpiration > 0 {
if val, err := ks.cache.Get(ctx, ks.cacheKey); err == nil { if val, err := ks.cache.GetByteArray(ctx, ks.cacheKey); err == nil {
err := json.Unmarshal(val.([]byte), &jwks) err := json.Unmarshal(val, &jwks)
return jwks, err return jwks, err
} }
} }
@ -200,7 +200,7 @@ func (ks *keySetHTTP) getJWKS(ctx context.Context) (keySetJWKS, error) {
} }
if ks.cacheExpiration > 0 { if ks.cacheExpiration > 0 {
err = ks.cache.Set(ctx, ks.cacheKey, jsonBuf.Bytes(), ks.cacheExpiration) err = ks.cache.SetByteArray(ctx, ks.cacheKey, jsonBuf.Bytes(), ks.cacheExpiration)
} }
return jwks, err return jwks, err
} }

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"net/http" "net/http"
"strconv"
"testing" "testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -53,7 +54,8 @@ func TestInitContextWithAuthProxy_CachedInvalidUserID(t *testing.T) {
key := fmt.Sprintf(authproxy.CachePrefix, h) key := fmt.Sprintf(authproxy.CachePrefix, h)
t.Logf("Injecting stale user ID in cache with key %q", key) t.Logf("Injecting stale user ID in cache with key %q", key)
err = svc.RemoteCache.Set(context.Background(), key, int64(33), 0) userIdPayload := []byte(strconv.FormatInt(int64(33), 10))
err = svc.RemoteCache.SetByteArray(context.Background(), key, userIdPayload, 0)
require.NoError(t, err) require.NoError(t, err)
authEnabled := svc.initContextWithAuthProxy(ctx, orgID) authEnabled := svc.initContextWithAuthProxy(ctx, orgID)
@ -62,9 +64,13 @@ func TestInitContextWithAuthProxy_CachedInvalidUserID(t *testing.T) {
require.Equal(t, userID, ctx.SignedInUser.UserID) require.Equal(t, userID, ctx.SignedInUser.UserID)
require.True(t, ctx.IsSignedIn) require.True(t, ctx.IsSignedIn)
i, err := svc.RemoteCache.Get(context.Background(), key) cachedByteArray, err := svc.RemoteCache.GetByteArray(context.Background(), key)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, userID, i.(int64))
cacheUserId, err := strconv.ParseInt(string(cachedByteArray), 10, 64)
require.NoError(t, err)
require.Equal(t, userID, cacheUserId)
} }
type fakeRenderService struct { type fakeRenderService struct {

View File

@ -10,6 +10,7 @@ import (
"net/mail" "net/mail"
"path" "path"
"reflect" "reflect"
"strconv"
"strings" "strings"
"time" "time"
@ -201,14 +202,19 @@ func (auth *AuthProxy) getUserViaCache(reqCtx *contextmodel.ReqContext) (int64,
return 0, err return 0, err
} }
auth.logger.Debug("Getting user ID via auth cache", "cacheKey", cacheKey) auth.logger.Debug("Getting user ID via auth cache", "cacheKey", cacheKey)
userID, err := auth.remoteCache.Get(reqCtx.Req.Context(), cacheKey) cachedValue, err := auth.remoteCache.GetByteArray(reqCtx.Req.Context(), cacheKey)
if err != nil {
return 0, err
}
userId, err := strconv.ParseInt(string(cachedValue), 10, 64)
if err != nil { if err != nil {
auth.logger.Debug("Failed getting user ID via auth cache", "error", err) auth.logger.Debug("Failed getting user ID via auth cache", "error", err)
return 0, err return 0, err
} }
auth.logger.Debug("Successfully got user ID via auth cache", "id", userID) auth.logger.Debug("Successfully got user ID via auth cache", "id", cachedValue)
return userID.(int64), nil return userId, nil
} }
// RemoveUserFromCache removes user from cache. // RemoveUserFromCache removes user from cache.
@ -363,14 +369,15 @@ func (auth *AuthProxy) Remember(reqCtx *contextmodel.ReqContext, id int64) error
} }
// Check if user already in cache // Check if user already in cache
userID, err := auth.remoteCache.Get(reqCtx.Req.Context(), key) cachedValue, err := auth.remoteCache.GetByteArray(reqCtx.Req.Context(), key)
if err == nil && userID != nil { if err == nil && len(cachedValue) != 0 {
return nil return nil
} }
expiration := time.Duration(auth.cfg.AuthProxySyncTTL) * time.Minute expiration := time.Duration(auth.cfg.AuthProxySyncTTL) * time.Minute
if err := auth.remoteCache.Set(reqCtx.Req.Context(), key, id, expiration); err != nil { userIdPayload := []byte(strconv.FormatInt(id, 10))
if err := auth.remoteCache.SetByteArray(reqCtx.Req.Context(), key, userIdPayload, expiration); err != nil {
return err return err
} }

View File

@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
"strconv"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -60,7 +61,8 @@ func TestMiddlewareContext(t *testing.T) {
h, err := HashCacheKey(hdrName) h, err := HashCacheKey(hdrName)
require.NoError(t, err) require.NoError(t, err)
key := fmt.Sprintf(CachePrefix, h) key := fmt.Sprintf(CachePrefix, h)
err = cache.Set(context.Background(), key, id, 0) userIdPayload := []byte(strconv.FormatInt(id, 10))
err = cache.SetByteArray(context.Background(), key, userIdPayload, 0)
require.NoError(t, err) require.NoError(t, err)
// Set up the middleware // Set up the middleware
auth, reqCtx := prepareMiddleware(t, cache, nil) auth, reqCtx := prepareMiddleware(t, cache, nil)
@ -82,7 +84,8 @@ func TestMiddlewareContext(t *testing.T) {
h, err := HashCacheKey(hdrName + "-" + group + "-" + role) h, err := HashCacheKey(hdrName + "-" + group + "-" + role)
require.NoError(t, err) require.NoError(t, err)
key := fmt.Sprintf(CachePrefix, h) key := fmt.Sprintf(CachePrefix, h)
err = cache.Set(context.Background(), key, id, 0) userIdPayload := []byte(strconv.FormatInt(id, 10))
err = cache.SetByteArray(context.Background(), key, userIdPayload, 0)
require.NoError(t, err) require.NoError(t, err)
auth, reqCtx := prepareMiddleware(t, cache, func(req *http.Request, cfg *setting.Cfg) { auth, reqCtx := prepareMiddleware(t, cache, func(req *http.Request, cfg *setting.Cfg) {

View File

@ -1,7 +1,9 @@
package rendering package rendering
import ( import (
"bytes"
"context" "context"
"encoding/gob"
"fmt" "fmt"
"time" "time"
@ -13,33 +15,38 @@ import (
const renderKeyPrefix = "render-%s" const renderKeyPrefix = "render-%s"
type RenderUser struct { type RenderUser struct {
OrgID int64 OrgID int64 `json:"org_id"`
UserID int64 UserID int64 `json:"user_id"`
OrgRole string OrgRole string `json:"org_role"`
} }
func (rs *RenderingService) GetRenderUser(ctx context.Context, key string) (*RenderUser, bool) { func (rs *RenderingService) GetRenderUser(ctx context.Context, key string) (*RenderUser, bool) {
val, err := rs.RemoteCacheService.Get(ctx, fmt.Sprintf(renderKeyPrefix, key)) val, err := rs.RemoteCacheService.GetByteArray(ctx, fmt.Sprintf(renderKeyPrefix, key))
if err != nil { if err != nil {
rs.log.Error("Failed to get render key from cache", "error", err) rs.log.Error("Failed to get render key from cache", "error", err)
} }
ru := &RenderUser{}
if val != nil { buf := bytes.NewBuffer(val)
if user, ok := val.(*RenderUser); ok { err = gob.NewDecoder(buf).Decode(&ru)
return user, true if err != nil {
} return nil, false
} }
return nil, false return ru, true
} }
func setRenderKey(cache *remotecache.RemoteCache, ctx context.Context, opts AuthOpts, renderKey string, expiry time.Duration) error { func setRenderKey(cache *remotecache.RemoteCache, ctx context.Context, opts AuthOpts, renderKey string, expiry time.Duration) error {
err := cache.Set(ctx, fmt.Sprintf(renderKeyPrefix, renderKey), &RenderUser{ buf := bytes.NewBuffer(nil)
err := gob.NewEncoder(buf).Encode(&RenderUser{
OrgID: opts.OrgID, OrgID: opts.OrgID,
UserID: opts.UserID, UserID: opts.UserID,
OrgRole: string(opts.OrgRole), OrgRole: string(opts.OrgRole),
}, expiry) })
return err if err != nil {
return err
}
return cache.SetByteArray(ctx, fmt.Sprintf(renderKeyPrefix, renderKey), buf.Bytes(), expiry)
} }
func generateAndSetRenderKey(cache *remotecache.RemoteCache, ctx context.Context, opts AuthOpts, expiry time.Duration) (string, error) { func generateAndSetRenderKey(cache *remotecache.RemoteCache, ctx context.Context, opts AuthOpts, expiry time.Duration) (string, error) {

View File

@ -2,6 +2,7 @@ package rendering
import ( import (
"context" "context"
"encoding/gob"
"errors" "errors"
"fmt" "fmt"
"math" "math"
@ -22,10 +23,6 @@ import (
"github.com/grafana/grafana/pkg/util" "github.com/grafana/grafana/pkg/util"
) )
func init() {
remotecache.Register(&RenderUser{})
}
var _ Service = (*RenderingService)(nil) var _ Service = (*RenderingService)(nil)
const ServiceName = "RenderingService" const ServiceName = "RenderingService"
@ -113,6 +110,9 @@ func ProvideService(cfg *setting.Cfg, remoteCache *remotecache.RemoteCache, rm p
domain: domain, domain: domain,
sanitizeURL: sanitizeURL, sanitizeURL: sanitizeURL,
} }
gob.Register(&RenderUser{})
return s, nil return s, nil
} }