diff --git a/pkg/services/authz/zanzana/server/server.go b/pkg/services/authz/zanzana/server/server.go index 039af673649..27f12cf6aad 100644 --- a/pkg/services/authz/zanzana/server/server.go +++ b/pkg/services/authz/zanzana/server/server.go @@ -1,7 +1,6 @@ package server import ( - "errors" "sync" authzv1 "github.com/grafana/authlib/authz/proto/v1" @@ -25,19 +24,16 @@ var _ authzextv1.AuthzExtentionServiceServer = (*Server)(nil) var tracer = otel.Tracer("github.com/grafana/grafana/pkg/services/authz/zanzana/server") -var errStoreNotFound = errors.New("store not found") -var errAuthorizationModelNotInitialized = errors.New("authorization model not initialized") - type Server struct { authzv1.UnimplementedAuthzServiceServer authzextv1.UnimplementedAuthzExtentionServiceServer openfga openfgav1.OpenFGAServiceServer - logger log.Logger - modules []transformer.ModuleFile - storeMap map[string]storeInfo - storeLock *sync.Mutex + logger log.Logger + modules []transformer.ModuleFile + stores map[string]storeInfo + storesMU *sync.Mutex } type storeInfo struct { @@ -65,9 +61,9 @@ func NewAuthzServer(cfg *setting.Cfg, openfga openfgav1.OpenFGAServiceServer) (* func NewAuthz(openfga openfgav1.OpenFGAServiceServer, opts ...ServerOption) (*Server, error) { s := &Server{ - openfga: openfga, - storeLock: &sync.Mutex{}, - storeMap: make(map[string]storeInfo), + openfga: openfga, + storesMU: &sync.Mutex{}, + stores: make(map[string]storeInfo), } for _, o := range opts { diff --git a/pkg/services/authz/zanzana/server/server_check.go b/pkg/services/authz/zanzana/server/server_check.go index 5da418c4129..3bb77aaf182 100644 --- a/pkg/services/authz/zanzana/server/server_check.go +++ b/pkg/services/authz/zanzana/server/server_check.go @@ -21,7 +21,7 @@ func (s *Server) Check(ctx context.Context, r *authzv1.CheckRequest) (*authzv1.C } func (s *Server) checkTyped(ctx context.Context, r *authzv1.CheckRequest, info common.TypeInfo) (*authzv1.CheckResponse, error) { - storeInf, err := s.getNamespaceStore(ctx, r.Namespace) + storeInf, err := s.getStoreInfo(ctx, r.Namespace) if err != nil { return nil, err } @@ -64,7 +64,7 @@ func (s *Server) checkTyped(ctx context.Context, r *authzv1.CheckRequest, info c } func (s *Server) checkGeneric(ctx context.Context, r *authzv1.CheckRequest) (*authzv1.CheckResponse, error) { - storeInf, err := s.getNamespaceStore(ctx, r.Namespace) + storeInf, err := s.getStoreInfo(ctx, r.Namespace) if err != nil { return nil, err } diff --git a/pkg/services/authz/zanzana/server/server_list.go b/pkg/services/authz/zanzana/server/server_list.go index c2ea35e4020..a874ac5feee 100644 --- a/pkg/services/authz/zanzana/server/server_list.go +++ b/pkg/services/authz/zanzana/server/server_list.go @@ -24,7 +24,7 @@ func (s *Server) List(ctx context.Context, r *authzextv1.ListRequest) (*authzext } func (s *Server) listTyped(ctx context.Context, r *authzextv1.ListRequest, info common.TypeInfo) (*authzextv1.ListResponse, error) { - storeInf, err := s.getNamespaceStore(ctx, r.Namespace) + storeInf, err := s.getStoreInfo(ctx, r.Namespace) if err != nil { return nil, err } @@ -67,7 +67,7 @@ func (s *Server) listTyped(ctx context.Context, r *authzextv1.ListRequest, info } func (s *Server) listGeneric(ctx context.Context, r *authzextv1.ListRequest) (*authzextv1.ListResponse, error) { - storeInf, err := s.getNamespaceStore(ctx, r.Namespace) + storeInf, err := s.getStoreInfo(ctx, r.Namespace) if err != nil { return nil, err } diff --git a/pkg/services/authz/zanzana/server/server_read.go b/pkg/services/authz/zanzana/server/server_read.go index 4c0d238d401..eafecdf5f76 100644 --- a/pkg/services/authz/zanzana/server/server_read.go +++ b/pkg/services/authz/zanzana/server/server_read.go @@ -13,7 +13,7 @@ func (s *Server) Read(ctx context.Context, req *authzextv1.ReadRequest) (*authze ctx, span := tracer.Start(ctx, "authzServer.Read") defer span.End() - storeInf, err := s.getNamespaceStore(ctx, req.Namespace) + storeInf, err := s.getStoreInfo(ctx, req.Namespace) if err != nil { return nil, err } diff --git a/pkg/services/authz/zanzana/server/server_store.go b/pkg/services/authz/zanzana/server/server_store.go index 3e760c377da..8af1ca98ef9 100644 --- a/pkg/services/authz/zanzana/server/server_store.go +++ b/pkg/services/authz/zanzana/server/server_store.go @@ -2,7 +2,6 @@ package server import ( "context" - "errors" "fmt" openfgav1 "github.com/openfga/api/proto/openfga/v1" @@ -12,66 +11,35 @@ import ( "github.com/grafana/grafana/pkg/services/authz/zanzana/schema" ) -func (s *Server) getOrCreateStore(ctx context.Context, namespace string) (*openfgav1.Store, error) { - store, err := s.getStore(ctx, namespace) - - if errors.Is(err, errStoreNotFound) { - var res *openfgav1.CreateStoreResponse - res, err = s.openfga.CreateStore(ctx, &openfgav1.CreateStoreRequest{Name: namespace}) - if res != nil { - store = &openfgav1.Store{ - Id: res.GetId(), - Name: res.GetName(), - CreatedAt: res.GetCreatedAt(), - } - s.storeMap[res.GetName()] = storeInfo{ - Id: res.GetId(), - } - } +func (s *Server) getStoreInfo(ctx context.Context, namespace string) (*storeInfo, error) { + s.storesMU.Lock() + defer s.storesMU.Unlock() + info, ok := s.stores[namespace] + if ok { + return &info, nil } - return store, err -} - -func (s *Server) getStoreInfo(namespace string) (*storeInfo, error) { - info, ok := s.storeMap[namespace] - if !ok { - return nil, errStoreNotFound + store, err := s.getOrCreateStore(ctx, namespace) + if err != nil { + return nil, err } + modelID, err := s.loadModel(ctx, store.GetId(), schema.SchemaModules) + if err != nil { + return nil, err + } + + info = storeInfo{ + Id: store.GetId(), + AuthorizationModelId: modelID, + } + + s.stores[namespace] = info + return &info, nil } -func (s *Server) getStore(ctx context.Context, namespace string) (*openfgav1.Store, error) { - if len(s.storeMap) == 0 { - err := s.initStores(ctx) - if err != nil { - return nil, err - } - } - - storeInf, err := s.getStoreInfo(namespace) - if err != nil { - return nil, err - } - - res, err := s.openfga.GetStore(ctx, &openfgav1.GetStoreRequest{ - StoreId: storeInf.Id, - }) - if err != nil { - return nil, err - } - - store := &openfgav1.Store{ - Id: res.GetId(), - Name: res.GetName(), - CreatedAt: res.GetCreatedAt(), - } - - return store, nil -} - -func (s *Server) initStores(ctx context.Context) error { +func (s *Server) getOrCreateStore(ctx context.Context, namespace string) (*openfgav1.Store, error) { var continuationToken string for { @@ -81,13 +49,12 @@ func (s *Server) initStores(ctx context.Context) error { }) if err != nil { - return fmt.Errorf("failed to load zanzana stores: %w", err) + return nil, fmt.Errorf("failed to load zanzana stores: %w", err) } - for _, store := range res.GetStores() { - name := store.GetName() - s.storeMap[name] = storeInfo{ - Id: store.GetId(), + for _, s := range res.GetStores() { + if s.GetName() == namespace { + return s, nil } } @@ -99,10 +66,18 @@ func (s *Server) initStores(ctx context.Context) error { continuationToken = res.GetContinuationToken() } - return nil + res, err := s.openfga.CreateStore(ctx, &openfgav1.CreateStoreRequest{Name: namespace}) + if err != nil { + return nil, err + } + + return &openfgav1.Store{ + Id: res.GetId(), + Name: res.GetName(), + }, nil } -func (s *Server) loadModel(ctx context.Context, namespace string, modules []transformer.ModuleFile) (string, error) { +func (s *Server) loadModel(ctx context.Context, storeID string, modules []transformer.ModuleFile) (string, error) { var continuationToken string model, err := schema.TransformModulesToModel(modules) @@ -110,41 +85,27 @@ func (s *Server) loadModel(ctx context.Context, namespace string, modules []tran return "", err } - store, err := s.getStore(ctx, namespace) + // ReadAuthorizationModels returns authorization models for a store sorted in descending order of creation. + // So with a pageSize of 1 we will get the latest model. + res, err := s.openfga.ReadAuthorizationModels(ctx, &openfgav1.ReadAuthorizationModelsRequest{ + StoreId: storeID, + PageSize: &wrapperspb.Int32Value{Value: 1}, + ContinuationToken: continuationToken, + }) + if err != nil { - return "", err + return "", fmt.Errorf("failed to load authorization model: %w", err) } - for { - // ReadAuthorizationModels returns authorization models for a store sorted in descending order of creation. - // So with a pageSize of 1 we will get the latest model. - res, err := s.openfga.ReadAuthorizationModels(ctx, &openfgav1.ReadAuthorizationModelsRequest{ - StoreId: store.GetId(), - PageSize: &wrapperspb.Int32Value{Value: 20}, - ContinuationToken: continuationToken, - }) - - if err != nil { - return "", fmt.Errorf("failed to load authorization model: %w", err) + for _, m := range res.GetAuthorizationModels() { + // If provided dsl is equal to a stored dsl we use that as the authorization id + if schema.EqualModels(m, model) { + return m.GetId(), nil } - - for _, m := range res.GetAuthorizationModels() { - // If provided dsl is equal to a stored dsl we use that as the authorization id - if schema.EqualModels(m, model) { - return m.GetId(), nil - } - } - - // If we have not found any matching authorization model we break the loop and create a new one - if res.GetContinuationToken() == "" { - break - } - - continuationToken = res.GetContinuationToken() } writeRes, err := s.openfga.WriteAuthorizationModel(ctx, &openfgav1.WriteAuthorizationModelRequest{ - StoreId: store.GetId(), + StoreId: storeID, TypeDefinitions: model.GetTypeDefinitions(), SchemaVersion: model.GetSchemaVersion(), Conditions: model.GetConditions(), @@ -156,44 +117,3 @@ func (s *Server) loadModel(ctx context.Context, namespace string, modules []tran return writeRes.GetAuthorizationModelId(), nil } - -func (s *Server) getNamespaceStore(ctx context.Context, namespace string) (*storeInfo, error) { - var storeInf *storeInfo - var err error - - s.storeLock.Lock() - defer s.storeLock.Unlock() - - storeInf, err = s.getStoreInfo(namespace) - if errors.Is(err, errStoreNotFound) || storeInf.AuthorizationModelId == "" { - storeInf, err = s.initNamespaceStore(ctx, namespace) - } - if err != nil { - return nil, err - } - - return storeInf, nil -} - -func (s *Server) initNamespaceStore(ctx context.Context, namespace string) (*storeInfo, error) { - store, err := s.getOrCreateStore(ctx, namespace) - if err != nil { - return nil, err - } - - modules := schema.SchemaModules - modelID, err := s.loadModel(ctx, namespace, modules) - if err != nil { - return nil, err - } - - if info, ok := s.storeMap[store.GetName()]; ok { - s.storeMap[store.GetName()] = storeInfo{ - Id: info.Id, - AuthorizationModelId: modelID, - } - } - - updatedInfo := s.storeMap[store.GetName()] - return &updatedInfo, nil -} diff --git a/pkg/services/authz/zanzana/server/server_test.go b/pkg/services/authz/zanzana/server/server_test.go index 4800f38541e..9e04ceb117b 100644 --- a/pkg/services/authz/zanzana/server/server_test.go +++ b/pkg/services/authz/zanzana/server/server_test.go @@ -62,7 +62,7 @@ func setup(t *testing.T, testDB db.DB, cfg *setting.Cfg) *Server { require.NoError(t, err) namespace := "default" - storeInf, err := srv.initNamespaceStore(context.Background(), namespace) + storeInf, err := srv.getStoreInfo(context.Background(), namespace) require.NoError(t, err) // seed tuples diff --git a/pkg/services/authz/zanzana/server/server_write.go b/pkg/services/authz/zanzana/server/server_write.go index 38c36fdebf6..a8c76c0ebe3 100644 --- a/pkg/services/authz/zanzana/server/server_write.go +++ b/pkg/services/authz/zanzana/server/server_write.go @@ -13,13 +13,10 @@ func (s *Server) Write(ctx context.Context, req *authzextv1.WriteRequest) (*auth ctx, span := tracer.Start(ctx, "authzServer.Write") defer span.End() - storeInf, err := s.getNamespaceStore(ctx, req.Namespace) + storeInf, err := s.getStoreInfo(ctx, req.Namespace) if err != nil { return nil, err } - if storeInf.AuthorizationModelId == "" { - return nil, errAuthorizationModelNotInitialized - } writeTuples := make([]*openfgav1.TupleKey, 0) for _, t := range req.GetWrites().GetTupleKeys() {