mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
Cache: Refactor cache clients to use byte array (#62930)
Signed-off-by: bergquist <carl.bergquist@gmail.com>
This commit is contained in:
parent
2804acd264
commit
b88206d98f
@ -612,7 +612,8 @@ func TestMiddlewareContext(t *testing.T) {
|
|||||||
h, err := authproxy.HashCacheKey(hdrName + "-" + group)
|
h, err := authproxy.HashCacheKey(hdrName + "-" + group)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
key := fmt.Sprintf(authproxy.CachePrefix, h)
|
key := fmt.Sprintf(authproxy.CachePrefix, h)
|
||||||
err = sc.remoteCacheService.Set(context.Background(), key, userID, 0)
|
userIdBytes := []byte(strconv.FormatInt(userID, 10))
|
||||||
|
err = sc.remoteCacheService.SetByteArray(context.Background(), key, userIdBytes, 0)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
sc.fakeReq("GET", "/")
|
sc.fakeReq("GET", "/")
|
||||||
|
|
||||||
|
@ -171,8 +171,8 @@ func (ks *keySetHTTP) getJWKS(ctx context.Context) (keySetJWKS, error) {
|
|||||||
var jwks keySetJWKS
|
var jwks keySetJWKS
|
||||||
|
|
||||||
if ks.cacheExpiration > 0 {
|
if ks.cacheExpiration > 0 {
|
||||||
if val, err := ks.cache.Get(ctx, ks.cacheKey); err == nil {
|
if val, err := ks.cache.GetByteArray(ctx, ks.cacheKey); err == nil {
|
||||||
err := json.Unmarshal(val.([]byte), &jwks)
|
err := json.Unmarshal(val, &jwks)
|
||||||
return jwks, err
|
return jwks, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -200,7 +200,7 @@ func (ks *keySetHTTP) getJWKS(ctx context.Context) (keySetJWKS, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if ks.cacheExpiration > 0 {
|
if ks.cacheExpiration > 0 {
|
||||||
err = ks.cache.Set(ctx, ks.cacheKey, jsonBuf.Bytes(), ks.cacheExpiration)
|
err = ks.cache.SetByteArray(ctx, ks.cacheKey, jsonBuf.Bytes(), ks.cacheExpiration)
|
||||||
}
|
}
|
||||||
return jwks, err
|
return jwks, err
|
||||||
}
|
}
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@ -53,7 +54,8 @@ func TestInitContextWithAuthProxy_CachedInvalidUserID(t *testing.T) {
|
|||||||
key := fmt.Sprintf(authproxy.CachePrefix, h)
|
key := fmt.Sprintf(authproxy.CachePrefix, h)
|
||||||
|
|
||||||
t.Logf("Injecting stale user ID in cache with key %q", key)
|
t.Logf("Injecting stale user ID in cache with key %q", key)
|
||||||
err = svc.RemoteCache.Set(context.Background(), key, int64(33), 0)
|
userIdPayload := []byte(strconv.FormatInt(int64(33), 10))
|
||||||
|
err = svc.RemoteCache.SetByteArray(context.Background(), key, userIdPayload, 0)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
authEnabled := svc.initContextWithAuthProxy(ctx, orgID)
|
authEnabled := svc.initContextWithAuthProxy(ctx, orgID)
|
||||||
@ -62,9 +64,13 @@ func TestInitContextWithAuthProxy_CachedInvalidUserID(t *testing.T) {
|
|||||||
require.Equal(t, userID, ctx.SignedInUser.UserID)
|
require.Equal(t, userID, ctx.SignedInUser.UserID)
|
||||||
require.True(t, ctx.IsSignedIn)
|
require.True(t, ctx.IsSignedIn)
|
||||||
|
|
||||||
i, err := svc.RemoteCache.Get(context.Background(), key)
|
cachedByteArray, err := svc.RemoteCache.GetByteArray(context.Background(), key)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, userID, i.(int64))
|
|
||||||
|
cacheUserId, err := strconv.ParseInt(string(cachedByteArray), 10, 64)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, userID, cacheUserId)
|
||||||
}
|
}
|
||||||
|
|
||||||
type fakeRenderService struct {
|
type fakeRenderService struct {
|
||||||
|
@ -10,6 +10,7 @@ import (
|
|||||||
"net/mail"
|
"net/mail"
|
||||||
"path"
|
"path"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -201,14 +202,19 @@ func (auth *AuthProxy) getUserViaCache(reqCtx *contextmodel.ReqContext) (int64,
|
|||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
auth.logger.Debug("Getting user ID via auth cache", "cacheKey", cacheKey)
|
auth.logger.Debug("Getting user ID via auth cache", "cacheKey", cacheKey)
|
||||||
userID, err := auth.remoteCache.Get(reqCtx.Req.Context(), cacheKey)
|
cachedValue, err := auth.remoteCache.GetByteArray(reqCtx.Req.Context(), cacheKey)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
userId, err := strconv.ParseInt(string(cachedValue), 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
auth.logger.Debug("Failed getting user ID via auth cache", "error", err)
|
auth.logger.Debug("Failed getting user ID via auth cache", "error", err)
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
auth.logger.Debug("Successfully got user ID via auth cache", "id", userID)
|
auth.logger.Debug("Successfully got user ID via auth cache", "id", cachedValue)
|
||||||
return userID.(int64), nil
|
return userId, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveUserFromCache removes user from cache.
|
// RemoveUserFromCache removes user from cache.
|
||||||
@ -363,14 +369,15 @@ func (auth *AuthProxy) Remember(reqCtx *contextmodel.ReqContext, id int64) error
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check if user already in cache
|
// Check if user already in cache
|
||||||
userID, err := auth.remoteCache.Get(reqCtx.Req.Context(), key)
|
cachedValue, err := auth.remoteCache.GetByteArray(reqCtx.Req.Context(), key)
|
||||||
if err == nil && userID != nil {
|
if err == nil && len(cachedValue) != 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
expiration := time.Duration(auth.cfg.AuthProxySyncTTL) * time.Minute
|
expiration := time.Duration(auth.cfg.AuthProxySyncTTL) * time.Minute
|
||||||
|
|
||||||
if err := auth.remoteCache.Set(reqCtx.Req.Context(), key, id, expiration); err != nil {
|
userIdPayload := []byte(strconv.FormatInt(id, 10))
|
||||||
|
if err := auth.remoteCache.SetByteArray(reqCtx.Req.Context(), key, userIdPayload, expiration); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@ -60,7 +61,8 @@ func TestMiddlewareContext(t *testing.T) {
|
|||||||
h, err := HashCacheKey(hdrName)
|
h, err := HashCacheKey(hdrName)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
key := fmt.Sprintf(CachePrefix, h)
|
key := fmt.Sprintf(CachePrefix, h)
|
||||||
err = cache.Set(context.Background(), key, id, 0)
|
userIdPayload := []byte(strconv.FormatInt(id, 10))
|
||||||
|
err = cache.SetByteArray(context.Background(), key, userIdPayload, 0)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
// Set up the middleware
|
// Set up the middleware
|
||||||
auth, reqCtx := prepareMiddleware(t, cache, nil)
|
auth, reqCtx := prepareMiddleware(t, cache, nil)
|
||||||
@ -82,7 +84,8 @@ func TestMiddlewareContext(t *testing.T) {
|
|||||||
h, err := HashCacheKey(hdrName + "-" + group + "-" + role)
|
h, err := HashCacheKey(hdrName + "-" + group + "-" + role)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
key := fmt.Sprintf(CachePrefix, h)
|
key := fmt.Sprintf(CachePrefix, h)
|
||||||
err = cache.Set(context.Background(), key, id, 0)
|
userIdPayload := []byte(strconv.FormatInt(id, 10))
|
||||||
|
err = cache.SetByteArray(context.Background(), key, userIdPayload, 0)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
auth, reqCtx := prepareMiddleware(t, cache, func(req *http.Request, cfg *setting.Cfg) {
|
auth, reqCtx := prepareMiddleware(t, cache, func(req *http.Request, cfg *setting.Cfg) {
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
package rendering
|
package rendering
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/gob"
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -13,33 +15,38 @@ import (
|
|||||||
const renderKeyPrefix = "render-%s"
|
const renderKeyPrefix = "render-%s"
|
||||||
|
|
||||||
type RenderUser struct {
|
type RenderUser struct {
|
||||||
OrgID int64
|
OrgID int64 `json:"org_id"`
|
||||||
UserID int64
|
UserID int64 `json:"user_id"`
|
||||||
OrgRole string
|
OrgRole string `json:"org_role"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rs *RenderingService) GetRenderUser(ctx context.Context, key string) (*RenderUser, bool) {
|
func (rs *RenderingService) GetRenderUser(ctx context.Context, key string) (*RenderUser, bool) {
|
||||||
val, err := rs.RemoteCacheService.Get(ctx, fmt.Sprintf(renderKeyPrefix, key))
|
val, err := rs.RemoteCacheService.GetByteArray(ctx, fmt.Sprintf(renderKeyPrefix, key))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
rs.log.Error("Failed to get render key from cache", "error", err)
|
rs.log.Error("Failed to get render key from cache", "error", err)
|
||||||
}
|
}
|
||||||
|
ru := &RenderUser{}
|
||||||
if val != nil {
|
buf := bytes.NewBuffer(val)
|
||||||
if user, ok := val.(*RenderUser); ok {
|
err = gob.NewDecoder(buf).Decode(&ru)
|
||||||
return user, true
|
if err != nil {
|
||||||
}
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, false
|
return ru, true
|
||||||
}
|
}
|
||||||
|
|
||||||
func setRenderKey(cache *remotecache.RemoteCache, ctx context.Context, opts AuthOpts, renderKey string, expiry time.Duration) error {
|
func setRenderKey(cache *remotecache.RemoteCache, ctx context.Context, opts AuthOpts, renderKey string, expiry time.Duration) error {
|
||||||
err := cache.Set(ctx, fmt.Sprintf(renderKeyPrefix, renderKey), &RenderUser{
|
buf := bytes.NewBuffer(nil)
|
||||||
|
err := gob.NewEncoder(buf).Encode(&RenderUser{
|
||||||
OrgID: opts.OrgID,
|
OrgID: opts.OrgID,
|
||||||
UserID: opts.UserID,
|
UserID: opts.UserID,
|
||||||
OrgRole: string(opts.OrgRole),
|
OrgRole: string(opts.OrgRole),
|
||||||
}, expiry)
|
})
|
||||||
return err
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return cache.SetByteArray(ctx, fmt.Sprintf(renderKeyPrefix, renderKey), buf.Bytes(), expiry)
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateAndSetRenderKey(cache *remotecache.RemoteCache, ctx context.Context, opts AuthOpts, expiry time.Duration) (string, error) {
|
func generateAndSetRenderKey(cache *remotecache.RemoteCache, ctx context.Context, opts AuthOpts, expiry time.Duration) (string, error) {
|
||||||
|
@ -2,6 +2,7 @@ package rendering
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/gob"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
@ -22,10 +23,6 @@ import (
|
|||||||
"github.com/grafana/grafana/pkg/util"
|
"github.com/grafana/grafana/pkg/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
|
||||||
remotecache.Register(&RenderUser{})
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ Service = (*RenderingService)(nil)
|
var _ Service = (*RenderingService)(nil)
|
||||||
|
|
||||||
const ServiceName = "RenderingService"
|
const ServiceName = "RenderingService"
|
||||||
@ -113,6 +110,9 @@ func ProvideService(cfg *setting.Cfg, remoteCache *remotecache.RemoteCache, rm p
|
|||||||
domain: domain,
|
domain: domain,
|
||||||
sanitizeURL: sanitizeURL,
|
sanitizeURL: sanitizeURL,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
gob.Register(&RenderUser{})
|
||||||
|
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user