Unified Storage: Add SQL template package (#87524)

* added sqltemplate package

* addded example

* fix linting issues

* improve code readability

* fix documentation
This commit is contained in:
Diego Augusto Molina 2024-05-08 17:58:18 -03:00 committed by GitHub
parent 4b38253b20
commit acf17c7fb1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 535 additions and 0 deletions

View File

@ -0,0 +1,14 @@
package sqltemplate
// Args keeps the data that needs to be passed to the engine for execution in
// the right order. Add it to your data types passed to SQLTemplate, either by
// embedding or with a named struct field if its Arg method would clash with
// another struct field.
type Args []any
// Arg can be called from within templates to pass arguments to the SQL driver
// to use in the execution of the query.
func (a *Args) Arg(x any) string {
*a = append(*a, x)
return "?"
}

View File

@ -0,0 +1,34 @@
package sqltemplate
import (
"testing"
)
func TestArgs_Arg(t *testing.T) {
t.Parallel()
shouldBeQuestionMark := func(t *testing.T, s string) {
t.Helper()
if s != "?" {
t.Fatalf("expecting question mark, got %q", s)
}
}
a := new(Args)
shouldBeQuestionMark(t, a.Arg(0))
shouldBeQuestionMark(t, a.Arg(1))
shouldBeQuestionMark(t, a.Arg(2))
shouldBeQuestionMark(t, a.Arg(3))
shouldBeQuestionMark(t, a.Arg(4))
for i, arg := range *a {
v, ok := arg.(int)
if !ok {
t.Fatalf("unexpected value: %T(%v)", arg, arg)
}
if v != i {
t.Fatalf("unexpected int value: %v", v)
}
}
}

View File

@ -0,0 +1,99 @@
package sqltemplate
import (
"errors"
"strings"
)
// Dialect-agnostic errors.
var (
ErrEmptyIdent = errors.New("empty identifier")
ErrInvalidRowLockingClause = errors.New("invalid row-locking clause")
)
// Dialect should be added to the data types passed to SQL templates to
// provide methods that deal with SQL implementation-specific traits. It can be
// embedded for ease of use, or with a named struct field if any of its methods
// would clash with other struct field names.
type Dialect interface {
// Ident returns the given string quoted in a way that is suitable to be
// used as an identifier. Database names, schema names, table names, column
// names are all examples of identifiers.
Ident(string) (string, error)
// SelectFor parses and returns the given row-locking clause for a SELECT
// statement. If the clause is invalid it returns an error. Implementations
// of this method should use ParseRowLockingClause.
// Example:
//
// SELECT *
// FROM mytab
// WHERE id = ?
// {{ .SelectFor Update NoWait }}; -- will be uppercased
SelectFor(...string) (string, error)
}
// RowLockingClause represents a row-locking clause in a SELECT statement.
type RowLockingClause string
// Valid returns whether the given option is valid.
func (o RowLockingClause) Valid() bool {
switch o {
case SelectForShare, SelectForShareNoWait, SelectForShareSkipLocked,
SelectForUpdate, SelectForUpdateNoWait, SelectForUpdateSkipLocked:
return true
}
return false
}
// ParseRowLockingClause parses a RowLockingClause from the given strings. This
// should be used by implementations of Dialect to parse the input of the
// SelectFor method.
func ParseRowLockingClause(s ...string) (RowLockingClause, error) {
opt := RowLockingClause(strings.ToUpper(strings.Join(s, " ")))
if !opt.Valid() {
return "", ErrInvalidRowLockingClause
}
return opt, nil
}
// Row-locking clause options.
const (
SelectForShare RowLockingClause = "SHARE"
SelectForShareNoWait RowLockingClause = "SHARE NOWAIT"
SelectForShareSkipLocked RowLockingClause = "SHARE SKIP LOCKED"
SelectForUpdate RowLockingClause = "UPDATE"
SelectForUpdateNoWait RowLockingClause = "UPDATE NOWAIT"
SelectForUpdateSkipLocked RowLockingClause = "UPDATE SKIP LOCKED"
)
// rowLockingClauseAll aids implementations that either support all the
// row-locking clause options or none. If it's true, it returns the clause,
// otherwise it returns an empty string.
type rowLockingClauseAll bool
func (rlc rowLockingClauseAll) SelectFor(s ...string) (string, error) {
// all implementations should err on invalid input, otherwise we would just
// be hiding the error until we change the dialect
o, err := ParseRowLockingClause(s...)
if err != nil {
return "", err
}
if !rlc {
return "", nil
}
return string(o), nil
}
// standardIdent provides standard SQL escaping of identifiers.
type standardIdent struct{}
func (standardIdent) Ident(s string) (string, error) {
if s == "" {
return "", ErrEmptyIdent
}
return `"` + strings.ReplaceAll(s, `"`, `""`) + `"`, nil
}

View File

@ -0,0 +1,18 @@
package sqltemplate
// MySQL is an implementation of Dialect for the MySQL DMBS. It relies on having
// ANSI_QUOTES SQL Mode enabled. For more information about ANSI_QUOTES and SQL
// Modes see:
//
// https://dev.mysql.com/doc/refman/8.4/en/sql-mode.html#sqlmode_ansi_quotes
var MySQL mysql
var _ Dialect = MySQL
type mysql struct {
standardIdent
}
func (mysql) SelectFor(s ...string) (string, error) {
return rowLockingClauseAll(true).SelectFor(s...)
}

View File

@ -0,0 +1,7 @@
package sqltemplate
import "testing"
func TestMySQL_SelectFor(t *testing.T) {
MySQL.SelectFor() //nolint: errcheck,gosec
}

View File

@ -0,0 +1,34 @@
package sqltemplate
import (
"errors"
"strings"
)
// PostgreSQL is an implementation of Dialect for the PostgreSQL DMBS.
var PostgreSQL postgresql
var _ Dialect = PostgreSQL
// PostgreSQL-specific errors.
var (
ErrPostgreSQLUnsupportedIdent = errors.New("identifiers in PostgreSQL cannot contain the character with code zero")
)
type postgresql struct {
standardIdent
}
func (p postgresql) Ident(s string) (string, error) {
// See:
// https://www.postgresql.org/docs/current/sql-syntax-lexical.html
if strings.IndexByte(s, 0) != -1 {
return "", ErrPostgreSQLUnsupportedIdent
}
return p.standardIdent.Ident(s)
}
func (postgresql) SelectFor(s ...string) (string, error) {
return rowLockingClauseAll(true).SelectFor(s...)
}

View File

@ -0,0 +1,42 @@
package sqltemplate
import (
"errors"
"testing"
)
func TestPostgreSQL_SelectFor(t *testing.T) {
PostgreSQL.SelectFor() //nolint: errcheck,gosec
}
func TestPostgreSQL_Ident(t *testing.T) {
t.Parallel()
testCases := []struct {
input string
output string
err error
}{
{input: ``, err: ErrEmptyIdent},
{input: `polite_example`, output: `"polite_example"`},
{input: `Juan Carlos`, output: `"Juan Carlos"`},
{
input: `unpolite_` + string([]byte{0}) + `example`,
err: ErrPostgreSQLUnsupportedIdent,
},
{
input: `exaggerated " ' ` + "`" + ` example`,
output: `"exaggerated "" ' ` + "`" + ` example"`,
},
}
for i, tc := range testCases {
gotOutput, gotErr := PostgreSQL.Ident(tc.input)
if !errors.Is(gotErr, tc.err) {
t.Fatalf("unexpected error %v in test case %d", gotErr, i)
}
if gotOutput != tc.output {
t.Fatalf("unexpected error %v in test case %d", gotErr, i)
}
}
}

View File

@ -0,0 +1,16 @@
package sqltemplate
// SQLite is an implementation of Dialect for the SQLite DMBS.
var SQLite sqlite
var _ Dialect = SQLite
type sqlite struct {
// See:
// https://www.sqlite.org/lang_keywords.html
standardIdent
}
func (sqlite) SelectFor(s ...string) (string, error) {
return rowLockingClauseAll(false).SelectFor(s...)
}

View File

@ -0,0 +1,7 @@
package sqltemplate
import "testing"
func TestSQLite_SelectFor(t *testing.T) {
SQLite.SelectFor() //nolint: errcheck,gosec
}

View File

@ -0,0 +1,143 @@
package sqltemplate
import (
"errors"
"strings"
"testing"
)
func TestSelectForOption_Valid(t *testing.T) {
t.Parallel()
testCases := []struct {
input RowLockingClause
expected bool
}{
{input: "", expected: false},
{input: "share", expected: false},
{input: SelectForShare, expected: true},
{input: SelectForShareNoWait, expected: true},
{input: SelectForShareSkipLocked, expected: true},
{input: SelectForUpdate, expected: true},
{input: SelectForUpdateNoWait, expected: true},
{input: SelectForUpdateSkipLocked, expected: true},
}
for i, tc := range testCases {
got := tc.input.Valid()
if got != tc.expected {
t.Fatalf("unexpected %v in test case %d", got, i)
}
}
}
func TestParseRowLockingClause(t *testing.T) {
t.Parallel()
splitSpace := func(s string) []string {
return strings.Split(s, " ")
}
testCases := []struct {
input []string
output RowLockingClause
err error
}{
{err: ErrInvalidRowLockingClause},
{
input: []string{" " + string(SelectForShare)},
err: ErrInvalidRowLockingClause,
},
{
input: splitSpace(string(SelectForShareNoWait)),
output: SelectForShareNoWait,
},
{
input: splitSpace(strings.ToLower(string(SelectForShareNoWait))),
output: SelectForShareNoWait,
},
{
input: splitSpace(strings.ToTitle(string(SelectForShareNoWait))),
output: SelectForShareNoWait,
},
}
for i, tc := range testCases {
gotOutput, gotErr := ParseRowLockingClause(tc.input...)
if !errors.Is(gotErr, tc.err) {
t.Fatalf("unexpected error %v in test case %d", gotErr, i)
}
if gotOutput != (tc.output) {
t.Fatalf("unexpected output %q in test case %d", gotOutput, i)
}
}
}
func TestRowLockingClauseAll_SelectFor(t *testing.T) {
t.Parallel()
splitSpace := func(s string) []string {
return strings.Split(s, " ")
}
testCases := []struct {
input []string
output RowLockingClause
err error
}{
{err: ErrInvalidRowLockingClause},
{input: []string{"invalid"}, err: ErrInvalidRowLockingClause},
{input: []string{" share"}, err: ErrInvalidRowLockingClause},
{
input: splitSpace(string(SelectForShare)),
output: SelectForShare,
},
}
for i, tc := range testCases {
gotOutput, gotErr := rowLockingClauseAll(true).SelectFor(tc.input...)
if !errors.Is(gotErr, tc.err) {
t.Fatalf("[true] unexpected error %v in test case %d", gotErr, i)
}
if gotOutput != string(tc.output) {
t.Fatalf("[true] unexpected error %v in test case %d", gotErr, i)
}
gotOutput, gotErr = rowLockingClauseAll(false).SelectFor(tc.input...)
if !errors.Is(gotErr, tc.err) {
t.Fatalf("[false] unexpected error %v in test case %d", gotErr, i)
}
if gotOutput != "" {
t.Fatalf("[false] unexpected error %v in test case %d", gotErr, i)
}
}
}
func TestStandardIdent_Ident(t *testing.T) {
t.Parallel()
testCases := []struct {
input string
output string
err error
}{
{input: ``, err: ErrEmptyIdent},
{input: `polite_example`, output: `"polite_example"`},
{input: `Juan Carlos`, output: `"Juan Carlos"`},
{
input: `exaggerated " ' ` + "`" + ` example`,
output: `"exaggerated "" ' ` + "`" + ` example"`,
},
}
for i, tc := range testCases {
gotOutput, gotErr := standardIdent{}.Ident(tc.input)
if !errors.Is(gotErr, tc.err) {
t.Fatalf("unexpected error %v in test case %d", gotErr, i)
}
if gotOutput != tc.output {
t.Fatalf("unexpected error %v in test case %d", gotErr, i)
}
}
}

View File

@ -0,0 +1,121 @@
package sqltemplate
import (
"fmt"
"strings"
"text/template"
)
// This file contains runnable examples. They serve the purpose of providing
// idiomatic usage of the package as well as showing how it actually works,
// since the examples are actually run together with regular Go tests. Note that
// the "Output" comment section at the end of each function starting with
// "Example" is used by the standard Go test tool to check that the standard
// output of the function matches the commented text until the end of the
// function. If you change the function, you may need to adapt that comment
// section as it's possible that the output changes, causing it to fail tests.
// To learn more about Go's runnable tests, which are a core builtin feature of
// Go's standard testing library, see:
// https://pkg.go.dev/testing#hdr-Examples
// In this example we will use both Args and Dialect to dynamically and securely
// build SQL queries, while also keeping track of the arguments that need to be
// passed to the database methods to replace the placeholder "?" with the
// correct values. If you're not familiar with Go text templating language,
// please, consider reading that library's documentation first.
// We will start with creating a simple text template to insert a new row into a
// users table:
var createUserTmpl = template.Must(template.New("query").Parse(`
INSERT INTO users (id, {{ .Ident "type" }}, name)
VALUES ({{ .Arg .ID }}, {{ .Arg .Type }}, {{ .Arg .Name}});
`))
// The two interesting methods here are Arg and Ident. Note that now we have a
// reusable text template, that will dynamically create the SQL code when
// executed, which is interesting because we have a SQL-implementation dependant
// code handled for us within the template (escaping the reserved word "type"),
// but also because the arguments to the database Exec method will be handled
// for us. The struct with the data needed to create a new user could be
// something like the following:
type CreateUserRequest struct {
ID int
Name string
Type string
}
// Note that this struct could actually come from a different definition, for
// example, from a DTO. We can reuse this DTO and create a smaller struct for
// the purpose of writing to the database without the need of mapping:
type DBCreateUserRequest struct {
Dialect // provides access to all Dialect methods, like Ident
*Args // provides access to Arg method, to keep track of db arguments
*CreateUserRequest
}
func Example() {
// Finally, we can take a request received from a user like the following:
dto := &CreateUserRequest{
ID: 1,
Name: "root",
Type: "admin",
}
// Put it into a database request:
req := DBCreateUserRequest{
Dialect: SQLite, // set at runtime, the template is agnostic
Args: new(Args),
CreateUserRequest: dto,
}
// Then we finally execute the template to both generate the SQL code and to
// populate req.Args with the arguments:
var b strings.Builder
err := createUserTmpl.Execute(&b, req)
if err != nil {
panic(err) // terminate the runnable example on error
}
// And we should finally be able to see the SQL generated, as well as
// getting the arguments populated for execution in a database. To execute
// it in the databse, we could run:
// db.ExecContext(ctx, b.String(), req.Args...)
// To provide the runnable example with some code to test, we will now print
// the values to standard output:
fmt.Println(b.String())
fmt.Printf("%#v", req.Args)
// Output:
// INSERT INTO users (id, "type", name)
// VALUES (?, ?, ?);
//
// &sqltemplate.Args{1, "admin", "root"}
}
// A more complex template example follows, which should be self-explanatory
// given the previous example. It is left as an exercise to the reader how the
// code should be implemented, based on the ExampleCreateUser function.
// List users example.
var _ = template.Must(template.New("query").Parse(`
SELECT id, {{ .Ident "type" }}, name
FROM users
WHERE
{{ if eq .By "type" }}
{{ .Ident "type" }} = {{ .Arg .Value }}
{{ else if eq .By "name" }}
name LIKE {{ .Arg .Value }}
{{ end }};
`))
type ListUsersRequest struct {
By string
Value string
}
type DBListUsersRequest struct {
Dialect
*Args
ListUsersRequest
}