diff --git a/go.mod b/go.mod index 2b8c07e2a6b..2b4682cfd9a 100644 --- a/go.mod +++ b/go.mod @@ -134,6 +134,7 @@ require ( github.com/oklog/ulid/v2 v2.1.0 // @grafana/identity-access-team github.com/olekukonko/tablewriter v0.0.5 // @grafana/grafana-backend-group github.com/openfga/api/proto v0.0.0-20240529184453-5b0b4941f3e0 // @grafana/identity-access-team + github.com/openfga/language/pkg/go v0.0.0-20240409225820-a53ea2892d6d // @grafana/identity-access-team github.com/openfga/openfga v1.5.4 // @grafana/identity-access-team github.com/patrickmn/go-cache v2.1.0+incompatible // @grafana/alerting-backend github.com/prometheus/alertmanager v0.27.0 // @grafana/alerting-backend @@ -455,7 +456,6 @@ require ( github.com/mfridman/interpolate v0.0.2 // indirect github.com/natefinch/wrap v0.2.0 // indirect github.com/ncruces/go-strftime v0.1.9 // indirect - github.com/openfga/language/pkg/go v0.0.0-20240409225820-a53ea2892d6d // indirect github.com/pelletier/go-toml/v2 v2.1.1 // indirect github.com/pressly/goose/v3 v3.20.0 // indirect github.com/sagikazarmark/locafero v0.4.0 // indirect diff --git a/pkg/services/authz/zanzana.go b/pkg/services/authz/zanzana.go index 5533aa4a0d5..64e7345e0ae 100644 --- a/pkg/services/authz/zanzana.go +++ b/pkg/services/authz/zanzana.go @@ -16,6 +16,7 @@ import ( "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/infra/tracing" "github.com/grafana/grafana/pkg/services/authz/zanzana" + "github.com/grafana/grafana/pkg/services/authz/zanzana/client" "github.com/grafana/grafana/pkg/services/featuremgmt" "github.com/grafana/grafana/pkg/services/grpcserver" "github.com/grafana/grafana/pkg/setting" @@ -25,7 +26,7 @@ import ( // It will also start an embedded ZanzanaSever if mode is set to "embedded". func ProvideZanzana(cfg *setting.Cfg, db db.DB, features featuremgmt.FeatureToggles) (zanzana.Client, error) { if !features.IsEnabledGlobally(featuremgmt.FlagZanzana) { - return zanzana.NoopClient{}, nil + return client.NewNoop(), nil } logger := log.New("zanzana") @@ -37,7 +38,11 @@ func ProvideZanzana(cfg *setting.Cfg, db db.DB, features featuremgmt.FeatureTogg if err != nil { return nil, fmt.Errorf("failed to create zanzana client to remote server: %w", err) } - client = zanzana.NewClient(conn) + + client, err = zanzana.NewClient(context.Background(), conn, cfg) + if err != nil { + return nil, fmt.Errorf("failed to initialize zanzana client: %w", err) + } case setting.ZanzanaModeEmbedded: store, err := zanzana.NewEmbeddedStore(cfg, db, logger) if err != nil { @@ -51,7 +56,12 @@ func ProvideZanzana(cfg *setting.Cfg, db db.DB, features featuremgmt.FeatureTogg channel := &inprocgrpc.Channel{} openfgav1.RegisterOpenFGAServiceServer(channel, srv) - client = zanzana.NewClient(channel) + + client, err = zanzana.NewClient(context.Background(), channel, cfg) + if err != nil { + return nil, fmt.Errorf("failed to initialize zanzana client: %w", err) + } + default: return nil, fmt.Errorf("unsupported zanzana mode: %s", cfg.Zanzana.Mode) } diff --git a/pkg/services/authz/zanzana/client.go b/pkg/services/authz/zanzana/client.go index b61cd4b46fd..48f3ea3f1cc 100644 --- a/pkg/services/authz/zanzana/client.go +++ b/pkg/services/authz/zanzana/client.go @@ -2,46 +2,28 @@ package zanzana import ( "context" + "fmt" "google.golang.org/grpc" openfgav1 "github.com/openfga/api/proto/openfga/v1" "github.com/grafana/grafana/pkg/infra/log" + "github.com/grafana/grafana/pkg/services/authz/zanzana/client" + "github.com/grafana/grafana/pkg/setting" ) -// Client is a wrapper around OpenFGAServiceClient with only methods using in Grafana included. +// Client is a wrapper around [openfgav1.OpenFGAServiceClient] type Client interface { Check(ctx context.Context, in *openfgav1.CheckRequest, opts ...grpc.CallOption) (*openfgav1.CheckResponse, error) ListObjects(ctx context.Context, in *openfgav1.ListObjectsRequest, opts ...grpc.CallOption) (*openfgav1.ListObjectsResponse, error) } -type zanzanaClient struct { - client openfgav1.OpenFGAServiceClient - logger log.Logger -} - -func NewClient(cc grpc.ClientConnInterface) Client { - return &zanzanaClient{ - client: openfgav1.NewOpenFGAServiceClient(cc), - logger: log.New("zanzana-client"), - } -} - -func (c *zanzanaClient) Check(ctx context.Context, in *openfgav1.CheckRequest, opts ...grpc.CallOption) (*openfgav1.CheckResponse, error) { - return c.client.Check(ctx, in, opts...) -} - -func (c *zanzanaClient) ListObjects(ctx context.Context, in *openfgav1.ListObjectsRequest, opts ...grpc.CallOption) (*openfgav1.ListObjectsResponse, error) { - return c.client.ListObjects(ctx, in, opts...) -} - -type NoopClient struct{} - -func (nc NoopClient) Check(ctx context.Context, in *openfgav1.CheckRequest, opts ...grpc.CallOption) (*openfgav1.CheckResponse, error) { - return nil, nil -} - -func (nc NoopClient) ListObjects(ctx context.Context, in *openfgav1.ListObjectsRequest, opts ...grpc.CallOption) (*openfgav1.ListObjectsResponse, error) { - return nil, nil +func NewClient(ctx context.Context, cc grpc.ClientConnInterface, cfg *setting.Cfg) (*client.Client, error) { + return client.New( + ctx, + cc, + client.WithTenantID(fmt.Sprintf("stack-%s", cfg.StackID)), + client.WithLogger(log.New("zanzana-client")), + ) } diff --git a/pkg/services/authz/zanzana/client/client.go b/pkg/services/authz/zanzana/client/client.go new file mode 100644 index 00000000000..0f54b1bac51 --- /dev/null +++ b/pkg/services/authz/zanzana/client/client.go @@ -0,0 +1,186 @@ +package client + +import ( + "context" + "errors" + "fmt" + + "google.golang.org/grpc" + "google.golang.org/protobuf/types/known/wrapperspb" + + openfgav1 "github.com/openfga/api/proto/openfga/v1" + + "github.com/grafana/grafana/pkg/infra/log" + "github.com/grafana/grafana/pkg/services/authz/zanzana/schema" +) + +type ClientOption func(c *Client) + +func WithTenantID(tenantID string) ClientOption { + return func(c *Client) { + c.tenantID = tenantID + } +} + +func WithLogger(logger log.Logger) ClientOption { + return func(c *Client) { + c.logger = logger + } +} + +type Client struct { + logger log.Logger + client openfgav1.OpenFGAServiceClient + tenantID string + storeID string + modelID string +} + +func New(ctx context.Context, cc grpc.ClientConnInterface, opts ...ClientOption) (*Client, error) { + c := &Client{ + client: openfgav1.NewOpenFGAServiceClient(cc), + } + + for _, o := range opts { + o(c) + } + + if c.logger == nil { + c.logger = log.NewNopLogger() + } + + if c.tenantID == "" { + c.tenantID = "stack-default" + } + + store, err := c.getOrCreateStore(ctx, c.tenantID) + if err != nil { + return nil, err + } + + c.storeID = store.GetId() + + modelID, err := c.loadModel(ctx, c.storeID, schema.DSL) + if err != nil { + return nil, err + } + + c.modelID = modelID + + return c, nil +} + +func (c *Client) Check(ctx context.Context, in *openfgav1.CheckRequest, opts ...grpc.CallOption) (*openfgav1.CheckResponse, error) { + return c.client.Check(ctx, in, opts...) +} + +func (c *Client) ListObjects(ctx context.Context, in *openfgav1.ListObjectsRequest, opts ...grpc.CallOption) (*openfgav1.ListObjectsResponse, error) { + return c.client.ListObjects(ctx, in, opts...) +} + +func (c *Client) getOrCreateStore(ctx context.Context, name string) (*openfgav1.Store, error) { + store, err := c.getStore(ctx, name) + + if errors.Is(err, errStoreNotFound) { + var res *openfgav1.CreateStoreResponse + res, err = c.client.CreateStore(ctx, &openfgav1.CreateStoreRequest{Name: name}) + if res != nil { + store = &openfgav1.Store{ + Id: res.GetId(), + Name: res.GetName(), + CreatedAt: res.GetCreatedAt(), + } + } + } + + return store, err +} + +var errStoreNotFound = errors.New("store not found") + +func (c *Client) getStore(ctx context.Context, name string) (*openfgav1.Store, error) { + var continuationToken string + + // OpenFGA client does not support any filters for stores. + // We should create an issue to support some way to get stores by name. + // For now we need to go thourh all stores until we find a match or we hit the end. + for { + res, err := c.client.ListStores(ctx, &openfgav1.ListStoresRequest{ + PageSize: &wrapperspb.Int32Value{Value: 20}, + ContinuationToken: continuationToken, + }) + + if err != nil { + return nil, fmt.Errorf("failed to initiate zanzana tenant: %w", err) + } + + for _, s := range res.GetStores() { + if s.GetName() == name { + return s, nil + } + } + + // we have no more stores to check + if res.GetContinuationToken() == "" { + return nil, errStoreNotFound + } + + continuationToken = res.GetContinuationToken() + } +} + +func (c *Client) loadModel(ctx context.Context, storeID string, dsl string) (string, error) { + var continuationToken string + + for { + // ReadAuthorizationModels returns authorization models for a store sorted in descending order of creation. + // So with a pageSize of 1 we will get the latest model. + res, err := c.client.ReadAuthorizationModels(ctx, &openfgav1.ReadAuthorizationModelsRequest{ + StoreId: storeID, + PageSize: &wrapperspb.Int32Value{Value: 20}, + ContinuationToken: continuationToken, + }) + + if err != nil { + return "", fmt.Errorf("failed to load authorization model: %w", err) + } + + for _, model := range res.GetAuthorizationModels() { + // We need to first convert stored model into dsl and compare it to provided dsl. + storedDSL, err := schema.TransformToDSL(model) + if err != nil { + return "", err + } + + // If provided dsl is equal to a stored dsl we use that as the authorization id + if schema.EqualModels(dsl, storedDSL) { + return res.AuthorizationModels[0].GetId(), nil + } + } + + // If we have not found any matching authorization model we break the loop and create a new one + if res.GetContinuationToken() == "" { + break + } + + continuationToken = res.GetContinuationToken() + } + + model, err := schema.TransformToModel(dsl) + if err != nil { + return "", err + } + + writeRes, err := c.client.WriteAuthorizationModel(ctx, &openfgav1.WriteAuthorizationModelRequest{ + StoreId: c.storeID, + TypeDefinitions: model.GetTypeDefinitions(), + SchemaVersion: model.GetSchemaVersion(), + Conditions: model.GetConditions(), + }) + + if err != nil { + return "", fmt.Errorf("failed to load authorization model: %w", err) + } + + return writeRes.GetAuthorizationModelId(), nil +} diff --git a/pkg/services/authz/zanzana/client/noop.go b/pkg/services/authz/zanzana/client/noop.go new file mode 100644 index 00000000000..ffc795b346c --- /dev/null +++ b/pkg/services/authz/zanzana/client/noop.go @@ -0,0 +1,23 @@ +package client + +import ( + "context" + + "google.golang.org/grpc" + + openfgav1 "github.com/openfga/api/proto/openfga/v1" +) + +func NewNoop() *NoopClient { + return &NoopClient{} +} + +type NoopClient struct{} + +func (nc NoopClient) Check(ctx context.Context, in *openfgav1.CheckRequest, opts ...grpc.CallOption) (*openfgav1.CheckResponse, error) { + return nil, nil +} + +func (nc NoopClient) ListObjects(ctx context.Context, in *openfgav1.ListObjectsRequest, opts ...grpc.CallOption) (*openfgav1.ListObjectsResponse, error) { + return nil, nil +} diff --git a/pkg/services/authz/zanzana/schema/schema.fga b/pkg/services/authz/zanzana/schema/schema.fga new file mode 100644 index 00000000000..0c5bd688d88 --- /dev/null +++ b/pkg/services/authz/zanzana/schema/schema.fga @@ -0,0 +1,24 @@ +model + schema 1.1 + +type instance + +type user + +type org + relations + define instance: [instance] + define member: [user] + define viewer: [user] + +type role + relations + define org: [org] + define instance: [instance] + define assignee: [user, team#member, role#assignee] + +type team + relations + define org: [org] + define admin: [user] + define member: [user] or admin diff --git a/pkg/services/authz/zanzana/schema/schema.go b/pkg/services/authz/zanzana/schema/schema.go new file mode 100644 index 00000000000..ca84e0221a6 --- /dev/null +++ b/pkg/services/authz/zanzana/schema/schema.go @@ -0,0 +1,8 @@ +package schema + +import ( + _ "embed" +) + +//go:embed schema.fga +var DSL string diff --git a/pkg/services/authz/zanzana/schema/transform.go b/pkg/services/authz/zanzana/schema/transform.go new file mode 100644 index 00000000000..e62237398d9 --- /dev/null +++ b/pkg/services/authz/zanzana/schema/transform.go @@ -0,0 +1,43 @@ +package schema + +import ( + _ "embed" + "fmt" + + openfgav1 "github.com/openfga/api/proto/openfga/v1" + language "github.com/openfga/language/pkg/go/transformer" +) + +func TransformToModel(dsl string) (*openfgav1.AuthorizationModel, error) { + parsedAuthModel, err := language.TransformDSLToProto(dsl) + if err != nil { + return nil, fmt.Errorf("failed to transform dsl to model: %w", err) + } + + return parsedAuthModel, nil +} + +func TransformToDSL(model *openfgav1.AuthorizationModel) (string, error) { + return language.TransformJSONProtoToDSL(model) +} + +// FIXME(kalleep): We need to figure out a better way to compare equality of two different +// authorization model. For now the easiest way I found to comparing different schemas was +// to convert them into their json representation but this requires us to first convert dsl into +// openfgav1.AuthorizationModel and then later parse it as json. +// Comparing parsed authorization model with authorization model from store directly by parsing them as +// as json won't work because stored model will have some fields set such as id that are not present in a parsed +// dsl from disk. +func EqualModels(a, b string) bool { + astr, err := language.TransformDSLToJSON(a) + if err != nil { + return false + } + + bstr, err := language.TransformDSLToJSON(b) + if err != nil { + return false + } + + return astr == bstr +} diff --git a/pkg/services/authz/zanzana/schema/transform_test.go b/pkg/services/authz/zanzana/schema/transform_test.go new file mode 100644 index 00000000000..f7a3b108fc3 --- /dev/null +++ b/pkg/services/authz/zanzana/schema/transform_test.go @@ -0,0 +1,131 @@ +package schema + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestEqualModels(t *testing.T) { + type testCase struct { + desc string + a string + b string + expected bool + } + + tests := []testCase{ + { + desc: "should be equal", + a: ` +model + schema 1.1 + +type instance + +type user + +type org + relations + define instance: [instance] + define member: [user] + define viewer: [user] + +type role + relations + define org: [org] + define instance: [instance] + define assignee: [user, team#member, role#assignee] + +type team + relations + define org: [org] + define admin: [user] + define member: [user] or org + `, + b: ` +model + schema 1.1 + +type instance + +type user + +type org + relations + define instance: [instance] + define member: [user] + define viewer: [user] + +type role + relations + define org: [org] + define instance: [instance] + define assignee: [user, team#member, role#assignee] + +type team + relations + define org: [org] + define admin: [user] + define member: [user] or org + `, + expected: true, + }, + { + desc: "should not be equal", + a: ` +model + schema 1.1 + +type instance + +type user + +type org + relations + define instance: [instance] + define member: [user] + define viewer: [user] + +type role + relations + define org: [org] + define instance: [instance] + define assignee: [user, team#member, role#assignee] + +type team + relations + define org: [org] + define admin: [user] + define member: [user] or org + `, + b: ` +model + schema 1.1 + +type instance + +type user + +type org + relations + define instance: [instance] + define member: [user] + define viewer: [user] + +type role + relations + define org: [org] + define instance: [instance] + define assignee: [user, team#member, role#assignee] +`, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + assert.Equal(t, tt.expected, EqualModels(tt.a, tt.b)) + }) + } +} diff --git a/pkg/services/authz/zanzana/store.go b/pkg/services/authz/zanzana/store.go index 1166aafeeed..380d3d08ba1 100644 --- a/pkg/services/authz/zanzana/store.go +++ b/pkg/services/authz/zanzana/store.go @@ -115,13 +115,13 @@ func parseConfig(cfg *setting.Cfg, logger log.Logger) (*sqlstore.DatabaseConfig, } zanzanaDBCfg := &sqlcommon.Config{ - Logger: newZanzanaLogger(logger), - // MaxTuplesPerWriteField: 0, - // MaxTypesPerModelField: 0, - MaxOpenConns: grafanaDBCfg.MaxOpenConn, - MaxIdleConns: grafanaDBCfg.MaxIdleConn, - ConnMaxLifetime: time.Duration(grafanaDBCfg.ConnMaxLifetime) * time.Second, - ExportMetrics: sec.Key("instrument_queries").MustBool(false), + Logger: newZanzanaLogger(logger), + MaxTuplesPerWriteField: 100, + MaxTypesPerModelField: 100, + MaxOpenConns: grafanaDBCfg.MaxOpenConn, + MaxIdleConns: grafanaDBCfg.MaxIdleConn, + ConnMaxLifetime: time.Duration(grafanaDBCfg.ConnMaxLifetime) * time.Second, + ExportMetrics: sec.Key("instrument_queries").MustBool(false), } return grafanaDBCfg, zanzanaDBCfg, nil