Zanzana: Lazy load cached store info (#96452)

* Lazy load cached store infos
This commit is contained in:
Karl Persson 2024-11-15 11:44:34 +01:00 committed by GitHub
parent 76a3d79231
commit 7e38fd733b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 64 additions and 151 deletions

View File

@ -1,7 +1,6 @@
package server package server
import ( import (
"errors"
"sync" "sync"
authzv1 "github.com/grafana/authlib/authz/proto/v1" authzv1 "github.com/grafana/authlib/authz/proto/v1"
@ -25,19 +24,16 @@ var _ authzextv1.AuthzExtentionServiceServer = (*Server)(nil)
var tracer = otel.Tracer("github.com/grafana/grafana/pkg/services/authz/zanzana/server") var tracer = otel.Tracer("github.com/grafana/grafana/pkg/services/authz/zanzana/server")
var errStoreNotFound = errors.New("store not found")
var errAuthorizationModelNotInitialized = errors.New("authorization model not initialized")
type Server struct { type Server struct {
authzv1.UnimplementedAuthzServiceServer authzv1.UnimplementedAuthzServiceServer
authzextv1.UnimplementedAuthzExtentionServiceServer authzextv1.UnimplementedAuthzExtentionServiceServer
openfga openfgav1.OpenFGAServiceServer openfga openfgav1.OpenFGAServiceServer
logger log.Logger logger log.Logger
modules []transformer.ModuleFile modules []transformer.ModuleFile
storeMap map[string]storeInfo stores map[string]storeInfo
storeLock *sync.Mutex storesMU *sync.Mutex
} }
type storeInfo struct { type storeInfo struct {
@ -65,9 +61,9 @@ func NewAuthzServer(cfg *setting.Cfg, openfga openfgav1.OpenFGAServiceServer) (*
func NewAuthz(openfga openfgav1.OpenFGAServiceServer, opts ...ServerOption) (*Server, error) { func NewAuthz(openfga openfgav1.OpenFGAServiceServer, opts ...ServerOption) (*Server, error) {
s := &Server{ s := &Server{
openfga: openfga, openfga: openfga,
storeLock: &sync.Mutex{}, storesMU: &sync.Mutex{},
storeMap: make(map[string]storeInfo), stores: make(map[string]storeInfo),
} }
for _, o := range opts { for _, o := range opts {

View File

@ -21,7 +21,7 @@ func (s *Server) Check(ctx context.Context, r *authzv1.CheckRequest) (*authzv1.C
} }
func (s *Server) checkTyped(ctx context.Context, r *authzv1.CheckRequest, info common.TypeInfo) (*authzv1.CheckResponse, error) { func (s *Server) checkTyped(ctx context.Context, r *authzv1.CheckRequest, info common.TypeInfo) (*authzv1.CheckResponse, error) {
storeInf, err := s.getNamespaceStore(ctx, r.Namespace) storeInf, err := s.getStoreInfo(ctx, r.Namespace)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -64,7 +64,7 @@ func (s *Server) checkTyped(ctx context.Context, r *authzv1.CheckRequest, info c
} }
func (s *Server) checkGeneric(ctx context.Context, r *authzv1.CheckRequest) (*authzv1.CheckResponse, error) { func (s *Server) checkGeneric(ctx context.Context, r *authzv1.CheckRequest) (*authzv1.CheckResponse, error) {
storeInf, err := s.getNamespaceStore(ctx, r.Namespace) storeInf, err := s.getStoreInfo(ctx, r.Namespace)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -24,7 +24,7 @@ func (s *Server) List(ctx context.Context, r *authzextv1.ListRequest) (*authzext
} }
func (s *Server) listTyped(ctx context.Context, r *authzextv1.ListRequest, info common.TypeInfo) (*authzextv1.ListResponse, error) { func (s *Server) listTyped(ctx context.Context, r *authzextv1.ListRequest, info common.TypeInfo) (*authzextv1.ListResponse, error) {
storeInf, err := s.getNamespaceStore(ctx, r.Namespace) storeInf, err := s.getStoreInfo(ctx, r.Namespace)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -67,7 +67,7 @@ func (s *Server) listTyped(ctx context.Context, r *authzextv1.ListRequest, info
} }
func (s *Server) listGeneric(ctx context.Context, r *authzextv1.ListRequest) (*authzextv1.ListResponse, error) { func (s *Server) listGeneric(ctx context.Context, r *authzextv1.ListRequest) (*authzextv1.ListResponse, error) {
storeInf, err := s.getNamespaceStore(ctx, r.Namespace) storeInf, err := s.getStoreInfo(ctx, r.Namespace)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -13,7 +13,7 @@ func (s *Server) Read(ctx context.Context, req *authzextv1.ReadRequest) (*authze
ctx, span := tracer.Start(ctx, "authzServer.Read") ctx, span := tracer.Start(ctx, "authzServer.Read")
defer span.End() defer span.End()
storeInf, err := s.getNamespaceStore(ctx, req.Namespace) storeInf, err := s.getStoreInfo(ctx, req.Namespace)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -2,7 +2,6 @@ package server
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
openfgav1 "github.com/openfga/api/proto/openfga/v1" openfgav1 "github.com/openfga/api/proto/openfga/v1"
@ -12,66 +11,35 @@ import (
"github.com/grafana/grafana/pkg/services/authz/zanzana/schema" "github.com/grafana/grafana/pkg/services/authz/zanzana/schema"
) )
func (s *Server) getOrCreateStore(ctx context.Context, namespace string) (*openfgav1.Store, error) { func (s *Server) getStoreInfo(ctx context.Context, namespace string) (*storeInfo, error) {
store, err := s.getStore(ctx, namespace) s.storesMU.Lock()
defer s.storesMU.Unlock()
if errors.Is(err, errStoreNotFound) { info, ok := s.stores[namespace]
var res *openfgav1.CreateStoreResponse if ok {
res, err = s.openfga.CreateStore(ctx, &openfgav1.CreateStoreRequest{Name: namespace}) return &info, nil
if res != nil {
store = &openfgav1.Store{
Id: res.GetId(),
Name: res.GetName(),
CreatedAt: res.GetCreatedAt(),
}
s.storeMap[res.GetName()] = storeInfo{
Id: res.GetId(),
}
}
} }
return store, err store, err := s.getOrCreateStore(ctx, namespace)
} if err != nil {
return nil, err
func (s *Server) getStoreInfo(namespace string) (*storeInfo, error) {
info, ok := s.storeMap[namespace]
if !ok {
return nil, errStoreNotFound
} }
modelID, err := s.loadModel(ctx, store.GetId(), schema.SchemaModules)
if err != nil {
return nil, err
}
info = storeInfo{
Id: store.GetId(),
AuthorizationModelId: modelID,
}
s.stores[namespace] = info
return &info, nil return &info, nil
} }
func (s *Server) getStore(ctx context.Context, namespace string) (*openfgav1.Store, error) { func (s *Server) getOrCreateStore(ctx context.Context, namespace string) (*openfgav1.Store, error) {
if len(s.storeMap) == 0 {
err := s.initStores(ctx)
if err != nil {
return nil, err
}
}
storeInf, err := s.getStoreInfo(namespace)
if err != nil {
return nil, err
}
res, err := s.openfga.GetStore(ctx, &openfgav1.GetStoreRequest{
StoreId: storeInf.Id,
})
if err != nil {
return nil, err
}
store := &openfgav1.Store{
Id: res.GetId(),
Name: res.GetName(),
CreatedAt: res.GetCreatedAt(),
}
return store, nil
}
func (s *Server) initStores(ctx context.Context) error {
var continuationToken string var continuationToken string
for { for {
@ -81,13 +49,12 @@ func (s *Server) initStores(ctx context.Context) error {
}) })
if err != nil { if err != nil {
return fmt.Errorf("failed to load zanzana stores: %w", err) return nil, fmt.Errorf("failed to load zanzana stores: %w", err)
} }
for _, store := range res.GetStores() { for _, s := range res.GetStores() {
name := store.GetName() if s.GetName() == namespace {
s.storeMap[name] = storeInfo{ return s, nil
Id: store.GetId(),
} }
} }
@ -99,10 +66,18 @@ func (s *Server) initStores(ctx context.Context) error {
continuationToken = res.GetContinuationToken() continuationToken = res.GetContinuationToken()
} }
return nil res, err := s.openfga.CreateStore(ctx, &openfgav1.CreateStoreRequest{Name: namespace})
if err != nil {
return nil, err
}
return &openfgav1.Store{
Id: res.GetId(),
Name: res.GetName(),
}, nil
} }
func (s *Server) loadModel(ctx context.Context, namespace string, modules []transformer.ModuleFile) (string, error) { func (s *Server) loadModel(ctx context.Context, storeID string, modules []transformer.ModuleFile) (string, error) {
var continuationToken string var continuationToken string
model, err := schema.TransformModulesToModel(modules) model, err := schema.TransformModulesToModel(modules)
@ -110,41 +85,27 @@ func (s *Server) loadModel(ctx context.Context, namespace string, modules []tran
return "", err return "", err
} }
store, err := s.getStore(ctx, namespace) // ReadAuthorizationModels returns authorization models for a store sorted in descending order of creation.
// So with a pageSize of 1 we will get the latest model.
res, err := s.openfga.ReadAuthorizationModels(ctx, &openfgav1.ReadAuthorizationModelsRequest{
StoreId: storeID,
PageSize: &wrapperspb.Int32Value{Value: 1},
ContinuationToken: continuationToken,
})
if err != nil { if err != nil {
return "", err return "", fmt.Errorf("failed to load authorization model: %w", err)
} }
for { for _, m := range res.GetAuthorizationModels() {
// ReadAuthorizationModels returns authorization models for a store sorted in descending order of creation. // If provided dsl is equal to a stored dsl we use that as the authorization id
// So with a pageSize of 1 we will get the latest model. if schema.EqualModels(m, model) {
res, err := s.openfga.ReadAuthorizationModels(ctx, &openfgav1.ReadAuthorizationModelsRequest{ return m.GetId(), nil
StoreId: store.GetId(),
PageSize: &wrapperspb.Int32Value{Value: 20},
ContinuationToken: continuationToken,
})
if err != nil {
return "", fmt.Errorf("failed to load authorization model: %w", err)
} }
for _, m := range res.GetAuthorizationModels() {
// If provided dsl is equal to a stored dsl we use that as the authorization id
if schema.EqualModels(m, model) {
return m.GetId(), nil
}
}
// If we have not found any matching authorization model we break the loop and create a new one
if res.GetContinuationToken() == "" {
break
}
continuationToken = res.GetContinuationToken()
} }
writeRes, err := s.openfga.WriteAuthorizationModel(ctx, &openfgav1.WriteAuthorizationModelRequest{ writeRes, err := s.openfga.WriteAuthorizationModel(ctx, &openfgav1.WriteAuthorizationModelRequest{
StoreId: store.GetId(), StoreId: storeID,
TypeDefinitions: model.GetTypeDefinitions(), TypeDefinitions: model.GetTypeDefinitions(),
SchemaVersion: model.GetSchemaVersion(), SchemaVersion: model.GetSchemaVersion(),
Conditions: model.GetConditions(), Conditions: model.GetConditions(),
@ -156,44 +117,3 @@ func (s *Server) loadModel(ctx context.Context, namespace string, modules []tran
return writeRes.GetAuthorizationModelId(), nil return writeRes.GetAuthorizationModelId(), nil
} }
func (s *Server) getNamespaceStore(ctx context.Context, namespace string) (*storeInfo, error) {
var storeInf *storeInfo
var err error
s.storeLock.Lock()
defer s.storeLock.Unlock()
storeInf, err = s.getStoreInfo(namespace)
if errors.Is(err, errStoreNotFound) || storeInf.AuthorizationModelId == "" {
storeInf, err = s.initNamespaceStore(ctx, namespace)
}
if err != nil {
return nil, err
}
return storeInf, nil
}
func (s *Server) initNamespaceStore(ctx context.Context, namespace string) (*storeInfo, error) {
store, err := s.getOrCreateStore(ctx, namespace)
if err != nil {
return nil, err
}
modules := schema.SchemaModules
modelID, err := s.loadModel(ctx, namespace, modules)
if err != nil {
return nil, err
}
if info, ok := s.storeMap[store.GetName()]; ok {
s.storeMap[store.GetName()] = storeInfo{
Id: info.Id,
AuthorizationModelId: modelID,
}
}
updatedInfo := s.storeMap[store.GetName()]
return &updatedInfo, nil
}

View File

@ -62,7 +62,7 @@ func setup(t *testing.T, testDB db.DB, cfg *setting.Cfg) *Server {
require.NoError(t, err) require.NoError(t, err)
namespace := "default" namespace := "default"
storeInf, err := srv.initNamespaceStore(context.Background(), namespace) storeInf, err := srv.getStoreInfo(context.Background(), namespace)
require.NoError(t, err) require.NoError(t, err)
// seed tuples // seed tuples

View File

@ -13,13 +13,10 @@ func (s *Server) Write(ctx context.Context, req *authzextv1.WriteRequest) (*auth
ctx, span := tracer.Start(ctx, "authzServer.Write") ctx, span := tracer.Start(ctx, "authzServer.Write")
defer span.End() defer span.End()
storeInf, err := s.getNamespaceStore(ctx, req.Namespace) storeInf, err := s.getStoreInfo(ctx, req.Namespace)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if storeInf.AuthorizationModelId == "" {
return nil, errAuthorizationModelNotInitialized
}
writeTuples := make([]*openfgav1.TupleKey, 0) writeTuples := make([]*openfgav1.TupleKey, 0)
for _, t := range req.GetWrites().GetTupleKeys() { for _, t := range req.GetWrites().GetTupleKeys() {