Zanzana: sqlite data store (#89486)

* Zanzana: Add sqlite3 store

* Zanzana: Initilize sqlite store with migrations
This commit is contained in:
Karl Persson 2024-06-25 09:52:33 +02:00 committed by GitHub
parent fd96edaef7
commit eea7319a67
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 1562 additions and 131 deletions

10
go.mod
View File

@ -30,6 +30,7 @@ require (
github.com/Masterminds/semver v1.5.0 // @grafana/grafana-backend-group
github.com/Masterminds/semver/v3 v3.2.0 // @grafana/grafana-release-guild
github.com/Masterminds/sprig/v3 v3.2.3 // @grafana/grafana-backend-group
github.com/Masterminds/squirrel v1.5.4 // @grafana/identity-access-team
github.com/ProtonMail/go-crypto v0.0.0-20230828082145-3c4c8a2d2371 // @grafana/plugins-platform-backend
github.com/VividCortex/mysqlerr v0.0.0-20170204212430-6c6b55f8796f // @grafana/grafana-backend-group
github.com/alicebob/miniredis/v2 v2.30.1 // @grafana/alerting-backend
@ -129,6 +130,7 @@ require (
github.com/modern-go/reflect2 v1.0.2 // @grafana/alerting-backend
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // @grafana/alerting-backend
github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f // @grafana/grafana-operator-experience-squad
github.com/oklog/ulid/v2 v2.1.0 // @grafana/identity-access-team
github.com/olekukonko/tablewriter v0.0.5 // @grafana/grafana-backend-group
github.com/openfga/api/proto v0.0.0-20240529184453-5b0b4941f3e0 // @grafana/identity-access-team
github.com/openfga/openfga v1.5.4 // @grafana/identity-access-team
@ -299,7 +301,7 @@ require (
github.com/google/gofuzz v1.2.0 // indirect
github.com/google/s2a-go v0.1.7 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
github.com/grafana/grafana/pkg/storage/unified/resource v0.0.0-20240620152449-c88de7f4d073 // @grafana/grafana-search-and-storage
github.com/grafana/grafana/pkg/storage/unified/resource v0.0.0-20240624122844-a89deaeb7365 // @grafana/grafana-search-and-storage
github.com/grafana/regexp v0.0.0-20221123153739-15dc172cd2db // indirect
github.com/grafana/sqlds/v3 v3.2.0 // indirect
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.1-0.20191002090509-6af20e3a5340 // indirect; @grafana/plugins-platform-backend
@ -436,9 +438,11 @@ require (
)
require (
github.com/Masterminds/squirrel v1.5.4 // indirect
github.com/antlr4-go/antlr/v4 v4.13.0 // indirect
github.com/envoyproxy/protoc-gen-validate v1.0.4 // indirect
github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 // indirect
github.com/hashicorp/go-cleanhttp v0.5.2 // indirect
github.com/hashicorp/go-retryablehttp v0.7.5 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect
@ -451,7 +455,7 @@ require (
github.com/mfridman/interpolate v0.0.2 // indirect
github.com/natefinch/wrap v0.2.0 // indirect
github.com/ncruces/go-strftime v0.1.9 // indirect
github.com/oklog/ulid/v2 v2.1.0 // indirect
github.com/openfga/language/pkg/go v0.0.0-20240409225820-a53ea2892d6d // indirect
github.com/pelletier/go-toml/v2 v2.1.1 // indirect
github.com/pressly/goose/v3 v3.20.0 // indirect
github.com/sagikazarmark/locafero v0.4.0 // indirect

4
go.sum
View File

@ -2334,8 +2334,8 @@ github.com/grafana/grafana/pkg/apiserver v0.0.0-20240226124929-648abdbd0ea4 h1:t
github.com/grafana/grafana/pkg/apiserver v0.0.0-20240226124929-648abdbd0ea4/go.mod h1:vpYI6DHvFO595rpQGooUjcyicjt9rOevldDdW79peV0=
github.com/grafana/grafana/pkg/promlib v0.0.6 h1:FuRyHMIgVVXkLuJnCflNfk3gqJflmyiI+/ZuJ9MoAfY=
github.com/grafana/grafana/pkg/promlib v0.0.6/go.mod h1:shFkrG1fQ/PPNRGhxAPNMLp0SAeG/jhqaLoG6n2191M=
github.com/grafana/grafana/pkg/storage/unified/resource v0.0.0-20240620152449-c88de7f4d073 h1:Au8+1QORZnMVo52+19dgkP4zzIlXyJ7x/d9ysSGQHOk=
github.com/grafana/grafana/pkg/storage/unified/resource v0.0.0-20240620152449-c88de7f4d073/go.mod h1:zOInHv2y6bsgm9bIMsCVDaz1XylqIVX9r4amH4iuWPE=
github.com/grafana/grafana/pkg/storage/unified/resource v0.0.0-20240624122844-a89deaeb7365 h1:XRHqYGxjN2+/4QHPoOtr7kYTL9p2P5UxTXfnbiaO/NI=
github.com/grafana/grafana/pkg/storage/unified/resource v0.0.0-20240624122844-a89deaeb7365/go.mod h1:X4dwV2eQI8z8G2aHXvhZZXu/y/rb3psQXuaZa66WZfA=
github.com/grafana/grafana/pkg/util/xorm v0.0.1 h1:72QZjxWIWpSeOF8ob4aMV058kfgZyeetkAB8dmeti2o=
github.com/grafana/grafana/pkg/util/xorm v0.0.1/go.mod h1:eNfbB9f2jM8o9RfwqwjY8SYm5tvowJ8Ly+iE4P9rXII=
github.com/grafana/otel-profiling-go v0.5.1 h1:stVPKAFZSa7eGiqbYuG25VcqYksR6iWvF3YH66t4qL8=

View File

@ -958,6 +958,7 @@ github.com/grafana/grafana-plugin-sdk-go v0.231.1-0.20240523124942-62dae9836284/
github.com/grafana/grafana-plugin-sdk-go v0.234.0/go.mod h1:FlXjmBESxaD6Hoi8ojWLkH007nyjtJM3XC8SpwzF/YE=
github.com/grafana/grafana/pkg/apimachinery v0.0.0-20240613114114-5e2f08de316d/go.mod h1:adT8O7k6ZSzUKjAC4WS6VfWlCE4G1VavPwSXVhvScCs=
github.com/grafana/grafana/pkg/promlib v0.0.3/go.mod h1:3El4NlsfALz8QQCbEGHGFvJUG+538QLMuALRhZ3pcoo=
github.com/grafana/grafana/pkg/storage/unified/resource v0.0.0-20240620152449-c88de7f4d073/go.mod h1:zOInHv2y6bsgm9bIMsCVDaz1XylqIVX9r4amH4iuWPE=
github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7 h1:pdN6V1QBWetyv/0+wjACpqVH+eVULgEjkurDLq3goeM=
github.com/grpc-ecosystem/go-grpc-middleware v1.0.0/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs=
github.com/grpc-ecosystem/grpc-gateway v1.9.0/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY=

View File

@ -3,9 +3,10 @@ package zanzana
import (
"context"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"google.golang.org/grpc"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/grafana/grafana/pkg/infra/log"
)

View File

@ -8,7 +8,7 @@ import (
)
func NewServer(store storage.OpenFGADatastore, logger log.Logger) (*server.Server, error) {
// FIXME(kalleep): add support for more options, configure logging, tracing etc
// FIXME(kalleep): add support for more options, tracing etc
opts := []server.OpenFGAServiceV1Option{
server.WithDatastore(store),
server.WithLogger(newZanzanaLogger(logger)),

View File

@ -1,21 +1,22 @@
package zanzana
import (
"errors"
"fmt"
"strings"
"time"
"xorm.io/xorm"
"github.com/openfga/openfga/assets"
"github.com/openfga/openfga/pkg/storage"
"github.com/openfga/openfga/pkg/storage/memory"
"github.com/openfga/openfga/pkg/storage/mysql"
"github.com/openfga/openfga/pkg/storage/postgres"
"github.com/openfga/openfga/pkg/storage/sqlcommon"
"xorm.io/xorm"
"github.com/grafana/grafana/pkg/infra/db"
"github.com/grafana/grafana/pkg/infra/log"
zstore "github.com/grafana/grafana/pkg/services/authz/zanzana/store"
"github.com/grafana/grafana/pkg/services/authz/zanzana/store/migration"
"github.com/grafana/grafana/pkg/services/authz/zanzana/store/sqlite"
"github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/services/sqlstore/migrator"
"github.com/grafana/grafana/pkg/setting"
@ -31,19 +32,37 @@ func NewStore(cfg *setting.Cfg, logger log.Logger) (storage.OpenFGADatastore, er
switch grafanaDBCfg.Type {
case migrator.SQLite:
return memory.New(), nil
connStr := grafanaDBCfg.ConnectionString
// Initilize connection using xorm engine so we can reuse it for both migrations and data store
engine, err := xorm.NewEngine(grafanaDBCfg.Type, connStr)
if err != nil {
return nil, fmt.Errorf("failed to connect to database: %w", err)
}
m := migrator.NewMigrator(engine, cfg)
if err := migration.RunWithMigrator(m, cfg, zstore.EmbedMigrations, zstore.SQLiteMigrationDir); err != nil {
return nil, fmt.Errorf("failed to run migrations: %w", err)
}
return sqlite.NewWithDB(engine.DB().DB, &sqlite.Config{
Config: zanzanaDBCfg,
QueryRetries: grafanaDBCfg.QueryRetries,
})
case migrator.MySQL:
// For mysql we need to pass parseTime parameter in connection string
connStr := grafanaDBCfg.ConnectionString + "&parseTime=true"
if err := runMigrations(cfg, migrator.MySQL, connStr, assets.MySQLMigrationDir); err != nil {
if err := migration.Run(cfg, migrator.MySQL, connStr, assets.EmbedMigrations, assets.MySQLMigrationDir); err != nil {
return nil, fmt.Errorf("failed to run migrations: %w", err)
}
return mysql.New(connStr, zanzanaDBCfg)
case migrator.Postgres:
if err := runMigrations(cfg, migrator.Postgres, grafanaDBCfg.ConnectionString, assets.PostgresMigrationDir); err != nil {
connStr := grafanaDBCfg.ConnectionString
if err := migration.Run(cfg, migrator.Postgres, connStr, assets.EmbedMigrations, assets.PostgresMigrationDir); err != nil {
return nil, fmt.Errorf("failed to run migrations: %w", err)
}
return postgres.New(grafanaDBCfg.ConnectionString, zanzanaDBCfg)
return postgres.New(connStr, zanzanaDBCfg)
}
// Should never happen
@ -60,18 +79,24 @@ func NewEmbeddedStore(cfg *setting.Cfg, db db.DB, logger log.Logger) (storage.Op
switch grafanaDBCfg.Type {
case migrator.SQLite:
// FIXME(kalleep): At the moment sqlite is not a supported data store.
// So we just return in memory store for now.
return memory.New(), nil
if err := migration.RunWithMigrator(m, cfg, zstore.EmbedMigrations, zstore.SQLiteMigrationDir); err != nil {
return nil, fmt.Errorf("failed to run migrations: %w", err)
}
// FIXME(kalleep): We should work on getting sqlite implemtation merged upstream and replace this one
return sqlite.NewWithDB(db.GetEngine().DB().DB, &sqlite.Config{
Config: zanzanaDBCfg,
QueryRetries: grafanaDBCfg.QueryRetries,
})
case migrator.MySQL:
if err := runMigrationsWithMigrator(m, cfg, assets.MySQLMigrationDir); err != nil {
if err := migration.RunWithMigrator(m, cfg, assets.EmbedMigrations, assets.MySQLMigrationDir); err != nil {
return nil, fmt.Errorf("failed to run migrations: %w", err)
}
// For mysql we need to pass parseTime parameter in connection string
return mysql.New(grafanaDBCfg.ConnectionString+"&parseTime=true", zanzanaDBCfg)
case migrator.Postgres:
if err := runMigrationsWithMigrator(m, cfg, assets.PostgresMigrationDir); err != nil {
if err := migration.RunWithMigrator(m, cfg, assets.EmbedMigrations, assets.PostgresMigrationDir); err != nil {
return nil, fmt.Errorf("failed to run migrations: %w", err)
}
@ -101,114 +126,3 @@ func parseConfig(cfg *setting.Cfg, logger log.Logger) (*sqlstore.DatabaseConfig,
return grafanaDBCfg, zanzanaDBCfg, nil
}
func runMigrations(cfg *setting.Cfg, typ, connStr, path string) error {
engine, err := xorm.NewEngine(typ, connStr)
if err != nil {
return fmt.Errorf("failed to parse database config: %w", err)
}
m := migrator.NewMigrator(engine, cfg)
m.AddCreateMigration()
return runMigrationsWithMigrator(m, cfg, path)
}
func runMigrationsWithMigrator(m *migrator.Migrator, cfg *setting.Cfg, path string) error {
migrations, err := getMigrations(path)
if err != nil {
return err
}
for _, mig := range migrations {
m.AddMigration(mig.name, mig.migration)
}
sec := cfg.Raw.Section("database")
return m.Start(
sec.Key("migration_locking").MustBool(true),
sec.Key("locking_attempt_timeout_sec").MustInt(),
)
}
type migration struct {
name string
migration migrator.Migration
}
func getMigrations(path string) ([]migration, error) {
entries, err := assets.EmbedMigrations.ReadDir(path)
if err != nil {
return nil, fmt.Errorf("failed to read migration dir: %w", err)
}
// parseStatements extracts statements from a sql file so we can execute
// them as separate migrations. OpenFGA uses Goose as their migration egine
// and Goose uses a single sql file for both up and down migrations.
// Grafana only supports up migration so we strip out the down migration
// and parse each individual statement
parseStatements := func(data []byte) ([]string, error) {
scripts := strings.Split(strings.TrimPrefix(string(data), "-- +goose Up"), "-- +goose Down")
if len(scripts) != 2 {
return nil, errors.New("malformed migration file")
}
// We assume that up migrations are always before down migrations
parts := strings.SplitAfter(scripts[0], ";")
stmts := make([]string, 0, len(parts))
for _, p := range parts {
p = strings.TrimSpace(p)
if p != "" {
stmts = append(stmts, p)
}
}
return stmts, nil
}
formatName := func(name string) string {
// Each migration file start with XXX where X is a number.
// We remove that part and prefix each migration with "zanzana".
return strings.TrimSuffix("zanzana"+name[3:], ".sql")
}
migrations := make([]migration, 0, len(entries))
for _, e := range entries {
data, err := assets.EmbedMigrations.ReadFile(path + "/" + e.Name())
if err != nil {
return nil, fmt.Errorf("failed to read migration file: %w", err)
}
stmts, err := parseStatements(data)
if err != nil {
return nil, fmt.Errorf("failed to parse migration: %w", err)
}
migrations = append(migrations, migration{
name: formatName(e.Name()),
migration: &rawMigration{stmts: stmts},
})
}
return migrations, nil
}
var _ migrator.CodeMigration = (*rawMigration)(nil)
type rawMigration struct {
stmts []string
migrator.MigrationBase
}
func (m *rawMigration) Exec(sess *xorm.Session, migrator *migrator.Migrator) error {
for _, stmt := range m.stmts {
if _, err := sess.Exec(stmt); err != nil {
return fmt.Errorf("failed to run migration: %w", err)
}
}
return nil
}
func (m *rawMigration) SQL(dialect migrator.Dialect) string {
return strings.Join(m.stmts, "\n")
}

View File

@ -0,0 +1,10 @@
package store
import "embed"
// EmbedMigrations within the grafana binary.
//
//go:embed migrations/*
var EmbedMigrations embed.FS
const SQLiteMigrationDir = "migrations/sqlite"

View File

@ -0,0 +1,124 @@
package migration
import (
"embed"
"errors"
"fmt"
"strings"
"xorm.io/xorm"
"github.com/grafana/grafana/pkg/services/sqlstore/migrator"
"github.com/grafana/grafana/pkg/setting"
)
func Run(cfg *setting.Cfg, typ, connStr string, fs embed.FS, path string) error {
engine, err := xorm.NewEngine(typ, connStr)
if err != nil {
return fmt.Errorf("failed to parse database config: %w", err)
}
m := migrator.NewMigrator(engine, cfg)
m.AddCreateMigration()
return RunWithMigrator(m, cfg, fs, path)
}
func RunWithMigrator(m *migrator.Migrator, cfg *setting.Cfg, fs embed.FS, path string) error {
migrations, err := getMigrations(fs, path)
if err != nil {
return err
}
for _, mig := range migrations {
m.AddMigration(mig.name, mig.migration)
}
sec := cfg.Raw.Section("database")
return m.Start(
sec.Key("migration_locking").MustBool(true),
sec.Key("locking_attempt_timeout_sec").MustInt(),
)
}
type migration struct {
name string
migration migrator.Migration
}
func getMigrations(fs embed.FS, path string) ([]migration, error) {
entries, err := fs.ReadDir(path)
if err != nil {
return nil, fmt.Errorf("failed to read migration dir: %w", err)
}
// parseStatements extracts statements from a sql file so we can execute
// them as separate migrations. OpenFGA uses Goose as their migration egine
// and Goose uses a single sql file for both up and down migrations.
// Grafana only supports up migration so we strip out the down migration
// and parse each individual statement
parseStatements := func(data []byte) ([]string, error) {
scripts := strings.Split(strings.TrimPrefix(string(data), "-- +goose Up"), "-- +goose Down")
if len(scripts) != 2 {
return nil, errors.New("malformed migration file")
}
// We assume that up migrations are always before down migrations
parts := strings.SplitAfter(scripts[0], ";")
stmts := make([]string, 0, len(parts))
for _, p := range parts {
p = strings.TrimSpace(p)
if p != "" {
stmts = append(stmts, p)
}
}
return stmts, nil
}
formatName := func(name string) string {
// Each migration file start with XXX where X is a number.
// We remove that part and prefix each migration with "zanzana".
return strings.TrimSuffix("zanzana"+name[3:], ".sql")
}
migrations := make([]migration, 0, len(entries))
for _, e := range entries {
data, err := fs.ReadFile(path + "/" + e.Name())
if err != nil {
return nil, fmt.Errorf("failed to read migration file: %w", err)
}
stmts, err := parseStatements(data)
if err != nil {
return nil, fmt.Errorf("failed to parse migration: %w", err)
}
migrations = append(migrations, migration{
name: formatName(e.Name()),
migration: &rawMigration{stmts: stmts},
})
}
return migrations, nil
}
var _ migrator.CodeMigration = (*rawMigration)(nil)
type rawMigration struct {
stmts []string
migrator.MigrationBase
}
func (m *rawMigration) Exec(sess *xorm.Session, migrator *migrator.Migrator) error {
for _, stmt := range m.stmts {
if _, err := sess.Exec(stmt); err != nil {
return fmt.Errorf("failed to run migration: %w", err)
}
}
return nil
}
func (m *rawMigration) SQL(dialect migrator.Dialect) string {
return strings.Join(m.stmts, "\n")
}

View File

@ -0,0 +1,56 @@
-- +goose Up
CREATE TABLE tuple (
store CHAR(26) NOT NULL,
object_type VARCHAR(128) NOT NULL,
object_id VARCHAR(128) NOT NULL,
relation VARCHAR(50) NOT NULL,
_user VARCHAR(256) NOT NULL,
user_type VARCHAR(7) NOT NULL,
ulid CHAR(26) NOT NULL,
inserted_at TIMESTAMP NOT NULL,
PRIMARY KEY (store, object_type, object_id, relation, _user)
);
CREATE UNIQUE INDEX idx_tuple_ulid ON tuple (ulid);
CREATE TABLE authorization_model (
store CHAR(26) NOT NULL,
authorization_model_id CHAR(26) NOT NULL,
type VARCHAR(256) NOT NULL,
type_definition BLOB,
PRIMARY KEY (store, authorization_model_id, type)
);
CREATE TABLE store (
id CHAR(26) PRIMARY KEY,
name VARCHAR(64) NOT NULL,
created_at TIMESTAMP NOT NULL,
updated_at TIMESTAMP,
deleted_at TIMESTAMP
);
CREATE TABLE assertion (
store CHAR(26) NOT NULL,
authorization_model_id CHAR(26) NOT NULL,
assertions BLOB,
PRIMARY KEY (store, authorization_model_id)
);
CREATE TABLE changelog (
store CHAR(26) NOT NULL,
object_type VARCHAR(256) NOT NULL,
object_id VARCHAR(256) NOT NULL,
relation VARCHAR(50) NOT NULL,
_user VARCHAR(512) NOT NULL,
operation INTEGER NOT NULL,
ulid CHAR(26) NOT NULL,
inserted_at TIMESTAMP NOT NULL,
PRIMARY KEY (store, ulid, object_type)
);
-- +goose Down
DROP TABLE tuple;
DROP TABLE authorization_model;
DROP TABLE store;
DROP TABLE assertion;
DROP TABLE changelog;

View File

@ -0,0 +1,5 @@
-- +goose Up
ALTER TABLE authorization_model ADD COLUMN schema_version VARCHAR(5) NOT NULL DEFAULT '1.0';
-- +goose Down
ALTER TABLE authorization_model DROP COLUMN schema_version;

View File

@ -0,0 +1,5 @@
-- +goose Up
CREATE INDEX idx_reverse_lookup_user on tuple (store, object_type, relation, _user);
-- +goose Down
DROP INDEX idx_reverse_lookup_user on tuple;

View File

@ -0,0 +1,5 @@
-- +goose Up
ALTER TABLE authorization_model ADD COLUMN serialized_protobuf LONGBLOB;
-- +goose Down
ALTER TABLE authorization_model DROP COLUMN serialized_protobuf;

View File

@ -0,0 +1,11 @@
-- +goose Up
ALTER TABLE tuple ADD COLUMN condition_name VARCHAR(256);
ALTER TABLE tuple ADD COLUMN condition_context LONGBLOB;
ALTER TABLE changelog ADD COLUMN condition_name VARCHAR(256);
ALTER TABLE changelog ADD COLUMN condition_context LONGBLOB;
-- +goose Down
ALTER TABLE tuple DROP COLUMN condition_name;
ALTER TABLE tuple DROP COLUMN condition_context;
ALTER TABLE changelog DROP COLUMN condition_name;
ALTER TABLE changelog DROP COLUMN condition_context;

View File

@ -0,0 +1,15 @@
package sqlite
import "github.com/openfga/openfga/pkg/storage/sqlcommon"
type Config struct {
*sqlcommon.Config
QueryRetries int
}
func NewConfig() *Config {
return &Config{
Config: sqlcommon.NewConfig(),
QueryRetries: 0,
}
}

View File

@ -0,0 +1,821 @@
package sqlite
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"net/url"
"strings"
"time"
sq "github.com/Masterminds/squirrel"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/collectors"
"go.opentelemetry.io/otel"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/structpb"
"google.golang.org/protobuf/types/known/timestamppb"
// Pull in sqlite driver.
"github.com/mattn/go-sqlite3"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/pkg/logger"
"github.com/openfga/openfga/pkg/storage"
"github.com/openfga/openfga/pkg/storage/sqlcommon"
tupleUtils "github.com/openfga/openfga/pkg/tuple"
)
var tracer = otel.Tracer("openfga/pkg/storage/sqlite")
// SQLite provides a SQLite based implementation of [storage.OpenFGADatastore].
type SQLite struct {
stbl sq.StatementBuilderType
cfg *Config
db *sql.DB
dbInfo *sqlcommon.DBInfo
sqlTime sq.Sqlizer
logger logger.Logger
dbStatsCollector prometheus.Collector
}
// Ensures that SQLite implements the OpenFGADatastore interface.
var _ storage.OpenFGADatastore = (*SQLite)(nil)
// New creates a new [SQLite] storage.
func New(uri string, cfg *Config) (*SQLite, error) {
// Set journal mode and busy timeout pragmas if not specified.
query := url.Values{}
var err error
if i := strings.Index(uri, "?"); i != -1 {
query, err = url.ParseQuery(uri[i+1:])
if err != nil {
return nil, fmt.Errorf("error parsing dsn: %w", err)
}
uri = uri[:i]
}
foundJournalMode := false
foundBusyTimeout := false
for _, val := range query["_pragma"] {
if strings.HasPrefix(val, "journal_mode") {
foundJournalMode = true
} else if strings.HasPrefix(val, "busy_timeout") {
foundBusyTimeout = true
}
}
if !foundJournalMode {
query.Add("_pragma", "journal_mode(WAL)")
}
if !foundBusyTimeout {
query.Add("_pragma", "busy_timeout(500)")
}
uri += "?" + query.Encode()
db, err := sql.Open("sqlite", uri)
if err != nil {
return nil, fmt.Errorf("initialize sqlite connection: %w", err)
}
return NewWithDB(db, cfg)
}
// NewWithDB creates a new [SQLite] storage using provided [*sql.DB]
func NewWithDB(db *sql.DB, cfg *Config) (*SQLite, error) {
var collector prometheus.Collector
if cfg.ExportMetrics {
collector = collectors.NewDBStatsCollector(db, "openfga")
if err := prometheus.Register(collector); err != nil {
return nil, fmt.Errorf("initialize metrics: %w", err)
}
}
sqlTime := sq.Expr("datetime('subsec')")
stbl := sq.StatementBuilder.RunWith(db)
dbInfo := sqlcommon.NewDBInfo(db, stbl, sqlTime)
return &SQLite{
cfg: cfg,
stbl: stbl,
db: db,
sqlTime: sqlTime,
dbInfo: dbInfo,
logger: cfg.Logger,
dbStatsCollector: collector,
}, nil
}
// Close see [storage.OpenFGADatastore].Close.
func (m *SQLite) Close() {
if m.dbStatsCollector != nil {
prometheus.Unregister(m.dbStatsCollector)
}
_ = m.db.Close()
}
// Read see [storage.RelationshipTupleReader].Read.
func (m *SQLite) Read(ctx context.Context, store string, tupleKey *openfgav1.TupleKey) (storage.TupleIterator, error) {
ctx, span := tracer.Start(ctx, "sqlite.Read")
defer span.End()
return m.read(ctx, store, tupleKey, nil)
}
// ReadPage see [storage.RelationshipTupleReader].ReadPage.
func (m *SQLite) ReadPage(
ctx context.Context,
store string,
tupleKey *openfgav1.TupleKey,
opts storage.PaginationOptions,
) ([]*openfgav1.Tuple, []byte, error) {
ctx, span := tracer.Start(ctx, "sqlite.ReadPage")
defer span.End()
iter, err := m.read(ctx, store, tupleKey, &opts)
if err != nil {
return nil, nil, err
}
defer iter.Stop()
return iter.ToArray(opts)
}
func (m *SQLite) read(ctx context.Context, store string, tupleKey *openfgav1.TupleKey, opts *storage.PaginationOptions) (*sqlcommon.SQLTupleIterator, error) {
ctx, span := tracer.Start(ctx, "sqlite.read")
defer span.End()
sb := m.stbl.
Select(
"store", "object_type", "object_id", "relation", "_user",
"condition_name", "condition_context", "ulid", "inserted_at",
).
From("tuple").
Where(sq.Eq{"store": store})
if opts != nil {
sb = sb.OrderBy("ulid")
}
objectType, objectID := tupleUtils.SplitObject(tupleKey.GetObject())
if objectType != "" {
sb = sb.Where(sq.Eq{"object_type": objectType})
}
if objectID != "" {
sb = sb.Where(sq.Eq{"object_id": objectID})
}
if tupleKey.GetRelation() != "" {
sb = sb.Where(sq.Eq{"relation": tupleKey.GetRelation()})
}
if tupleKey.GetUser() != "" {
sb = sb.Where(sq.Eq{"_user": tupleKey.GetUser()})
}
if opts != nil && opts.From != "" {
token, err := sqlcommon.UnmarshallContToken(opts.From)
if err != nil {
return nil, err
}
sb = sb.Where(sq.GtOrEq{"ulid": token.Ulid})
}
if opts != nil && opts.PageSize != 0 {
sb = sb.Limit(uint64(opts.PageSize + 1)) // + 1 is used to determine whether to return a continuation token.
}
rows, err := sb.QueryContext(ctx)
if err != nil {
return nil, handleSQLError(err)
}
return sqlcommon.NewSQLTupleIterator(rows), nil
}
// Write see [storage.RelationshipTupleWriter].Write.
func (m *SQLite) Write(ctx context.Context, store string, deletes storage.Deletes, writes storage.Writes) error {
ctx, span := tracer.Start(ctx, "sqlite.Write")
defer span.End()
if len(deletes)+len(writes) > m.MaxTuplesPerWrite() {
return storage.ErrExceededWriteBatchLimit
}
return m.busyRetry(func() error {
now := time.Now().UTC()
return write(ctx, m.db, m.stbl, m.sqlTime, store, deletes, writes, now)
})
}
// ReadUserTuple see [storage.RelationshipTupleReader].ReadUserTuple.
func (m *SQLite) ReadUserTuple(ctx context.Context, store string, tupleKey *openfgav1.TupleKey) (*openfgav1.Tuple, error) {
ctx, span := tracer.Start(ctx, "sqlite.ReadUserTuple")
defer span.End()
objectType, objectID := tupleUtils.SplitObject(tupleKey.GetObject())
userType := tupleUtils.GetUserTypeFromUser(tupleKey.GetUser())
var conditionName sql.NullString
var conditionContext []byte
var record storage.TupleRecord
err := m.stbl.
Select(
"object_type", "object_id", "relation", "_user",
"condition_name", "condition_context",
).
From("tuple").
Where(sq.Eq{
"store": store,
"object_type": objectType,
"object_id": objectID,
"relation": tupleKey.GetRelation(),
"_user": tupleKey.GetUser(),
"user_type": userType,
}).
QueryRowContext(ctx).
Scan(
&record.ObjectType,
&record.ObjectID,
&record.Relation,
&record.User,
&conditionName,
&conditionContext,
)
if err != nil {
return nil, handleSQLError(err)
}
if conditionName.String != "" {
record.ConditionName = conditionName.String
if conditionContext != nil {
var conditionContextStruct structpb.Struct
if err := proto.Unmarshal(conditionContext, &conditionContextStruct); err != nil {
return nil, err
}
record.ConditionContext = &conditionContextStruct
}
}
return record.AsTuple(), nil
}
// ReadUsersetTuples see [storage.RelationshipTupleReader].ReadUsersetTuples.
func (m *SQLite) ReadUsersetTuples(
ctx context.Context,
store string,
filter storage.ReadUsersetTuplesFilter,
) (storage.TupleIterator, error) {
ctx, span := tracer.Start(ctx, "sqlite.ReadUsersetTuples")
defer span.End()
sb := m.stbl.
Select(
"store", "object_type", "object_id", "relation", "_user",
"condition_name", "condition_context", "ulid", "inserted_at",
).
From("tuple").
Where(sq.Eq{"store": store}).
Where(sq.Eq{"user_type": tupleUtils.UserSet})
objectType, objectID := tupleUtils.SplitObject(filter.Object)
if objectType != "" {
sb = sb.Where(sq.Eq{"object_type": objectType})
}
if objectID != "" {
sb = sb.Where(sq.Eq{"object_id": objectID})
}
if filter.Relation != "" {
sb = sb.Where(sq.Eq{"relation": filter.Relation})
}
if len(filter.AllowedUserTypeRestrictions) > 0 {
orConditions := sq.Or{}
for _, userset := range filter.AllowedUserTypeRestrictions {
if _, ok := userset.GetRelationOrWildcard().(*openfgav1.RelationReference_Relation); ok {
orConditions = append(orConditions, sq.Like{"_user": userset.GetType() + ":%#" + userset.GetRelation()})
}
if _, ok := userset.GetRelationOrWildcard().(*openfgav1.RelationReference_Wildcard); ok {
orConditions = append(orConditions, sq.Eq{"_user": userset.GetType() + ":*"})
}
}
sb = sb.Where(orConditions)
}
rows, err := sb.QueryContext(ctx)
if err != nil {
return nil, handleSQLError(err)
}
return sqlcommon.NewSQLTupleIterator(rows), nil
}
// ReadStartingWithUser see [storage.RelationshipTupleReader].ReadStartingWithUser.
func (m *SQLite) ReadStartingWithUser(
ctx context.Context,
store string,
opts storage.ReadStartingWithUserFilter,
) (storage.TupleIterator, error) {
ctx, span := tracer.Start(ctx, "sqlite.ReadStartingWithUser")
defer span.End()
targetUsersArg := make([]string, 0, len(opts.UserFilter))
for _, u := range opts.UserFilter {
targetUser := u.GetObject()
if u.GetRelation() != "" {
targetUser = strings.Join([]string{u.GetObject(), u.GetRelation()}, "#")
}
targetUsersArg = append(targetUsersArg, targetUser)
}
rows, err := m.stbl.
Select(
"store", "object_type", "object_id", "relation", "_user",
"condition_name", "condition_context", "ulid", "inserted_at",
).
From("tuple").
Where(sq.Eq{
"store": store,
"object_type": opts.ObjectType,
"relation": opts.Relation,
"_user": targetUsersArg,
}).QueryContext(ctx)
if err != nil {
return nil, handleSQLError(err)
}
return sqlcommon.NewSQLTupleIterator(rows), nil
}
// MaxTuplesPerWrite see [storage.RelationshipTupleWriter].MaxTuplesPerWrite.
func (m *SQLite) MaxTuplesPerWrite() int {
return m.cfg.MaxTuplesPerWriteField
}
// ReadAuthorizationModel see [storage.AuthorizationModelReadBackend].ReadAuthorizationModel.
func (m *SQLite) ReadAuthorizationModel(ctx context.Context, store string, modelID string) (*openfgav1.AuthorizationModel, error) {
ctx, span := tracer.Start(ctx, "sqlite.ReadAuthorizationModel")
defer span.End()
return sqlcommon.ReadAuthorizationModel(ctx, m.dbInfo, store, modelID)
}
// ReadAuthorizationModels see [storage.AuthorizationModelReadBackend].ReadAuthorizationModels.
func (m *SQLite) ReadAuthorizationModels(
ctx context.Context,
store string,
opts storage.PaginationOptions,
) ([]*openfgav1.AuthorizationModel, []byte, error) {
ctx, span := tracer.Start(ctx, "sqlite.ReadAuthorizationModels")
defer span.End()
sb := m.stbl.Select("authorization_model_id").
Distinct().
From("authorization_model").
Where(sq.Eq{"store": store}).
OrderBy("authorization_model_id desc")
if opts.From != "" {
token, err := sqlcommon.UnmarshallContToken(opts.From)
if err != nil {
return nil, nil, err
}
sb = sb.Where(sq.LtOrEq{"authorization_model_id": token.Ulid})
}
if opts.PageSize > 0 {
sb = sb.Limit(uint64(opts.PageSize + 1)) // + 1 is used to determine whether to return a continuation token.
}
rows, err := sb.QueryContext(ctx)
if err != nil {
return nil, nil, handleSQLError(err)
}
defer func() { _ = rows.Close() }()
var modelIDs []string
var modelID string
for rows.Next() {
err = rows.Scan(&modelID)
if err != nil {
return nil, nil, handleSQLError(err)
}
modelIDs = append(modelIDs, modelID)
}
if err := rows.Err(); err != nil {
return nil, nil, handleSQLError(err)
}
var token []byte
numModelIDs := len(modelIDs)
if len(modelIDs) > opts.PageSize {
numModelIDs = opts.PageSize
token, err = json.Marshal(sqlcommon.NewContToken(modelID, ""))
if err != nil {
return nil, nil, err
}
}
// TODO: make this concurrent with a maximum of 5 goroutines. This may be helpful:
// https://stackoverflow.com/questions/25306073/always-have-x-number-of-goroutines-running-at-any-time
models := make([]*openfgav1.AuthorizationModel, 0, numModelIDs)
// We use numModelIDs here to avoid retrieving possibly one extra model.
for i := 0; i < numModelIDs; i++ {
model, err := m.ReadAuthorizationModel(ctx, store, modelIDs[i])
if err != nil {
return nil, nil, err
}
models = append(models, model)
}
return models, token, nil
}
// FindLatestAuthorizationModel see [storage.AuthorizationModelReadBackend].FindLatestAuthorizationModel.
func (m *SQLite) FindLatestAuthorizationModel(ctx context.Context, store string) (*openfgav1.AuthorizationModel, error) {
ctx, span := tracer.Start(ctx, "sqlite.FindLatestAuthorizationModel")
defer span.End()
return sqlcommon.FindLatestAuthorizationModel(ctx, m.dbInfo, store)
}
// MaxTypesPerAuthorizationModel see [storage.TypeDefinitionWriteBackend].MaxTypesPerAuthorizationModel.
func (m *SQLite) MaxTypesPerAuthorizationModel() int {
return m.cfg.MaxTypesPerModelField
}
// WriteAuthorizationModel see [storage.TypeDefinitionWriteBackend].WriteAuthorizationModel.
func (m *SQLite) WriteAuthorizationModel(ctx context.Context, store string, model *openfgav1.AuthorizationModel) error {
ctx, span := tracer.Start(ctx, "sqlite.WriteAuthorizationModel")
defer span.End()
typeDefinitions := model.GetTypeDefinitions()
if len(typeDefinitions) > m.MaxTypesPerAuthorizationModel() {
return storage.ExceededMaxTypeDefinitionsLimitError(m.MaxTypesPerAuthorizationModel())
}
return m.busyRetry(func() error {
return sqlcommon.WriteAuthorizationModel(ctx, m.dbInfo, store, model)
})
}
// CreateStore adds a new store to the SQLite storage.
func (m *SQLite) CreateStore(ctx context.Context, store *openfgav1.Store) (*openfgav1.Store, error) {
ctx, span := tracer.Start(ctx, "sqlite.CreateStore")
defer span.End()
txn, err := m.db.BeginTx(ctx, &sql.TxOptions{})
if err != nil {
return nil, handleSQLError(err)
}
defer func() {
_ = txn.Rollback()
}()
_, err = m.stbl.
Insert("store").
Columns("id", "name", "created_at", "updated_at").
Values(store.GetId(), store.GetName(), sq.Expr("datetime('subsec')"), sq.Expr("datetime('subsec')")).
RunWith(txn).
ExecContext(ctx)
if err != nil {
return nil, handleSQLError(err)
}
var createdAt time.Time
var id, name string
err = m.stbl.
Select("id", "name", "created_at").
From("store").
Where(sq.Eq{"id": store.GetId()}).
RunWith(txn).
QueryRowContext(ctx).
Scan(&id, &name, &createdAt)
if err != nil {
return nil, handleSQLError(err)
}
err = txn.Commit()
if err != nil {
return nil, handleSQLError(err)
}
return &openfgav1.Store{
Id: id,
Name: name,
CreatedAt: timestamppb.New(createdAt),
UpdatedAt: timestamppb.New(createdAt),
}, nil
}
// GetStore retrieves the details of a specific store from the SQLite using its storeID.
func (m *SQLite) GetStore(ctx context.Context, id string) (*openfgav1.Store, error) {
ctx, span := tracer.Start(ctx, "sqlite.GetStore")
defer span.End()
row := m.stbl.
Select("id", "name", "created_at", "updated_at").
From("store").
Where(sq.Eq{
"id": id,
"deleted_at": nil,
}).
QueryRowContext(ctx)
var storeID, name string
var createdAt, updatedAt time.Time
err := row.Scan(&storeID, &name, &createdAt, &updatedAt)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, storage.ErrNotFound
}
return nil, handleSQLError(err)
}
return &openfgav1.Store{
Id: storeID,
Name: name,
CreatedAt: timestamppb.New(createdAt),
UpdatedAt: timestamppb.New(updatedAt),
}, nil
}
// ListStores provides a paginated list of all stores present in the SQLite storage.
func (m *SQLite) ListStores(ctx context.Context, opts storage.PaginationOptions) ([]*openfgav1.Store, []byte, error) {
ctx, span := tracer.Start(ctx, "sqlite.ListStores")
defer span.End()
sb := m.stbl.
Select("id", "name", "created_at", "updated_at").
From("store").
Where(sq.Eq{"deleted_at": nil}).
OrderBy("id")
if opts.From != "" {
token, err := sqlcommon.UnmarshallContToken(opts.From)
if err != nil {
return nil, nil, err
}
sb = sb.Where(sq.GtOrEq{"id": token.Ulid})
}
if opts.PageSize > 0 {
sb = sb.Limit(uint64(opts.PageSize + 1)) // + 1 is used to determine whether to return a continuation token.
}
rows, err := sb.QueryContext(ctx)
if err != nil {
return nil, nil, handleSQLError(err)
}
defer func() { _ = rows.Close() }()
var stores []*openfgav1.Store
var id string
for rows.Next() {
var name string
var createdAt, updatedAt time.Time
err := rows.Scan(&id, &name, &createdAt, &updatedAt)
if err != nil {
return nil, nil, handleSQLError(err)
}
stores = append(stores, &openfgav1.Store{
Id: id,
Name: name,
CreatedAt: timestamppb.New(createdAt),
UpdatedAt: timestamppb.New(updatedAt),
})
}
if err := rows.Err(); err != nil {
return nil, nil, handleSQLError(err)
}
if len(stores) > opts.PageSize {
contToken, err := json.Marshal(sqlcommon.NewContToken(id, ""))
if err != nil {
return nil, nil, err
}
return stores[:opts.PageSize], contToken, nil
}
return stores, nil, nil
}
// DeleteStore removes a store from the SQLite storage.
func (m *SQLite) DeleteStore(ctx context.Context, id string) error {
ctx, span := tracer.Start(ctx, "sqlite.DeleteStore")
defer span.End()
_, err := m.stbl.
Update("store").
Set("deleted_at", sq.Expr("datetime('subsec')")).
Where(sq.Eq{"id": id}).
ExecContext(ctx)
if err != nil {
return handleSQLError(err)
}
return nil
}
// WriteAssertions see [storage.AssertionsBackend].WriteAssertions.
func (m *SQLite) WriteAssertions(ctx context.Context, store, modelID string, assertions []*openfgav1.Assertion) error {
ctx, span := tracer.Start(ctx, "sqlite.WriteAssertions")
defer span.End()
marshalledAssertions, err := proto.Marshal(&openfgav1.Assertions{Assertions: assertions})
if err != nil {
return err
}
return m.busyRetry(func() error {
_, err = m.stbl.
Insert("assertion").
Columns("store", "authorization_model_id", "assertions").
Values(store, modelID, marshalledAssertions).
Suffix("ON CONFLICT(store,authorization_model_id) DO UPDATE SET assertions = ?", marshalledAssertions).
ExecContext(ctx)
if err != nil {
return handleSQLError(err)
}
return nil
})
}
// ReadAssertions see [storage.AssertionsBackend].ReadAssertions.
func (m *SQLite) ReadAssertions(ctx context.Context, store, modelID string) ([]*openfgav1.Assertion, error) {
ctx, span := tracer.Start(ctx, "sqlite.ReadAssertions")
defer span.End()
var marshalledAssertions []byte
err := m.stbl.
Select("assertions").
From("assertion").
Where(sq.Eq{
"store": store,
"authorization_model_id": modelID,
}).
QueryRowContext(ctx).
Scan(&marshalledAssertions)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return []*openfgav1.Assertion{}, nil
}
return nil, handleSQLError(err)
}
var assertions openfgav1.Assertions
err = proto.Unmarshal(marshalledAssertions, &assertions)
if err != nil {
return nil, err
}
return assertions.GetAssertions(), nil
}
// ReadChanges see [storage.ChangelogBackend].ReadChanges.
func (m *SQLite) ReadChanges(
ctx context.Context,
store, objectTypeFilter string,
opts storage.PaginationOptions,
horizonOffset time.Duration,
) ([]*openfgav1.TupleChange, []byte, error) {
ctx, span := tracer.Start(ctx, "sqlite.ReadChanges")
defer span.End()
sb := m.stbl.
Select(
"ulid", "object_type", "object_id", "relation", "_user", "operation",
"condition_name", "condition_context", "inserted_at",
).
From("changelog").
Where(sq.Eq{"store": store}).
Where(fmt.Sprintf("inserted_at <= datetime('subsec','-%f seconds')", horizonOffset.Seconds())).
OrderBy("ulid asc")
if objectTypeFilter != "" {
sb = sb.Where(sq.Eq{"object_type": objectTypeFilter})
}
if opts.From != "" {
token, err := sqlcommon.UnmarshallContToken(opts.From)
if err != nil {
return nil, nil, err
}
if token.ObjectType != objectTypeFilter {
return nil, nil, storage.ErrMismatchObjectType
}
sb = sb.Where(sq.Gt{"ulid": token.Ulid}) // > as we always return a continuation token.
}
if opts.PageSize > 0 {
sb = sb.Limit(uint64(opts.PageSize)) // + 1 is NOT used here as we always return a continuation token.
}
rows, err := sb.QueryContext(ctx)
if err != nil {
return nil, nil, handleSQLError(err)
}
defer func() { _ = rows.Close() }()
var changes []*openfgav1.TupleChange
var ulid string
for rows.Next() {
var objectType, objectID, relation, user string
var operation int
var insertedAt time.Time
var conditionName sql.NullString
var conditionContext []byte
err = rows.Scan(
&ulid,
&objectType,
&objectID,
&relation,
&user,
&operation,
&conditionName,
&conditionContext,
&insertedAt,
)
if err != nil {
return nil, nil, handleSQLError(err)
}
var conditionContextStruct structpb.Struct
if conditionName.String != "" {
if conditionContext != nil {
if err := proto.Unmarshal(conditionContext, &conditionContextStruct); err != nil {
return nil, nil, err
}
}
}
tk := tupleUtils.NewTupleKeyWithCondition(
tupleUtils.BuildObject(objectType, objectID),
relation,
user,
conditionName.String,
&conditionContextStruct,
)
changes = append(changes, &openfgav1.TupleChange{
TupleKey: tk,
Operation: openfgav1.TupleOperation(operation),
Timestamp: timestamppb.New(insertedAt.UTC()),
})
}
if len(changes) == 0 {
return nil, nil, storage.ErrNotFound
}
contToken, err := json.Marshal(sqlcommon.NewContToken(ulid, objectTypeFilter))
if err != nil {
return nil, nil, err
}
return changes, contToken, nil
}
func (m *SQLite) IsReady(ctx context.Context) (storage.ReadinessStatus, error) {
if err := m.db.PingContext(ctx); err != nil {
return storage.ReadinessStatus{}, err
}
return storage.ReadinessStatus{
IsReady: true,
}, nil
}
// SQLite will return an SQLITE_BUSY error when the database is locked rather than waiting for the lock.
// This function retries the operation up to 5 times before returning the error.
func (m *SQLite) busyRetry(fn func() error) error {
for retries := 0; ; retries++ {
err := fn()
if err == nil || retries == m.cfg.QueryRetries {
return err
}
var sqliteErr *sqlite3.Error
if errors.As(err, &sqliteErr) && (sqliteErr.Code == sqlite3.ErrLocked || sqliteErr.Code == sqlite3.ErrBusy) {
time.Sleep(10 * time.Millisecond)
continue
}
return err
}
}
func handleSQLError(err error, args ...any) error {
if strings.Contains(err.Error(), "UNIQUE constraint failed:") {
if len(args) > 0 {
if tk, ok := args[0].(*openfgav1.TupleKey); ok {
return storage.InvalidWriteInputError(tk, openfgav1.TupleOperation_TUPLE_OPERATION_WRITE)
}
}
return storage.ErrCollision
}
return sqlcommon.HandleSQLError(err, args...)
}

View File

@ -0,0 +1,294 @@
package sqlite
import (
"context"
"database/sql"
"testing"
"time"
sq "github.com/Masterminds/squirrel"
"github.com/oklog/ulid/v2"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
"github.com/openfga/openfga/pkg/storage"
"github.com/openfga/openfga/pkg/storage/sqlcommon"
"github.com/openfga/openfga/pkg/storage/test"
"github.com/openfga/openfga/pkg/tuple"
"github.com/openfga/openfga/pkg/typesystem"
"github.com/grafana/grafana/pkg/infra/db"
"github.com/grafana/grafana/pkg/services/authz/zanzana/store"
"github.com/grafana/grafana/pkg/services/authz/zanzana/store/migration"
"github.com/grafana/grafana/pkg/services/sqlstore/migrator"
"github.com/grafana/grafana/pkg/tests/testsuite"
)
func TestMain(m *testing.M) {
testsuite.Run(m)
}
// TestIntegrationDatastore runs open fga default datastore test suite
func TestIntegrationDatastore(t *testing.T) {
db := sqliteIntegrationTest(t)
ds, err := NewWithDB(db, NewConfig())
require.NoError(t, err)
test.RunAllTests(t, ds)
}
// TestIntegrationReadEnsureNoOrder asserts that the read response is not ordered by ulid.
func TestIntegrationReadEnsureNoOrder(t *testing.T) {
db := sqliteIntegrationTest(t)
ds, err := NewWithDB(db, NewConfig())
require.NoError(t, err)
ctx := context.Background()
store := "store"
firstTuple := tuple.NewTupleKey("doc:object_id_1", "relation", "user:user_1")
secondTuple := tuple.NewTupleKey("doc:object_id_2", "relation", "user:user_2")
err = sqlcommon.Write(ctx,
sqlcommon.NewDBInfo(ds.db, ds.stbl, sq.Expr("datetime('subsec')")),
store,
[]*openfgav1.TupleKeyWithoutCondition{},
[]*openfgav1.TupleKey{firstTuple},
time.Now())
require.NoError(t, err)
// Tweak time so that ULID is smaller.
err = sqlcommon.Write(ctx,
sqlcommon.NewDBInfo(ds.db, ds.stbl, sq.Expr("datetime('subsec')")),
store,
[]*openfgav1.TupleKeyWithoutCondition{},
[]*openfgav1.TupleKey{secondTuple},
time.Now().Add(time.Minute*-1))
require.NoError(t, err)
iter, err := ds.Read(ctx,
store,
tuple.NewTupleKey("doc:", "relation", ""))
defer iter.Stop()
require.NoError(t, err)
// We expect that objectID1 will return first because it is inserted first.
curTuple, err := iter.Next(ctx)
require.NoError(t, err)
require.Equal(t, firstTuple, curTuple.GetKey())
curTuple, err = iter.Next(ctx)
require.NoError(t, err)
require.Equal(t, secondTuple, curTuple.GetKey())
}
// TestIntegrationReadPageEnsureNoOrder asserts that the read page is ordered by ulid.
func TestIntegrationReadPageEnsureOrder(t *testing.T) {
db := sqliteIntegrationTest(t)
ds, err := NewWithDB(db, NewConfig())
require.NoError(t, err)
ctx := context.Background()
store := "store"
firstTuple := tuple.NewTupleKey("doc:object_id_1", "relation", "user:user_1")
secondTuple := tuple.NewTupleKey("doc:object_id_2", "relation", "user:user_2")
err = sqlcommon.Write(ctx,
sqlcommon.NewDBInfo(ds.db, ds.stbl, sq.Expr("datetime('subsec')")),
store,
[]*openfgav1.TupleKeyWithoutCondition{},
[]*openfgav1.TupleKey{firstTuple},
time.Now())
require.NoError(t, err)
// Tweak time so that ULID is smaller.
err = sqlcommon.Write(ctx,
sqlcommon.NewDBInfo(ds.db, ds.stbl, sq.Expr("datetime('subsec')")),
store,
[]*openfgav1.TupleKeyWithoutCondition{},
[]*openfgav1.TupleKey{secondTuple},
time.Now().Add(time.Minute*-1))
require.NoError(t, err)
tuples, _, err := ds.ReadPage(ctx,
store,
tuple.NewTupleKey("doc:", "relation", ""),
storage.NewPaginationOptions(0, ""))
require.NoError(t, err)
require.Len(t, tuples, 2)
// We expect that objectID2 will return first because it has a smaller ulid.
require.Equal(t, secondTuple, tuples[0].GetKey())
require.Equal(t, firstTuple, tuples[1].GetKey())
}
func TestIntegrationReadAuthorizationModelUnmarshallError(t *testing.T) {
db := sqliteIntegrationTest(t)
ds, err := NewWithDB(db, NewConfig())
require.NoError(t, err)
ctx := context.Background()
store := "store"
modelID := "foo"
schemaVersion := typesystem.SchemaVersion1_0
bytes, err := proto.Marshal(&openfgav1.TypeDefinition{Type: "document"})
require.NoError(t, err)
pbdata := []byte{0x01, 0x02, 0x03}
_, err = ds.db.ExecContext(ctx, "INSERT INTO authorization_model (store, authorization_model_id, schema_version, type, type_definition, serialized_protobuf) VALUES (?, ?, ?, ?, ?, ?)", store, modelID, schemaVersion, "document", bytes, pbdata)
require.NoError(t, err)
_, err = ds.ReadAuthorizationModel(ctx, store, modelID)
require.Error(t, err)
require.Contains(t, err.Error(), "cannot parse invalid wire-format data")
}
// TestIntegrationAllowNullCondition tests that tuple and changelog rows existing before
// migration 005_add_conditions_to_tuples can be successfully read.
func TestIntegrationAllowNullCondition(t *testing.T) {
db := sqliteIntegrationTest(t)
ds, err := NewWithDB(db, NewConfig())
require.NoError(t, err)
ctx := context.Background()
stmt := `
INSERT INTO tuple (
store, object_type, object_id, relation, _user, user_type, ulid,
condition_name, condition_context, inserted_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('subsec'));
`
_, err = ds.db.ExecContext(
ctx, stmt, "store", "folder", "2021-budget", "owner", "user:anne", "user",
ulid.Make().String(), nil, nil,
)
require.NoError(t, err)
tk := tuple.NewTupleKey("folder:2021-budget", "owner", "user:anne")
iter, err := ds.Read(ctx, "store", tk)
require.NoError(t, err)
defer iter.Stop()
curTuple, err := iter.Next(ctx)
require.NoError(t, err)
require.Equal(t, tk, curTuple.GetKey())
tuples, _, err := ds.ReadPage(ctx, "store", &openfgav1.TupleKey{}, storage.PaginationOptions{
PageSize: 2,
})
require.NoError(t, err)
require.Len(t, tuples, 1)
require.Equal(t, tk, tuples[0].GetKey())
userTuple, err := ds.ReadUserTuple(ctx, "store", tk)
require.NoError(t, err)
require.Equal(t, tk, userTuple.GetKey())
tk2 := tuple.NewTupleKey("folder:2022-budget", "viewer", "user:anne")
_, err = ds.db.ExecContext(
ctx, stmt, "store", "folder", "2022-budget", "viewer", "user:anne", "userset",
ulid.Make().String(), nil, nil,
)
require.NoError(t, err)
iter, err = ds.ReadUsersetTuples(ctx, "store", storage.ReadUsersetTuplesFilter{Object: "folder:2022-budget"})
require.NoError(t, err)
defer iter.Stop()
curTuple, err = iter.Next(ctx)
require.NoError(t, err)
require.Equal(t, tk2, curTuple.GetKey())
iter, err = ds.ReadStartingWithUser(ctx, "store", storage.ReadStartingWithUserFilter{
ObjectType: "folder",
Relation: "owner",
UserFilter: []*openfgav1.ObjectRelation{
{Object: "user:anne"},
},
})
require.NoError(t, err)
defer iter.Stop()
curTuple, err = iter.Next(ctx)
require.NoError(t, err)
require.Equal(t, tk, curTuple.GetKey())
stmt = `
INSERT INTO changelog (
store, object_type, object_id, relation, _user, ulid,
condition_name, condition_context, inserted_at, operation
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, datetime('subsec'), ?);
`
_, err = ds.db.ExecContext(
ctx, stmt, "store", "folder", "2021-budget", "owner", "user:anne",
ulid.Make().String(), nil, nil, openfgav1.TupleOperation_TUPLE_OPERATION_WRITE,
)
require.NoError(t, err)
_, err = ds.db.ExecContext(
ctx, stmt, "store", "folder", "2021-budget", "owner", "user:anne",
ulid.Make().String(), nil, nil, openfgav1.TupleOperation_TUPLE_OPERATION_DELETE,
)
require.NoError(t, err)
changes, _, err := ds.ReadChanges(ctx, "store", "folder", storage.PaginationOptions{}, 0)
require.NoError(t, err)
require.Len(t, changes, 2)
require.Equal(t, tk, changes[0].GetTupleKey())
require.Equal(t, tk, changes[1].GetTupleKey())
}
// TestIntegrationMarshalledAssertions tests that previously persisted marshalled
// assertions can be read back. In any case where the Assertions proto model
// needs to change, we'll likely need to introduce a series of data migrations.
func TestIntegrationMarshalledAssertions(t *testing.T) {
db := sqliteIntegrationTest(t)
ds, err := NewWithDB(db, NewConfig())
require.NoError(t, err)
ctx := context.Background()
// Note: this represents an assertion written on v1.3.7.
stmt := `
INSERT INTO assertion (
store, authorization_model_id, assertions
) VALUES (?, ?, UNHEX('0A2B0A270A12666F6C6465723A323032312D62756467657412056F776E65721A0A757365723A616E6E657A1001'));
`
_, err = ds.db.ExecContext(ctx, stmt, "store", "model")
require.NoError(t, err)
assertions, err := ds.ReadAssertions(ctx, "store", "model")
require.NoError(t, err)
expectedAssertions := []*openfgav1.Assertion{
{
TupleKey: &openfgav1.AssertionTupleKey{
Object: "folder:2021-budget",
Relation: "owner",
User: "user:annez",
},
Expectation: true,
},
}
require.Equal(t, expectedAssertions, assertions)
}
func sqliteIntegrationTest(tb testing.TB) *sql.DB {
if testing.Short() || !db.IsTestDbSQLite() {
tb.Skip("skipping integration test")
}
db, cfg := db.InitTestDBWithCfg(tb)
m := migrator.NewMigrator(db.GetEngine(), cfg)
err := migration.RunWithMigrator(m, cfg, store.EmbedMigrations, store.SQLiteMigrationDir)
require.NoError(tb, err)
return db.GetEngine().DB().DB
}

View File

@ -0,0 +1,165 @@
package sqlite
import (
"context"
"database/sql"
"time"
sq "github.com/Masterminds/squirrel"
"github.com/oklog/ulid/v2"
"google.golang.org/protobuf/proto"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/pkg/storage"
tupleUtils "github.com/openfga/openfga/pkg/tuple"
)
// write is copied from https://github.com/openfga/openfga/blob/main/pkg/storage/sqlcommon/sqlcommon.go#L330-L456
// but uses custom handleSQLError.
func write(
ctx context.Context,
db *sql.DB,
stbl sq.StatementBuilderType,
sqlTime any,
store string,
deletes storage.Deletes,
writes storage.Writes,
now time.Time,
) error {
txn, err := db.BeginTx(ctx, nil)
if err != nil {
return handleSQLError(err)
}
defer func() {
_ = txn.Rollback()
}()
changelogBuilder := stbl.
Insert("changelog").
Columns(
"store", "object_type", "object_id", "relation", "_user",
"condition_name", "condition_context", "operation", "ulid", "inserted_at",
)
deleteBuilder := stbl.Delete("tuple")
for _, tk := range deletes {
id := ulid.MustNew(ulid.Timestamp(now), ulid.DefaultEntropy()).String()
objectType, objectID := tupleUtils.SplitObject(tk.GetObject())
res, err := deleteBuilder.
Where(sq.Eq{
"store": store,
"object_type": objectType,
"object_id": objectID,
"relation": tk.GetRelation(),
"_user": tk.GetUser(),
"user_type": tupleUtils.GetUserTypeFromUser(tk.GetUser()),
}).
RunWith(txn). // Part of a txn.
ExecContext(ctx)
if err != nil {
return handleSQLError(err, tk)
}
rowsAffected, err := res.RowsAffected()
if err != nil {
return handleSQLError(err)
}
if rowsAffected != 1 {
return storage.InvalidWriteInputError(
tk,
openfgav1.TupleOperation_TUPLE_OPERATION_DELETE,
)
}
changelogBuilder = changelogBuilder.Values(
store, objectType, objectID,
tk.GetRelation(), tk.GetUser(),
"", nil, // Redact condition info for deletes since we only need the base triplet (object, relation, user).
openfgav1.TupleOperation_TUPLE_OPERATION_DELETE,
id, sqlTime,
)
}
insertBuilder := stbl.
Insert("tuple").
Columns(
"store", "object_type", "object_id", "relation", "_user", "user_type",
"condition_name", "condition_context", "ulid", "inserted_at",
)
for _, tk := range writes {
id := ulid.MustNew(ulid.Timestamp(now), ulid.DefaultEntropy()).String()
objectType, objectID := tupleUtils.SplitObject(tk.GetObject())
conditionName, conditionContext, err := marshalRelationshipCondition(tk.GetCondition())
if err != nil {
return err
}
_, err = insertBuilder.
Values(
store,
objectType,
objectID,
tk.GetRelation(),
tk.GetUser(),
tupleUtils.GetUserTypeFromUser(tk.GetUser()),
conditionName,
conditionContext,
id,
sqlTime,
).
RunWith(txn). // Part of a txn.
ExecContext(ctx)
if err != nil {
return handleSQLError(err, tk)
}
changelogBuilder = changelogBuilder.Values(
store,
objectType,
objectID,
tk.GetRelation(),
tk.GetUser(),
conditionName,
conditionContext,
openfgav1.TupleOperation_TUPLE_OPERATION_WRITE,
id,
sqlTime,
)
}
if len(writes) > 0 || len(deletes) > 0 {
_, err := changelogBuilder.RunWith(txn).ExecContext(ctx) // Part of a txn.
if err != nil {
return handleSQLError(err)
}
}
if err := txn.Commit(); err != nil {
return handleSQLError(err)
}
return nil
}
// copied from https://github.com/openfga/openfga/blob/main/pkg/storage/sqlcommon/encoding.go#L8-L24
func marshalRelationshipCondition(
rel *openfgav1.RelationshipCondition,
) (name string, context []byte, err error) {
if rel != nil {
if rel.GetContext() != nil && len(rel.GetContext().GetFields()) > 0 {
context, err = proto.Marshal(rel.GetContext())
if err != nil {
return name, context, err
}
}
return rel.GetName(), context, err
}
return name, context, err
}