database: update xorm to v0.6.4 and xorm core to v0.5.7

This commit is contained in:
Daniel Lee 2018-03-15 09:51:29 +01:00
parent 0185ad5b04
commit 1c20126f87
49 changed files with 2995 additions and 1894 deletions

8
Gopkg.lock generated
View File

@ -171,12 +171,14 @@
[[projects]]
name = "github.com/go-xorm/core"
packages = ["."]
revision = "e8409d73255791843585964791443dbad877058c"
revision = "da1adaf7a28ca792961721a34e6e04945200c890"
version = "v0.5.7"
[[projects]]
name = "github.com/go-xorm/xorm"
packages = ["."]
revision = "6687a2b4e824f4d87f2d65060ec5cb0d896dff1e"
revision = "1933dd69e294c0a26c0266637067f24dbb25770c"
version = "v0.6.4"
[[projects]]
branch = "master"
@ -631,6 +633,6 @@
[solve-meta]
analyzer-name = "dep"
analyzer-version = 1
inputs-digest = "4de68f1342ba98a637ec8ca7496aeeae2021bf9e4c7c80db7924e14709151a62"
inputs-digest = "112ccff73f668c8c4dbe3d41c37ebee65fd7d839f5a4fa0665c593cae0095dad"
solver-name = "gps-cdcl"
solver-version = 1

View File

@ -85,13 +85,11 @@ ignored = [
[[constraint]]
name = "github.com/go-xorm/core"
revision = "e8409d73255791843585964791443dbad877058c"
#version = "0.5.7" //keeping this since we would rather depend on version then commit
version = "0.5.7"
[[constraint]]
name = "github.com/go-xorm/xorm"
revision = "6687a2b4e824f4d87f2d65060ec5cb0d896dff1e"
#version = "0.6.4" //keeping this since we would rather depend on version then commit
version = "0.6.4"
[[constraint]]
name = "github.com/gorilla/websocket"

View File

@ -125,7 +125,7 @@ func (mg *Migrator) exec(m Migration, sess *xorm.Session) error {
condition := m.GetCondition()
if condition != nil {
sql, args := condition.Sql(mg.dialect)
results, err := sess.Query(sql, args...)
results, err := sess.SQL(sql).Query(args...)
if err != nil || len(results) == 0 {
mg.Logger.Info("Skipping migration condition not fulfilled", "id", m.Id())
return sess.Rollback()

View File

@ -13,12 +13,13 @@ const (
ONLYFROMDB
)
// database column
// Column defines database column
type Column struct {
Name string
TableName string
FieldName string
SQLType SQLType
IsJSON bool
Length int
Length2 int
Nullable bool
@ -37,6 +38,7 @@ type Column struct {
SetOptions map[string]int
DisableTimeZone bool
TimeZone *time.Location // column specified time zone
Comment string
}
func NewColumn(name, fieldName string, sqlType SQLType, len1, len2 int, nullable bool) *Column {
@ -60,6 +62,7 @@ func NewColumn(name, fieldName string, sqlType SQLType, len1, len2 int, nullable
IsVersion: false,
DefaultIsEmpty: false,
EnumOptions: make(map[string]int),
Comment: "",
}
}

View File

@ -244,6 +244,9 @@ func (b *Base) CreateTableSql(table *Table, tableName, storeEngine, charset stri
sql += col.StringNoPk(b.dialect)
}
sql = strings.TrimSpace(sql)
if b.DriverName() == MYSQL && len(col.Comment) > 0 {
sql += " COMMENT '" + col.Comment + "'"
}
sql += ", "
}

View File

@ -196,7 +196,7 @@ func (rs *Rows) ScanMap(dest interface{}) error {
newDest := make([]interface{}, len(cols))
vvv := vv.Elem()
for i, _ := range cols {
for i := range cols {
newDest[i] = ReflectNew(vvv.Type().Elem()).Interface()
//v := reflect.New(vvv.Type().Elem())
//newDest[i] = v.Interface()
@ -247,6 +247,18 @@ type Row struct {
err error // deferred error for easy chaining
}
// ErrorRow return an error row
func ErrorRow(err error) *Row {
return &Row{
err: err,
}
}
// NewRow from rows
func NewRow(rows *Rows, err error) *Row {
return &Row{rows, err}
}
func (row *Row) Columns() ([]string, error) {
if row.err != nil {
return nil, row.err

View File

@ -22,6 +22,7 @@ type Table struct {
Cacher Cacher
StoreEngine string
Charset string
Comment string
}
func (table *Table) Columns() []*Column {

View File

@ -101,6 +101,7 @@ var (
Bytea = "BYTEA"
Bool = "BOOL"
Boolean = "BOOLEAN"
Serial = "SERIAL"
BigSerial = "BIGSERIAL"
@ -163,7 +164,7 @@ var (
uintTypes = sort.StringSlice{"*uint", "*uint16", "*uint32", "*uint8"}
)
// !nashtsai! treat following var as interal const values, these are used for reflect.TypeOf comparision
// !nashtsai! treat following var as interal const values, these are used for reflect.TypeOf comparison
var (
c_EMPTY_STRING string
c_BOOL_DEFAULT bool

View File

@ -21,7 +21,6 @@ type LRUCacher struct {
sqlIndex map[string]map[string]*list.Element
store core.CacheStore
mutex sync.Mutex
// maxSize int
MaxElementSize int
Expired time.Duration
GcInterval time.Duration
@ -54,8 +53,6 @@ func (m *LRUCacher) RunGC() {
// GC check ids lit and sql list to remove all element expired
func (m *LRUCacher) GC() {
//fmt.Println("begin gc ...")
//defer fmt.Println("end gc ...")
m.mutex.Lock()
defer m.mutex.Unlock()
var removedNum int
@ -64,12 +61,10 @@ func (m *LRUCacher) GC() {
time.Now().Sub(e.Value.(*idNode).lastVisit) > m.Expired {
removedNum++
next := e.Next()
//fmt.Println("removing ...", e.Value)
node := e.Value.(*idNode)
m.delBean(node.tbName, node.id)
e = next
} else {
//fmt.Printf("removing %d cache nodes ..., left %d\n", removedNum, m.idList.Len())
break
}
}
@ -80,12 +75,10 @@ func (m *LRUCacher) GC() {
time.Now().Sub(e.Value.(*sqlNode).lastVisit) > m.Expired {
removedNum++
next := e.Next()
//fmt.Println("removing ...", e.Value)
node := e.Value.(*sqlNode)
m.delIds(node.tbName, node.sql)
e = next
} else {
//fmt.Printf("removing %d cache nodes ..., left %d\n", removedNum, m.sqlList.Len())
break
}
}
@ -116,7 +109,6 @@ func (m *LRUCacher) GetIds(tableName, sql string) interface{} {
}
m.delIds(tableName, sql)
return nil
}
@ -134,7 +126,6 @@ func (m *LRUCacher) GetBean(tableName string, id string) interface{} {
// if expired, remove the node and return nil
if time.Now().Sub(lastTime) > m.Expired {
m.delBean(tableName, id)
//m.clearIds(tableName)
return nil
}
m.idList.MoveToBack(el)
@ -148,7 +139,6 @@ func (m *LRUCacher) GetBean(tableName string, id string) interface{} {
// store bean is not exist, then remove memory's index
m.delBean(tableName, id)
//m.clearIds(tableName)
return nil
}
@ -166,8 +156,8 @@ func (m *LRUCacher) clearIds(tableName string) {
// ClearIds clears all sql-ids mapping on table tableName from cache
func (m *LRUCacher) ClearIds(tableName string) {
m.mutex.Lock()
defer m.mutex.Unlock()
m.clearIds(tableName)
m.mutex.Unlock()
}
func (m *LRUCacher) clearBeans(tableName string) {
@ -184,14 +174,13 @@ func (m *LRUCacher) clearBeans(tableName string) {
// ClearBeans clears all beans in some table
func (m *LRUCacher) ClearBeans(tableName string) {
m.mutex.Lock()
defer m.mutex.Unlock()
m.clearBeans(tableName)
m.mutex.Unlock()
}
// PutIds pus ids into table
func (m *LRUCacher) PutIds(tableName, sql string, ids interface{}) {
m.mutex.Lock()
defer m.mutex.Unlock()
if _, ok := m.sqlIndex[tableName]; !ok {
m.sqlIndex[tableName] = make(map[string]*list.Element)
}
@ -207,12 +196,12 @@ func (m *LRUCacher) PutIds(tableName, sql string, ids interface{}) {
node := e.Value.(*sqlNode)
m.delIds(node.tbName, node.sql)
}
m.mutex.Unlock()
}
// PutBean puts beans into table
func (m *LRUCacher) PutBean(tableName string, id string, obj interface{}) {
m.mutex.Lock()
defer m.mutex.Unlock()
var el *list.Element
var ok bool
@ -229,6 +218,7 @@ func (m *LRUCacher) PutBean(tableName string, id string, obj interface{}) {
node := e.Value.(*idNode)
m.delBean(node.tbName, node.id)
}
m.mutex.Unlock()
}
func (m *LRUCacher) delIds(tableName, sql string) {
@ -244,8 +234,8 @@ func (m *LRUCacher) delIds(tableName, sql string) {
// DelIds deletes ids
func (m *LRUCacher) DelIds(tableName, sql string) {
m.mutex.Lock()
defer m.mutex.Unlock()
m.delIds(tableName, sql)
m.mutex.Unlock()
}
func (m *LRUCacher) delBean(tableName string, id string) {
@ -261,8 +251,8 @@ func (m *LRUCacher) delBean(tableName string, id string) {
// DelBean deletes beans in some table
func (m *LRUCacher) DelBean(tableName string, id string) {
m.mutex.Lock()
defer m.mutex.Unlock()
m.delBean(tableName, id)
m.mutex.Unlock()
}
type idNode struct {

26
vendor/github.com/go-xorm/xorm/context.go generated vendored Normal file
View File

@ -0,0 +1,26 @@
// Copyright 2017 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.8
package xorm
import "context"
// PingContext tests if database is alive
func (engine *Engine) PingContext(ctx context.Context) error {
session := engine.NewSession()
defer session.Close()
return session.PingContext(ctx)
}
// PingContext test if database is ok
func (session *Session) PingContext(ctx context.Context) error {
if session.isAutoClose {
defer session.Close()
}
session.engine.logger.Infof("PING DATABASE %v", session.engine.DriverName())
return session.DB().PingContext(ctx)
}

View File

@ -209,10 +209,10 @@ func convertAssign(dest, src interface{}) error {
if src == nil {
dv.Set(reflect.Zero(dv.Type()))
return nil
} else {
}
dv.Set(reflect.New(dv.Type().Elem()))
return convertAssign(dv.Interface(), src)
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
s := asString(src)
i64, err := strconv.ParseInt(s, 10, dv.Type().Bits())
@ -247,3 +247,102 @@ func convertAssign(dest, src interface{}) error {
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest)
}
func asKind(vv reflect.Value, tp reflect.Type) (interface{}, error) {
switch tp.Kind() {
case reflect.Int64:
return vv.Int(), nil
case reflect.Int:
return int(vv.Int()), nil
case reflect.Int32:
return int32(vv.Int()), nil
case reflect.Int16:
return int16(vv.Int()), nil
case reflect.Int8:
return int8(vv.Int()), nil
case reflect.Uint64:
return vv.Uint(), nil
case reflect.Uint:
return uint(vv.Uint()), nil
case reflect.Uint32:
return uint32(vv.Uint()), nil
case reflect.Uint16:
return uint16(vv.Uint()), nil
case reflect.Uint8:
return uint8(vv.Uint()), nil
case reflect.String:
return vv.String(), nil
case reflect.Slice:
if tp.Elem().Kind() == reflect.Uint8 {
v, err := strconv.ParseInt(string(vv.Interface().([]byte)), 10, 64)
if err != nil {
return nil, err
}
return v, nil
}
}
return nil, fmt.Errorf("unsupported primary key type: %v, %v", tp, vv)
}
func convertFloat(v interface{}) (float64, error) {
switch v.(type) {
case float32:
return float64(v.(float32)), nil
case float64:
return v.(float64), nil
case string:
i, err := strconv.ParseFloat(v.(string), 64)
if err != nil {
return 0, err
}
return i, nil
case []byte:
i, err := strconv.ParseFloat(string(v.([]byte)), 64)
if err != nil {
return 0, err
}
return i, nil
}
return 0, fmt.Errorf("unsupported type: %v", v)
}
func convertInt(v interface{}) (int64, error) {
switch v.(type) {
case int:
return int64(v.(int)), nil
case int8:
return int64(v.(int8)), nil
case int16:
return int64(v.(int16)), nil
case int32:
return int64(v.(int32)), nil
case int64:
return v.(int64), nil
case []byte:
i, err := strconv.ParseInt(string(v.([]byte)), 10, 64)
if err != nil {
return 0, err
}
return i, nil
case string:
i, err := strconv.ParseInt(v.(string), 10, 64)
if err != nil {
return 0, err
}
return i, nil
}
return 0, fmt.Errorf("unsupported type: %v", v)
}
func asBool(bs []byte) (bool, error) {
if len(bs) == 0 {
return false, nil
}
if bs[0] == 0x00 {
return false, nil
} else if bs[0] == 0x01 {
return true, nil
}
return strconv.ParseBool(string(bs))
}

View File

@ -215,10 +215,10 @@ func (db *mssql) SqlType(c *core.Column) string {
var res string
switch t := c.SQLType.Name; t {
case core.Bool:
res = core.TinyInt
if c.Default == "true" {
res = core.Bit
if strings.EqualFold(c.Default, "true") {
c.Default = "1"
} else if c.Default == "false" {
} else {
c.Default = "0"
}
case core.Serial:
@ -250,6 +250,9 @@ func (db *mssql) SqlType(c *core.Column) string {
case core.Uuid:
res = core.Varchar
c.Length = 40
case core.TinyInt:
res = core.TinyInt
c.Length = 0
default:
res = t
}
@ -335,9 +338,15 @@ func (db *mssql) TableCheckSql(tableName string) (string, []interface{}) {
func (db *mssql) GetColumns(tableName string) ([]string, map[string]*core.Column, error) {
args := []interface{}{}
s := `select a.name as name, b.name as ctype,a.max_length,a.precision,a.scale,a.is_nullable as nullable,
replace(replace(isnull(c.text,''),'(',''),')','') as vdefault
from sys.columns a left join sys.types b on a.user_type_id=b.user_type_id
replace(replace(isnull(c.text,''),'(',''),')','') as vdefault,
ISNULL(i.is_primary_key, 0)
from sys.columns a
left join sys.types b on a.user_type_id=b.user_type_id
left join sys.syscomments c on a.default_object_id=c.id
LEFT OUTER JOIN
sys.index_columns ic ON ic.object_id = a.object_id AND ic.column_id = a.column_id
LEFT OUTER JOIN
sys.indexes i ON ic.object_id = i.object_id AND ic.index_id = i.index_id
where a.object_id=object_id('` + tableName + `')`
db.LogSQL(s, args)
@ -352,8 +361,8 @@ func (db *mssql) GetColumns(tableName string) ([]string, map[string]*core.Column
for rows.Next() {
var name, ctype, vdefault string
var maxLen, precision, scale int
var nullable bool
err = rows.Scan(&name, &ctype, &maxLen, &precision, &scale, &nullable, &vdefault)
var nullable, isPK bool
err = rows.Scan(&name, &ctype, &maxLen, &precision, &scale, &nullable, &vdefault, &isPK)
if err != nil {
return nil, nil, err
}
@ -363,6 +372,7 @@ func (db *mssql) GetColumns(tableName string) ([]string, map[string]*core.Column
col.Name = strings.Trim(name, "` ")
col.Nullable = nullable
col.Default = vdefault
col.IsPrimaryKey = isPK
ct := strings.ToUpper(ctype)
if ct == "DECIMAL" {
col.Length = precision
@ -468,9 +478,10 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =?
}
colName = strings.Trim(colName, "` ")
var isRegular bool
if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) {
indexName = indexName[5+len(tableName):]
isRegular = true
}
var index *core.Index
@ -479,6 +490,7 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =?
index = new(core.Index)
index.Type = indexType
index.Name = indexName
index.IsRegular = isRegular
indexes[indexName] = index
}
index.AddColumn(colName)
@ -534,7 +546,6 @@ type odbcDriver struct {
func (p *odbcDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) {
kv := strings.Split(dataSourceName, ";")
var dbName string
for _, c := range kv {
vv := strings.Split(strings.TrimSpace(c), "=")
if len(vv) == 2 {

View File

@ -299,7 +299,7 @@ func (db *mysql) TableCheckSql(tableName string) (string, []interface{}) {
func (db *mysql) GetColumns(tableName string) ([]string, map[string]*core.Column, error) {
args := []interface{}{db.DbName, tableName}
s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," +
" `COLUMN_KEY`, `EXTRA` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?"
" `COLUMN_KEY`, `EXTRA`,`COLUMN_COMMENT` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?"
db.LogSQL(s, args)
rows, err := db.DB().Query(s, args...)
@ -314,13 +314,14 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*core.Column
col := new(core.Column)
col.Indexes = make(map[string]int)
var columnName, isNullable, colType, colKey, extra string
var columnName, isNullable, colType, colKey, extra, comment string
var colDefault *string
err = rows.Scan(&columnName, &isNullable, &colDefault, &colType, &colKey, &extra)
err = rows.Scan(&columnName, &isNullable, &colDefault, &colType, &colKey, &extra, &comment)
if err != nil {
return nil, nil, err
}
col.Name = strings.Trim(columnName, "` ")
col.Comment = comment
if "YES" == isNullable {
col.Nullable = true
}
@ -407,7 +408,7 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*core.Column
func (db *mysql) GetTables() ([]*core.Table, error) {
args := []interface{}{db.DbName}
s := "SELECT `TABLE_NAME`, `ENGINE`, `TABLE_ROWS`, `AUTO_INCREMENT` from " +
s := "SELECT `TABLE_NAME`, `ENGINE`, `TABLE_ROWS`, `AUTO_INCREMENT`, `TABLE_COMMENT` from " +
"`INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? AND (`ENGINE`='MyISAM' OR `ENGINE` = 'InnoDB' OR `ENGINE` = 'TokuDB')"
db.LogSQL(s, args)
@ -420,14 +421,15 @@ func (db *mysql) GetTables() ([]*core.Table, error) {
tables := make([]*core.Table, 0)
for rows.Next() {
table := core.NewEmptyTable()
var name, engine, tableRows string
var name, engine, tableRows, comment string
var autoIncr *string
err = rows.Scan(&name, &engine, &tableRows, &autoIncr)
err = rows.Scan(&name, &engine, &tableRows, &autoIncr, &comment)
if err != nil {
return nil, err
}
table.Name = name
table.Comment = comment
table.StoreEngine = engine
tables = append(tables, table)
}

View File

@ -824,6 +824,12 @@ func (db *oracle) GetIndexes(tableName string) (map[string]*core.Index, error) {
indexName = strings.Trim(indexName, `" `)
var isRegular bool
if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) {
indexName = indexName[5+len(tableName):]
isRegular = true
}
if uniqueness == "UNIQUE" {
indexType = core.UniqueType
} else {
@ -836,6 +842,7 @@ func (db *oracle) GetIndexes(tableName string) (map[string]*core.Index, error) {
index = new(core.Index)
index.Type = indexType
index.Name = indexName
index.IsRegular = isRegular
indexes[indexName] = index
}
index.AddColumn(colName)

View File

@ -8,7 +8,6 @@ import (
"errors"
"fmt"
"net/url"
"sort"
"strconv"
"strings"
@ -781,6 +780,9 @@ func (db *postgres) SqlType(c *core.Column) string {
case core.TinyInt:
res = core.SmallInt
return res
case core.Bit:
res = core.Boolean
return res
case core.MediumInt, core.Int, core.Integer:
if c.IsAutoIncrement {
return core.Serial
@ -1078,9 +1080,10 @@ func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error)
}
cs := strings.Split(indexdef, "(")
colNames = strings.Split(cs[1][0:len(cs[1])-1], ",")
var isRegular bool
if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) {
newIdxName := indexName[5+len(tableName):]
isRegular = true
if newIdxName != "" {
indexName = newIdxName
}
@ -1090,6 +1093,7 @@ func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error)
for _, colName := range colNames {
index.Cols = append(index.Cols, strings.Trim(colName, `" `))
}
index.IsRegular = isRegular
indexes[index.Name] = index
}
return indexes, nil
@ -1112,10 +1116,6 @@ func (vs values) Get(k string) (v string) {
return vs[k]
}
func errorf(s string, args ...interface{}) {
panic(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...)))
}
func parseURL(connstr string) (string, error) {
u, err := url.Parse(connstr)
if err != nil {
@ -1126,46 +1126,18 @@ func parseURL(connstr string) (string, error) {
return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme)
}
var kvs []string
escaper := strings.NewReplacer(` `, `\ `, `'`, `\'`, `\`, `\\`)
accrue := func(k, v string) {
if v != "" {
kvs = append(kvs, k+"="+escaper.Replace(v))
}
}
if u.User != nil {
v := u.User.Username()
accrue("user", v)
v, _ = u.User.Password()
accrue("password", v)
}
i := strings.Index(u.Host, ":")
if i < 0 {
accrue("host", u.Host)
} else {
accrue("host", u.Host[:i])
accrue("port", u.Host[i+1:])
}
if u.Path != "" {
accrue("dbname", u.Path[1:])
return escaper.Replace(u.Path[1:]), nil
}
q := u.Query()
for k := range q {
accrue(k, q.Get(k))
return "", nil
}
sort.Strings(kvs) // Makes testing easier (not a performance concern)
return strings.Join(kvs, " "), nil
}
func parseOpts(name string, o values) {
func parseOpts(name string, o values) error {
if len(name) == 0 {
return
return fmt.Errorf("invalid options: %s", name)
}
name = strings.TrimSpace(name)
@ -1174,31 +1146,36 @@ func parseOpts(name string, o values) {
for _, p := range ps {
kv := strings.Split(p, "=")
if len(kv) < 2 {
errorf("invalid option: %q", p)
return fmt.Errorf("invalid option: %q", p)
}
o.Set(kv[0], kv[1])
}
return nil
}
func (p *pqDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) {
db := &core.Uri{DbType: core.POSTGRES}
o := make(values)
var err error
if strings.HasPrefix(dataSourceName, "postgresql://") || strings.HasPrefix(dataSourceName, "postgres://") {
dataSourceName, err = parseURL(dataSourceName)
db.DbName, err = parseURL(dataSourceName)
if err != nil {
return nil, err
}
} else {
o := make(values)
err = parseOpts(dataSourceName, o)
if err != nil {
return nil, err
}
parseOpts(dataSourceName, o)
db.DbName = o.Get("dbname")
}
if db.DbName == "" {
return nil, errors.New("dbname is empty")
}
/*db.Schema = o.Get("schema")
if len(db.Schema) == 0 {
db.Schema = "public"
}*/
return db, nil
}

View File

@ -14,10 +14,6 @@ import (
"github.com/go-xorm/core"
)
// func init() {
// RegisterDialect("sqlite3", &sqlite3{})
// }
var (
sqlite3ReservedWords = map[string]bool{
"ABORT": true,
@ -310,11 +306,25 @@ func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*core.Colu
for _, colStr := range colCreates {
reg = regexp.MustCompile(`,\s`)
colStr = reg.ReplaceAllString(colStr, ",")
if strings.HasPrefix(strings.TrimSpace(colStr), "PRIMARY KEY") {
parts := strings.Split(strings.TrimSpace(colStr), "(")
if len(parts) == 2 {
pkCols := strings.Split(strings.TrimRight(strings.TrimSpace(parts[1]), ")"), ",")
for _, pk := range pkCols {
if col, ok := cols[strings.Trim(strings.TrimSpace(pk), "`")]; ok {
col.IsPrimaryKey = true
}
}
}
continue
}
fields := strings.Fields(strings.TrimSpace(colStr))
col := new(core.Column)
col.Indexes = make(map[string]int)
col.Nullable = true
col.DefaultIsEmpty = true
for idx, field := range fields {
if idx == 0 {
col.Name = strings.Trim(strings.Trim(field, "`[] "), `"`)
@ -405,8 +415,10 @@ func (db *sqlite3) GetIndexes(tableName string) (map[string]*core.Index, error)
}
indexName := strings.Trim(sql[nNStart+6:nNEnd], "` []")
var isRegular bool
if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) {
index.Name = indexName[5+len(tableName):]
isRegular = true
} else {
index.Name = indexName
}
@ -425,6 +437,7 @@ func (db *sqlite3) GetIndexes(tableName string) (map[string]*core.Index, error)
for _, col := range colIndexes {
index.Cols = append(index.Cols, strings.Trim(col, "` []"))
}
index.IsRegular = isRegular
indexes[index.Name] = index
}

View File

@ -8,7 +8,7 @@ Package xorm is a simple and powerful ORM for Go.
Installation
Make sure you have installed Go 1.1+ and then:
Make sure you have installed Go 1.6+ and then:
go get github.com/go-xorm/xorm
@ -51,11 +51,15 @@ There are 8 major ORM methods and many helpful methods to use to operate databas
// INSERT INTO struct1 () values ()
// INSERT INTO struct2 () values (),(),()
2. Query one record from database
2. Query one record or one variable from database
has, err := engine.Get(&user)
// SELECT * FROM user LIMIT 1
var id int64
has, err := engine.Table("user").Where("name = ?", name).Get(&id)
// SELECT id FROM user WHERE name = ? LIMIT 1
3. Query multiple records from database
var sliceOfStructs []Struct
@ -86,7 +90,7 @@ another is Rows
5. Update one or more records
affected, err := engine.Id(...).Update(&user)
affected, err := engine.ID(...).Update(&user)
// UPDATE user SET ...
6. Delete one or more records, Delete MUST has condition
@ -99,6 +103,9 @@ another is Rows
counts, err := engine.Count(&user)
// SELECT count(*) AS total FROM user
counts, err := engine.SQL("select count(*) FROM user").Count()
// select count(*) FROM user
8. Sum records
sumFloat64, err := engine.Sum(&user, "id")

View File

@ -19,6 +19,7 @@ import (
"sync"
"time"
"github.com/go-xorm/builder"
"github.com/go-xorm/core"
)
@ -40,12 +41,29 @@ type Engine struct {
showExecTime bool
logger core.ILogger
TZLocation *time.Location
TZLocation *time.Location // The timezone of the application
DatabaseTZ *time.Location // The timezone of the database
disableGlobalCache bool
tagHandlers map[string]tagHandler
engineGroup *EngineGroup
}
// BufferSize sets buffer size for iterate
func (engine *Engine) BufferSize(size int) *Session {
session := engine.NewSession()
session.isAutoClose = true
return session.BufferSize(size)
}
// CondDeleted returns the conditions whether a record is soft deleted.
func (engine *Engine) CondDeleted(colName string) builder.Cond {
if engine.dialect.DBType() == core.MSSQL {
return builder.IsNull{colName}
}
return builder.IsNull{colName}.Or(builder.Eq{colName: zeroTime1})
}
// ShowSQL show SQL statement or not on logger if log level is great than INFO
@ -78,6 +96,11 @@ func (engine *Engine) SetLogger(logger core.ILogger) {
engine.dialect.SetLogger(logger)
}
// SetLogLevel sets the logger level
func (engine *Engine) SetLogLevel(level core.LogLevel) {
engine.logger.SetLevel(level)
}
// SetDisableGlobalCache disable global cache or not
func (engine *Engine) SetDisableGlobalCache(disable bool) {
if engine.disableGlobalCache != disable {
@ -143,7 +166,6 @@ func (engine *Engine) Quote(value string) string {
// QuoteTo quotes string and writes into the buffer
func (engine *Engine) QuoteTo(buf *bytes.Buffer, value string) {
if buf == nil {
return
}
@ -169,7 +191,7 @@ func (engine *Engine) quote(sql string) string {
return engine.dialect.QuoteStr() + sql + engine.dialect.QuoteStr()
}
// SqlType will be depracated, please use SQLType instead
// SqlType will be deprecated, please use SQLType instead
//
// Deprecated: use SQLType instead
func (engine *Engine) SqlType(c *core.Column) string {
@ -201,26 +223,36 @@ func (engine *Engine) SetDefaultCacher(cacher core.Cacher) {
engine.Cacher = cacher
}
// GetDefaultCacher returns the default cacher
func (engine *Engine) GetDefaultCacher() core.Cacher {
return engine.Cacher
}
// NoCache If you has set default cacher, and you want temporilly stop use cache,
// you can use NoCache()
func (engine *Engine) NoCache() *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.NoCache()
}
// NoCascade If you do not want to auto cascade load object
func (engine *Engine) NoCascade() *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.NoCascade()
}
// MapCacher Set a table use a special cacher
func (engine *Engine) MapCacher(bean interface{}, cacher core.Cacher) {
func (engine *Engine) MapCacher(bean interface{}, cacher core.Cacher) error {
v := rValue(bean)
tb := engine.autoMapType(v)
tb, err := engine.autoMapType(v)
if err != nil {
return err
}
tb.Cacher = cacher
return nil
}
// NewDB provides an interface to operate database directly
@ -240,7 +272,7 @@ func (engine *Engine) Dialect() core.Dialect {
// NewSession New a session
func (engine *Engine) NewSession() *Session {
session := &Session{Engine: engine}
session := &Session{engine: engine}
session.Init()
return session
}
@ -254,7 +286,6 @@ func (engine *Engine) Close() error {
func (engine *Engine) Ping() error {
session := engine.NewSession()
defer session.Close()
engine.logger.Infof("PING DATABASE %v", engine.DriverName())
return session.Ping()
}
@ -262,43 +293,13 @@ func (engine *Engine) Ping() error {
func (engine *Engine) logSQL(sqlStr string, sqlArgs ...interface{}) {
if engine.showSQL && !engine.showExecTime {
if len(sqlArgs) > 0 {
engine.logger.Infof("[SQL] %v %v", sqlStr, sqlArgs)
engine.logger.Infof("[SQL] %v %#v", sqlStr, sqlArgs)
} else {
engine.logger.Infof("[SQL] %v", sqlStr)
}
}
}
func (engine *Engine) logSQLQueryTime(sqlStr string, args []interface{}, executionBlock func() (*core.Stmt, *core.Rows, error)) (*core.Stmt, *core.Rows, error) {
if engine.showSQL && engine.showExecTime {
b4ExecTime := time.Now()
stmt, res, err := executionBlock()
execDuration := time.Since(b4ExecTime)
if len(args) > 0 {
engine.logger.Infof("[SQL] %s %v - took: %v", sqlStr, args, execDuration)
} else {
engine.logger.Infof("[SQL] %s - took: %v", sqlStr, execDuration)
}
return stmt, res, err
}
return executionBlock()
}
func (engine *Engine) logSQLExecutionTime(sqlStr string, args []interface{}, executionBlock func() (sql.Result, error)) (sql.Result, error) {
if engine.showSQL && engine.showExecTime {
b4ExecTime := time.Now()
res, err := executionBlock()
execDuration := time.Since(b4ExecTime)
if len(args) > 0 {
engine.logger.Infof("[sql] %s [args] %v - took: %v", sqlStr, args, execDuration)
} else {
engine.logger.Infof("[sql] %s - took: %v", sqlStr, execDuration)
}
return res, err
}
return executionBlock()
}
// Sql provides raw sql input parameter. When you have a complex SQL statement
// and cannot use Where, Id, In and etc. Methods to describe, you can use SQL.
//
@ -315,7 +316,7 @@ func (engine *Engine) Sql(querystring string, args ...interface{}) *Session {
// This code will execute "select * from user" and set the records to users
func (engine *Engine) SQL(query interface{}, args ...interface{}) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.SQL(query, args...)
}
@ -324,14 +325,14 @@ func (engine *Engine) SQL(query interface{}, args ...interface{}) *Session {
// invoked. Call NoAutoTime if you dont' want to fill automatically.
func (engine *Engine) NoAutoTime() *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.NoAutoTime()
}
// NoAutoCondition disable auto generate Where condition from bean or not
func (engine *Engine) NoAutoCondition(no ...bool) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.NoAutoCondition(no...)
}
@ -565,56 +566,56 @@ func (engine *Engine) tbName(v reflect.Value) string {
// Cascade use cascade or not
func (engine *Engine) Cascade(trueOrFalse ...bool) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.Cascade(trueOrFalse...)
}
// Where method provide a condition query
func (engine *Engine) Where(query interface{}, args ...interface{}) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.Where(query, args...)
}
// Id will be depracated, please use ID instead
// Id will be deprecated, please use ID instead
func (engine *Engine) Id(id interface{}) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.Id(id)
}
// ID method provoide a condition as (id) = ?
func (engine *Engine) ID(id interface{}) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.ID(id)
}
// Before apply before Processor, affected bean is passed to closure arg
func (engine *Engine) Before(closures func(interface{})) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.Before(closures)
}
// After apply after insert Processor, affected bean is passed to closure arg
func (engine *Engine) After(closures func(interface{})) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.After(closures)
}
// Charset set charset when create table, only support mysql now
func (engine *Engine) Charset(charset string) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.Charset(charset)
}
// StoreEngine set store engine when create table, only support mysql now
func (engine *Engine) StoreEngine(storeEngine string) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.StoreEngine(storeEngine)
}
@ -623,35 +624,35 @@ func (engine *Engine) StoreEngine(storeEngine string) *Session {
// but distinct will not provide id
func (engine *Engine) Distinct(columns ...string) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.Distinct(columns...)
}
// Select customerize your select columns or contents
func (engine *Engine) Select(str string) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.Select(str)
}
// Cols only use the parameters as select or update columns
func (engine *Engine) Cols(columns ...string) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.Cols(columns...)
}
// AllCols indicates that all columns should be use
func (engine *Engine) AllCols() *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.AllCols()
}
// MustCols specify some columns must use even if they are empty
func (engine *Engine) MustCols(columns ...string) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.MustCols(columns...)
}
@ -662,77 +663,84 @@ func (engine *Engine) MustCols(columns ...string) *Session {
// it will use parameters's columns
func (engine *Engine) UseBool(columns ...string) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.UseBool(columns...)
}
// Omit only not use the parameters as select or update columns
func (engine *Engine) Omit(columns ...string) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.Omit(columns...)
}
// Nullable set null when column is zero-value and nullable for update
func (engine *Engine) Nullable(columns ...string) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.Nullable(columns...)
}
// In will generate "column IN (?, ?)"
func (engine *Engine) In(column string, args ...interface{}) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.In(column, args...)
}
// NotIn will generate "column NOT IN (?, ?)"
func (engine *Engine) NotIn(column string, args ...interface{}) *Session {
session := engine.NewSession()
session.isAutoClose = true
return session.NotIn(column, args...)
}
// Incr provides a update string like "column = column + ?"
func (engine *Engine) Incr(column string, arg ...interface{}) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.Incr(column, arg...)
}
// Decr provides a update string like "column = column - ?"
func (engine *Engine) Decr(column string, arg ...interface{}) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.Decr(column, arg...)
}
// SetExpr provides a update string like "column = {expression}"
func (engine *Engine) SetExpr(column string, expression string) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.SetExpr(column, expression)
}
// Table temporarily change the Get, Find, Update's table
func (engine *Engine) Table(tableNameOrBean interface{}) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.Table(tableNameOrBean)
}
// Alias set the table alias
func (engine *Engine) Alias(alias string) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.Alias(alias)
}
// Limit will generate "LIMIT start, limit"
func (engine *Engine) Limit(limit int, start ...int) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.Limit(limit, start...)
}
// Desc will generate "ORDER BY column1 DESC, column2 DESC"
func (engine *Engine) Desc(colNames ...string) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.Desc(colNames...)
}
@ -744,39 +752,53 @@ func (engine *Engine) Desc(colNames ...string) *Session {
//
func (engine *Engine) Asc(colNames ...string) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.Asc(colNames...)
}
// OrderBy will generate "ORDER BY order"
func (engine *Engine) OrderBy(order string) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.OrderBy(order)
}
// Prepare enables prepare statement
func (engine *Engine) Prepare() *Session {
session := engine.NewSession()
session.isAutoClose = true
return session.Prepare()
}
// Join the join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
func (engine *Engine) Join(joinOperator string, tablename interface{}, condition string, args ...interface{}) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.Join(joinOperator, tablename, condition, args...)
}
// GroupBy generate group by statement
func (engine *Engine) GroupBy(keys string) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.GroupBy(keys)
}
// Having generate having statement
func (engine *Engine) Having(conditions string) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.Having(conditions)
}
func (engine *Engine) autoMapType(v reflect.Value) *core.Table {
// UnMapType removes the datbase mapper of a type
func (engine *Engine) UnMapType(t reflect.Type) {
engine.mutex.Lock()
defer engine.mutex.Unlock()
delete(engine.Tables, t)
}
func (engine *Engine) autoMapType(v reflect.Value) (*core.Table, error) {
t := v.Type()
engine.mutex.Lock()
defer engine.mutex.Unlock()
@ -785,8 +807,9 @@ func (engine *Engine) autoMapType(v reflect.Value) *core.Table {
var err error
table, err = engine.mapType(v)
if err != nil {
engine.logger.Error(err)
} else {
return nil, err
}
engine.Tables[t] = table
if engine.Cacher != nil {
if v.CanAddr() {
@ -796,13 +819,11 @@ func (engine *Engine) autoMapType(v reflect.Value) *core.Table {
}
}
}
}
return table
return table, nil
}
// GobRegister register one struct to gob for cache use
func (engine *Engine) GobRegister(v interface{}) *Engine {
//fmt.Printf("Type: %[1]T => Data: %[1]#v\n", v)
gob.Register(v)
return engine
}
@ -813,10 +834,19 @@ type Table struct {
Name string
}
// IsValid if table is valid
func (t *Table) IsValid() bool {
return t.Table != nil && len(t.Name) > 0
}
// TableInfo get table info according to bean's content
func (engine *Engine) TableInfo(bean interface{}) *Table {
v := rValue(bean)
return &Table{engine.autoMapType(v), engine.tbName(v)}
tb, err := engine.autoMapType(v)
if err != nil {
engine.logger.Error(err)
}
return &Table{tb, engine.tbName(v)}
}
func addIndex(indexName string, table *core.Table, col *core.Column, indexType int) {
@ -911,6 +941,7 @@ func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) {
k := strings.ToUpper(key)
ctx.tagName = k
ctx.params = []string{}
pStart := strings.Index(k, "(")
if pStart == 0 {
@ -918,18 +949,18 @@ func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) {
}
if pStart > -1 {
if !strings.HasSuffix(k, ")") {
return nil, errors.New("cannot match ) charactor")
return nil, fmt.Errorf("field %s tag %s cannot match ) charactor", col.FieldName, key)
}
ctx.tagName = k[:pStart]
ctx.params = strings.Split(k[pStart+1:len(k)-1], ",")
ctx.params = strings.Split(key[pStart+1:len(k)-1], ",")
}
if j > 0 {
ctx.preTag = strings.ToUpper(tags[j-1])
}
if j < len(tags)-1 {
ctx.nextTag = strings.ToUpper(tags[j+1])
ctx.nextTag = tags[j+1]
} else {
ctx.nextTag = ""
}
@ -993,6 +1024,10 @@ func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) {
col = core.NewColumn(engine.ColumnMapper.Obj2Table(t.Field(i).Name),
t.Field(i).Name, sqlType, sqlType.DefaultLength,
sqlType.DefaultLength2, true)
if fieldType.Kind() == reflect.Int64 && (strings.ToUpper(col.FieldName) == "ID" || strings.HasSuffix(strings.ToUpper(col.FieldName), ".ID")) {
idFieldColName = col.Name
}
}
if col.IsAutoIncrement {
col.Nullable = false
@ -1000,9 +1035,6 @@ func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) {
table.AddColumn(col)
if fieldType.Kind() == reflect.Int64 && (strings.ToUpper(col.FieldName) == "ID" || strings.HasSuffix(strings.ToUpper(col.FieldName), ".ID")) {
idFieldColName = col.Name
}
} // end for
if idFieldColName != "" && len(table.PrimaryKeys) == 0 {
@ -1066,21 +1098,54 @@ func (engine *Engine) IdOfV(rv reflect.Value) core.PK {
// IDOfV get id from one value of struct
func (engine *Engine) IDOfV(rv reflect.Value) core.PK {
pk, err := engine.idOfV(rv)
if err != nil {
engine.logger.Error(err)
return nil
}
return pk
}
func (engine *Engine) idOfV(rv reflect.Value) (core.PK, error) {
v := reflect.Indirect(rv)
table := engine.autoMapType(v)
table, err := engine.autoMapType(v)
if err != nil {
return nil, err
}
pk := make([]interface{}, len(table.PrimaryKeys))
for i, col := range table.PKColumns() {
var err error
pkField := v.FieldByName(col.FieldName)
switch pkField.Kind() {
case reflect.String:
pk[i] = pkField.String()
pk[i], err = engine.idTypeAssertion(col, pkField.String())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
pk[i] = pkField.Int()
pk[i], err = engine.idTypeAssertion(col, strconv.FormatInt(pkField.Int(), 10))
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
pk[i] = pkField.Uint()
// id of uint will be converted to int64
pk[i], err = engine.idTypeAssertion(col, strconv.FormatUint(pkField.Uint(), 10))
}
if err != nil {
return nil, err
}
}
return core.PK(pk)
return core.PK(pk), nil
}
func (engine *Engine) idTypeAssertion(col *core.Column, sid string) (interface{}, error) {
if col.SQLType.IsNumeric() {
n, err := strconv.ParseInt(sid, 10, 64)
if err != nil {
return nil, err
}
return n, nil
} else if col.SQLType.IsText() {
return sid, nil
} else {
return nil, errors.New("not supported")
}
}
// CreateIndexes create indexes
@ -1101,13 +1166,6 @@ func (engine *Engine) getCacher2(table *core.Table) core.Cacher {
return table.Cacher
}
func (engine *Engine) getCacher(v reflect.Value) core.Cacher {
if table := engine.autoMapType(v); table != nil {
return table.Cacher
}
return engine.Cacher
}
// ClearCacheBean if enabled cache, clear the cache bean
func (engine *Engine) ClearCacheBean(bean interface{}, id string) error {
v := rValue(bean)
@ -1116,7 +1174,10 @@ func (engine *Engine) ClearCacheBean(bean interface{}, id string) error {
return errors.New("error params")
}
tableName := engine.tbName(v)
table := engine.autoMapType(v)
table, err := engine.autoMapType(v)
if err != nil {
return err
}
cacher := table.Cacher
if cacher == nil {
cacher = engine.Cacher
@ -1137,7 +1198,11 @@ func (engine *Engine) ClearCache(beans ...interface{}) error {
return errors.New("error params")
}
tableName := engine.tbName(v)
table := engine.autoMapType(v)
table, err := engine.autoMapType(v)
if err != nil {
return err
}
cacher := table.Cacher
if cacher == nil {
cacher = engine.Cacher
@ -1154,19 +1219,23 @@ func (engine *Engine) ClearCache(beans ...interface{}) error {
// table, column, index, unique. but will not delete or change anything.
// If you change some field, you should change the database manually.
func (engine *Engine) Sync(beans ...interface{}) error {
session := engine.NewSession()
defer session.Close()
for _, bean := range beans {
v := rValue(bean)
tableName := engine.tbName(v)
table := engine.autoMapType(v)
table, err := engine.autoMapType(v)
if err != nil {
return err
}
s := engine.NewSession()
defer s.Close()
isExist, err := s.Table(bean).isTableExist(tableName)
isExist, err := session.Table(bean).isTableExist(tableName)
if err != nil {
return err
}
if !isExist {
err = engine.CreateTables(bean)
err = session.createTable(bean)
if err != nil {
return err
}
@ -1177,11 +1246,11 @@ func (engine *Engine) Sync(beans ...interface{}) error {
}*/
var isEmpty bool
if isEmpty {
err = engine.DropTables(bean)
err = session.dropTable(bean)
if err != nil {
return err
}
err = engine.CreateTables(bean)
err = session.createTable(bean)
if err != nil {
return err
}
@ -1192,9 +1261,9 @@ func (engine *Engine) Sync(beans ...interface{}) error {
return err
}
if !isExist {
session := engine.NewSession()
session.Statement.setRefValue(v)
defer session.Close()
if err := session.statement.setRefValue(v); err != nil {
return err
}
err = session.addColumn(col.Name)
if err != nil {
return err
@ -1203,19 +1272,19 @@ func (engine *Engine) Sync(beans ...interface{}) error {
}
for name, index := range table.Indexes {
session := engine.NewSession()
session.Statement.setRefValue(v)
defer session.Close()
if err := session.statement.setRefValue(v); err != nil {
return err
}
if index.Type == core.UniqueType {
//isExist, err := session.isIndexExist(table.Name, name, true)
isExist, err := session.isIndexExist2(tableName, index.Cols, true)
if err != nil {
return err
}
if !isExist {
session := engine.NewSession()
session.Statement.setRefValue(v)
defer session.Close()
if err := session.statement.setRefValue(v); err != nil {
return err
}
err = session.addUnique(tableName, name)
if err != nil {
return err
@ -1227,9 +1296,10 @@ func (engine *Engine) Sync(beans ...interface{}) error {
return err
}
if !isExist {
session := engine.NewSession()
session.Statement.setRefValue(v)
defer session.Close()
if err := session.statement.setRefValue(v); err != nil {
return err
}
err = session.addIndex(tableName, name)
if err != nil {
return err
@ -1251,35 +1321,6 @@ func (engine *Engine) Sync2(beans ...interface{}) error {
return s.Sync2(beans...)
}
func (engine *Engine) unMap(beans ...interface{}) (e error) {
engine.mutex.Lock()
defer engine.mutex.Unlock()
for _, bean := range beans {
t := rType(bean)
if _, ok := engine.Tables[t]; ok {
delete(engine.Tables, t)
}
}
return
}
// Drop all mapped table
func (engine *Engine) dropAll() error {
session := engine.NewSession()
defer session.Close()
err := session.Begin()
if err != nil {
return err
}
err = session.dropAll()
if err != nil {
session.Rollback()
return err
}
return session.Commit()
}
// CreateTables create tabls according bean
func (engine *Engine) CreateTables(beans ...interface{}) error {
session := engine.NewSession()
@ -1291,7 +1332,7 @@ func (engine *Engine) CreateTables(beans ...interface{}) error {
}
for _, bean := range beans {
err = session.CreateTable(bean)
err = session.createTable(bean)
if err != nil {
session.Rollback()
return err
@ -1311,7 +1352,7 @@ func (engine *Engine) DropTables(beans ...interface{}) error {
}
for _, bean := range beans {
err = session.DropTable(bean)
err = session.dropTable(bean)
if err != nil {
session.Rollback()
return err
@ -1320,10 +1361,11 @@ func (engine *Engine) DropTables(beans ...interface{}) error {
return session.Commit()
}
func (engine *Engine) createAll() error {
// DropIndexes drop indexes of a table
func (engine *Engine) DropIndexes(bean interface{}) error {
session := engine.NewSession()
defer session.Close()
return session.createAll()
return session.DropIndexes(bean)
}
// Exec raw sql
@ -1334,10 +1376,24 @@ func (engine *Engine) Exec(sql string, args ...interface{}) (sql.Result, error)
}
// Query a raw sql and return records as []map[string][]byte
func (engine *Engine) Query(sql string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) {
func (engine *Engine) Query(sqlorArgs ...interface{}) (resultsSlice []map[string][]byte, err error) {
session := engine.NewSession()
defer session.Close()
return session.Query(sql, paramStr...)
return session.Query(sqlorArgs...)
}
// QueryString runs a raw sql and return records as []map[string]string
func (engine *Engine) QueryString(sqlorArgs ...interface{}) ([]map[string]string, error) {
session := engine.NewSession()
defer session.Close()
return session.QueryString(sqlorArgs...)
}
// QueryInterface runs a raw sql and return records as []map[string]interface{}
func (engine *Engine) QueryInterface(sqlorArgs ...interface{}) ([]map[string]interface{}, error) {
session := engine.NewSession()
defer session.Close()
return session.QueryInterface(sqlorArgs...)
}
// Insert one or more records
@ -1381,6 +1437,13 @@ func (engine *Engine) Get(bean interface{}) (bool, error) {
return session.Get(bean)
}
// Exist returns true if the record exist otherwise return false
func (engine *Engine) Exist(bean ...interface{}) (bool, error) {
session := engine.NewSession()
defer session.Close()
return session.Exist(bean...)
}
// Find retrieve records from table, condiBeans's non-empty fields
// are conditions. beans could be []Struct, []*Struct, map[int64]Struct
// map[int64]*Struct
@ -1406,10 +1469,10 @@ func (engine *Engine) Rows(bean interface{}) (*Rows, error) {
}
// Count counts the records. bean's non-empty fields are conditions.
func (engine *Engine) Count(bean interface{}) (int64, error) {
func (engine *Engine) Count(bean ...interface{}) (int64, error) {
session := engine.NewSession()
defer session.Close()
return session.Count(bean)
return session.Count(bean...)
}
// Sum sum the records by some column. bean's non-empty fields are conditions.
@ -1419,6 +1482,13 @@ func (engine *Engine) Sum(bean interface{}, colName string) (float64, error) {
return session.Sum(bean, colName)
}
// SumInt sum the records by some column. bean's non-empty fields are conditions.
func (engine *Engine) SumInt(bean interface{}, colName string) (int64, error) {
session := engine.NewSession()
defer session.Close()
return session.SumInt(bean, colName)
}
// Sums sum the records by some columns. bean's non-empty fields are conditions.
func (engine *Engine) Sums(bean interface{}, colNames ...string) ([]float64, error) {
session := engine.NewSession()
@ -1474,7 +1544,6 @@ func (engine *Engine) Import(r io.Reader) ([]sql.Result, error) {
results = append(results, result)
if err != nil {
return nil, err
//lastError = err
}
}
}
@ -1482,49 +1551,32 @@ func (engine *Engine) Import(r io.Reader) ([]sql.Result, error) {
return results, lastError
}
// TZTime change one time to xorm time location
func (engine *Engine) TZTime(t time.Time) time.Time {
if !t.IsZero() { // if time is not initialized it's not suitable for Time.In()
return t.In(engine.TZLocation)
}
return t
}
// NowTime return current time
func (engine *Engine) NowTime(sqlTypeName string) interface{} {
// nowTime return current time
func (engine *Engine) nowTime(col *core.Column) (interface{}, time.Time) {
t := time.Now()
return engine.FormatTime(sqlTypeName, t)
var tz = engine.DatabaseTZ
if !col.DisableTimeZone && col.TimeZone != nil {
tz = col.TimeZone
}
// NowTime2 return current time
func (engine *Engine) NowTime2(sqlTypeName string) (interface{}, time.Time) {
t := time.Now()
return engine.FormatTime(sqlTypeName, t), t
}
// FormatTime format time
func (engine *Engine) FormatTime(sqlTypeName string, t time.Time) (v interface{}) {
return engine.formatTime(engine.TZLocation, sqlTypeName, t)
return engine.formatTime(col.SQLType.Name, t.In(tz)), t.In(engine.TZLocation)
}
func (engine *Engine) formatColTime(col *core.Column, t time.Time) (v interface{}) {
if col.DisableTimeZone {
return engine.formatTime(nil, col.SQLType.Name, t)
} else if col.TimeZone != nil {
return engine.formatTime(col.TimeZone, col.SQLType.Name, t)
if t.IsZero() {
if col.Nullable {
return nil
}
return engine.formatTime(engine.TZLocation, col.SQLType.Name, t)
return ""
}
func (engine *Engine) formatTime(tz *time.Location, sqlTypeName string, t time.Time) (v interface{}) {
if engine.dialect.DBType() == core.ORACLE {
return t
if col.TimeZone != nil {
return engine.formatTime(col.SQLType.Name, t.In(col.TimeZone))
}
if tz != nil {
t = t.In(tz)
} else {
t = engine.TZTime(t)
return engine.formatTime(col.SQLType.Name, t.In(engine.DatabaseTZ))
}
// formatTime format time as column type
func (engine *Engine) formatTime(sqlTypeName string, t time.Time) (v interface{}) {
switch sqlTypeName {
case core.Time:
s := t.Format("2006-01-02 15:04:05") //time.RFC3339
@ -1532,18 +1584,10 @@ func (engine *Engine) formatTime(tz *time.Location, sqlTypeName string, t time.T
case core.Date:
v = t.Format("2006-01-02")
case core.DateTime, core.TimeStamp:
if engine.dialect.DBType() == "ql" {
v = t
} else if engine.dialect.DBType() == "sqlite3" {
v = t.UTC().Format("2006-01-02 15:04:05")
} else {
v = t.Format("2006-01-02 15:04:05")
}
case core.TimeStampz:
if engine.dialect.DBType() == core.MSSQL {
v = t.Format("2006-01-02T15:04:05.9999999Z07:00")
} else if engine.DriverName() == "mssql" {
v = t
} else {
v = t.Format(time.RFC3339Nano)
}
@ -1555,9 +1599,39 @@ func (engine *Engine) formatTime(tz *time.Location, sqlTypeName string, t time.T
return
}
// GetColumnMapper returns the column name mapper
func (engine *Engine) GetColumnMapper() core.IMapper {
return engine.ColumnMapper
}
// GetTableMapper returns the table name mapper
func (engine *Engine) GetTableMapper() core.IMapper {
return engine.TableMapper
}
// GetTZLocation returns time zone of the application
func (engine *Engine) GetTZLocation() *time.Location {
return engine.TZLocation
}
// SetTZLocation sets time zone of the application
func (engine *Engine) SetTZLocation(tz *time.Location) {
engine.TZLocation = tz
}
// GetTZDatabase returns time zone of the database
func (engine *Engine) GetTZDatabase() *time.Location {
return engine.DatabaseTZ
}
// SetTZDatabase sets time zone of the database
func (engine *Engine) SetTZDatabase(tz *time.Location) {
engine.DatabaseTZ = tz
}
// Unscoped always disable struct tag "deleted"
func (engine *Engine) Unscoped() *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.Unscoped()
}

230
vendor/github.com/go-xorm/xorm/engine_cond.go generated vendored Normal file
View File

@ -0,0 +1,230 @@
// Copyright 2017 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
import (
"database/sql/driver"
"encoding/json"
"fmt"
"reflect"
"time"
"github.com/go-xorm/builder"
"github.com/go-xorm/core"
)
func (engine *Engine) buildConds(table *core.Table, bean interface{},
includeVersion bool, includeUpdated bool, includeNil bool,
includeAutoIncr bool, allUseBool bool, useAllCols bool, unscoped bool,
mustColumnMap map[string]bool, tableName, aliasName string, addedTableName bool) (builder.Cond, error) {
var conds []builder.Cond
for _, col := range table.Columns() {
if !includeVersion && col.IsVersion {
continue
}
if !includeUpdated && col.IsUpdated {
continue
}
if !includeAutoIncr && col.IsAutoIncrement {
continue
}
if engine.dialect.DBType() == core.MSSQL && (col.SQLType.Name == core.Text || col.SQLType.IsBlob() || col.SQLType.Name == core.TimeStampz) {
continue
}
if col.SQLType.IsJson() {
continue
}
var colName string
if addedTableName {
var nm = tableName
if len(aliasName) > 0 {
nm = aliasName
}
colName = engine.Quote(nm) + "." + engine.Quote(col.Name)
} else {
colName = engine.Quote(col.Name)
}
fieldValuePtr, err := col.ValueOf(bean)
if err != nil {
engine.logger.Error(err)
continue
}
if col.IsDeleted && !unscoped { // tag "deleted" is enabled
conds = append(conds, engine.CondDeleted(colName))
}
fieldValue := *fieldValuePtr
if fieldValue.Interface() == nil {
continue
}
fieldType := reflect.TypeOf(fieldValue.Interface())
requiredField := useAllCols
if b, ok := getFlagForColumn(mustColumnMap, col); ok {
if b {
requiredField = true
} else {
continue
}
}
if fieldType.Kind() == reflect.Ptr {
if fieldValue.IsNil() {
if includeNil {
conds = append(conds, builder.Eq{colName: nil})
}
continue
} else if !fieldValue.IsValid() {
continue
} else {
// dereference ptr type to instance type
fieldValue = fieldValue.Elem()
fieldType = reflect.TypeOf(fieldValue.Interface())
requiredField = true
}
}
var val interface{}
switch fieldType.Kind() {
case reflect.Bool:
if allUseBool || requiredField {
val = fieldValue.Interface()
} else {
// if a bool in a struct, it will not be as a condition because it default is false,
// please use Where() instead
continue
}
case reflect.String:
if !requiredField && fieldValue.String() == "" {
continue
}
// for MyString, should convert to string or panic
if fieldType.String() != reflect.String.String() {
val = fieldValue.String()
} else {
val = fieldValue.Interface()
}
case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64:
if !requiredField && fieldValue.Int() == 0 {
continue
}
val = fieldValue.Interface()
case reflect.Float32, reflect.Float64:
if !requiredField && fieldValue.Float() == 0.0 {
continue
}
val = fieldValue.Interface()
case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64:
if !requiredField && fieldValue.Uint() == 0 {
continue
}
t := int64(fieldValue.Uint())
val = reflect.ValueOf(&t).Interface()
case reflect.Struct:
if fieldType.ConvertibleTo(core.TimeType) {
t := fieldValue.Convert(core.TimeType).Interface().(time.Time)
if !requiredField && (t.IsZero() || !fieldValue.IsValid()) {
continue
}
val = engine.formatColTime(col, t)
} else if _, ok := reflect.New(fieldType).Interface().(core.Conversion); ok {
continue
} else if valNul, ok := fieldValue.Interface().(driver.Valuer); ok {
val, _ = valNul.Value()
if val == nil {
continue
}
} else {
if col.SQLType.IsJson() {
if col.SQLType.IsText() {
bytes, err := json.Marshal(fieldValue.Interface())
if err != nil {
engine.logger.Error(err)
continue
}
val = string(bytes)
} else if col.SQLType.IsBlob() {
var bytes []byte
var err error
bytes, err = json.Marshal(fieldValue.Interface())
if err != nil {
engine.logger.Error(err)
continue
}
val = bytes
}
} else {
engine.autoMapType(fieldValue)
if table, ok := engine.Tables[fieldValue.Type()]; ok {
if len(table.PrimaryKeys) == 1 {
pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName)
// fix non-int pk issues
//if pkField.Int() != 0 {
if pkField.IsValid() && !isZero(pkField.Interface()) {
val = pkField.Interface()
} else {
continue
}
} else {
//TODO: how to handler?
return nil, fmt.Errorf("not supported %v as %v", fieldValue.Interface(), table.PrimaryKeys)
}
} else {
val = fieldValue.Interface()
}
}
}
case reflect.Array:
continue
case reflect.Slice, reflect.Map:
if fieldValue == reflect.Zero(fieldType) {
continue
}
if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 {
continue
}
if col.SQLType.IsText() {
bytes, err := json.Marshal(fieldValue.Interface())
if err != nil {
engine.logger.Error(err)
continue
}
val = string(bytes)
} else if col.SQLType.IsBlob() {
var bytes []byte
var err error
if (fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice) &&
fieldType.Elem().Kind() == reflect.Uint8 {
if fieldValue.Len() > 0 {
val = fieldValue.Bytes()
} else {
continue
}
} else {
bytes, err = json.Marshal(fieldValue.Interface())
if err != nil {
engine.logger.Error(err)
continue
}
val = bytes
}
} else {
continue
}
default:
val = fieldValue.Interface()
}
conds = append(conds, builder.Eq{colName: val})
}
return builder.And(conds...), nil
}

194
vendor/github.com/go-xorm/xorm/engine_group.go generated vendored Normal file
View File

@ -0,0 +1,194 @@
// Copyright 2017 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
import (
"github.com/go-xorm/core"
)
// EngineGroup defines an engine group
type EngineGroup struct {
*Engine
slaves []*Engine
policy GroupPolicy
}
// NewEngineGroup creates a new engine group
func NewEngineGroup(args1 interface{}, args2 interface{}, policies ...GroupPolicy) (*EngineGroup, error) {
var eg EngineGroup
if len(policies) > 0 {
eg.policy = policies[0]
} else {
eg.policy = RoundRobinPolicy()
}
driverName, ok1 := args1.(string)
conns, ok2 := args2.([]string)
if ok1 && ok2 {
engines := make([]*Engine, len(conns))
for i, conn := range conns {
engine, err := NewEngine(driverName, conn)
if err != nil {
return nil, err
}
engine.engineGroup = &eg
engines[i] = engine
}
eg.Engine = engines[0]
eg.slaves = engines[1:]
return &eg, nil
}
master, ok3 := args1.(*Engine)
slaves, ok4 := args2.([]*Engine)
if ok3 && ok4 {
master.engineGroup = &eg
for i := 0; i < len(slaves); i++ {
slaves[i].engineGroup = &eg
}
eg.Engine = master
eg.slaves = slaves
return &eg, nil
}
return nil, ErrParamsType
}
// Close the engine
func (eg *EngineGroup) Close() error {
err := eg.Engine.Close()
if err != nil {
return err
}
for i := 0; i < len(eg.slaves); i++ {
err := eg.slaves[i].Close()
if err != nil {
return err
}
}
return nil
}
// Master returns the master engine
func (eg *EngineGroup) Master() *Engine {
return eg.Engine
}
// Ping tests if database is alive
func (eg *EngineGroup) Ping() error {
if err := eg.Engine.Ping(); err != nil {
return err
}
for _, slave := range eg.slaves {
if err := slave.Ping(); err != nil {
return err
}
}
return nil
}
// SetColumnMapper set the column name mapping rule
func (eg *EngineGroup) SetColumnMapper(mapper core.IMapper) {
eg.Engine.ColumnMapper = mapper
for i := 0; i < len(eg.slaves); i++ {
eg.slaves[i].ColumnMapper = mapper
}
}
// SetDefaultCacher set the default cacher
func (eg *EngineGroup) SetDefaultCacher(cacher core.Cacher) {
eg.Engine.SetDefaultCacher(cacher)
for i := 0; i < len(eg.slaves); i++ {
eg.slaves[i].SetDefaultCacher(cacher)
}
}
// SetLogger set the new logger
func (eg *EngineGroup) SetLogger(logger core.ILogger) {
eg.Engine.SetLogger(logger)
for i := 0; i < len(eg.slaves); i++ {
eg.slaves[i].SetLogger(logger)
}
}
// SetLogLevel sets the logger level
func (eg *EngineGroup) SetLogLevel(level core.LogLevel) {
eg.Engine.SetLogLevel(level)
for i := 0; i < len(eg.slaves); i++ {
eg.slaves[i].SetLogLevel(level)
}
}
// SetMapper set the name mapping rules
func (eg *EngineGroup) SetMapper(mapper core.IMapper) {
eg.Engine.SetMapper(mapper)
for i := 0; i < len(eg.slaves); i++ {
eg.slaves[i].SetMapper(mapper)
}
}
// SetMaxIdleConns set the max idle connections on pool, default is 2
func (eg *EngineGroup) SetMaxIdleConns(conns int) {
eg.Engine.db.SetMaxIdleConns(conns)
for i := 0; i < len(eg.slaves); i++ {
eg.slaves[i].db.SetMaxIdleConns(conns)
}
}
// SetMaxOpenConns is only available for go 1.2+
func (eg *EngineGroup) SetMaxOpenConns(conns int) {
eg.Engine.db.SetMaxOpenConns(conns)
for i := 0; i < len(eg.slaves); i++ {
eg.slaves[i].db.SetMaxOpenConns(conns)
}
}
// SetPolicy set the group policy
func (eg *EngineGroup) SetPolicy(policy GroupPolicy) *EngineGroup {
eg.policy = policy
return eg
}
// SetTableMapper set the table name mapping rule
func (eg *EngineGroup) SetTableMapper(mapper core.IMapper) {
eg.Engine.TableMapper = mapper
for i := 0; i < len(eg.slaves); i++ {
eg.slaves[i].TableMapper = mapper
}
}
// ShowExecTime show SQL statement and execute time or not on logger if log level is great than INFO
func (eg *EngineGroup) ShowExecTime(show ...bool) {
eg.Engine.ShowExecTime(show...)
for i := 0; i < len(eg.slaves); i++ {
eg.slaves[i].ShowExecTime(show...)
}
}
// ShowSQL show SQL statement or not on logger if log level is great than INFO
func (eg *EngineGroup) ShowSQL(show ...bool) {
eg.Engine.ShowSQL(show...)
for i := 0; i < len(eg.slaves); i++ {
eg.slaves[i].ShowSQL(show...)
}
}
// Slave returns one of the physical databases which is a slave according the policy
func (eg *EngineGroup) Slave() *Engine {
switch len(eg.slaves) {
case 0:
return eg.Engine
case 1:
return eg.slaves[0]
}
return eg.policy.Slave(eg)
}
// Slaves returns all the slaves
func (eg *EngineGroup) Slaves() []*Engine {
return eg.slaves
}

116
vendor/github.com/go-xorm/xorm/engine_group_policy.go generated vendored Normal file
View File

@ -0,0 +1,116 @@
// Copyright 2017 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
import (
"math/rand"
"sync"
"time"
)
// GroupPolicy is be used by chosing the current slave from slaves
type GroupPolicy interface {
Slave(*EngineGroup) *Engine
}
// GroupPolicyHandler should be used when a function is a GroupPolicy
type GroupPolicyHandler func(*EngineGroup) *Engine
// Slave implements the chosen of slaves
func (h GroupPolicyHandler) Slave(eg *EngineGroup) *Engine {
return h(eg)
}
// RandomPolicy implmentes randomly chose the slave of slaves
func RandomPolicy() GroupPolicyHandler {
var r = rand.New(rand.NewSource(time.Now().UnixNano()))
return func(g *EngineGroup) *Engine {
return g.Slaves()[r.Intn(len(g.Slaves()))]
}
}
// WeightRandomPolicy implmentes randomly chose the slave of slaves
func WeightRandomPolicy(weights []int) GroupPolicyHandler {
var rands = make([]int, 0, len(weights))
for i := 0; i < len(weights); i++ {
for n := 0; n < weights[i]; n++ {
rands = append(rands, i)
}
}
var r = rand.New(rand.NewSource(time.Now().UnixNano()))
return func(g *EngineGroup) *Engine {
var slaves = g.Slaves()
idx := rands[r.Intn(len(rands))]
if idx >= len(slaves) {
idx = len(slaves) - 1
}
return slaves[idx]
}
}
func RoundRobinPolicy() GroupPolicyHandler {
var pos = -1
var lock sync.Mutex
return func(g *EngineGroup) *Engine {
var slaves = g.Slaves()
lock.Lock()
defer lock.Unlock()
pos++
if pos >= len(slaves) {
pos = 0
}
return slaves[pos]
}
}
func WeightRoundRobinPolicy(weights []int) GroupPolicyHandler {
var rands = make([]int, 0, len(weights))
for i := 0; i < len(weights); i++ {
for n := 0; n < weights[i]; n++ {
rands = append(rands, i)
}
}
var pos = -1
var lock sync.Mutex
return func(g *EngineGroup) *Engine {
var slaves = g.Slaves()
lock.Lock()
defer lock.Unlock()
pos++
if pos >= len(rands) {
pos = 0
}
idx := rands[pos]
if idx >= len(slaves) {
idx = len(slaves) - 1
}
return slaves[idx]
}
}
// LeastConnPolicy implements GroupPolicy, every time will get the least connections slave
func LeastConnPolicy() GroupPolicyHandler {
return func(g *EngineGroup) *Engine {
var slaves = g.Slaves()
connections := 0
idx := 0
for i := 0; i < len(slaves); i++ {
openConnections := slaves[i].DB().Stats().OpenConnections
if i == 0 {
connections = openConnections
idx = i
} else if openConnections <= connections {
connections = openConnections
idx = i
}
}
return slaves[idx]
}
}

22
vendor/github.com/go-xorm/xorm/engine_maxlife.go generated vendored Normal file
View File

@ -0,0 +1,22 @@
// Copyright 2017 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.6
package xorm
import "time"
// SetConnMaxLifetime sets the maximum amount of time a connection may be reused.
func (engine *Engine) SetConnMaxLifetime(d time.Duration) {
engine.db.SetConnMaxLifetime(d)
}
// SetConnMaxLifetime sets the maximum amount of time a connection may be reused.
func (eg *EngineGroup) SetConnMaxLifetime(d time.Duration) {
eg.Engine.SetConnMaxLifetime(d)
for i := 0; i < len(eg.slaves); i++ {
eg.slaves[i].SetConnMaxLifetime(d)
}
}

View File

@ -23,4 +23,6 @@ var (
ErrNeedDeletedCond = errors.New("Delete need at least one condition")
// ErrNotImplemented not implemented
ErrNotImplemented = errors.New("Not implemented")
// ErrConditionType condition type unsupported
ErrConditionType = errors.New("Unsupported conditon type")
)

View File

@ -196,26 +196,44 @@ func isArrayValueZero(v reflect.Value) bool {
func int64ToIntValue(id int64, tp reflect.Type) reflect.Value {
var v interface{}
switch tp.Kind() {
case reflect.Int16:
v = int16(id)
case reflect.Int32:
v = int32(id)
case reflect.Int:
v = int(id)
case reflect.Int64:
v = id
case reflect.Uint16:
v = uint16(id)
case reflect.Uint32:
v = uint32(id)
case reflect.Uint64:
v = uint64(id)
case reflect.Uint:
v = uint(id)
kind := tp.Kind()
if kind == reflect.Ptr {
kind = tp.Elem().Kind()
}
switch kind {
case reflect.Int16:
temp := int16(id)
v = &temp
case reflect.Int32:
temp := int32(id)
v = &temp
case reflect.Int:
temp := int(id)
v = &temp
case reflect.Int64:
temp := id
v = &temp
case reflect.Uint16:
temp := uint16(id)
v = &temp
case reflect.Uint32:
temp := uint32(id)
v = &temp
case reflect.Uint64:
temp := uint64(id)
v = &temp
case reflect.Uint:
temp := uint(id)
v = &temp
}
if tp.Kind() == reflect.Ptr {
return reflect.ValueOf(v).Convert(tp)
}
return reflect.ValueOf(v).Elem().Convert(tp)
}
func int64ToInt(id int64, tp reflect.Type) interface{} {
return int64ToIntValue(id, tp).Interface()
@ -302,180 +320,6 @@ func sliceEq(left, right []string) bool {
return true
}
func reflect2value(rawValue *reflect.Value) (str string, err error) {
aa := reflect.TypeOf((*rawValue).Interface())
vv := reflect.ValueOf((*rawValue).Interface())
switch aa.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
str = strconv.FormatInt(vv.Int(), 10)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
str = strconv.FormatUint(vv.Uint(), 10)
case reflect.Float32, reflect.Float64:
str = strconv.FormatFloat(vv.Float(), 'f', -1, 64)
case reflect.String:
str = vv.String()
case reflect.Array, reflect.Slice:
switch aa.Elem().Kind() {
case reflect.Uint8:
data := rawValue.Interface().([]byte)
str = string(data)
default:
err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name())
}
// time type
case reflect.Struct:
if aa.ConvertibleTo(core.TimeType) {
str = vv.Convert(core.TimeType).Interface().(time.Time).Format(time.RFC3339Nano)
} else {
err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name())
}
case reflect.Bool:
str = strconv.FormatBool(vv.Bool())
case reflect.Complex128, reflect.Complex64:
str = fmt.Sprintf("%v", vv.Complex())
/* TODO: unsupported types below
case reflect.Map:
case reflect.Ptr:
case reflect.Uintptr:
case reflect.UnsafePointer:
case reflect.Chan, reflect.Func, reflect.Interface:
*/
default:
err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name())
}
return
}
func value2Bytes(rawValue *reflect.Value) (data []byte, err error) {
var str string
str, err = reflect2value(rawValue)
if err != nil {
return
}
data = []byte(str)
return
}
func value2String(rawValue *reflect.Value) (data string, err error) {
data, err = reflect2value(rawValue)
if err != nil {
return
}
return
}
func rows2Strings(rows *core.Rows) (resultsSlice []map[string]string, err error) {
fields, err := rows.Columns()
if err != nil {
return nil, err
}
for rows.Next() {
result, err := row2mapStr(rows, fields)
if err != nil {
return nil, err
}
resultsSlice = append(resultsSlice, result)
}
return resultsSlice, nil
}
func rows2maps(rows *core.Rows) (resultsSlice []map[string][]byte, err error) {
fields, err := rows.Columns()
if err != nil {
return nil, err
}
for rows.Next() {
result, err := row2map(rows, fields)
if err != nil {
return nil, err
}
resultsSlice = append(resultsSlice, result)
}
return resultsSlice, nil
}
func row2map(rows *core.Rows, fields []string) (resultsMap map[string][]byte, err error) {
result := make(map[string][]byte)
scanResultContainers := make([]interface{}, len(fields))
for i := 0; i < len(fields); i++ {
var scanResultContainer interface{}
scanResultContainers[i] = &scanResultContainer
}
if err := rows.Scan(scanResultContainers...); err != nil {
return nil, err
}
for ii, key := range fields {
rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii]))
//if row is null then ignore
if rawValue.Interface() == nil {
//fmt.Println("ignore ...", key, rawValue)
continue
}
if data, err := value2Bytes(&rawValue); err == nil {
result[key] = data
} else {
return nil, err // !nashtsai! REVIEW, should return err or just error log?
}
}
return result, nil
}
func row2mapStr(rows *core.Rows, fields []string) (resultsMap map[string]string, err error) {
result := make(map[string]string)
scanResultContainers := make([]interface{}, len(fields))
for i := 0; i < len(fields); i++ {
var scanResultContainer interface{}
scanResultContainers[i] = &scanResultContainer
}
if err := rows.Scan(scanResultContainers...); err != nil {
return nil, err
}
for ii, key := range fields {
rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii]))
//if row is null then ignore
if rawValue.Interface() == nil {
//fmt.Println("ignore ...", key, rawValue)
continue
}
if data, err := value2String(&rawValue); err == nil {
result[key] = data
} else {
return nil, err // !nashtsai! REVIEW, should return err or just error log?
}
}
return result, nil
}
func txQuery2(tx *core.Tx, sqlStr string, params ...interface{}) (resultsSlice []map[string]string, err error) {
rows, err := tx.Query(sqlStr, params...)
if err != nil {
return nil, err
}
defer rows.Close()
return rows2Strings(rows)
}
func query2(db *core.DB, sqlStr string, params ...interface{}) (resultsSlice []map[string]string, err error) {
s, err := db.Prepare(sqlStr)
if err != nil {
return nil, err
}
defer s.Close()
rows, err := s.Query(params...)
if err != nil {
return nil, err
}
defer rows.Close()
return rows2Strings(rows)
}
func setColumnInt(bean interface{}, col *core.Column, t int64) {
v, err := col.ValueOf(bean)
if err != nil {
@ -514,7 +358,7 @@ func genCols(table *core.Table, session *Session, bean interface{}, useCol bool,
for _, col := range table.Columns() {
if useCol && !col.IsVersion && !col.IsCreated && !col.IsUpdated {
if _, ok := getFlagForColumn(session.Statement.columnMap, col); !ok {
if _, ok := getFlagForColumn(session.statement.columnMap, col); !ok {
continue
}
}
@ -542,6 +386,10 @@ func genCols(table *core.Table, session *Session, bean interface{}, useCol bool,
if len(fieldValue.String()) == 0 {
continue
}
case reflect.Ptr:
if fieldValue.Pointer() == 0 {
continue
}
}
}
@ -549,28 +397,32 @@ func genCols(table *core.Table, session *Session, bean interface{}, useCol bool,
continue
}
if session.Statement.ColumnStr != "" {
if _, ok := getFlagForColumn(session.Statement.columnMap, col); !ok {
if session.statement.ColumnStr != "" {
if _, ok := getFlagForColumn(session.statement.columnMap, col); !ok {
continue
} else if _, ok := session.statement.incrColumns[col.Name]; ok {
continue
} else if _, ok := session.statement.decrColumns[col.Name]; ok {
continue
}
}
if session.Statement.OmitStr != "" {
if _, ok := getFlagForColumn(session.Statement.columnMap, col); ok {
if session.statement.OmitStr != "" {
if _, ok := getFlagForColumn(session.statement.columnMap, col); ok {
continue
}
}
// !evalphobia! set fieldValue as nil when column is nullable and zero-value
if _, ok := getFlagForColumn(session.Statement.nullableMap, col); ok {
if _, ok := getFlagForColumn(session.statement.nullableMap, col); ok {
if col.Nullable && isZero(fieldValue.Interface()) {
var nilValue *int
fieldValue = reflect.ValueOf(nilValue)
}
}
if (col.IsCreated || col.IsUpdated) && session.Statement.UseAutoTime /*&& isZero(fieldValue.Interface())*/ {
if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime /*&& isZero(fieldValue.Interface())*/ {
// if time is non-empty, then set to auto time
val, t := session.Engine.NowTime2(col.SQLType.Name)
val, t := session.engine.nowTime(col)
args = append(args, val)
var colName = col.Name
@ -578,7 +430,7 @@ func genCols(table *core.Table, session *Session, bean interface{}, useCol bool,
col := table.GetColumn(colName)
setColumnTime(bean, col, t)
})
} else if col.IsVersion && session.Statement.checkVersion {
} else if col.IsVersion && session.statement.checkVersion {
args = append(args, 1)
} else {
arg, err := session.value2Interface(col, fieldValue)
@ -589,7 +441,7 @@ func genCols(table *core.Table, session *Session, bean interface{}, useCol bool,
}
if includeQuote {
colNames = append(colNames, session.Engine.Quote(col.Name)+" = ?")
colNames = append(colNames, session.engine.Quote(col.Name)+" = ?")
} else {
colNames = append(colNames, col.Name)
}
@ -602,7 +454,6 @@ func indexName(tableName, idxName string) string {
}
func getFlagForColumn(m map[string]bool, col *core.Column) (val bool, has bool) {
if len(m) == 0 {
return false, false
}

21
vendor/github.com/go-xorm/xorm/helpler_time.go generated vendored Normal file
View File

@ -0,0 +1,21 @@
// Copyright 2017 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
import "time"
const (
zeroTime0 = "0000-00-00 00:00:00"
zeroTime1 = "0001-01-01 00:00:00"
)
func formatTime(t time.Time) string {
return t.Format("2006-01-02 15:04:05")
}
func isTimeZero(t time.Time) bool {
return t.IsZero() || formatTime(t) == zeroTime0 ||
formatTime(t) == zeroTime1
}

103
vendor/github.com/go-xorm/xorm/interface.go generated vendored Normal file
View File

@ -0,0 +1,103 @@
// Copyright 2017 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
import (
"database/sql"
"reflect"
"time"
"github.com/go-xorm/core"
)
// Interface defines the interface which Engine, EngineGroup and Session will implementate.
type Interface interface {
AllCols() *Session
Alias(alias string) *Session
Asc(colNames ...string) *Session
BufferSize(size int) *Session
Cols(columns ...string) *Session
Count(...interface{}) (int64, error)
CreateIndexes(bean interface{}) error
CreateUniques(bean interface{}) error
Decr(column string, arg ...interface{}) *Session
Desc(...string) *Session
Delete(interface{}) (int64, error)
Distinct(columns ...string) *Session
DropIndexes(bean interface{}) error
Exec(string, ...interface{}) (sql.Result, error)
Exist(bean ...interface{}) (bool, error)
Find(interface{}, ...interface{}) error
Get(interface{}) (bool, error)
GroupBy(keys string) *Session
ID(interface{}) *Session
In(string, ...interface{}) *Session
Incr(column string, arg ...interface{}) *Session
Insert(...interface{}) (int64, error)
InsertOne(interface{}) (int64, error)
IsTableEmpty(bean interface{}) (bool, error)
IsTableExist(beanOrTableName interface{}) (bool, error)
Iterate(interface{}, IterFunc) error
Limit(int, ...int) *Session
NoAutoCondition(...bool) *Session
NotIn(string, ...interface{}) *Session
Join(joinOperator string, tablename interface{}, condition string, args ...interface{}) *Session
Omit(columns ...string) *Session
OrderBy(order string) *Session
Ping() error
Query(sqlOrAgrs ...interface{}) (resultsSlice []map[string][]byte, err error)
QueryInterface(sqlorArgs ...interface{}) ([]map[string]interface{}, error)
QueryString(sqlorArgs ...interface{}) ([]map[string]string, error)
Rows(bean interface{}) (*Rows, error)
SetExpr(string, string) *Session
SQL(interface{}, ...interface{}) *Session
Sum(bean interface{}, colName string) (float64, error)
SumInt(bean interface{}, colName string) (int64, error)
Sums(bean interface{}, colNames ...string) ([]float64, error)
SumsInt(bean interface{}, colNames ...string) ([]int64, error)
Table(tableNameOrBean interface{}) *Session
Unscoped() *Session
Update(bean interface{}, condiBeans ...interface{}) (int64, error)
UseBool(...string) *Session
Where(interface{}, ...interface{}) *Session
}
// EngineInterface defines the interface which Engine, EngineGroup will implementate.
type EngineInterface interface {
Interface
Before(func(interface{})) *Session
Charset(charset string) *Session
CreateTables(...interface{}) error
DBMetas() ([]*core.Table, error)
Dialect() core.Dialect
DropTables(...interface{}) error
DumpAllToFile(fp string, tp ...core.DbType) error
GetColumnMapper() core.IMapper
GetDefaultCacher() core.Cacher
GetTableMapper() core.IMapper
GetTZDatabase() *time.Location
GetTZLocation() *time.Location
NewSession() *Session
NoAutoTime() *Session
Quote(string) string
SetDefaultCacher(core.Cacher)
SetLogLevel(core.LogLevel)
SetMapper(core.IMapper)
SetTZDatabase(tz *time.Location)
SetTZLocation(tz *time.Location)
ShowSQL(show ...bool)
Sync(...interface{}) error
Sync2(...interface{}) error
StoreEngine(storeEngine string) *Session
TableInfo(bean interface{}) *Table
UnMapType(reflect.Type)
}
var (
_ Interface = &Session{}
_ EngineInterface = &Engine{}
_ EngineInterface = &EngineGroup{}
)

View File

@ -29,13 +29,6 @@ type AfterSetProcessor interface {
AfterSet(string, Cell)
}
// !nashtsai! TODO enable BeforeValidateProcessor when xorm start to support validations
//// Executed before an object is validated
//type BeforeValidateProcessor interface {
// BeforeValidate()
//}
// --
// AfterInsertProcessor executed after an object is persisted to the database
type AfterInsertProcessor interface {
AfterInsert()
@ -50,3 +43,36 @@ type AfterUpdateProcessor interface {
type AfterDeleteProcessor interface {
AfterDelete()
}
// AfterLoadProcessor executed after an ojbect has been loaded from database
type AfterLoadProcessor interface {
AfterLoad()
}
// AfterLoadSessionProcessor executed after an ojbect has been loaded from database with session parameter
type AfterLoadSessionProcessor interface {
AfterLoad(*Session)
}
type executedProcessorFunc func(*Session, interface{}) error
type executedProcessor struct {
fun executedProcessorFunc
session *Session
bean interface{}
}
func (executor *executedProcessor) execute() error {
return executor.fun(executor.session, executor.bean)
}
func (session *Session) executeProcessors() error {
processors := session.afterProcessors
session.afterProcessors = make([]executedProcessor, 0)
for _, processor := range processors {
if err := processor.execute(); err != nil {
return err
}
}
return nil
}

View File

@ -17,7 +17,6 @@ type Rows struct {
NoTypeCheck bool
session *Session
stmt *core.Stmt
rows *core.Rows
fields []string
beanType reflect.Type
@ -29,51 +28,34 @@ func newRows(session *Session, bean interface{}) (*Rows, error) {
rows.session = session
rows.beanType = reflect.Indirect(reflect.ValueOf(bean)).Type()
defer rows.session.resetStatement()
var sqlStr string
var args []interface{}
var err error
rows.session.Statement.setRefValue(rValue(bean))
if len(session.Statement.TableName()) <= 0 {
if err = rows.session.statement.setRefValue(rValue(bean)); err != nil {
return nil, err
}
if len(session.statement.TableName()) <= 0 {
return nil, ErrTableNotFound
}
if rows.session.Statement.RawSQL == "" {
sqlStr, args = rows.session.Statement.genGetSQL(bean)
} else {
sqlStr = rows.session.Statement.RawSQL
args = rows.session.Statement.RawParams
}
for _, filter := range rows.session.Engine.dialect.Filters() {
sqlStr = filter.Do(sqlStr, session.Engine.dialect, rows.session.Statement.RefTable)
}
rows.session.saveLastSQL(sqlStr, args...)
var err error
if rows.session.prepareStmt {
rows.stmt, err = rows.session.DB().Prepare(sqlStr)
if rows.session.statement.RawSQL == "" {
sqlStr, args, err = rows.session.statement.genGetSQL(bean)
if err != nil {
rows.lastError = err
rows.Close()
return nil, err
}
rows.rows, err = rows.stmt.Query(args...)
if err != nil {
rows.lastError = err
rows.Close()
return nil, err
}
} else {
rows.rows, err = rows.session.DB().Query(sqlStr, args...)
sqlStr = rows.session.statement.RawSQL
args = rows.session.statement.RawParams
}
rows.rows, err = rows.session.queryRows(sqlStr, args...)
if err != nil {
rows.lastError = err
rows.Close()
return nil, err
}
}
rows.fields, err = rows.rows.Columns()
if err != nil {
@ -113,15 +95,26 @@ func (rows *Rows) Scan(bean interface{}) error {
}
dataStruct := rValue(bean)
rows.session.Statement.setRefValue(dataStruct)
_, err := rows.session.row2Bean(rows.rows, rows.fields, len(rows.fields), bean, &dataStruct, rows.session.Statement.RefTable)
if err := rows.session.statement.setRefValue(dataStruct); err != nil {
return err
}
scanResults, err := rows.session.row2Slice(rows.rows, rows.fields, bean)
if err != nil {
return err
}
_, err = rows.session.slice2Bean(scanResults, rows.fields, bean, &dataStruct, rows.session.statement.RefTable)
if err != nil {
return err
}
return rows.session.executeProcessors()
}
// Close session if session.IsAutoClose is true, and claimed any opened resources
func (rows *Rows) Close() error {
if rows.session.IsAutoClose {
if rows.session.isAutoClose {
defer rows.session.Close()
}
@ -129,17 +122,10 @@ func (rows *Rows) Close() error {
if rows.rows != nil {
rows.lastError = rows.rows.Close()
if rows.lastError != nil {
defer rows.stmt.Close()
return rows.lastError
}
}
if rows.stmt != nil {
rows.lastError = rows.stmt.Close()
}
} else {
if rows.stmt != nil {
defer rows.stmt.Close()
}
if rows.rows != nil {
defer rows.rows.Close()
}

View File

@ -11,7 +11,6 @@ import (
"fmt"
"hash/crc32"
"reflect"
"strconv"
"strings"
"time"
@ -22,17 +21,16 @@ import (
// kind of database operations.
type Session struct {
db *core.DB
Engine *Engine
Tx *core.Tx
Statement Statement
IsAutoCommit bool
IsCommitedOrRollbacked bool
TransType string
IsAutoClose bool
engine *Engine
tx *core.Tx
statement Statement
isAutoCommit bool
isCommitedOrRollbacked bool
isAutoClose bool
// Automatically reset the statement after operations that execute a SQL
// query such as Count(), Find(), Get(), ...
AutoResetStatement bool
autoResetStatement bool
// !nashtsai! storing these beans due to yet committed tx
afterInsertBeans map[interface{}]*[]func(interface{})
@ -43,14 +41,17 @@ type Session struct {
beforeClosures []func(interface{})
afterClosures []func(interface{})
afterProcessors []executedProcessor
prepareStmt bool
stmtCache map[uint32]*core.Stmt //key: hash.Hash32 of (queryStr, len(queryStr))
cascadeDeep int
// !evalphobia! stored the last executed query on this session
//beforeSQLExec func(string, ...interface{})
lastSQL string
lastSQLArgs []interface{}
err error
}
// Clone copy all the session's content and return a new session
@ -61,12 +62,12 @@ func (session *Session) Clone() *Session {
// Init reset the session as the init status.
func (session *Session) Init() {
session.Statement.Init()
session.Statement.Engine = session.Engine
session.IsAutoCommit = true
session.IsCommitedOrRollbacked = false
session.IsAutoClose = false
session.AutoResetStatement = true
session.statement.Init()
session.statement.Engine = session.engine
session.isAutoCommit = true
session.isCommitedOrRollbacked = false
session.isAutoClose = false
session.autoResetStatement = true
session.prepareStmt = false
// !nashtsai! is lazy init better?
@ -75,6 +76,9 @@ func (session *Session) Init() {
session.afterDeleteBeans = make(map[interface{}]*[]func(interface{}), 0)
session.beforeClosures = make([]func(interface{}), 0)
session.afterClosures = make([]func(interface{}), 0)
session.stmtCache = make(map[uint32]*core.Stmt)
session.afterProcessors = make([]executedProcessor, 0)
session.lastSQL = ""
session.lastSQLArgs = []interface{}{}
@ -89,19 +93,23 @@ func (session *Session) Close() {
if session.db != nil {
// When Close be called, if session is a transaction and do not call
// Commit or Rollback, then call Rollback.
if session.Tx != nil && !session.IsCommitedOrRollbacked {
if session.tx != nil && !session.isCommitedOrRollbacked {
session.Rollback()
}
session.Tx = nil
session.tx = nil
session.stmtCache = nil
session.Init()
session.db = nil
}
}
// IsClosed returns if session is closed
func (session *Session) IsClosed() bool {
return session.db == nil
}
func (session *Session) resetStatement() {
if session.AutoResetStatement {
session.Statement.Init()
if session.autoResetStatement {
session.statement.Init()
}
}
@ -129,75 +137,75 @@ func (session *Session) After(closures func(interface{})) *Session {
// Table can input a string or pointer to struct for special a table to operate.
func (session *Session) Table(tableNameOrBean interface{}) *Session {
session.Statement.Table(tableNameOrBean)
session.statement.Table(tableNameOrBean)
return session
}
// Alias set the table alias
func (session *Session) Alias(alias string) *Session {
session.Statement.Alias(alias)
session.statement.Alias(alias)
return session
}
// NoCascade indicate that no cascade load child object
func (session *Session) NoCascade() *Session {
session.Statement.UseCascade = false
session.statement.UseCascade = false
return session
}
// ForUpdate Set Read/Write locking for UPDATE
func (session *Session) ForUpdate() *Session {
session.Statement.IsForUpdate = true
session.statement.IsForUpdate = true
return session
}
// NoAutoCondition disable generate SQL condition from beans
func (session *Session) NoAutoCondition(no ...bool) *Session {
session.Statement.NoAutoCondition(no...)
session.statement.NoAutoCondition(no...)
return session
}
// Limit provide limit and offset query condition
func (session *Session) Limit(limit int, start ...int) *Session {
session.Statement.Limit(limit, start...)
session.statement.Limit(limit, start...)
return session
}
// OrderBy provide order by query condition, the input parameter is the content
// after order by on a sql statement.
func (session *Session) OrderBy(order string) *Session {
session.Statement.OrderBy(order)
session.statement.OrderBy(order)
return session
}
// Desc provide desc order by query condition, the input parameters are columns.
func (session *Session) Desc(colNames ...string) *Session {
session.Statement.Desc(colNames...)
session.statement.Desc(colNames...)
return session
}
// Asc provide asc order by query condition, the input parameters are columns.
func (session *Session) Asc(colNames ...string) *Session {
session.Statement.Asc(colNames...)
session.statement.Asc(colNames...)
return session
}
// StoreEngine is only avialble mysql dialect currently
func (session *Session) StoreEngine(storeEngine string) *Session {
session.Statement.StoreEngine = storeEngine
session.statement.StoreEngine = storeEngine
return session
}
// Charset is only avialble mysql dialect currently
func (session *Session) Charset(charset string) *Session {
session.Statement.Charset = charset
session.statement.Charset = charset
return session
}
// Cascade indicates if loading sub Struct
func (session *Session) Cascade(trueOrFalse ...bool) *Session {
if len(trueOrFalse) >= 1 {
session.Statement.UseCascade = trueOrFalse[0]
session.statement.UseCascade = trueOrFalse[0]
}
return session
}
@ -205,32 +213,32 @@ func (session *Session) Cascade(trueOrFalse ...bool) *Session {
// NoCache ask this session do not retrieve data from cache system and
// get data from database directly.
func (session *Session) NoCache() *Session {
session.Statement.UseCache = false
session.statement.UseCache = false
return session
}
// Join join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
func (session *Session) Join(joinOperator string, tablename interface{}, condition string, args ...interface{}) *Session {
session.Statement.Join(joinOperator, tablename, condition, args...)
session.statement.Join(joinOperator, tablename, condition, args...)
return session
}
// GroupBy Generate Group By statement
func (session *Session) GroupBy(keys string) *Session {
session.Statement.GroupBy(keys)
session.statement.GroupBy(keys)
return session
}
// Having Generate Having statement
func (session *Session) Having(conditions string) *Session {
session.Statement.Having(conditions)
session.statement.Having(conditions)
return session
}
// DB db return the wrapper of sql.DB
func (session *Session) DB() *core.DB {
if session.db == nil {
session.db = session.Engine.db
session.db = session.engine.db
session.stmtCache = make(map[uint32]*core.Stmt, 0)
}
return session.db
@ -243,25 +251,25 @@ func cleanupProcessorsClosures(slices *[]func(interface{})) {
}
func (session *Session) canCache() bool {
if session.Statement.RefTable == nil ||
session.Statement.JoinStr != "" ||
session.Statement.RawSQL != "" ||
!session.Statement.UseCache ||
session.Statement.IsForUpdate ||
session.Tx != nil ||
len(session.Statement.selectStr) > 0 {
if session.statement.RefTable == nil ||
session.statement.JoinStr != "" ||
session.statement.RawSQL != "" ||
!session.statement.UseCache ||
session.statement.IsForUpdate ||
session.tx != nil ||
len(session.statement.selectStr) > 0 {
return false
}
return true
}
func (session *Session) doPrepare(sqlStr string) (stmt *core.Stmt, err error) {
func (session *Session) doPrepare(db *core.DB, sqlStr string) (stmt *core.Stmt, err error) {
crc := crc32.ChecksumIEEE([]byte(sqlStr))
// TODO try hash(sqlStr+len(sqlStr))
var has bool
stmt, has = session.stmtCache[crc]
if !has {
stmt, err = session.DB().Prepare(sqlStr)
stmt, err = db.Prepare(sqlStr)
if err != nil {
return nil, err
}
@ -273,18 +281,18 @@ func (session *Session) doPrepare(sqlStr string) (stmt *core.Stmt, err error) {
func (session *Session) getField(dataStruct *reflect.Value, key string, table *core.Table, idx int) *reflect.Value {
var col *core.Column
if col = table.GetColumnIdx(key, idx); col == nil {
//session.Engine.logger.Warnf("table %v has no column %v. %v", table.Name, key, table.ColumnsSeq())
//session.engine.logger.Warnf("table %v has no column %v. %v", table.Name, key, table.ColumnsSeq())
return nil
}
fieldValue, err := col.ValueOfV(dataStruct)
if err != nil {
session.Engine.logger.Error(err)
session.engine.logger.Error(err)
return nil
}
if !fieldValue.IsValid() || !fieldValue.CanSet() {
session.Engine.logger.Warnf("table %v's column %v is not valid or cannot set", table.Name, key)
session.engine.logger.Warnf("table %v's column %v is not valid or cannot set", table.Name, key)
return nil
}
return fieldValue
@ -293,28 +301,40 @@ func (session *Session) getField(dataStruct *reflect.Value, key string, table *c
// Cell cell is a result of one column field
type Cell *interface{}
func (session *Session) rows2Beans(rows *core.Rows, fields []string, fieldsCount int,
func (session *Session) rows2Beans(rows *core.Rows, fields []string,
table *core.Table, newElemFunc func([]string) reflect.Value,
sliceValueSetFunc func(*reflect.Value, core.PK) error) error {
for rows.Next() {
var newValue = newElemFunc(fields)
bean := newValue.Interface()
dataStruct := rValue(bean)
pk, err := session.row2Bean(rows, fields, fieldsCount, bean, &dataStruct, table)
if err != nil {
return err
}
dataStruct := newValue.Elem()
err = sliceValueSetFunc(&newValue, pk)
// handle beforeClosures
scanResults, err := session.row2Slice(rows, fields, bean)
if err != nil {
return err
}
pk, err := session.slice2Bean(scanResults, fields, bean, &dataStruct, table)
if err != nil {
return err
}
session.afterProcessors = append(session.afterProcessors, executedProcessor{
fun: func(*Session, interface{}) error {
return sliceValueSetFunc(&newValue, pk)
},
session: session,
bean: bean,
})
}
return nil
}
func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount int, bean interface{}, dataStruct *reflect.Value, table *core.Table) (core.PK, error) {
scanResults := make([]interface{}, fieldsCount)
func (session *Session) row2Slice(rows *core.Rows, fields []string, bean interface{}) ([]interface{}, error) {
for _, closure := range session.beforeClosures {
closure(bean)
}
scanResults := make([]interface{}, len(fields))
for i := 0; i < len(fields); i++ {
var cell interface{}
scanResults[i] = &cell
@ -328,7 +348,10 @@ func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount i
b.BeforeSet(key, Cell(scanResults[ii].(*interface{})))
}
}
return scanResults, nil
}
func (session *Session) slice2Bean(scanResults []interface{}, fields []string, bean interface{}, dataStruct *reflect.Value, table *core.Table) (core.PK, error) {
defer func() {
if b, hasAfterSet := bean.(AfterSetProcessor); hasAfterSet {
for ii, key := range fields {
@ -337,6 +360,40 @@ func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount i
}
}()
// handle afterClosures
for _, closure := range session.afterClosures {
session.afterProcessors = append(session.afterProcessors, executedProcessor{
fun: func(sess *Session, bean interface{}) error {
closure(bean)
return nil
},
session: session,
bean: bean,
})
}
if a, has := bean.(AfterLoadProcessor); has {
session.afterProcessors = append(session.afterProcessors, executedProcessor{
fun: func(sess *Session, bean interface{}) error {
a.AfterLoad()
return nil
},
session: session,
bean: bean,
})
}
if a, has := bean.(AfterLoadSessionProcessor); has {
session.afterProcessors = append(session.afterProcessors, executedProcessor{
fun: func(sess *Session, bean interface{}) error {
a.AfterLoad(sess)
return nil
},
session: session,
bean: bean,
})
}
var tempMap = make(map[string]int)
var pk core.PK
for ii, key := range fields {
@ -361,9 +418,11 @@ func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount i
if fieldValue.CanAddr() {
if structConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok {
if data, err := value2Bytes(&rawValue); err == nil {
structConvert.FromDB(data)
if err := structConvert.FromDB(data); err != nil {
return nil, err
}
} else {
session.Engine.logger.Error(err)
return nil, err
}
continue
}
@ -376,7 +435,7 @@ func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount i
}
fieldValue.Interface().(core.Conversion).FromDB(data)
} else {
session.Engine.logger.Error(err)
return nil, err
}
continue
}
@ -403,17 +462,19 @@ func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount i
hasAssigned = true
if len(bs) > 0 {
if fieldType.Kind() == reflect.String {
fieldValue.SetString(string(bs))
continue
}
if fieldValue.CanAddr() {
err := json.Unmarshal(bs, fieldValue.Addr().Interface())
if err != nil {
session.Engine.logger.Error(key, err)
return nil, err
}
} else {
x := reflect.New(fieldType)
err := json.Unmarshal(bs, x.Interface())
if err != nil {
session.Engine.logger.Error(key, err)
return nil, err
}
fieldValue.Set(x.Elem())
@ -438,14 +499,12 @@ func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount i
if fieldValue.CanAddr() {
err := json.Unmarshal(bs, fieldValue.Addr().Interface())
if err != nil {
session.Engine.logger.Error(err)
return nil, err
}
} else {
x := reflect.New(fieldType)
err := json.Unmarshal(bs, x.Interface())
if err != nil {
session.Engine.logger.Error(err)
return nil, err
}
fieldValue.Set(x.Elem())
@ -462,16 +521,21 @@ func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount i
x := reflect.New(fieldType)
err := json.Unmarshal(vv.Bytes(), x.Interface())
if err != nil {
session.Engine.logger.Error(err)
return nil, err
}
fieldValue.Set(x.Elem())
} else {
if fieldValue.Len() > 0 {
for i := 0; i < fieldValue.Len(); i++ {
if i < vv.Len() {
fieldValue.Index(i).Set(vv.Index(i))
}
}
} else {
for i := 0; i < vv.Len(); i++ {
fieldValue.Set(reflect.Append(*fieldValue, vv.Index(i)))
}
}
}
}
}
@ -509,57 +573,38 @@ func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount i
}
case reflect.Struct:
if fieldType.ConvertibleTo(core.TimeType) {
dbTZ := session.engine.DatabaseTZ
if col.TimeZone != nil {
dbTZ = col.TimeZone
}
if rawValueType == core.TimeType {
hasAssigned = true
t := vv.Convert(core.TimeType).Interface().(time.Time)
z, _ := t.Zone()
dbTZ := session.Engine.DatabaseTZ
if dbTZ == nil {
if session.Engine.dialect.DBType() == core.SQLITE {
dbTZ = time.UTC
} else {
dbTZ = time.Local
}
}
// set new location if database don't save timezone or give an incorrect timezone
if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbTZ.String() { // !nashtsai! HACK tmp work around for lib/pq doesn't properly time with location
session.Engine.logger.Debugf("empty zone key[%v] : %v | zone: %v | location: %+v\n", key, t, z, *t.Location())
session.engine.logger.Debugf("empty zone key[%v] : %v | zone: %v | location: %+v\n", key, t, z, *t.Location())
t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(),
t.Minute(), t.Second(), t.Nanosecond(), dbTZ)
}
// !nashtsai! convert to engine location
if col.TimeZone == nil {
t = t.In(session.Engine.TZLocation)
} else {
t = t.In(col.TimeZone)
}
t = t.In(session.engine.TZLocation)
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
// t = fieldValue.Interface().(time.Time)
// z, _ = t.Zone()
// session.Engine.LogDebug("fieldValue key[%v]: %v | zone: %v | location: %+v\n", key, t, z, *t.Location())
} else if rawValueType == core.IntType || rawValueType == core.Int64Type ||
rawValueType == core.Int32Type {
hasAssigned = true
var tz *time.Location
if col.TimeZone == nil {
tz = session.Engine.TZLocation
} else {
tz = col.TimeZone
}
t := time.Unix(vv.Int(), 0).In(tz)
//vv = reflect.ValueOf(t)
t := time.Unix(vv.Int(), 0).In(session.engine.TZLocation)
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
} else {
if d, ok := vv.Interface().([]uint8); ok {
hasAssigned = true
t, err := session.byte2Time(col, d)
if err != nil {
session.Engine.logger.Error("byte2Time error:", err.Error())
session.engine.logger.Error("byte2Time error:", err.Error())
hasAssigned = false
} else {
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
@ -568,20 +613,20 @@ func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount i
hasAssigned = true
t, err := session.str2Time(col, d)
if err != nil {
session.Engine.logger.Error("byte2Time error:", err.Error())
session.engine.logger.Error("byte2Time error:", err.Error())
hasAssigned = false
} else {
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
}
} else {
panic(fmt.Sprintf("rawValueType is %v, value is %v", rawValueType, vv.Interface()))
return nil, fmt.Errorf("rawValueType is %v, value is %v", rawValueType, vv.Interface())
}
}
} else if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
// !<winxxp>! 增加支持sql.Scanner接口的结构如sql.NullString
hasAssigned = true
if err := nulVal.Scan(vv.Interface()); err != nil {
session.Engine.logger.Error("sql.Sanner error:", err.Error())
session.engine.logger.Error("sql.Sanner error:", err.Error())
hasAssigned = false
}
} else if col.SQLType.IsJson() {
@ -591,7 +636,6 @@ func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount i
if len([]byte(vv.String())) > 0 {
err := json.Unmarshal([]byte(vv.String()), x.Interface())
if err != nil {
session.Engine.logger.Error(err)
return nil, err
}
fieldValue.Set(x.Elem())
@ -602,48 +646,25 @@ func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount i
if len(vv.Bytes()) > 0 {
err := json.Unmarshal(vv.Bytes(), x.Interface())
if err != nil {
session.Engine.logger.Error(err)
return nil, err
}
fieldValue.Set(x.Elem())
}
}
} else if session.Statement.UseCascade {
table := session.Engine.autoMapType(*fieldValue)
if table != nil {
} else if session.statement.UseCascade {
table, err := session.engine.autoMapType(*fieldValue)
if err != nil {
return nil, err
}
hasAssigned = true
if len(table.PrimaryKeys) != 1 {
panic("unsupported non or composited primary key cascade")
return nil, errors.New("unsupported non or composited primary key cascade")
}
var pk = make(core.PK, len(table.PrimaryKeys))
switch rawValueType.Kind() {
case reflect.Int64:
pk[0] = vv.Int()
case reflect.Int:
pk[0] = int(vv.Int())
case reflect.Int32:
pk[0] = int32(vv.Int())
case reflect.Int16:
pk[0] = int16(vv.Int())
case reflect.Int8:
pk[0] = int8(vv.Int())
case reflect.Uint64:
pk[0] = vv.Uint()
case reflect.Uint:
pk[0] = uint(vv.Uint())
case reflect.Uint32:
pk[0] = uint32(vv.Uint())
case reflect.Uint16:
pk[0] = uint16(vv.Uint())
case reflect.Uint8:
pk[0] = uint8(vv.Uint())
case reflect.String:
pk[0] = vv.String()
case reflect.Slice:
pk[0], _ = strconv.ParseInt(string(rawValue.Interface().([]byte)), 10, 64)
default:
panic(fmt.Sprintf("unsupported primary key type: %v, %v", rawValueType, fieldValue))
pk[0], err = asKind(vv, rawValueType)
if err != nil {
return nil, err
}
if !isPKZero(pk) {
@ -651,27 +672,19 @@ func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount i
// however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne
// property to be fetched lazily
structInter := reflect.New(fieldValue.Type())
newsession := session.Engine.NewSession()
defer newsession.Close()
has, err := newsession.Id(pk).NoCascade().Get(structInter.Interface())
has, err := session.ID(pk).NoCascade().get(structInter.Interface())
if err != nil {
return nil, err
}
if has {
//v := structInter.Elem().Interface()
//fieldValue.Set(reflect.ValueOf(v))
fieldValue.Set(structInter.Elem())
} else {
return nil, errors.New("cascade obj is not exist")
}
}
} else {
session.Engine.logger.Error("unsupported struct type in Scan: ", fieldValue.Type().String())
}
}
case reflect.Ptr:
// !nashtsai! TODO merge duplicated codes above
//typeStr := fieldType.String()
switch fieldType {
// following types case matching ptr's native type, therefore assign ptr directly
case core.PtrStringType:
@ -769,10 +782,9 @@ func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount i
if len([]byte(vv.String())) > 0 {
err := json.Unmarshal([]byte(vv.String()), &x)
if err != nil {
session.Engine.logger.Error(err)
} else {
fieldValue.Set(reflect.ValueOf(&x))
return nil, err
}
fieldValue.Set(reflect.ValueOf(&x))
}
hasAssigned = true
case core.Complex128Type:
@ -780,24 +792,23 @@ func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount i
if len([]byte(vv.String())) > 0 {
err := json.Unmarshal([]byte(vv.String()), &x)
if err != nil {
session.Engine.logger.Error(err)
} else {
fieldValue.Set(reflect.ValueOf(&x))
return nil, err
}
fieldValue.Set(reflect.ValueOf(&x))
}
hasAssigned = true
} // switch fieldType
// default:
// session.Engine.LogError("unsupported type in Scan: ", reflect.TypeOf(v).String())
} // switch fieldType.Kind()
// !nashtsai! for value can't be assigned directly fallback to convert to []byte then back to value
if !hasAssigned {
data, err := value2Bytes(&rawValue)
if err == nil {
session.bytes2Value(col, fieldValue, data)
} else {
session.Engine.logger.Error(err.Error())
if err != nil {
return nil, err
}
if err = session.bytes2Value(col, fieldValue, data); err != nil {
return nil, err
}
}
}
@ -805,19 +816,11 @@ func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount i
return pk, nil
}
func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) {
for _, filter := range session.Engine.dialect.Filters() {
*sqlStr = filter.Do(*sqlStr, session.Engine.dialect, session.Statement.RefTable)
}
session.saveLastSQL(*sqlStr, paramStr...)
}
// saveLastSQL stores executed query information
func (session *Session) saveLastSQL(sql string, args ...interface{}) {
session.lastSQL = sql
session.lastSQLArgs = args
session.Engine.logSQL(sql, args...)
session.engine.logSQL(sql, args...)
}
// LastSQL returns last query information
@ -827,8 +830,8 @@ func (session *Session) LastSQL() (string, []interface{}) {
// tbName get some table's table name
func (session *Session) tbNameNoSchema(table *core.Table) string {
if len(session.Statement.AltTableName) > 0 {
return session.Statement.AltTableName
if len(session.statement.AltTableName) > 0 {
return session.statement.AltTableName
}
return table.Name
@ -836,6 +839,6 @@ func (session *Session) tbNameNoSchema(table *core.Table) string {
// Unscoped always disable struct tag "deleted"
func (session *Session) Unscoped() *Session {
session.Statement.Unscoped()
session.statement.Unscoped()
return session
}

View File

@ -6,43 +6,43 @@ package xorm
// Incr provides a query string like "count = count + 1"
func (session *Session) Incr(column string, arg ...interface{}) *Session {
session.Statement.Incr(column, arg...)
session.statement.Incr(column, arg...)
return session
}
// Decr provides a query string like "count = count - 1"
func (session *Session) Decr(column string, arg ...interface{}) *Session {
session.Statement.Decr(column, arg...)
session.statement.Decr(column, arg...)
return session
}
// SetExpr provides a query string like "column = {expression}"
func (session *Session) SetExpr(column string, expression string) *Session {
session.Statement.SetExpr(column, expression)
session.statement.SetExpr(column, expression)
return session
}
// Select provides some columns to special
func (session *Session) Select(str string) *Session {
session.Statement.Select(str)
session.statement.Select(str)
return session
}
// Cols provides some columns to special
func (session *Session) Cols(columns ...string) *Session {
session.Statement.Cols(columns...)
session.statement.Cols(columns...)
return session
}
// AllCols ask all columns
func (session *Session) AllCols() *Session {
session.Statement.AllCols()
session.statement.AllCols()
return session
}
// MustCols specify some columns must use even if they are empty
func (session *Session) MustCols(columns ...string) *Session {
session.Statement.MustCols(columns...)
session.statement.MustCols(columns...)
return session
}
@ -52,7 +52,7 @@ func (session *Session) MustCols(columns ...string) *Session {
// If no parameters, it will use all the bool field of struct, or
// it will use parameters's columns
func (session *Session) UseBool(columns ...string) *Session {
session.Statement.UseBool(columns...)
session.statement.UseBool(columns...)
return session
}
@ -60,25 +60,25 @@ func (session *Session) UseBool(columns ...string) *Session {
// distinct will not be cached because cache system need id,
// but distinct will not provide id
func (session *Session) Distinct(columns ...string) *Session {
session.Statement.Distinct(columns...)
session.statement.Distinct(columns...)
return session
}
// Omit Only not use the parameters as select or update columns
func (session *Session) Omit(columns ...string) *Session {
session.Statement.Omit(columns...)
session.statement.Omit(columns...)
return session
}
// Nullable Set null when column is zero-value and nullable for update
func (session *Session) Nullable(columns ...string) *Session {
session.Statement.Nullable(columns...)
session.statement.Nullable(columns...)
return session
}
// NoAutoTime means do not automatically give created field and updated field
// the current time on the current session temporarily
func (session *Session) NoAutoTime() *Session {
session.Statement.UseAutoTime = false
session.statement.UseAutoTime = false
return session
}

View File

@ -17,25 +17,25 @@ func (session *Session) Sql(query string, args ...interface{}) *Session {
// SQL provides raw sql input parameter. When you have a complex SQL statement
// and cannot use Where, Id, In and etc. Methods to describe, you can use SQL.
func (session *Session) SQL(query interface{}, args ...interface{}) *Session {
session.Statement.SQL(query, args...)
session.statement.SQL(query, args...)
return session
}
// Where provides custom query condition.
func (session *Session) Where(query interface{}, args ...interface{}) *Session {
session.Statement.Where(query, args...)
session.statement.Where(query, args...)
return session
}
// And provides custom query condition.
func (session *Session) And(query interface{}, args ...interface{}) *Session {
session.Statement.And(query, args...)
session.statement.And(query, args...)
return session
}
// Or provides custom query condition.
func (session *Session) Or(query interface{}, args ...interface{}) *Session {
session.Statement.Or(query, args...)
session.statement.Or(query, args...)
return session
}
@ -48,23 +48,23 @@ func (session *Session) Id(id interface{}) *Session {
// ID provides converting id as a query condition
func (session *Session) ID(id interface{}) *Session {
session.Statement.ID(id)
session.statement.ID(id)
return session
}
// In provides a query string like "id in (1, 2, 3)"
func (session *Session) In(column string, args ...interface{}) *Session {
session.Statement.In(column, args...)
session.statement.In(column, args...)
return session
}
// NotIn provides a query string like "id in (1, 2, 3)"
func (session *Session) NotIn(column string, args ...interface{}) *Session {
session.Statement.NotIn(column, args...)
session.statement.NotIn(column, args...)
return session
}
// Conds returns session query conditions
// Conds returns session query conditions except auto bean conditions
func (session *Session) Conds() builder.Cond {
return session.Statement.cond
return session.statement.cond
}

View File

@ -23,41 +23,38 @@ func (session *Session) str2Time(col *core.Column, data string) (outTime time.Ti
var x time.Time
var err error
if sdata == "0000-00-00 00:00:00" ||
sdata == "0001-01-01 00:00:00" {
var parseLoc = session.engine.DatabaseTZ
if col.TimeZone != nil {
parseLoc = col.TimeZone
}
if sdata == zeroTime0 || sdata == zeroTime1 {
} else if !strings.ContainsAny(sdata, "- :") { // !nashtsai! has only found that mymysql driver is using this for time type column
// time stamp
sd, err := strconv.ParseInt(sdata, 10, 64)
if err == nil {
x = time.Unix(sd, 0)
// !nashtsai! HACK mymysql driver is causing Local location being change to CHAT and cause wrong time conversion
if col.TimeZone == nil {
x = x.In(session.Engine.TZLocation)
//session.engine.logger.Debugf("time(0) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
} else {
x = x.In(col.TimeZone)
}
session.Engine.logger.Debugf("time(0) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
} else {
session.Engine.logger.Debugf("time(0) err key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
//session.engine.logger.Debugf("time(0) err key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
}
} else if len(sdata) > 19 && strings.Contains(sdata, "-") {
x, err = time.ParseInLocation(time.RFC3339Nano, sdata, session.Engine.TZLocation)
session.Engine.logger.Debugf("time(1) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
x, err = time.ParseInLocation(time.RFC3339Nano, sdata, parseLoc)
session.engine.logger.Debugf("time(1) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
if err != nil {
x, err = time.ParseInLocation("2006-01-02 15:04:05.999999999", sdata, session.Engine.TZLocation)
session.Engine.logger.Debugf("time(2) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
x, err = time.ParseInLocation("2006-01-02 15:04:05.999999999", sdata, parseLoc)
//session.engine.logger.Debugf("time(2) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
}
if err != nil {
x, err = time.ParseInLocation("2006-01-02 15:04:05.9999999 Z07:00", sdata, session.Engine.TZLocation)
session.Engine.logger.Debugf("time(3) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
x, err = time.ParseInLocation("2006-01-02 15:04:05.9999999 Z07:00", sdata, parseLoc)
//session.engine.logger.Debugf("time(3) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
}
} else if len(sdata) == 19 && strings.Contains(sdata, "-") {
x, err = time.ParseInLocation("2006-01-02 15:04:05", sdata, session.Engine.TZLocation)
session.Engine.logger.Debugf("time(4) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
x, err = time.ParseInLocation("2006-01-02 15:04:05", sdata, parseLoc)
//session.engine.logger.Debugf("time(4) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
} else if len(sdata) == 10 && sdata[4] == '-' && sdata[7] == '-' {
x, err = time.ParseInLocation("2006-01-02", sdata, session.Engine.TZLocation)
session.Engine.logger.Debugf("time(5) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
x, err = time.ParseInLocation("2006-01-02", sdata, parseLoc)
//session.engine.logger.Debugf("time(5) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
} else if col.SQLType.Name == core.Time {
if strings.Contains(sdata, " ") {
ssd := strings.Split(sdata, " ")
@ -65,13 +62,13 @@ func (session *Session) str2Time(col *core.Column, data string) (outTime time.Ti
}
sdata = strings.TrimSpace(sdata)
if session.Engine.dialect.DBType() == core.MYSQL && len(sdata) > 8 {
if session.engine.dialect.DBType() == core.MYSQL && len(sdata) > 8 {
sdata = sdata[len(sdata)-8:]
}
st := fmt.Sprintf("2006-01-02 %v", sdata)
x, err = time.ParseInLocation("2006-01-02 15:04:05", st, session.Engine.TZLocation)
session.Engine.logger.Debugf("time(6) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
x, err = time.ParseInLocation("2006-01-02 15:04:05", st, parseLoc)
//session.engine.logger.Debugf("time(6) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
} else {
outErr = fmt.Errorf("unsupported time format %v", sdata)
return
@ -80,7 +77,7 @@ func (session *Session) str2Time(col *core.Column, data string) (outTime time.Ti
outErr = fmt.Errorf("unsupported time format %v: %v", sdata, err)
return
}
outTime = x
outTime = x.In(session.engine.TZLocation)
return
}
@ -108,7 +105,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
if len(data) > 0 {
err := json.Unmarshal(data, x.Interface())
if err != nil {
session.Engine.logger.Error(err)
session.engine.logger.Error(err)
return err
}
fieldValue.Set(x.Elem())
@ -122,7 +119,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
if len(data) > 0 {
err := json.Unmarshal(data, x.Interface())
if err != nil {
session.Engine.logger.Error(err)
session.engine.logger.Error(err)
return err
}
fieldValue.Set(x.Elem())
@ -135,7 +132,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
if len(data) > 0 {
err := json.Unmarshal(data, x.Interface())
if err != nil {
session.Engine.logger.Error(err)
session.engine.logger.Error(err)
return err
}
fieldValue.Set(x.Elem())
@ -147,8 +144,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
case reflect.String:
fieldValue.SetString(string(data))
case reflect.Bool:
d := string(data)
v, err := strconv.ParseBool(d)
v, err := asBool(data)
if err != nil {
return fmt.Errorf("arg %v as bool: %s", key, err.Error())
}
@ -159,7 +155,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
var err error
// for mysql, when use bit, it returned \x01
if col.SQLType.Name == core.Bit &&
session.Engine.dialect.DBType() == core.MYSQL { // !nashtsai! TODO dialect needs to provide conversion interface API
session.engine.dialect.DBType() == core.MYSQL { // !nashtsai! TODO dialect needs to provide conversion interface API
if len(data) == 1 {
x = int64(data[0])
} else {
@ -207,16 +203,19 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
}
v = x
fieldValue.Set(reflect.ValueOf(v).Convert(fieldType))
} else if session.Statement.UseCascade {
table := session.Engine.autoMapType(*fieldValue)
if table != nil {
} else if session.statement.UseCascade {
table, err := session.engine.autoMapType(*fieldValue)
if err != nil {
return err
}
// TODO: current only support 1 primary key
if len(table.PrimaryKeys) > 1 {
panic("unsupported composited primary key cascade")
return errors.New("unsupported composited primary key cascade")
}
var pk = make(core.PK, len(table.PrimaryKeys))
rawValueType := table.ColumnType(table.PKColumns()[0].FieldName)
var err error
pk[0], err = str2PK(string(data), rawValueType)
if err != nil {
return err
@ -227,9 +226,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
// however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne
// property to be fetched lazily
structInter := reflect.New(fieldValue.Type())
newsession := session.Engine.NewSession()
defer newsession.Close()
has, err := newsession.Id(pk).NoCascade().Get(structInter.Interface())
has, err := session.ID(pk).NoCascade().get(structInter.Interface())
if err != nil {
return err
}
@ -240,9 +237,6 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
return errors.New("cascade obj is not exist")
}
}
} else {
return fmt.Errorf("unsupported struct type in Scan: %s", fieldValue.Type().String())
}
}
}
case reflect.Ptr:
@ -267,7 +261,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
if len(data) > 0 {
err := json.Unmarshal(data, &x)
if err != nil {
session.Engine.logger.Error(err)
session.engine.logger.Error(err)
return err
}
fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType))
@ -278,7 +272,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
if len(data) > 0 {
err := json.Unmarshal(data, &x)
if err != nil {
session.Engine.logger.Error(err)
session.engine.logger.Error(err)
return err
}
fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType))
@ -350,7 +344,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
var err error
// for mysql, when use bit, it returned \x01
if col.SQLType.Name == core.Bit &&
strings.Contains(session.Engine.DriverName(), "mysql") {
strings.Contains(session.engine.DriverName(), "mysql") {
if len(data) == 1 {
x = int64(data[0])
} else {
@ -375,7 +369,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
var err error
// for mysql, when use bit, it returned \x01
if col.SQLType.Name == core.Bit &&
strings.Contains(session.Engine.DriverName(), "mysql") {
strings.Contains(session.engine.DriverName(), "mysql") {
if len(data) == 1 {
x = int(data[0])
} else {
@ -403,7 +397,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
var err error
// for mysql, when use bit, it returned \x01
if col.SQLType.Name == core.Bit &&
session.Engine.dialect.DBType() == core.MYSQL {
session.engine.dialect.DBType() == core.MYSQL {
if len(data) == 1 {
x = int32(data[0])
} else {
@ -431,7 +425,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
var err error
// for mysql, when use bit, it returned \x01
if col.SQLType.Name == core.Bit &&
strings.Contains(session.Engine.DriverName(), "mysql") {
strings.Contains(session.engine.DriverName(), "mysql") {
if len(data) == 1 {
x = int8(data[0])
} else {
@ -459,7 +453,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
var err error
// for mysql, when use bit, it returned \x01
if col.SQLType.Name == core.Bit &&
strings.Contains(session.Engine.DriverName(), "mysql") {
strings.Contains(session.engine.DriverName(), "mysql") {
if len(data) == 1 {
x = int16(data[0])
} else {
@ -491,15 +485,18 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
v = x
fieldValue.Set(reflect.ValueOf(&x))
default:
if session.Statement.UseCascade {
if session.statement.UseCascade {
structInter := reflect.New(fieldType.Elem())
table := session.Engine.autoMapType(structInter.Elem())
if table != nil {
if len(table.PrimaryKeys) > 1 {
panic("unsupported composited primary key cascade")
table, err := session.engine.autoMapType(structInter.Elem())
if err != nil {
return err
}
if len(table.PrimaryKeys) > 1 {
return errors.New("unsupported composited primary key cascade")
}
var pk = make(core.PK, len(table.PrimaryKeys))
var err error
rawValueType := table.ColumnType(table.PKColumns()[0].FieldName)
pk[0], err = str2PK(string(data), rawValueType)
if err != nil {
@ -510,9 +507,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
// !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch
// however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne
// property to be fetched lazily
newsession := session.Engine.NewSession()
defer newsession.Close()
has, err := newsession.Id(pk).NoCascade().Get(structInter.Interface())
has, err := session.ID(pk).NoCascade().get(structInter.Interface())
if err != nil {
return err
}
@ -523,7 +518,6 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
return errors.New("cascade obj is not exist")
}
}
}
} else {
return fmt.Errorf("unsupported struct type in Scan: %s", fieldValue.Type().String())
}
@ -570,7 +564,7 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val
if fieldValue.IsNil() {
return nil, nil
} else if !fieldValue.IsValid() {
session.Engine.logger.Warn("the field[", col.FieldName, "] is invalid")
session.engine.logger.Warn("the field[", col.FieldName, "] is invalid")
return nil, nil
} else {
// !nashtsai! deference pointer type to instance type
@ -588,12 +582,7 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val
case reflect.Struct:
if fieldType.ConvertibleTo(core.TimeType) {
t := fieldValue.Convert(core.TimeType).Interface().(time.Time)
if session.Engine.dialect.DBType() == core.MSSQL {
if t.IsZero() {
return nil, nil
}
}
tf := session.Engine.FormatTime(col.SQLType.Name, t)
tf := session.engine.formatColTime(col, t)
return tf, nil
}
@ -603,7 +592,10 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val
return v.Value()
}
fieldTable := session.Engine.autoMapType(fieldValue)
fieldTable, err := session.engine.autoMapType(fieldValue)
if err != nil {
return nil, err
}
if len(fieldTable.PrimaryKeys) == 1 {
pkField := reflect.Indirect(fieldValue).FieldByName(fieldTable.PKColumns()[0].FieldName)
return pkField.Interface(), nil
@ -614,14 +606,14 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val
if col.SQLType.IsText() {
bytes, err := json.Marshal(fieldValue.Interface())
if err != nil {
session.Engine.logger.Error(err)
session.engine.logger.Error(err)
return 0, err
}
return string(bytes), nil
} else if col.SQLType.IsBlob() {
bytes, err := json.Marshal(fieldValue.Interface())
if err != nil {
session.Engine.logger.Error(err)
session.engine.logger.Error(err)
return 0, err
}
return bytes, nil
@ -630,7 +622,7 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val
case reflect.Complex64, reflect.Complex128:
bytes, err := json.Marshal(fieldValue.Interface())
if err != nil {
session.Engine.logger.Error(err)
session.engine.logger.Error(err)
return 0, err
}
return string(bytes), nil
@ -642,7 +634,7 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val
if col.SQLType.IsText() {
bytes, err := json.Marshal(fieldValue.Interface())
if err != nil {
session.Engine.logger.Error(err)
session.engine.logger.Error(err)
return 0, err
}
return string(bytes), nil
@ -655,7 +647,7 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val
} else {
bytes, err = json.Marshal(fieldValue.Interface())
if err != nil {
session.Engine.logger.Error(err)
session.engine.logger.Error(err)
return 0, err
}
}

View File

@ -12,26 +12,26 @@ import (
"github.com/go-xorm/core"
)
func (session *Session) cacheDelete(sqlStr string, args ...interface{}) error {
if session.Statement.RefTable == nil ||
session.Tx != nil {
func (session *Session) cacheDelete(table *core.Table, tableName, sqlStr string, args ...interface{}) error {
if table == nil ||
session.tx != nil {
return ErrCacheFailed
}
for _, filter := range session.Engine.dialect.Filters() {
sqlStr = filter.Do(sqlStr, session.Engine.dialect, session.Statement.RefTable)
for _, filter := range session.engine.dialect.Filters() {
sqlStr = filter.Do(sqlStr, session.engine.dialect, table)
}
newsql := session.Statement.convertIDSQL(sqlStr)
newsql := session.statement.convertIDSQL(sqlStr)
if newsql == "" {
return ErrCacheFailed
}
cacher := session.Engine.getCacher2(session.Statement.RefTable)
tableName := session.Statement.TableName()
cacher := session.engine.getCacher2(table)
pkColumns := table.PKColumns()
ids, err := core.GetCacheSql(cacher, tableName, newsql, args)
if err != nil {
resultsSlice, err := session.query(newsql, args...)
resultsSlice, err := session.queryBytes(newsql, args...)
if err != nil {
return err
}
@ -40,7 +40,7 @@ func (session *Session) cacheDelete(sqlStr string, args ...interface{}) error {
for _, data := range resultsSlice {
var id int64
var pk core.PK = make([]interface{}, 0)
for _, col := range session.Statement.RefTable.PKColumns() {
for _, col := range pkColumns {
if v, ok := data[col.Name]; !ok {
return errors.New("no id")
} else if col.SQLType.IsText() {
@ -58,33 +58,30 @@ func (session *Session) cacheDelete(sqlStr string, args ...interface{}) error {
ids = append(ids, pk)
}
}
} /*else {
session.Engine.LogDebug("delete cache sql %v", newsql)
cacher.DelIds(tableName, genSqlKey(newsql, args))
}*/
}
for _, id := range ids {
session.Engine.logger.Debug("[cacheDelete] delete cache obj", tableName, id)
session.engine.logger.Debug("[cacheDelete] delete cache obj:", tableName, id)
sid, err := id.ToString()
if err != nil {
return err
}
cacher.DelBean(tableName, sid)
}
session.Engine.logger.Debug("[cacheDelete] clear cache sql", tableName)
session.engine.logger.Debug("[cacheDelete] clear cache table:", tableName)
cacher.ClearIds(tableName)
return nil
}
// Delete records, bean's non-empty fields are conditions
func (session *Session) Delete(bean interface{}) (int64, error) {
defer session.resetStatement()
if session.IsAutoClose {
if session.isAutoClose {
defer session.Close()
}
session.Statement.setRefValue(rValue(bean))
var table = session.Statement.RefTable
if err := session.statement.setRefValue(rValue(bean)); err != nil {
return 0, err
}
// handle before delete processors
for _, closure := range session.beforeClosures {
@ -96,13 +93,17 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
processor.BeforeDelete()
}
// --
condSQL, condArgs, _ := session.Statement.genConds(bean)
if len(condSQL) == 0 && session.Statement.LimitN == 0 {
condSQL, condArgs, err := session.statement.genConds(bean)
if err != nil {
return 0, err
}
if len(condSQL) == 0 && session.statement.LimitN == 0 {
return 0, ErrNeedDeletedCond
}
var tableName = session.Engine.Quote(session.Statement.TableName())
var tableNameNoQuote = session.statement.TableName()
var tableName = session.engine.Quote(tableNameNoQuote)
var table = session.statement.RefTable
var deleteSQL string
if len(condSQL) > 0 {
deleteSQL = fmt.Sprintf("DELETE FROM %v WHERE %v", tableName, condSQL)
@ -111,15 +112,15 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
}
var orderSQL string
if len(session.Statement.OrderStr) > 0 {
orderSQL += fmt.Sprintf(" ORDER BY %s", session.Statement.OrderStr)
if len(session.statement.OrderStr) > 0 {
orderSQL += fmt.Sprintf(" ORDER BY %s", session.statement.OrderStr)
}
if session.Statement.LimitN > 0 {
orderSQL += fmt.Sprintf(" LIMIT %d", session.Statement.LimitN)
if session.statement.LimitN > 0 {
orderSQL += fmt.Sprintf(" LIMIT %d", session.statement.LimitN)
}
if len(orderSQL) > 0 {
switch session.Engine.dialect.DBType() {
switch session.engine.dialect.DBType() {
case core.POSTGRES:
inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL)
if len(condSQL) > 0 {
@ -144,7 +145,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
var realSQL string
argsForCache := make([]interface{}, 0, len(condArgs)*2)
if session.Statement.unscoped || table.DeletedColumn() == nil { // tag "deleted" is disabled
if session.statement.unscoped || table.DeletedColumn() == nil { // tag "deleted" is disabled
realSQL = deleteSQL
copy(argsForCache, condArgs)
argsForCache = append(condArgs, argsForCache...)
@ -155,12 +156,12 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
deletedColumn := table.DeletedColumn()
realSQL = fmt.Sprintf("UPDATE %v SET %v = ? WHERE %v",
session.Engine.Quote(session.Statement.TableName()),
session.Engine.Quote(deletedColumn.Name),
session.engine.Quote(session.statement.TableName()),
session.engine.Quote(deletedColumn.Name),
condSQL)
if len(orderSQL) > 0 {
switch session.Engine.dialect.DBType() {
switch session.engine.dialect.DBType() {
case core.POSTGRES:
inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL)
if len(condSQL) > 0 {
@ -183,12 +184,12 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
}
}
// !oinume! Insert NowTime to the head of session.Statement.Params
// !oinume! Insert nowTime to the head of session.statement.Params
condArgs = append(condArgs, "")
paramsLen := len(condArgs)
copy(condArgs[1:paramsLen], condArgs[0:paramsLen-1])
val, t := session.Engine.NowTime2(deletedColumn.SQLType.Name)
val, t := session.engine.nowTime(deletedColumn)
condArgs[0] = val
var colName = deletedColumn.Name
@ -198,17 +199,18 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
})
}
if cacher := session.Engine.getCacher2(session.Statement.RefTable); cacher != nil && session.Statement.UseCache {
session.cacheDelete(deleteSQL, argsForCache...)
if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache {
session.cacheDelete(table, tableNameNoQuote, deleteSQL, argsForCache...)
}
session.statement.RefTable = table
res, err := session.exec(realSQL, condArgs...)
if err != nil {
return 0, err
}
// handle after delete processors
if session.IsAutoCommit {
if session.isAutoCommit {
for _, closure := range session.afterClosures {
closure(bean)
}

77
vendor/github.com/go-xorm/xorm/session_exist.go generated vendored Normal file
View File

@ -0,0 +1,77 @@
// Copyright 2017 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
import (
"errors"
"fmt"
"reflect"
"github.com/go-xorm/builder"
)
// Exist returns true if the record exist otherwise return false
func (session *Session) Exist(bean ...interface{}) (bool, error) {
if session.isAutoClose {
defer session.Close()
}
var sqlStr string
var args []interface{}
var err error
if session.statement.RawSQL == "" {
if len(bean) == 0 {
tableName := session.statement.TableName()
if len(tableName) <= 0 {
return false, ErrTableNotFound
}
if session.statement.cond.IsValid() {
condSQL, condArgs, err := builder.ToSQL(session.statement.cond)
if err != nil {
return false, err
}
sqlStr = fmt.Sprintf("SELECT * FROM %s WHERE %s LIMIT 1", tableName, condSQL)
args = condArgs
} else {
sqlStr = fmt.Sprintf("SELECT * FROM %s LIMIT 1", tableName)
args = []interface{}{}
}
} else {
beanValue := reflect.ValueOf(bean[0])
if beanValue.Kind() != reflect.Ptr {
return false, errors.New("needs a pointer")
}
if beanValue.Elem().Kind() == reflect.Struct {
if err := session.statement.setRefValue(beanValue.Elem()); err != nil {
return false, err
}
}
if len(session.statement.TableName()) <= 0 {
return false, ErrTableNotFound
}
session.statement.Limit(1)
sqlStr, args, err = session.statement.genGetSQL(bean[0])
if err != nil {
return false, err
}
}
} else {
sqlStr = session.statement.RawSQL
args = session.statement.RawParams
}
rows, err := session.queryRows(sqlStr, args...)
if err != nil {
return false, err
}
defer rows.Close()
return rows.Next(), nil
}

View File

@ -8,7 +8,6 @@ import (
"errors"
"fmt"
"reflect"
"strconv"
"strings"
"github.com/go-xorm/builder"
@ -24,11 +23,13 @@ const (
// are conditions. beans could be []Struct, []*Struct, map[int64]Struct
// map[int64]*Struct
func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) error {
defer session.resetStatement()
if session.IsAutoClose {
if session.isAutoClose {
defer session.Close()
}
return session.find(rowsSlicePtr, condiBean...)
}
func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) error {
sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
if sliceValue.Kind() != reflect.Slice && sliceValue.Kind() != reflect.Map {
return errors.New("needs a pointer to a slice or a map")
@ -37,77 +38,79 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
sliceElementType := sliceValue.Type().Elem()
var tp = tpStruct
if session.Statement.RefTable == nil {
if session.statement.RefTable == nil {
if sliceElementType.Kind() == reflect.Ptr {
if sliceElementType.Elem().Kind() == reflect.Struct {
pv := reflect.New(sliceElementType.Elem())
session.Statement.setRefValue(pv.Elem())
if err := session.statement.setRefValue(pv.Elem()); err != nil {
return err
}
} else {
tp = tpNonStruct
}
} else if sliceElementType.Kind() == reflect.Struct {
pv := reflect.New(sliceElementType)
session.Statement.setRefValue(pv.Elem())
if err := session.statement.setRefValue(pv.Elem()); err != nil {
return err
}
} else {
tp = tpNonStruct
}
}
var table = session.Statement.RefTable
var table = session.statement.RefTable
var addedTableName = (len(session.Statement.JoinStr) > 0)
var addedTableName = (len(session.statement.JoinStr) > 0)
var autoCond builder.Cond
if tp == tpStruct {
if !session.Statement.noAutoCondition && len(condiBean) > 0 {
if !session.statement.noAutoCondition && len(condiBean) > 0 {
var err error
autoCond, err = session.Statement.buildConds(table, condiBean[0], true, true, false, true, addedTableName)
autoCond, err = session.statement.buildConds(table, condiBean[0], true, true, false, true, addedTableName)
if err != nil {
panic(err)
return err
}
} else {
// !oinume! Add "<col> IS NULL" to WHERE whatever condiBean is given.
// See https://github.com/go-xorm/xorm/issues/179
if col := table.DeletedColumn(); col != nil && !session.Statement.unscoped { // tag "deleted" is enabled
var colName = session.Engine.Quote(col.Name)
if col := table.DeletedColumn(); col != nil && !session.statement.unscoped { // tag "deleted" is enabled
var colName = session.engine.Quote(col.Name)
if addedTableName {
var nm = session.Statement.TableName()
if len(session.Statement.TableAlias) > 0 {
nm = session.Statement.TableAlias
var nm = session.statement.TableName()
if len(session.statement.TableAlias) > 0 {
nm = session.statement.TableAlias
}
colName = session.Engine.Quote(nm) + "." + colName
}
if session.Engine.dialect.DBType() == core.MSSQL {
autoCond = builder.IsNull{colName}
} else {
autoCond = builder.IsNull{colName}.Or(builder.Eq{colName: "0001-01-01 00:00:00"})
colName = session.engine.Quote(nm) + "." + colName
}
autoCond = session.engine.CondDeleted(colName)
}
}
}
var sqlStr string
var args []interface{}
if session.Statement.RawSQL == "" {
if len(session.Statement.TableName()) <= 0 {
var err error
if session.statement.RawSQL == "" {
if len(session.statement.TableName()) <= 0 {
return ErrTableNotFound
}
var columnStr = session.Statement.ColumnStr
if len(session.Statement.selectStr) > 0 {
columnStr = session.Statement.selectStr
var columnStr = session.statement.ColumnStr
if len(session.statement.selectStr) > 0 {
columnStr = session.statement.selectStr
} else {
if session.Statement.JoinStr == "" {
if session.statement.JoinStr == "" {
if columnStr == "" {
if session.Statement.GroupByStr != "" {
columnStr = session.Statement.Engine.Quote(strings.Replace(session.Statement.GroupByStr, ",", session.Engine.Quote(","), -1))
if session.statement.GroupByStr != "" {
columnStr = session.statement.Engine.Quote(strings.Replace(session.statement.GroupByStr, ",", session.engine.Quote(","), -1))
} else {
columnStr = session.Statement.genColumnStr()
columnStr = session.statement.genColumnStr()
}
}
} else {
if columnStr == "" {
if session.Statement.GroupByStr != "" {
columnStr = session.Statement.Engine.Quote(strings.Replace(session.Statement.GroupByStr, ",", session.Engine.Quote(","), -1))
if session.statement.GroupByStr != "" {
columnStr = session.statement.Engine.Quote(strings.Replace(session.statement.GroupByStr, ",", session.engine.Quote(","), -1))
} else {
columnStr = "*"
}
@ -118,31 +121,37 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
}
}
condSQL, condArgs, _ := builder.ToSQL(session.Statement.cond.And(autoCond))
session.statement.cond = session.statement.cond.And(autoCond)
condSQL, condArgs, err := builder.ToSQL(session.statement.cond)
if err != nil {
return err
}
args = append(session.Statement.joinArgs, condArgs...)
sqlStr = session.Statement.genSelectSQL(columnStr, condSQL)
args = append(session.statement.joinArgs, condArgs...)
sqlStr, err = session.statement.genSelectSQL(columnStr, condSQL)
if err != nil {
return err
}
// for mssql and use limit
qs := strings.Count(sqlStr, "?")
if len(args)*2 == qs {
args = append(args, args...)
}
} else {
sqlStr = session.Statement.RawSQL
args = session.Statement.RawParams
sqlStr = session.statement.RawSQL
args = session.statement.RawParams
}
var err error
if session.canCache() {
if cacher := session.Engine.getCacher2(table); cacher != nil &&
!session.Statement.IsDistinct &&
!session.Statement.unscoped {
if cacher := session.engine.getCacher2(table); cacher != nil &&
!session.statement.IsDistinct &&
!session.statement.unscoped {
err = session.cacheFind(sliceElementType, sqlStr, rowsSlicePtr, args...)
if err != ErrCacheFailed {
return err
}
err = nil // !nashtsai! reset err to nil for ErrCacheFailed
session.Engine.logger.Warn("Cache Find Failed")
session.engine.logger.Warn("Cache Find Failed")
}
}
@ -150,21 +159,13 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
}
func (session *Session) noCacheFind(table *core.Table, containerValue reflect.Value, sqlStr string, args ...interface{}) error {
var rawRows *core.Rows
var err error
session.queryPreprocess(&sqlStr, args...)
if session.IsAutoCommit {
_, rawRows, err = session.innerQuery(sqlStr, args...)
} else {
rawRows, err = session.Tx.Query(sqlStr, args...)
}
rows, err := session.queryRows(sqlStr, args...)
if err != nil {
return err
}
defer rawRows.Close()
defer rows.Close()
fields, err := rawRows.Columns()
fields, err := rows.Columns()
if err != nil {
return err
}
@ -234,20 +235,29 @@ func (session *Session) noCacheFind(table *core.Table, containerValue reflect.Va
if elemType.Kind() == reflect.Struct {
var newValue = newElemFunc(fields)
dataStruct := rValue(newValue.Interface())
return session.rows2Beans(rawRows, fields, len(fields), session.Engine.autoMapType(dataStruct), newElemFunc, containerValueSetFunc)
tb, err := session.engine.autoMapType(dataStruct)
if err != nil {
return err
}
err = session.rows2Beans(rows, fields, tb, newElemFunc, containerValueSetFunc)
rows.Close()
if err != nil {
return err
}
return session.executeProcessors()
}
for rawRows.Next() {
for rows.Next() {
var newValue = newElemFunc(fields)
bean := newValue.Interface()
switch elemType.Kind() {
case reflect.Slice:
err = rawRows.ScanSlice(bean)
err = rows.ScanSlice(bean)
case reflect.Map:
err = rawRows.ScanMap(bean)
err = rows.ScanMap(bean)
default:
err = rawRows.Scan(bean)
err = rows.Scan(bean)
}
if err != nil {
@ -278,22 +288,21 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
return ErrCacheFailed
}
for _, filter := range session.Engine.dialect.Filters() {
sqlStr = filter.Do(sqlStr, session.Engine.dialect, session.Statement.RefTable)
for _, filter := range session.engine.dialect.Filters() {
sqlStr = filter.Do(sqlStr, session.engine.dialect, session.statement.RefTable)
}
newsql := session.Statement.convertIDSQL(sqlStr)
newsql := session.statement.convertIDSQL(sqlStr)
if newsql == "" {
return ErrCacheFailed
}
tableName := session.Statement.TableName()
table := session.Statement.RefTable
cacher := session.Engine.getCacher2(table)
tableName := session.statement.TableName()
table := session.statement.RefTable
cacher := session.engine.getCacher2(table)
ids, err := core.GetCacheSql(cacher, tableName, newsql, args)
if err != nil {
rows, err := session.DB().Query(newsql, args...)
rows, err := session.queryRows(newsql, args...)
if err != nil {
return err
}
@ -304,7 +313,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
for rows.Next() {
i++
if i > 500 {
session.Engine.logger.Debug("[cacheFind] ids length > 500, no cache")
session.engine.logger.Debug("[cacheFind] ids length > 500, no cache")
return ErrCacheFailed
}
var res = make([]string, len(table.PrimaryKeys))
@ -312,32 +321,24 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
if err != nil {
return err
}
var pk core.PK = make([]interface{}, len(table.PrimaryKeys))
for i, col := range table.PKColumns() {
if col.SQLType.IsNumeric() {
n, err := strconv.ParseInt(res[i], 10, 64)
pk[i], err = session.engine.idTypeAssertion(col, res[i])
if err != nil {
return err
}
pk[i] = n
} else if col.SQLType.IsText() {
pk[i] = res[i]
} else {
return errors.New("not supported")
}
}
ids = append(ids, pk)
}
session.Engine.logger.Debug("[cacheFind] cache sql:", ids, tableName, newsql, args)
session.engine.logger.Debug("[cacheFind] cache sql:", ids, tableName, sqlStr, newsql, args)
err = core.PutCacheSql(cacher, ids, tableName, newsql, args)
if err != nil {
return err
}
} else {
session.Engine.logger.Debug("[cacheFind] cache hit sql:", newsql, args)
session.engine.logger.Debug("[cacheFind] cache hit sql:", tableName, sqlStr, newsql, args)
}
sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
@ -352,20 +353,20 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
return err
}
bean := cacher.GetBean(tableName, sid)
if bean == nil {
if bean == nil || reflect.ValueOf(bean).Elem().Type() != t {
ides = append(ides, id)
ididxes[sid] = idx
} else {
session.Engine.logger.Debug("[cacheFind] cache hit bean:", tableName, id, bean)
session.engine.logger.Debug("[cacheFind] cache hit bean:", tableName, id, bean)
pk := session.Engine.IdOf(bean)
pk := session.engine.IdOf(bean)
xid, err := pk.ToString()
if err != nil {
return err
}
if sid != xid {
session.Engine.logger.Error("[cacheFind] error cache", xid, sid, bean)
session.engine.logger.Error("[cacheFind] error cache", xid, sid, bean)
return ErrCacheFailed
}
temps[idx] = bean
@ -373,9 +374,6 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
}
if len(ides) > 0 {
newSession := session.Engine.NewSession()
defer newSession.Close()
slices := reflect.New(reflect.SliceOf(t))
beans := slices.Interface()
@ -385,18 +383,18 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
ff = append(ff, ie[0])
}
newSession.In("`"+table.PrimaryKeys[0]+"`", ff...)
session.In("`"+table.PrimaryKeys[0]+"`", ff...)
} else {
for _, ie := range ides {
cond := builder.NewCond()
for i, name := range table.PrimaryKeys {
cond = cond.And(builder.Eq{"`" + name + "`": ie[i]})
}
newSession.Or(cond)
session.Or(cond)
}
}
err = newSession.NoCache().Find(beans)
err = session.NoCache().Table(tableName).find(beans)
if err != nil {
return err
}
@ -407,7 +405,10 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
if rv.Kind() != reflect.Ptr {
rv = rv.Addr()
}
id := session.Engine.IdOfV(rv)
id, err := session.engine.idOfV(rv)
if err != nil {
return err
}
sid, err := id.ToString()
if err != nil {
return err
@ -415,7 +416,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
bean := rv.Interface()
temps[ididxes[sid]] = bean
session.Engine.logger.Debug("[cacheFind] cache bean:", tableName, id, bean, temps)
session.engine.logger.Debug("[cacheFind] cache bean:", tableName, id, bean, temps)
cacher.PutBean(tableName, sid, bean)
}
}
@ -423,7 +424,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
for j := 0; j < len(temps); j++ {
bean := temps[j]
if bean == nil {
session.Engine.logger.Warn("[cacheFind] cache no hit:", tableName, ids[j], temps)
session.engine.logger.Warn("[cacheFind] cache no hit:", tableName, ids[j], temps)
// return errors.New("cache error") // !nashtsai! no need to return error, but continue instead
continue
}

View File

@ -15,42 +15,49 @@ import (
// Get retrieve one record from database, bean's non-empty fields
// will be as conditions
func (session *Session) Get(bean interface{}) (bool, error) {
defer session.resetStatement()
if session.IsAutoClose {
if session.isAutoClose {
defer session.Close()
}
beanValue := reflect.ValueOf(bean)
if beanValue.Kind() != reflect.Ptr {
return false, errors.New("needs a pointer to a struct")
return session.get(bean)
}
// FIXME: remove this after support non-struct Get
if beanValue.Elem().Kind() != reflect.Struct {
return false, errors.New("needs a pointer to a struct")
func (session *Session) get(bean interface{}) (bool, error) {
beanValue := reflect.ValueOf(bean)
if beanValue.Kind() != reflect.Ptr {
return false, errors.New("needs a pointer to a value")
} else if beanValue.Elem().Kind() == reflect.Ptr {
return false, errors.New("a pointer to a pointer is not allowed")
}
if beanValue.Elem().Kind() == reflect.Struct {
session.Statement.setRefValue(beanValue.Elem())
if err := session.statement.setRefValue(beanValue.Elem()); err != nil {
return false, err
}
}
var sqlStr string
var args []interface{}
var err error
if session.Statement.RawSQL == "" {
if len(session.Statement.TableName()) <= 0 {
if session.statement.RawSQL == "" {
if len(session.statement.TableName()) <= 0 {
return false, ErrTableNotFound
}
session.Statement.Limit(1)
sqlStr, args = session.Statement.genGetSQL(bean)
session.statement.Limit(1)
sqlStr, args, err = session.statement.genGetSQL(bean)
if err != nil {
return false, err
}
} else {
sqlStr = session.Statement.RawSQL
args = session.Statement.RawParams
sqlStr = session.statement.RawSQL
args = session.statement.RawParams
}
if session.canCache() {
if cacher := session.Engine.getCacher2(session.Statement.RefTable); cacher != nil &&
!session.Statement.unscoped {
table := session.statement.RefTable
if session.canCache() && beanValue.Elem().Kind() == reflect.Struct {
if cacher := session.engine.getCacher2(table); cacher != nil &&
!session.statement.unscoped {
has, err := session.cacheGet(bean, sqlStr, args...)
if err != ErrCacheFailed {
return has, err
@ -58,48 +65,52 @@ func (session *Session) Get(bean interface{}) (bool, error) {
}
}
return session.nocacheGet(beanValue.Elem().Kind(), bean, sqlStr, args...)
return session.nocacheGet(beanValue.Elem().Kind(), table, bean, sqlStr, args...)
}
func (session *Session) nocacheGet(beanKind reflect.Kind, bean interface{}, sqlStr string, args ...interface{}) (bool, error) {
var rawRows *core.Rows
var err error
session.queryPreprocess(&sqlStr, args...)
if session.IsAutoCommit {
_, rawRows, err = session.innerQuery(sqlStr, args...)
} else {
rawRows, err = session.Tx.Query(sqlStr, args...)
}
func (session *Session) nocacheGet(beanKind reflect.Kind, table *core.Table, bean interface{}, sqlStr string, args ...interface{}) (bool, error) {
rows, err := session.queryRows(sqlStr, args...)
if err != nil {
return false, err
}
defer rows.Close()
defer rawRows.Close()
if rawRows.Next() {
fields, err := rawRows.Columns()
if err != nil {
// WARN: Alougth rawRows return true, but get fields failed
return true, err
if !rows.Next() {
return false, nil
}
switch beanKind {
case reflect.Struct:
fields, err := rows.Columns()
if err != nil {
// WARN: Alougth rows return true, but get fields failed
return true, err
}
scanResults, err := session.row2Slice(rows, fields, bean)
if err != nil {
return false, err
}
// close it before covert data
rows.Close()
dataStruct := rValue(bean)
session.Statement.setRefValue(dataStruct)
_, err = session.row2Bean(rawRows, fields, len(fields), bean, &dataStruct, session.Statement.RefTable)
_, err = session.slice2Bean(scanResults, fields, bean, &dataStruct, table)
if err != nil {
return true, err
}
return true, session.executeProcessors()
case reflect.Slice:
err = rawRows.ScanSlice(bean)
err = rows.ScanSlice(bean)
case reflect.Map:
err = rawRows.ScanMap(bean)
err = rows.ScanMap(bean)
default:
err = rawRows.Scan(bean)
err = rows.Scan(bean)
}
return true, err
}
return false, nil
}
func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interface{}) (has bool, err error) {
// if has no reftable, then don't use cache currently
@ -107,22 +118,22 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf
return false, ErrCacheFailed
}
for _, filter := range session.Engine.dialect.Filters() {
sqlStr = filter.Do(sqlStr, session.Engine.dialect, session.Statement.RefTable)
for _, filter := range session.engine.dialect.Filters() {
sqlStr = filter.Do(sqlStr, session.engine.dialect, session.statement.RefTable)
}
newsql := session.Statement.convertIDSQL(sqlStr)
newsql := session.statement.convertIDSQL(sqlStr)
if newsql == "" {
return false, ErrCacheFailed
}
cacher := session.Engine.getCacher2(session.Statement.RefTable)
tableName := session.Statement.TableName()
session.Engine.logger.Debug("[cacheGet] find sql:", newsql, args)
cacher := session.engine.getCacher2(session.statement.RefTable)
tableName := session.statement.TableName()
session.engine.logger.Debug("[cacheGet] find sql:", newsql, args)
table := session.statement.RefTable
ids, err := core.GetCacheSql(cacher, tableName, newsql, args)
table := session.Statement.RefTable
if err != nil {
var res = make([]string, len(table.PrimaryKeys))
rows, err := session.DB().Query(newsql, args...)
rows, err := session.NoCache().queryRows(newsql, args...)
if err != nil {
return false, err
}
@ -153,19 +164,19 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf
}
ids = []core.PK{pk}
session.Engine.logger.Debug("[cacheGet] cache ids:", newsql, ids)
session.engine.logger.Debug("[cacheGet] cache ids:", newsql, ids)
err = core.PutCacheSql(cacher, ids, tableName, newsql, args)
if err != nil {
return false, err
}
} else {
session.Engine.logger.Debug("[cacheGet] cache hit sql:", newsql)
session.engine.logger.Debug("[cacheGet] cache hit sql:", newsql, ids)
}
if len(ids) > 0 {
structValue := reflect.Indirect(reflect.ValueOf(bean))
id := ids[0]
session.Engine.logger.Debug("[cacheGet] get bean:", tableName, id)
session.engine.logger.Debug("[cacheGet] get bean:", tableName, id)
sid, err := id.ToString()
if err != nil {
return false, err
@ -173,15 +184,15 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf
cacheBean := cacher.GetBean(tableName, sid)
if cacheBean == nil {
cacheBean = bean
has, err = session.nocacheGet(reflect.Struct, cacheBean, sqlStr, args...)
has, err = session.nocacheGet(reflect.Struct, table, cacheBean, sqlStr, args...)
if err != nil || !has {
return has, err
}
session.Engine.logger.Debug("[cacheGet] cache bean:", tableName, id, cacheBean)
session.engine.logger.Debug("[cacheGet] cache bean:", tableName, id, cacheBean)
cacher.PutBean(tableName, sid, cacheBean)
} else {
session.Engine.logger.Debug("[cacheGet] cache hit bean:", tableName, id, cacheBean)
session.engine.logger.Debug("[cacheGet] cache hit bean:", tableName, id, cacheBean)
has = true
}
structValue.Set(reflect.Indirect(reflect.ValueOf(cacheBean)))

View File

@ -19,17 +19,16 @@ func (session *Session) Insert(beans ...interface{}) (int64, error) {
var affected int64
var err error
if session.IsAutoClose {
if session.isAutoClose {
defer session.Close()
}
defer session.resetStatement()
for _, bean := range beans {
sliceValue := reflect.Indirect(reflect.ValueOf(bean))
if sliceValue.Kind() == reflect.Slice {
size := sliceValue.Len()
if size > 0 {
if session.Engine.SupportInsertMany() {
if session.engine.SupportInsertMany() {
cnt, err := session.innerInsertMulti(bean)
if err != nil {
return affected, err
@ -67,13 +66,15 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
return 0, errors.New("could not insert a empty slice")
}
session.Statement.setRefValue(sliceValue.Index(0))
if err := session.statement.setRefValue(reflect.ValueOf(sliceValue.Index(0).Interface())); err != nil {
return 0, err
}
if len(session.Statement.TableName()) <= 0 {
if len(session.statement.TableName()) <= 0 {
return 0, ErrTableNotFound
}
table := session.Statement.RefTable
table := session.statement.RefTable
size := sliceValue.Len()
var colNames []string
@ -114,18 +115,18 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
if col.IsDeleted {
continue
}
if session.Statement.ColumnStr != "" {
if _, ok := getFlagForColumn(session.Statement.columnMap, col); !ok {
if session.statement.ColumnStr != "" {
if _, ok := getFlagForColumn(session.statement.columnMap, col); !ok {
continue
}
}
if session.Statement.OmitStr != "" {
if _, ok := getFlagForColumn(session.Statement.columnMap, col); ok {
if session.statement.OmitStr != "" {
if _, ok := getFlagForColumn(session.statement.columnMap, col); ok {
continue
}
}
if (col.IsCreated || col.IsUpdated) && session.Statement.UseAutoTime {
val, t := session.Engine.NowTime2(col.SQLType.Name)
if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime {
val, t := session.engine.nowTime(col)
args = append(args, val)
var colName = col.Name
@ -133,7 +134,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
col := table.GetColumn(colName)
setColumnTime(bean, col, t)
})
} else if col.IsVersion && session.Statement.checkVersion {
} else if col.IsVersion && session.statement.checkVersion {
args = append(args, 1)
var colName = col.Name
session.afterClosures = append(session.afterClosures, func(bean interface{}) {
@ -169,18 +170,18 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
if col.IsDeleted {
continue
}
if session.Statement.ColumnStr != "" {
if _, ok := getFlagForColumn(session.Statement.columnMap, col); !ok {
if session.statement.ColumnStr != "" {
if _, ok := getFlagForColumn(session.statement.columnMap, col); !ok {
continue
}
}
if session.Statement.OmitStr != "" {
if _, ok := getFlagForColumn(session.Statement.columnMap, col); ok {
if session.statement.OmitStr != "" {
if _, ok := getFlagForColumn(session.statement.columnMap, col); ok {
continue
}
}
if (col.IsCreated || col.IsUpdated) && session.Statement.UseAutoTime {
val, t := session.Engine.NowTime2(col.SQLType.Name)
if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime {
val, t := session.engine.nowTime(col)
args = append(args, val)
var colName = col.Name
@ -188,7 +189,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
col := table.GetColumn(colName)
setColumnTime(bean, col, t)
})
} else if col.IsVersion && session.Statement.checkVersion {
} else if col.IsVersion && session.statement.checkVersion {
args = append(args, 1)
var colName = col.Name
session.afterClosures = append(session.afterClosures, func(bean interface{}) {
@ -212,25 +213,26 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
var sql = "INSERT INTO %s (%v%v%v) VALUES (%v)"
var statement string
if session.Engine.dialect.DBType() == core.ORACLE {
var tableName = session.statement.TableName()
if session.engine.dialect.DBType() == core.ORACLE {
sql = "INSERT ALL INTO %s (%v%v%v) VALUES (%v) SELECT 1 FROM DUAL"
temp := fmt.Sprintf(") INTO %s (%v%v%v) VALUES (",
session.Engine.Quote(session.Statement.TableName()),
session.Engine.QuoteStr(),
strings.Join(colNames, session.Engine.QuoteStr() + ", " + session.Engine.QuoteStr()),
session.Engine.QuoteStr())
session.engine.Quote(tableName),
session.engine.QuoteStr(),
strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()),
session.engine.QuoteStr())
statement = fmt.Sprintf(sql,
session.Engine.Quote(session.Statement.TableName()),
session.Engine.QuoteStr(),
strings.Join(colNames, session.Engine.QuoteStr() + ", " + session.Engine.QuoteStr()),
session.Engine.QuoteStr(),
session.engine.Quote(tableName),
session.engine.QuoteStr(),
strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()),
session.engine.QuoteStr(),
strings.Join(colMultiPlaces, temp))
} else {
statement = fmt.Sprintf(sql,
session.Engine.Quote(session.Statement.TableName()),
session.Engine.QuoteStr(),
strings.Join(colNames, session.Engine.QuoteStr() + ", " + session.Engine.QuoteStr()),
session.Engine.QuoteStr(),
session.engine.Quote(tableName),
session.engine.QuoteStr(),
strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()),
session.engine.QuoteStr(),
strings.Join(colMultiPlaces, "),("))
}
res, err := session.exec(statement, args...)
@ -238,8 +240,8 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
return 0, err
}
if cacher := session.Engine.getCacher2(table); cacher != nil && session.Statement.UseCache {
session.cacheInsert(session.Statement.TableName())
if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache {
session.cacheInsert(table, tableName)
}
lenAfterClosures := len(session.afterClosures)
@ -247,7 +249,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
elemValue := reflect.Indirect(sliceValue.Index(i)).Addr().Interface()
// handle AfterInsertProcessor
if session.IsAutoCommit {
if session.isAutoCommit {
// !nashtsai! does user expect it's same slice to passed closure when using Before()/After() when insert multi??
for _, closure := range session.afterClosures {
closure(elemValue)
@ -278,8 +280,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
// InsertMulti insert multiple records
func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) {
defer session.resetStatement()
if session.IsAutoClose {
if session.isAutoClose {
defer session.Close()
}
@ -297,12 +298,14 @@ func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) {
}
func (session *Session) innerInsert(bean interface{}) (int64, error) {
session.Statement.setRefValue(rValue(bean))
if len(session.Statement.TableName()) <= 0 {
if err := session.statement.setRefValue(rValue(bean)); err != nil {
return 0, err
}
if len(session.statement.TableName()) <= 0 {
return 0, ErrTableNotFound
}
table := session.Statement.RefTable
table := session.statement.RefTable
// handle BeforeInsertProcessor
for _, closure := range session.beforeClosures {
@ -314,12 +317,12 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
processor.BeforeInsert()
}
// --
colNames, args, err := genCols(session.Statement.RefTable, session, bean, false, false)
colNames, args, err := genCols(session.statement.RefTable, session, bean, false, false)
if err != nil {
return 0, err
}
// insert expr columns, override if exists
exprColumns := session.Statement.getExpr()
exprColumns := session.statement.getExpr()
exprColVals := make([]string, 0, len(exprColumns))
for _, v := range exprColumns {
// remove the expr columns
@ -339,18 +342,30 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
if len(exprColVals) > 0 {
colPlaces = colPlaces + strings.Join(exprColVals, ", ")
} else {
if len(colPlaces) > 0 {
colPlaces = colPlaces[0 : len(colPlaces)-2]
}
}
sqlStr := fmt.Sprintf("INSERT INTO %s (%v%v%v) VALUES (%v)",
session.Engine.Quote(session.Statement.TableName()),
session.Engine.QuoteStr(),
strings.Join(colNames, session.Engine.Quote(", ")),
session.Engine.QuoteStr(),
var sqlStr string
var tableName = session.statement.TableName()
if len(colPlaces) > 0 {
sqlStr = fmt.Sprintf("INSERT INTO %s (%v%v%v) VALUES (%v)",
session.engine.Quote(tableName),
session.engine.QuoteStr(),
strings.Join(colNames, session.engine.Quote(", ")),
session.engine.QuoteStr(),
colPlaces)
} else {
if session.engine.dialect.DBType() == core.MYSQL {
sqlStr = fmt.Sprintf("INSERT INTO %s VALUES ()", session.engine.Quote(tableName))
} else {
sqlStr = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES", session.engine.Quote(tableName))
}
}
handleAfterInsertProcessorFunc := func(bean interface{}) {
if session.IsAutoCommit {
if session.isAutoCommit {
for _, closure := range session.afterClosures {
closure(bean)
}
@ -379,23 +394,22 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
// for postgres, many of them didn't implement lastInsertId, so we should
// implemented it ourself.
if session.Engine.dialect.DBType() == core.ORACLE && len(table.AutoIncrement) > 0 {
//assert table.AutoIncrement != ""
res, err := session.query("select seq_atable.currval from dual", args...)
if session.engine.dialect.DBType() == core.ORACLE && len(table.AutoIncrement) > 0 {
res, err := session.queryBytes("select seq_atable.currval from dual", args...)
if err != nil {
return 0, err
}
handleAfterInsertProcessorFunc(bean)
defer handleAfterInsertProcessorFunc(bean)
if cacher := session.Engine.getCacher2(table); cacher != nil && session.Statement.UseCache {
session.cacheInsert(session.Statement.TableName())
if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache {
session.cacheInsert(table, tableName)
}
if table.Version != "" && session.Statement.checkVersion {
if table.Version != "" && session.statement.checkVersion {
verValue, err := table.VersionColumn().ValueOf(bean)
if err != nil {
session.Engine.logger.Error(err)
session.engine.logger.Error(err)
} else if verValue.IsValid() && verValue.CanSet() {
verValue.SetInt(1)
}
@ -413,7 +427,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
aiValue, err := table.AutoIncrColumn().ValueOf(bean)
if err != nil {
session.Engine.logger.Error(err)
session.engine.logger.Error(err)
}
if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() {
@ -423,24 +437,24 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
aiValue.Set(int64ToIntValue(id, aiValue.Type()))
return 1, nil
} else if session.Engine.dialect.DBType() == core.POSTGRES && len(table.AutoIncrement) > 0 {
} else if session.engine.dialect.DBType() == core.POSTGRES && len(table.AutoIncrement) > 0 {
//assert table.AutoIncrement != ""
sqlStr = sqlStr + " RETURNING " + session.Engine.Quote(table.AutoIncrement)
res, err := session.query(sqlStr, args...)
sqlStr = sqlStr + " RETURNING " + session.engine.Quote(table.AutoIncrement)
res, err := session.queryBytes(sqlStr, args...)
if err != nil {
return 0, err
}
handleAfterInsertProcessorFunc(bean)
defer handleAfterInsertProcessorFunc(bean)
if cacher := session.Engine.getCacher2(table); cacher != nil && session.Statement.UseCache {
session.cacheInsert(session.Statement.TableName())
if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache {
session.cacheInsert(table, tableName)
}
if table.Version != "" && session.Statement.checkVersion {
if table.Version != "" && session.statement.checkVersion {
verValue, err := table.VersionColumn().ValueOf(bean)
if err != nil {
session.Engine.logger.Error(err)
session.engine.logger.Error(err)
} else if verValue.IsValid() && verValue.CanSet() {
verValue.SetInt(1)
}
@ -458,7 +472,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
aiValue, err := table.AutoIncrColumn().ValueOf(bean)
if err != nil {
session.Engine.logger.Error(err)
session.engine.logger.Error(err)
}
if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() {
@ -476,14 +490,14 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
defer handleAfterInsertProcessorFunc(bean)
if cacher := session.Engine.getCacher2(table); cacher != nil && session.Statement.UseCache {
session.cacheInsert(session.Statement.TableName())
if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache {
session.cacheInsert(table, tableName)
}
if table.Version != "" && session.Statement.checkVersion {
if table.Version != "" && session.statement.checkVersion {
verValue, err := table.VersionColumn().ValueOf(bean)
if err != nil {
session.Engine.logger.Error(err)
session.engine.logger.Error(err)
} else if verValue.IsValid() && verValue.CanSet() {
verValue.SetInt(1)
}
@ -501,7 +515,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
aiValue, err := table.AutoIncrColumn().ValueOf(bean)
if err != nil {
session.Engine.logger.Error(err)
session.engine.logger.Error(err)
}
if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() {
@ -518,24 +532,21 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
// The in parameter bean must a struct or a point to struct. The return
// parameter is inserted and error
func (session *Session) InsertOne(bean interface{}) (int64, error) {
defer session.resetStatement()
if session.IsAutoClose {
if session.isAutoClose {
defer session.Close()
}
return session.innerInsert(bean)
}
func (session *Session) cacheInsert(tables ...string) error {
if session.Statement.RefTable == nil {
func (session *Session) cacheInsert(table *core.Table, tables ...string) error {
if table == nil {
return ErrCacheFailed
}
table := session.Statement.RefTable
cacher := session.Engine.getCacher2(table)
cacher := session.engine.getCacher2(table)
for _, t := range tables {
session.Engine.logger.Debug("[cache] clear sql:", t)
session.engine.logger.Debug("[cache] clear sql:", t)
cacher.ClearIds(t)
}

View File

@ -19,6 +19,14 @@ func (session *Session) Rows(bean interface{}) (*Rows, error) {
// are conditions. beans could be []Struct, []*Struct, map[int64]Struct
// map[int64]*Struct
func (session *Session) Iterate(bean interface{}, fun IterFunc) error {
if session.isAutoClose {
defer session.Close()
}
if session.statement.bufferSize > 0 {
return session.bufferIterate(bean, fun)
}
rows, err := session.Rows(bean)
if err != nil {
return err
@ -40,3 +48,49 @@ func (session *Session) Iterate(bean interface{}, fun IterFunc) error {
}
return err
}
// BufferSize sets the buffersize for iterate
func (session *Session) BufferSize(size int) *Session {
session.statement.bufferSize = size
return session
}
func (session *Session) bufferIterate(bean interface{}, fun IterFunc) error {
if session.isAutoClose {
defer session.Close()
}
var bufferSize = session.statement.bufferSize
var limit = session.statement.LimitN
if limit > 0 && bufferSize > limit {
bufferSize = limit
}
var start = session.statement.Start
v := rValue(bean)
sliceType := reflect.SliceOf(v.Type())
var idx = 0
for {
slice := reflect.New(sliceType)
if err := session.Limit(bufferSize, start).find(slice.Interface(), bean); err != nil {
return err
}
for i := 0; i < slice.Elem().Len(); i++ {
if err := fun(idx, slice.Elem().Index(i).Addr().Interface()); err != nil {
return err
}
idx++
}
start = start + slice.Elem().Len()
if limit > 0 && idx+bufferSize > limit {
bufferSize = limit - idx
}
if bufferSize <= 0 || slice.Elem().Len() < bufferSize || idx == limit {
break
}
}
return nil
}

252
vendor/github.com/go-xorm/xorm/session_query.go generated vendored Normal file
View File

@ -0,0 +1,252 @@
// Copyright 2017 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
import (
"fmt"
"reflect"
"strconv"
"strings"
"time"
"github.com/go-xorm/builder"
"github.com/go-xorm/core"
)
func (session *Session) genQuerySQL(sqlorArgs ...interface{}) (string, []interface{}, error) {
if len(sqlorArgs) > 0 {
return sqlorArgs[0].(string), sqlorArgs[1:], nil
}
if session.statement.RawSQL != "" {
return session.statement.RawSQL, session.statement.RawParams, nil
}
if len(session.statement.TableName()) <= 0 {
return "", nil, ErrTableNotFound
}
var columnStr = session.statement.ColumnStr
if len(session.statement.selectStr) > 0 {
columnStr = session.statement.selectStr
} else {
if session.statement.JoinStr == "" {
if columnStr == "" {
if session.statement.GroupByStr != "" {
columnStr = session.statement.Engine.Quote(strings.Replace(session.statement.GroupByStr, ",", session.engine.Quote(","), -1))
} else {
columnStr = session.statement.genColumnStr()
}
}
} else {
if columnStr == "" {
if session.statement.GroupByStr != "" {
columnStr = session.statement.Engine.Quote(strings.Replace(session.statement.GroupByStr, ",", session.engine.Quote(","), -1))
} else {
columnStr = "*"
}
}
}
if columnStr == "" {
columnStr = "*"
}
}
condSQL, condArgs, err := builder.ToSQL(session.statement.cond)
if err != nil {
return "", nil, err
}
args := append(session.statement.joinArgs, condArgs...)
sqlStr, err := session.statement.genSelectSQL(columnStr, condSQL)
if err != nil {
return "", nil, err
}
// for mssql and use limit
qs := strings.Count(sqlStr, "?")
if len(args)*2 == qs {
args = append(args, args...)
}
return sqlStr, args, nil
}
// Query runs a raw sql and return records as []map[string][]byte
func (session *Session) Query(sqlorArgs ...interface{}) ([]map[string][]byte, error) {
if session.isAutoClose {
defer session.Close()
}
sqlStr, args, err := session.genQuerySQL(sqlorArgs...)
if err != nil {
return nil, err
}
return session.queryBytes(sqlStr, args...)
}
func value2String(rawValue *reflect.Value) (str string, err error) {
aa := reflect.TypeOf((*rawValue).Interface())
vv := reflect.ValueOf((*rawValue).Interface())
switch aa.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
str = strconv.FormatInt(vv.Int(), 10)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
str = strconv.FormatUint(vv.Uint(), 10)
case reflect.Float32, reflect.Float64:
str = strconv.FormatFloat(vv.Float(), 'f', -1, 64)
case reflect.String:
str = vv.String()
case reflect.Array, reflect.Slice:
switch aa.Elem().Kind() {
case reflect.Uint8:
data := rawValue.Interface().([]byte)
str = string(data)
if str == "\x00" {
str = "0"
}
default:
err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name())
}
// time type
case reflect.Struct:
if aa.ConvertibleTo(core.TimeType) {
str = vv.Convert(core.TimeType).Interface().(time.Time).Format(time.RFC3339Nano)
} else {
err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name())
}
case reflect.Bool:
str = strconv.FormatBool(vv.Bool())
case reflect.Complex128, reflect.Complex64:
str = fmt.Sprintf("%v", vv.Complex())
/* TODO: unsupported types below
case reflect.Map:
case reflect.Ptr:
case reflect.Uintptr:
case reflect.UnsafePointer:
case reflect.Chan, reflect.Func, reflect.Interface:
*/
default:
err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name())
}
return
}
func row2mapStr(rows *core.Rows, fields []string) (resultsMap map[string]string, err error) {
result := make(map[string]string)
scanResultContainers := make([]interface{}, len(fields))
for i := 0; i < len(fields); i++ {
var scanResultContainer interface{}
scanResultContainers[i] = &scanResultContainer
}
if err := rows.Scan(scanResultContainers...); err != nil {
return nil, err
}
for ii, key := range fields {
rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii]))
// if row is null then as empty string
if rawValue.Interface() == nil {
result[key] = ""
continue
}
if data, err := value2String(&rawValue); err == nil {
result[key] = data
} else {
return nil, err
}
}
return result, nil
}
func rows2Strings(rows *core.Rows) (resultsSlice []map[string]string, err error) {
fields, err := rows.Columns()
if err != nil {
return nil, err
}
for rows.Next() {
result, err := row2mapStr(rows, fields)
if err != nil {
return nil, err
}
resultsSlice = append(resultsSlice, result)
}
return resultsSlice, nil
}
// QueryString runs a raw sql and return records as []map[string]string
func (session *Session) QueryString(sqlorArgs ...interface{}) ([]map[string]string, error) {
if session.isAutoClose {
defer session.Close()
}
sqlStr, args, err := session.genQuerySQL(sqlorArgs...)
if err != nil {
return nil, err
}
rows, err := session.queryRows(sqlStr, args...)
if err != nil {
return nil, err
}
defer rows.Close()
return rows2Strings(rows)
}
func row2mapInterface(rows *core.Rows, fields []string) (resultsMap map[string]interface{}, err error) {
resultsMap = make(map[string]interface{}, len(fields))
scanResultContainers := make([]interface{}, len(fields))
for i := 0; i < len(fields); i++ {
var scanResultContainer interface{}
scanResultContainers[i] = &scanResultContainer
}
if err := rows.Scan(scanResultContainers...); err != nil {
return nil, err
}
for ii, key := range fields {
resultsMap[key] = reflect.Indirect(reflect.ValueOf(scanResultContainers[ii])).Interface()
}
return
}
func rows2Interfaces(rows *core.Rows) (resultsSlice []map[string]interface{}, err error) {
fields, err := rows.Columns()
if err != nil {
return nil, err
}
for rows.Next() {
result, err := row2mapInterface(rows, fields)
if err != nil {
return nil, err
}
resultsSlice = append(resultsSlice, result)
}
return resultsSlice, nil
}
// QueryInterface runs a raw sql and return records as []map[string]interface{}
func (session *Session) QueryInterface(sqlorArgs ...interface{}) ([]map[string]interface{}, error) {
if session.isAutoClose {
defer session.Close()
}
sqlStr, args, err := session.genQuerySQL(sqlorArgs...)
if err != nil {
return nil, err
}
rows, err := session.queryRows(sqlStr, args...)
if err != nil {
return nil, err
}
defer rows.Close()
return rows2Interfaces(rows)
}

View File

@ -6,96 +6,179 @@ package xorm
import (
"database/sql"
"reflect"
"time"
"github.com/go-xorm/core"
)
func (session *Session) query(sqlStr string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) {
session.queryPreprocess(&sqlStr, paramStr...)
if session.IsAutoCommit {
return session.innerQuery2(sqlStr, paramStr...)
}
return session.txQuery(session.Tx, sqlStr, paramStr...)
func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) {
for _, filter := range session.engine.dialect.Filters() {
*sqlStr = filter.Do(*sqlStr, session.engine.dialect, session.statement.RefTable)
}
func (session *Session) txQuery(tx *core.Tx, sqlStr string, params ...interface{}) (resultsSlice []map[string][]byte, err error) {
rows, err := tx.Query(sqlStr, params...)
if err != nil {
return nil, err
}
defer rows.Close()
return rows2maps(rows)
session.lastSQL = *sqlStr
session.lastSQLArgs = paramStr
}
func (session *Session) innerQuery(sqlStr string, params ...interface{}) (*core.Stmt, *core.Rows, error) {
var callback func() (*core.Stmt, *core.Rows, error)
if session.prepareStmt {
callback = func() (*core.Stmt, *core.Rows, error) {
stmt, err := session.doPrepare(sqlStr)
if err != nil {
return nil, nil, err
}
rows, err := stmt.Query(params...)
if err != nil {
return nil, nil, err
}
return stmt, rows, nil
}
} else {
callback = func() (*core.Stmt, *core.Rows, error) {
rows, err := session.DB().Query(sqlStr, params...)
if err != nil {
return nil, nil, err
}
return nil, rows, err
}
}
stmt, rows, err := session.Engine.logSQLQueryTime(sqlStr, params, callback)
if err != nil {
return nil, nil, err
}
return stmt, rows, nil
}
func (session *Session) innerQuery2(sqlStr string, params ...interface{}) ([]map[string][]byte, error) {
_, rows, err := session.innerQuery(sqlStr, params...)
if rows != nil {
defer rows.Close()
}
if err != nil {
return nil, err
}
return rows2maps(rows)
}
// Query a raw sql and return records as []map[string][]byte
func (session *Session) Query(sqlStr string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) {
func (session *Session) queryRows(sqlStr string, args ...interface{}) (*core.Rows, error) {
defer session.resetStatement()
if session.IsAutoClose {
defer session.Close()
session.queryPreprocess(&sqlStr, args...)
if session.engine.showSQL {
if session.engine.showExecTime {
b4ExecTime := time.Now()
defer func() {
execDuration := time.Since(b4ExecTime)
if len(args) > 0 {
session.engine.logger.Infof("[SQL] %s %#v - took: %v", sqlStr, args, execDuration)
} else {
session.engine.logger.Infof("[SQL] %s - took: %v", sqlStr, execDuration)
}
}()
} else {
if len(args) > 0 {
session.engine.logger.Infof("[SQL] %v %#v", sqlStr, args)
} else {
session.engine.logger.Infof("[SQL] %v", sqlStr)
}
}
}
return session.query(sqlStr, paramStr...)
if session.isAutoCommit {
var db *core.DB
if session.engine.engineGroup != nil {
db = session.engine.engineGroup.Slave().DB()
} else {
db = session.DB()
}
// =============================
// for string
// =============================
func (session *Session) query2(sqlStr string, paramStr ...interface{}) (resultsSlice []map[string]string, err error) {
session.queryPreprocess(&sqlStr, paramStr...)
if session.IsAutoCommit {
return query2(session.DB(), sqlStr, paramStr...)
}
return txQuery2(session.Tx, sqlStr, paramStr...)
}
// Execute sql
func (session *Session) innerExec(sqlStr string, args ...interface{}) (sql.Result, error) {
if session.prepareStmt {
stmt, err := session.doPrepare(sqlStr)
// don't clear stmt since session will cache them
stmt, err := session.doPrepare(db, sqlStr)
if err != nil {
return nil, err
}
rows, err := stmt.Query(args...)
if err != nil {
return nil, err
}
return rows, nil
}
rows, err := db.Query(sqlStr, args...)
if err != nil {
return nil, err
}
return rows, nil
}
rows, err := session.tx.Query(sqlStr, args...)
if err != nil {
return nil, err
}
return rows, nil
}
func (session *Session) queryRow(sqlStr string, args ...interface{}) *core.Row {
return core.NewRow(session.queryRows(sqlStr, args...))
}
func value2Bytes(rawValue *reflect.Value) ([]byte, error) {
str, err := value2String(rawValue)
if err != nil {
return nil, err
}
return []byte(str), nil
}
func row2map(rows *core.Rows, fields []string) (resultsMap map[string][]byte, err error) {
result := make(map[string][]byte)
scanResultContainers := make([]interface{}, len(fields))
for i := 0; i < len(fields); i++ {
var scanResultContainer interface{}
scanResultContainers[i] = &scanResultContainer
}
if err := rows.Scan(scanResultContainers...); err != nil {
return nil, err
}
for ii, key := range fields {
rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii]))
//if row is null then ignore
if rawValue.Interface() == nil {
result[key] = []byte{}
continue
}
if data, err := value2Bytes(&rawValue); err == nil {
result[key] = data
} else {
return nil, err // !nashtsai! REVIEW, should return err or just error log?
}
}
return result, nil
}
func rows2maps(rows *core.Rows) (resultsSlice []map[string][]byte, err error) {
fields, err := rows.Columns()
if err != nil {
return nil, err
}
for rows.Next() {
result, err := row2map(rows, fields)
if err != nil {
return nil, err
}
resultsSlice = append(resultsSlice, result)
}
return resultsSlice, nil
}
func (session *Session) queryBytes(sqlStr string, args ...interface{}) ([]map[string][]byte, error) {
rows, err := session.queryRows(sqlStr, args...)
if err != nil {
return nil, err
}
defer rows.Close()
return rows2maps(rows)
}
func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, error) {
defer session.resetStatement()
session.queryPreprocess(&sqlStr, args...)
if session.engine.showSQL {
if session.engine.showExecTime {
b4ExecTime := time.Now()
defer func() {
execDuration := time.Since(b4ExecTime)
if len(args) > 0 {
session.engine.logger.Infof("[SQL] %s %#v - took: %v", sqlStr, args, execDuration)
} else {
session.engine.logger.Infof("[SQL] %s - took: %v", sqlStr, execDuration)
}
}()
} else {
if len(args) > 0 {
session.engine.logger.Infof("[SQL] %v %#v", sqlStr, args)
} else {
session.engine.logger.Infof("[SQL] %v", sqlStr)
}
}
}
if !session.isAutoCommit {
return session.tx.Exec(sqlStr, args...)
}
if session.prepareStmt {
stmt, err := session.doPrepare(session.DB(), sqlStr)
if err != nil {
return nil, err
}
@ -110,33 +193,9 @@ func (session *Session) innerExec(sqlStr string, args ...interface{}) (sql.Resul
return session.DB().Exec(sqlStr, args...)
}
func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, error) {
for _, filter := range session.Engine.dialect.Filters() {
// TODO: for table name, it's no need to RefTable
sqlStr = filter.Do(sqlStr, session.Engine.dialect, session.Statement.RefTable)
}
session.saveLastSQL(sqlStr, args...)
return session.Engine.logSQLExecutionTime(sqlStr, args, func() (sql.Result, error) {
if session.IsAutoCommit {
// FIXME: oci8 can not auto commit (github.com/mattn/go-oci8)
if session.Engine.dialect.DBType() == core.ORACLE {
session.Begin()
r, err := session.Tx.Exec(sqlStr, args...)
session.Commit()
return r, err
}
return session.innerExec(sqlStr, args...)
}
return session.Tx.Exec(sqlStr, args...)
})
}
// Exec raw sql
func (session *Session) Exec(sqlStr string, args ...interface{}) (sql.Result, error) {
defer session.resetStatement()
if session.IsAutoClose {
if session.isAutoClose {
defer session.Close()
}

View File

@ -16,38 +16,50 @@ import (
// Ping test if database is ok
func (session *Session) Ping() error {
defer session.resetStatement()
if session.IsAutoClose {
if session.isAutoClose {
defer session.Close()
}
session.engine.logger.Infof("PING DATABASE %v", session.engine.DriverName())
return session.DB().Ping()
}
// CreateTable create a table according a bean
func (session *Session) CreateTable(bean interface{}) error {
v := rValue(bean)
session.Statement.setRefValue(v)
defer session.resetStatement()
if session.IsAutoClose {
if session.isAutoClose {
defer session.Close()
}
return session.createOneTable()
return session.createTable(bean)
}
func (session *Session) createTable(bean interface{}) error {
v := rValue(bean)
if err := session.statement.setRefValue(v); err != nil {
return err
}
sqlStr := session.statement.genCreateTableSQL()
_, err := session.exec(sqlStr)
return err
}
// CreateIndexes create indexes
func (session *Session) CreateIndexes(bean interface{}) error {
v := rValue(bean)
session.Statement.setRefValue(v)
defer session.resetStatement()
if session.IsAutoClose {
if session.isAutoClose {
defer session.Close()
}
sqls := session.Statement.genIndexSQL()
return session.createIndexes(bean)
}
func (session *Session) createIndexes(bean interface{}) error {
v := rValue(bean)
if err := session.statement.setRefValue(v); err != nil {
return err
}
sqls := session.statement.genIndexSQL()
for _, sqlStr := range sqls {
_, err := session.exec(sqlStr)
if err != nil {
@ -59,15 +71,19 @@ func (session *Session) CreateIndexes(bean interface{}) error {
// CreateUniques create uniques
func (session *Session) CreateUniques(bean interface{}) error {
v := rValue(bean)
session.Statement.setRefValue(v)
defer session.resetStatement()
if session.IsAutoClose {
if session.isAutoClose {
defer session.Close()
}
return session.createUniques(bean)
}
sqls := session.Statement.genUniqueSQL()
func (session *Session) createUniques(bean interface{}) error {
v := rValue(bean)
if err := session.statement.setRefValue(v); err != nil {
return err
}
sqls := session.statement.genUniqueSQL()
for _, sqlStr := range sqls {
_, err := session.exec(sqlStr)
if err != nil {
@ -77,41 +93,22 @@ func (session *Session) CreateUniques(bean interface{}) error {
return nil
}
func (session *Session) createOneTable() error {
sqlStr := session.Statement.genCreateTableSQL()
_, err := session.exec(sqlStr)
return err
}
// to be deleted
func (session *Session) createAll() error {
if session.IsAutoClose {
defer session.Close()
}
for _, table := range session.Engine.Tables {
session.Statement.RefTable = table
session.Statement.tableName = table.Name
err := session.createOneTable()
session.resetStatement()
if err != nil {
return err
}
}
return nil
}
// DropIndexes drop indexes
func (session *Session) DropIndexes(bean interface{}) error {
v := rValue(bean)
session.Statement.setRefValue(v)
defer session.resetStatement()
if session.IsAutoClose {
if session.isAutoClose {
defer session.Close()
}
sqls := session.Statement.genDelIndexSQL()
return session.dropIndexes(bean)
}
func (session *Session) dropIndexes(bean interface{}) error {
v := rValue(bean)
if err := session.statement.setRefValue(v); err != nil {
return err
}
sqls := session.statement.genDelIndexSQL()
for _, sqlStr := range sqls {
_, err := session.exec(sqlStr)
if err != nil {
@ -123,15 +120,23 @@ func (session *Session) DropIndexes(bean interface{}) error {
// DropTable drop table will drop table if exist, if drop failed, it will return error
func (session *Session) DropTable(beanOrTableName interface{}) error {
tableName, err := session.Engine.tableName(beanOrTableName)
if session.isAutoClose {
defer session.Close()
}
return session.dropTable(beanOrTableName)
}
func (session *Session) dropTable(beanOrTableName interface{}) error {
tableName, err := session.engine.tableName(beanOrTableName)
if err != nil {
return err
}
var needDrop = true
if !session.Engine.dialect.SupportDropIfExists() {
sqlStr, args := session.Engine.dialect.TableCheckSql(tableName)
results, err := session.query(sqlStr, args...)
if !session.engine.dialect.SupportDropIfExists() {
sqlStr, args := session.engine.dialect.TableCheckSql(tableName)
results, err := session.queryBytes(sqlStr, args...)
if err != nil {
return err
}
@ -139,7 +144,7 @@ func (session *Session) DropTable(beanOrTableName interface{}) error {
}
if needDrop {
sqlStr := session.Engine.Dialect().DropTableSql(tableName)
sqlStr := session.engine.Dialect().DropTableSql(tableName)
_, err = session.exec(sqlStr)
return err
}
@ -148,7 +153,11 @@ func (session *Session) DropTable(beanOrTableName interface{}) error {
// IsTableExist if a table is exist
func (session *Session) IsTableExist(beanOrTableName interface{}) (bool, error) {
tableName, err := session.Engine.tableName(beanOrTableName)
if session.isAutoClose {
defer session.Close()
}
tableName, err := session.engine.tableName(beanOrTableName)
if err != nil {
return false, err
}
@ -157,12 +166,8 @@ func (session *Session) IsTableExist(beanOrTableName interface{}) (bool, error)
}
func (session *Session) isTableExist(tableName string) (bool, error) {
defer session.resetStatement()
if session.IsAutoClose {
defer session.Close()
}
sqlStr, args := session.Engine.dialect.TableCheckSql(tableName)
results, err := session.query(sqlStr, args...)
sqlStr, args := session.engine.dialect.TableCheckSql(tableName)
results, err := session.queryBytes(sqlStr, args...)
return len(results) > 0, err
}
@ -172,6 +177,9 @@ func (session *Session) IsTableEmpty(bean interface{}) (bool, error) {
t := v.Type()
if t.Kind() == reflect.String {
if session.isAutoClose {
defer session.Close()
}
return session.isTableEmpty(bean.(string))
} else if t.Kind() == reflect.Struct {
rows, err := session.Count(bean)
@ -181,15 +189,9 @@ func (session *Session) IsTableEmpty(bean interface{}) (bool, error) {
}
func (session *Session) isTableEmpty(tableName string) (bool, error) {
defer session.resetStatement()
if session.IsAutoClose {
defer session.Close()
}
var total int64
sqlStr := fmt.Sprintf("select count(*) from %s", session.Engine.Quote(tableName))
err := session.DB().QueryRow(sqlStr).Scan(&total)
session.saveLastSQL(sqlStr)
sqlStr := fmt.Sprintf("select count(*) from %s", session.engine.Quote(tableName))
err := session.queryRow(sqlStr).Scan(&total)
if err != nil {
if err == sql.ErrNoRows {
err = nil
@ -200,30 +202,9 @@ func (session *Session) isTableEmpty(tableName string) (bool, error) {
return total == 0, nil
}
func (session *Session) isIndexExist(tableName, idxName string, unique bool) (bool, error) {
defer session.resetStatement()
if session.IsAutoClose {
defer session.Close()
}
var idx string
if unique {
idx = uniqueName(tableName, idxName)
} else {
idx = indexName(tableName, idxName)
}
sqlStr, args := session.Engine.dialect.IndexCheckSql(tableName, idx)
results, err := session.query(sqlStr, args...)
return len(results) > 0, err
}
// find if index is exist according cols
func (session *Session) isIndexExist2(tableName string, cols []string, unique bool) (bool, error) {
defer session.resetStatement()
if session.IsAutoClose {
defer session.Close()
}
indexes, err := session.Engine.dialect.GetIndexes(tableName)
indexes, err := session.engine.dialect.GetIndexes(tableName)
if err != nil {
return false, err
}
@ -240,62 +221,34 @@ func (session *Session) isIndexExist2(tableName string, cols []string, unique bo
}
func (session *Session) addColumn(colName string) error {
defer session.resetStatement()
if session.IsAutoClose {
defer session.Close()
}
col := session.Statement.RefTable.GetColumn(colName)
sql, args := session.Statement.genAddColumnStr(col)
col := session.statement.RefTable.GetColumn(colName)
sql, args := session.statement.genAddColumnStr(col)
_, err := session.exec(sql, args...)
return err
}
func (session *Session) addIndex(tableName, idxName string) error {
defer session.resetStatement()
if session.IsAutoClose {
defer session.Close()
}
index := session.Statement.RefTable.Indexes[idxName]
sqlStr := session.Engine.dialect.CreateIndexSql(tableName, index)
index := session.statement.RefTable.Indexes[idxName]
sqlStr := session.engine.dialect.CreateIndexSql(tableName, index)
_, err := session.exec(sqlStr)
return err
}
func (session *Session) addUnique(tableName, uqeName string) error {
defer session.resetStatement()
if session.IsAutoClose {
defer session.Close()
}
index := session.Statement.RefTable.Indexes[uqeName]
sqlStr := session.Engine.dialect.CreateIndexSql(tableName, index)
index := session.statement.RefTable.Indexes[uqeName]
sqlStr := session.engine.dialect.CreateIndexSql(tableName, index)
_, err := session.exec(sqlStr)
return err
}
// To be deleted
func (session *Session) dropAll() error {
defer session.resetStatement()
if session.IsAutoClose {
defer session.Close()
}
for _, table := range session.Engine.Tables {
session.Statement.Init()
session.Statement.RefTable = table
sqlStr := session.Engine.Dialect().DropTableSql(session.Statement.TableName())
_, err := session.exec(sqlStr)
if err != nil {
return err
}
}
return nil
}
// Sync2 synchronize structs to database tables
func (session *Session) Sync2(beans ...interface{}) error {
engine := session.Engine
engine := session.engine
if session.isAutoClose {
session.isAutoClose = false
defer session.Close()
}
tables, err := engine.DBMetas()
if err != nil {
@ -322,17 +275,17 @@ func (session *Session) Sync2(beans ...interface{}) error {
}
if oriTable == nil {
err = session.StoreEngine(session.Statement.StoreEngine).CreateTable(bean)
err = session.StoreEngine(session.statement.StoreEngine).createTable(bean)
if err != nil {
return err
}
err = session.CreateUniques(bean)
err = session.createUniques(bean)
if err != nil {
return err
}
err = session.CreateIndexes(bean)
err = session.createIndexes(bean)
if err != nil {
return err
}
@ -357,7 +310,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
engine.dialect.DBType() == core.POSTGRES {
engine.logger.Infof("Table %s column %s change type from %s to %s\n",
tbName, col.Name, curType, expectedType)
_, err = engine.Exec(engine.dialect.ModifyColumnSql(table.Name, col))
_, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col))
} else {
engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s\n",
tbName, col.Name, curType, expectedType)
@ -367,7 +320,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
if oriCol.Length < col.Length {
engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
tbName, col.Name, oriCol.Length, col.Length)
_, err = engine.Exec(engine.dialect.ModifyColumnSql(table.Name, col))
_, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col))
}
}
} else {
@ -381,7 +334,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
if oriCol.Length < col.Length {
engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
tbName, col.Name, oriCol.Length, col.Length)
_, err = engine.Exec(engine.dialect.ModifyColumnSql(table.Name, col))
_, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col))
}
}
}
@ -394,10 +347,8 @@ func (session *Session) Sync2(beans ...interface{}) error {
tbName, col.Name, oriCol.Nullable, col.Nullable)
}
} else {
session := engine.NewSession()
session.Statement.RefTable = table
session.Statement.tableName = tbName
defer session.Close()
session.statement.RefTable = table
session.statement.tableName = tbName
err = session.addColumn(col.Name)
}
if err != nil {
@ -421,7 +372,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
if oriIndex != nil {
if oriIndex.Type != index.Type {
sql := engine.dialect.DropIndexSql(tbName, oriIndex)
_, err = engine.Exec(sql)
_, err = session.exec(sql)
if err != nil {
return err
}
@ -437,7 +388,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
for name2, index2 := range oriTable.Indexes {
if _, ok := foundIndexNames[name2]; !ok {
sql := engine.dialect.DropIndexSql(tbName, index2)
_, err = engine.Exec(sql)
_, err = session.exec(sql)
if err != nil {
return err
}
@ -446,16 +397,12 @@ func (session *Session) Sync2(beans ...interface{}) error {
for name, index := range addedNames {
if index.Type == core.UniqueType {
session := engine.NewSession()
session.Statement.RefTable = table
session.Statement.tableName = tbName
defer session.Close()
session.statement.RefTable = table
session.statement.tableName = tbName
err = session.addUnique(tbName, name)
} else if index.Type == core.IndexType {
session := engine.NewSession()
session.Statement.RefTable = table
session.Statement.tableName = tbName
defer session.Close()
session.statement.RefTable = table
session.statement.tableName = tbName
err = session.addIndex(tbName, name)
}
if err != nil {

98
vendor/github.com/go-xorm/xorm/session_stats.go generated vendored Normal file
View File

@ -0,0 +1,98 @@
// Copyright 2016 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
import (
"database/sql"
"errors"
"reflect"
)
// Count counts the records. bean's non-empty fields
// are conditions.
func (session *Session) Count(bean ...interface{}) (int64, error) {
if session.isAutoClose {
defer session.Close()
}
var sqlStr string
var args []interface{}
var err error
if session.statement.RawSQL == "" {
sqlStr, args, err = session.statement.genCountSQL(bean...)
if err != nil {
return 0, err
}
} else {
sqlStr = session.statement.RawSQL
args = session.statement.RawParams
}
var total int64
err = session.queryRow(sqlStr, args...).Scan(&total)
if err == sql.ErrNoRows || err == nil {
return total, nil
}
return 0, err
}
// sum call sum some column. bean's non-empty fields are conditions.
func (session *Session) sum(res interface{}, bean interface{}, columnNames ...string) error {
if session.isAutoClose {
defer session.Close()
}
v := reflect.ValueOf(res)
if v.Kind() != reflect.Ptr {
return errors.New("need a pointer to a variable")
}
var isSlice = v.Elem().Kind() == reflect.Slice
var sqlStr string
var args []interface{}
var err error
if len(session.statement.RawSQL) == 0 {
sqlStr, args, err = session.statement.genSumSQL(bean, columnNames...)
if err != nil {
return err
}
} else {
sqlStr = session.statement.RawSQL
args = session.statement.RawParams
}
if isSlice {
err = session.queryRow(sqlStr, args...).ScanSlice(res)
} else {
err = session.queryRow(sqlStr, args...).Scan(res)
}
if err == sql.ErrNoRows || err == nil {
return nil
}
return err
}
// Sum call sum some column. bean's non-empty fields are conditions.
func (session *Session) Sum(bean interface{}, columnName string) (res float64, err error) {
return res, session.sum(&res, bean, columnName)
}
// SumInt call sum some column. bean's non-empty fields are conditions.
func (session *Session) SumInt(bean interface{}, columnName string) (res int64, err error) {
return res, session.sum(&res, bean, columnName)
}
// Sums call sum some columns. bean's non-empty fields are conditions.
func (session *Session) Sums(bean interface{}, columnNames ...string) ([]float64, error) {
var res = make([]float64, len(columnNames), len(columnNames))
return res, session.sum(&res, bean, columnNames...)
}
// SumsInt sum specify columns and return as []int64 instead of []float64
func (session *Session) SumsInt(bean interface{}, columnNames ...string) ([]int64, error) {
var res = make([]int64, len(columnNames), len(columnNames))
return res, session.sum(&res, bean, columnNames...)
}

View File

@ -1,137 +0,0 @@
// Copyright 2016 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
import "database/sql"
// Count counts the records. bean's non-empty fields
// are conditions.
func (session *Session) Count(bean interface{}) (int64, error) {
defer session.resetStatement()
if session.IsAutoClose {
defer session.Close()
}
var sqlStr string
var args []interface{}
if session.Statement.RawSQL == "" {
sqlStr, args = session.Statement.genCountSQL(bean)
} else {
sqlStr = session.Statement.RawSQL
args = session.Statement.RawParams
}
session.queryPreprocess(&sqlStr, args...)
var err error
var total int64
if session.IsAutoCommit {
err = session.DB().QueryRow(sqlStr, args...).Scan(&total)
} else {
err = session.Tx.QueryRow(sqlStr, args...).Scan(&total)
}
if err == sql.ErrNoRows || err == nil {
return total, nil
}
return 0, err
}
// Sum call sum some column. bean's non-empty fields are conditions.
func (session *Session) Sum(bean interface{}, columnName string) (float64, error) {
defer session.resetStatement()
if session.IsAutoClose {
defer session.Close()
}
var sqlStr string
var args []interface{}
if len(session.Statement.RawSQL) == 0 {
sqlStr, args = session.Statement.genSumSQL(bean, columnName)
} else {
sqlStr = session.Statement.RawSQL
args = session.Statement.RawParams
}
session.queryPreprocess(&sqlStr, args...)
var err error
var res float64
if session.IsAutoCommit {
err = session.DB().QueryRow(sqlStr, args...).Scan(&res)
} else {
err = session.Tx.QueryRow(sqlStr, args...).Scan(&res)
}
if err == sql.ErrNoRows || err == nil {
return res, nil
}
return 0, err
}
// Sums call sum some columns. bean's non-empty fields are conditions.
func (session *Session) Sums(bean interface{}, columnNames ...string) ([]float64, error) {
defer session.resetStatement()
if session.IsAutoClose {
defer session.Close()
}
var sqlStr string
var args []interface{}
if len(session.Statement.RawSQL) == 0 {
sqlStr, args = session.Statement.genSumSQL(bean, columnNames...)
} else {
sqlStr = session.Statement.RawSQL
args = session.Statement.RawParams
}
session.queryPreprocess(&sqlStr, args...)
var err error
var res = make([]float64, len(columnNames), len(columnNames))
if session.IsAutoCommit {
err = session.DB().QueryRow(sqlStr, args...).ScanSlice(&res)
} else {
err = session.Tx.QueryRow(sqlStr, args...).ScanSlice(&res)
}
if err == sql.ErrNoRows || err == nil {
return res, nil
}
return nil, err
}
// SumsInt sum specify columns and return as []int64 instead of []float64
func (session *Session) SumsInt(bean interface{}, columnNames ...string) ([]int64, error) {
defer session.resetStatement()
if session.IsAutoClose {
defer session.Close()
}
var sqlStr string
var args []interface{}
if len(session.Statement.RawSQL) == 0 {
sqlStr, args = session.Statement.genSumSQL(bean, columnNames...)
} else {
sqlStr = session.Statement.RawSQL
args = session.Statement.RawParams
}
session.queryPreprocess(&sqlStr, args...)
var err error
var res = make([]int64, len(columnNames), len(columnNames))
if session.IsAutoCommit {
err = session.DB().QueryRow(sqlStr, args...).ScanSlice(&res)
} else {
err = session.Tx.QueryRow(sqlStr, args...).ScanSlice(&res)
}
if err == sql.ErrNoRows || err == nil {
return res, nil
}
return nil, err
}

View File

@ -6,14 +6,14 @@ package xorm
// Begin a transaction
func (session *Session) Begin() error {
if session.IsAutoCommit {
if session.isAutoCommit {
tx, err := session.DB().Begin()
if err != nil {
return err
}
session.IsAutoCommit = false
session.IsCommitedOrRollbacked = false
session.Tx = tx
session.isAutoCommit = false
session.isCommitedOrRollbacked = false
session.tx = tx
session.saveLastSQL("BEGIN TRANSACTION")
}
return nil
@ -21,25 +21,23 @@ func (session *Session) Begin() error {
// Rollback When using transaction, you can rollback if any error
func (session *Session) Rollback() error {
if !session.IsAutoCommit && !session.IsCommitedOrRollbacked {
session.saveLastSQL(session.Engine.dialect.RollBackStr())
session.IsCommitedOrRollbacked = true
return session.Tx.Rollback()
if !session.isAutoCommit && !session.isCommitedOrRollbacked {
session.saveLastSQL(session.engine.dialect.RollBackStr())
session.isCommitedOrRollbacked = true
return session.tx.Rollback()
}
return nil
}
// Commit When using transaction, Commit will commit all operations.
func (session *Session) Commit() error {
if !session.IsAutoCommit && !session.IsCommitedOrRollbacked {
if !session.isAutoCommit && !session.isCommitedOrRollbacked {
session.saveLastSQL("COMMIT")
session.IsCommitedOrRollbacked = true
session.isCommitedOrRollbacked = true
var err error
if err = session.Tx.Commit(); err == nil {
if err = session.tx.Commit(); err == nil {
// handle processors after tx committed
closureCallFunc := func(closuresPtr *[]func(interface{}), bean interface{}) {
if closuresPtr != nil {
for _, closure := range *closuresPtr {
closure(bean)

View File

@ -15,20 +15,20 @@ import (
"github.com/go-xorm/core"
)
func (session *Session) cacheUpdate(sqlStr string, args ...interface{}) error {
if session.Statement.RefTable == nil ||
session.Tx != nil {
func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string, args ...interface{}) error {
if table == nil ||
session.tx != nil {
return ErrCacheFailed
}
oldhead, newsql := session.Statement.convertUpdateSQL(sqlStr)
oldhead, newsql := session.statement.convertUpdateSQL(sqlStr)
if newsql == "" {
return ErrCacheFailed
}
for _, filter := range session.Engine.dialect.Filters() {
newsql = filter.Do(newsql, session.Engine.dialect, session.Statement.RefTable)
for _, filter := range session.engine.dialect.Filters() {
newsql = filter.Do(newsql, session.engine.dialect, table)
}
session.Engine.logger.Debug("[cacheUpdate] new sql", oldhead, newsql)
session.engine.logger.Debug("[cacheUpdate] new sql", oldhead, newsql)
var nStart int
if len(args) > 0 {
@ -39,13 +39,12 @@ func (session *Session) cacheUpdate(sqlStr string, args ...interface{}) error {
nStart = strings.Count(oldhead, "$")
}
}
table := session.Statement.RefTable
cacher := session.Engine.getCacher2(table)
tableName := session.Statement.TableName()
session.Engine.logger.Debug("[cacheUpdate] get cache sql", newsql, args[nStart:])
cacher := session.engine.getCacher2(table)
session.engine.logger.Debug("[cacheUpdate] get cache sql", newsql, args[nStart:])
ids, err := core.GetCacheSql(cacher, tableName, newsql, args[nStart:])
if err != nil {
rows, err := session.DB().Query(newsql, args[nStart:]...)
rows, err := session.NoCache().queryRows(newsql, args[nStart:]...)
if err != nil {
return err
}
@ -75,9 +74,9 @@ func (session *Session) cacheUpdate(sqlStr string, args ...interface{}) error {
ids = append(ids, pk)
}
session.Engine.logger.Debug("[cacheUpdate] find updated id", ids)
session.engine.logger.Debug("[cacheUpdate] find updated id", ids)
} /*else {
session.Engine.LogDebug("[xorm:cacheUpdate] del cached sql:", tableName, newsql, args)
session.engine.LogDebug("[xorm:cacheUpdate] del cached sql:", tableName, newsql, args)
cacher.DelIds(tableName, genSqlKey(newsql, args))
}*/
@ -103,36 +102,36 @@ func (session *Session) cacheUpdate(sqlStr string, args ...interface{}) error {
colName := sps2[len(sps2)-1]
if strings.Contains(colName, "`") {
colName = strings.TrimSpace(strings.Replace(colName, "`", "", -1))
} else if strings.Contains(colName, session.Engine.QuoteStr()) {
colName = strings.TrimSpace(strings.Replace(colName, session.Engine.QuoteStr(), "", -1))
} else if strings.Contains(colName, session.engine.QuoteStr()) {
colName = strings.TrimSpace(strings.Replace(colName, session.engine.QuoteStr(), "", -1))
} else {
session.Engine.logger.Debug("[cacheUpdate] cannot find column", tableName, colName)
session.engine.logger.Debug("[cacheUpdate] cannot find column", tableName, colName)
return ErrCacheFailed
}
if col := table.GetColumn(colName); col != nil {
fieldValue, err := col.ValueOf(bean)
if err != nil {
session.Engine.logger.Error(err)
session.engine.logger.Error(err)
} else {
session.Engine.logger.Debug("[cacheUpdate] set bean field", bean, colName, fieldValue.Interface())
if col.IsVersion && session.Statement.checkVersion {
session.engine.logger.Debug("[cacheUpdate] set bean field", bean, colName, fieldValue.Interface())
if col.IsVersion && session.statement.checkVersion {
fieldValue.SetInt(fieldValue.Int() + 1)
} else {
fieldValue.Set(reflect.ValueOf(args[idx]))
}
}
} else {
session.Engine.logger.Errorf("[cacheUpdate] ERROR: column %v is not table %v's",
session.engine.logger.Errorf("[cacheUpdate] ERROR: column %v is not table %v's",
colName, table.Name)
}
}
session.Engine.logger.Debug("[cacheUpdate] update cache", tableName, id, bean)
session.engine.logger.Debug("[cacheUpdate] update cache", tableName, id, bean)
cacher.PutBean(tableName, sid, bean)
}
}
session.Engine.logger.Debug("[cacheUpdate] clear cached table sql:", tableName)
session.engine.logger.Debug("[cacheUpdate] clear cached table sql:", tableName)
cacher.ClearIds(tableName)
return nil
}
@ -144,8 +143,7 @@ func (session *Session) cacheUpdate(sqlStr string, args ...interface{}) error {
// You should call UseBool if you have bool to use.
// 2.float32 & float64 may be not inexact as conditions
func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int64, error) {
defer session.resetStatement()
if session.IsAutoClose {
if session.isAutoClose {
defer session.Close()
}
@ -169,19 +167,21 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
var isMap = t.Kind() == reflect.Map
var isStruct = t.Kind() == reflect.Struct
if isStruct {
session.Statement.setRefValue(v)
if err := session.statement.setRefValue(v); err != nil {
return 0, err
}
if len(session.Statement.TableName()) <= 0 {
if len(session.statement.TableName()) <= 0 {
return 0, ErrTableNotFound
}
if session.Statement.ColumnStr == "" {
colNames, args = buildUpdates(session.Engine, session.Statement.RefTable, bean, false, false,
false, false, session.Statement.allUseBool, session.Statement.useAllCols,
session.Statement.mustColumnMap, session.Statement.nullableMap,
session.Statement.columnMap, true, session.Statement.unscoped)
if session.statement.ColumnStr == "" {
colNames, args = buildUpdates(session.engine, session.statement.RefTable, bean, false, false,
false, false, session.statement.allUseBool, session.statement.useAllCols,
session.statement.mustColumnMap, session.statement.nullableMap,
session.statement.columnMap, true, session.statement.unscoped)
} else {
colNames, args, err = genCols(session.Statement.RefTable, session, bean, true, true)
colNames, args, err = genCols(session.statement.RefTable, session, bean, true, true)
if err != nil {
return 0, err
}
@ -192,19 +192,20 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
bValue := reflect.Indirect(reflect.ValueOf(bean))
for _, v := range bValue.MapKeys() {
colNames = append(colNames, session.Engine.Quote(v.String())+" = ?")
colNames = append(colNames, session.engine.Quote(v.String())+" = ?")
args = append(args, bValue.MapIndex(v).Interface())
}
} else {
return 0, ErrParamsType
}
table := session.Statement.RefTable
table := session.statement.RefTable
if session.Statement.UseAutoTime && table != nil && table.Updated != "" {
colNames = append(colNames, session.Engine.Quote(table.Updated)+" = ?")
if session.statement.UseAutoTime && table != nil && table.Updated != "" {
if _, ok := session.statement.columnMap[strings.ToLower(table.Updated)]; !ok {
colNames = append(colNames, session.engine.Quote(table.Updated)+" = ?")
col := table.UpdatedColumn()
val, t := session.Engine.NowTime2(col.SQLType.Name)
val, t := session.engine.nowTime(col)
args = append(args, val)
var colName = col.Name
@ -215,45 +216,60 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
})
}
}
}
//for update action to like "column = column + ?"
incColumns := session.Statement.getInc()
incColumns := session.statement.getInc()
for _, v := range incColumns {
colNames = append(colNames, session.Engine.Quote(v.colName)+" = "+session.Engine.Quote(v.colName)+" + ?")
colNames = append(colNames, session.engine.Quote(v.colName)+" = "+session.engine.Quote(v.colName)+" + ?")
args = append(args, v.arg)
}
//for update action to like "column = column - ?"
decColumns := session.Statement.getDec()
decColumns := session.statement.getDec()
for _, v := range decColumns {
colNames = append(colNames, session.Engine.Quote(v.colName)+" = "+session.Engine.Quote(v.colName)+" - ?")
colNames = append(colNames, session.engine.Quote(v.colName)+" = "+session.engine.Quote(v.colName)+" - ?")
args = append(args, v.arg)
}
//for update action to like "column = expression"
exprColumns := session.Statement.getExpr()
exprColumns := session.statement.getExpr()
for _, v := range exprColumns {
colNames = append(colNames, session.Engine.Quote(v.colName)+" = "+v.expr)
colNames = append(colNames, session.engine.Quote(v.colName)+" = "+v.expr)
}
session.Statement.processIDParam()
if err = session.statement.processIDParam(); err != nil {
return 0, err
}
var autoCond builder.Cond
if !session.Statement.noAutoCondition && len(condiBean) > 0 {
if !session.statement.noAutoCondition && len(condiBean) > 0 {
if c, ok := condiBean[0].(map[string]interface{}); ok {
autoCond = builder.Eq(c)
} else {
ct := reflect.TypeOf(condiBean[0])
k := ct.Kind()
if k == reflect.Ptr {
k = ct.Elem().Kind()
}
if k == reflect.Struct {
var err error
autoCond, err = session.Statement.buildConds(session.Statement.RefTable, condiBean[0], true, true, false, true, false)
autoCond, err = session.statement.buildConds(session.statement.RefTable, condiBean[0], true, true, false, true, false)
if err != nil {
return 0, err
}
} else {
return 0, ErrConditionType
}
}
}
st := session.Statement
defer session.resetStatement()
st := &session.statement
var sqlStr string
var condArgs []interface{}
var condSQL string
cond := session.Statement.cond.And(autoCond)
cond := session.statement.cond.And(autoCond)
var doIncVer = (table != nil && table.Version != "" && session.Statement.checkVersion)
var doIncVer = (table != nil && table.Version != "" && session.statement.checkVersion)
var verValue *reflect.Value
if doIncVer {
verValue, err = table.VersionColumn().ValueOf(bean)
@ -261,11 +277,15 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
return 0, err
}
cond = cond.And(builder.Eq{session.Engine.Quote(table.Version): verValue.Interface()})
colNames = append(colNames, session.Engine.Quote(table.Version)+" = "+session.Engine.Quote(table.Version)+" + 1")
cond = cond.And(builder.Eq{session.engine.Quote(table.Version): verValue.Interface()})
colNames = append(colNames, session.engine.Quote(table.Version)+" = "+session.engine.Quote(table.Version)+" + 1")
}
condSQL, condArgs, err = builder.ToSQL(cond)
if err != nil {
return 0, err
}
condSQL, condArgs, _ = builder.ToSQL(cond)
if len(condSQL) > 0 {
condSQL = "WHERE " + condSQL
}
@ -274,6 +294,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
condSQL = condSQL + fmt.Sprintf(" ORDER BY %v", st.OrderStr)
}
var tableName = session.statement.TableName()
// TODO: Oracle support needed
var top string
if st.LimitN > 0 {
@ -282,27 +303,53 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
} else if st.Engine.dialect.DBType() == core.SQLITE {
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN)
cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)",
session.Engine.Quote(session.Statement.TableName()), tempCondSQL), condArgs...))
condSQL, condArgs, _ = builder.ToSQL(cond)
session.engine.Quote(tableName), tempCondSQL), condArgs...))
condSQL, condArgs, err = builder.ToSQL(cond)
if err != nil {
return 0, err
}
if len(condSQL) > 0 {
condSQL = "WHERE " + condSQL
}
} else if st.Engine.dialect.DBType() == core.POSTGRES {
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN)
cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)",
session.Engine.Quote(session.Statement.TableName()), tempCondSQL), condArgs...))
condSQL, condArgs, _ = builder.ToSQL(cond)
session.engine.Quote(tableName), tempCondSQL), condArgs...))
condSQL, condArgs, err = builder.ToSQL(cond)
if err != nil {
return 0, err
}
if len(condSQL) > 0 {
condSQL = "WHERE " + condSQL
}
} else if st.Engine.dialect.DBType() == core.MSSQL {
top = fmt.Sprintf("top (%d) ", st.LimitN)
if st.OrderStr != "" && st.Engine.dialect.DBType() == core.MSSQL &&
table != nil && len(table.PrimaryKeys) == 1 {
cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)",
table.PrimaryKeys[0], st.LimitN, table.PrimaryKeys[0],
session.engine.Quote(tableName), condSQL), condArgs...)
condSQL, condArgs, err = builder.ToSQL(cond)
if err != nil {
return 0, err
}
if len(condSQL) > 0 {
condSQL = "WHERE " + condSQL
}
} else {
top = fmt.Sprintf("TOP (%d) ", st.LimitN)
}
}
}
if len(colNames) <= 0 {
return 0, errors.New("No content found to be updated")
}
sqlStr = fmt.Sprintf("UPDATE %v%v SET %v %v",
top,
session.Engine.Quote(session.Statement.TableName()),
session.engine.Quote(tableName),
strings.Join(colNames, ", "),
condSQL)
@ -316,19 +363,20 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
}
if table != nil {
if cacher := session.Engine.getCacher2(table); cacher != nil && session.Statement.UseCache {
cacher.ClearIds(session.Statement.TableName())
cacher.ClearBeans(session.Statement.TableName())
if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache {
//session.cacheUpdate(table, tableName, sqlStr, args...)
cacher.ClearIds(tableName)
cacher.ClearBeans(tableName)
}
}
// handle after update processors
if session.IsAutoCommit {
if session.isAutoCommit {
for _, closure := range session.afterClosures {
closure(bean)
}
if processor, ok := interface{}(bean).(AfterUpdateProcessor); ok {
session.Engine.logger.Debug("[event]", session.Statement.TableName(), " has after update processor")
session.engine.logger.Debug("[event]", tableName, " has after update processor")
processor.AfterUpdate()
}
} else {

View File

@ -73,6 +73,7 @@ type Statement struct {
decrColumns map[string]decrParam
exprColumns map[string]exprParam
cond builder.Cond
bufferSize int
}
// Init reset all the statement's fields
@ -111,6 +112,7 @@ func (statement *Statement) Init() {
statement.decrColumns = make(map[string]decrParam)
statement.exprColumns = make(map[string]exprParam)
statement.cond = builder.NewCond()
statement.bufferSize = 0
}
// NoAutoCondition if you do not want convert bean's field as query condition, then use this function
@ -158,6 +160,9 @@ func (statement *Statement) And(query interface{}, args ...interface{}) *Stateme
case string:
cond := builder.Expr(query.(string), args...)
statement.cond = statement.cond.And(cond)
case map[string]interface{}:
cond := builder.Eq(query.(map[string]interface{}))
statement.cond = statement.cond.And(cond)
case builder.Cond:
cond := query.(builder.Cond)
statement.cond = statement.cond.And(cond)
@ -179,6 +184,9 @@ func (statement *Statement) Or(query interface{}, args ...interface{}) *Statemen
case string:
cond := builder.Expr(query.(string), args...)
statement.cond = statement.cond.Or(cond)
case map[string]interface{}:
cond := builder.Eq(query.(map[string]interface{}))
statement.cond = statement.cond.Or(cond)
case builder.Cond:
cond := query.(builder.Cond)
statement.cond = statement.cond.Or(cond)
@ -207,9 +215,14 @@ func (statement *Statement) NotIn(column string, args ...interface{}) *Statement
return statement
}
func (statement *Statement) setRefValue(v reflect.Value) {
statement.RefTable = statement.Engine.autoMapType(reflect.Indirect(v))
func (statement *Statement) setRefValue(v reflect.Value) error {
var err error
statement.RefTable, err = statement.Engine.autoMapType(reflect.Indirect(v))
if err != nil {
return err
}
statement.tableName = statement.Engine.tbName(v)
return nil
}
// Table tempororily set table name, the parameter could be a string or a pointer of struct
@ -219,7 +232,12 @@ func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
if t.Kind() == reflect.String {
statement.AltTableName = tableNameOrBean.(string)
} else if t.Kind() == reflect.Struct {
statement.RefTable = statement.Engine.autoMapType(v)
var err error
statement.RefTable, err = statement.Engine.autoMapType(v)
if err != nil {
statement.Engine.logger.Error(err)
return statement
}
statement.AltTableName = statement.Engine.tbName(v)
}
return statement
@ -262,6 +280,9 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{},
fieldValue := *fieldValuePtr
fieldType := reflect.TypeOf(fieldValue.Interface())
if fieldType == nil {
continue
}
requiredField := useAllCols
includeNil := useAllCols
@ -366,7 +387,7 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{},
if !requiredField && (t.IsZero() || !fieldValue.IsValid()) {
continue
}
val = engine.FormatTime(col.SQLType.Name, t)
val = engine.formatColTime(col, t)
} else if nulType, ok := fieldValue.Interface().(driver.Valuer); ok {
val, _ = nulType.Value()
} else {
@ -480,224 +501,6 @@ func (statement *Statement) colName(col *core.Column, tableName string) string {
return statement.Engine.Quote(col.Name)
}
func buildConds(engine *Engine, table *core.Table, bean interface{},
includeVersion bool, includeUpdated bool, includeNil bool,
includeAutoIncr bool, allUseBool bool, useAllCols bool, unscoped bool,
mustColumnMap map[string]bool, tableName, aliasName string, addedTableName bool) (builder.Cond, error) {
var conds []builder.Cond
for _, col := range table.Columns() {
if !includeVersion && col.IsVersion {
continue
}
if !includeUpdated && col.IsUpdated {
continue
}
if !includeAutoIncr && col.IsAutoIncrement {
continue
}
if engine.dialect.DBType() == core.MSSQL && (col.SQLType.Name == core.Text || col.SQLType.IsBlob() || col.SQLType.Name == core.TimeStampz) {
continue
}
if col.SQLType.IsJson() {
continue
}
var colName string
if addedTableName {
var nm = tableName
if len(aliasName) > 0 {
nm = aliasName
}
colName = engine.Quote(nm) + "." + engine.Quote(col.Name)
} else {
colName = engine.Quote(col.Name)
}
fieldValuePtr, err := col.ValueOf(bean)
if err != nil {
engine.logger.Error(err)
continue
}
if col.IsDeleted && !unscoped { // tag "deleted" is enabled
if engine.dialect.DBType() == core.MSSQL {
conds = append(conds, builder.IsNull{colName})
} else {
conds = append(conds, builder.IsNull{colName}.Or(builder.Eq{colName: "0001-01-01 00:00:00"}))
}
}
fieldValue := *fieldValuePtr
if fieldValue.Interface() == nil {
continue
}
fieldType := reflect.TypeOf(fieldValue.Interface())
requiredField := useAllCols
if b, ok := getFlagForColumn(mustColumnMap, col); ok {
if b {
requiredField = true
} else {
continue
}
}
if fieldType.Kind() == reflect.Ptr {
if fieldValue.IsNil() {
if includeNil {
conds = append(conds, builder.Eq{colName: nil})
}
continue
} else if !fieldValue.IsValid() {
continue
} else {
// dereference ptr type to instance type
fieldValue = fieldValue.Elem()
fieldType = reflect.TypeOf(fieldValue.Interface())
requiredField = true
}
}
var val interface{}
switch fieldType.Kind() {
case reflect.Bool:
if allUseBool || requiredField {
val = fieldValue.Interface()
} else {
// if a bool in a struct, it will not be as a condition because it default is false,
// please use Where() instead
continue
}
case reflect.String:
if !requiredField && fieldValue.String() == "" {
continue
}
// for MyString, should convert to string or panic
if fieldType.String() != reflect.String.String() {
val = fieldValue.String()
} else {
val = fieldValue.Interface()
}
case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64:
if !requiredField && fieldValue.Int() == 0 {
continue
}
val = fieldValue.Interface()
case reflect.Float32, reflect.Float64:
if !requiredField && fieldValue.Float() == 0.0 {
continue
}
val = fieldValue.Interface()
case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64:
if !requiredField && fieldValue.Uint() == 0 {
continue
}
t := int64(fieldValue.Uint())
val = reflect.ValueOf(&t).Interface()
case reflect.Struct:
if fieldType.ConvertibleTo(core.TimeType) {
t := fieldValue.Convert(core.TimeType).Interface().(time.Time)
if !requiredField && (t.IsZero() || !fieldValue.IsValid()) {
continue
}
val = engine.FormatTime(col.SQLType.Name, t)
} else if _, ok := reflect.New(fieldType).Interface().(core.Conversion); ok {
continue
} else if valNul, ok := fieldValue.Interface().(driver.Valuer); ok {
val, _ = valNul.Value()
if val == nil {
continue
}
} else {
if col.SQLType.IsJson() {
if col.SQLType.IsText() {
bytes, err := json.Marshal(fieldValue.Interface())
if err != nil {
engine.logger.Error(err)
continue
}
val = string(bytes)
} else if col.SQLType.IsBlob() {
var bytes []byte
var err error
bytes, err = json.Marshal(fieldValue.Interface())
if err != nil {
engine.logger.Error(err)
continue
}
val = bytes
}
} else {
engine.autoMapType(fieldValue)
if table, ok := engine.Tables[fieldValue.Type()]; ok {
if len(table.PrimaryKeys) == 1 {
pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName)
// fix non-int pk issues
//if pkField.Int() != 0 {
if pkField.IsValid() && !isZero(pkField.Interface()) {
val = pkField.Interface()
} else {
continue
}
} else {
//TODO: how to handler?
panic(fmt.Sprintln("not supported", fieldValue.Interface(), "as", table.PrimaryKeys))
}
} else {
val = fieldValue.Interface()
}
}
}
case reflect.Array:
continue
case reflect.Slice, reflect.Map:
if fieldValue == reflect.Zero(fieldType) {
continue
}
if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 {
continue
}
if col.SQLType.IsText() {
bytes, err := json.Marshal(fieldValue.Interface())
if err != nil {
engine.logger.Error(err)
continue
}
val = string(bytes)
} else if col.SQLType.IsBlob() {
var bytes []byte
var err error
if (fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice) &&
fieldType.Elem().Kind() == reflect.Uint8 {
if fieldValue.Len() > 0 {
val = fieldValue.Bytes()
} else {
continue
}
} else {
bytes, err = json.Marshal(fieldValue.Interface())
if err != nil {
engine.logger.Error(err)
continue
}
val = bytes
}
} else {
continue
}
default:
val = fieldValue.Interface()
}
conds = append(conds, builder.Eq{colName: val})
}
return builder.And(conds...), nil
}
// TableName return current tableName
func (statement *Statement) TableName() string {
if statement.AltTableName != "" {
@ -800,6 +603,22 @@ func (statement *Statement) col2NewColsWithQuote(columns ...string) []string {
return newColumns
}
func (statement *Statement) colmap2NewColsWithQuote() []string {
newColumns := make([]string, 0, len(statement.columnMap))
for col := range statement.columnMap {
fields := strings.Split(strings.TrimSpace(col), ".")
if len(fields) == 1 {
newColumns = append(newColumns, statement.Engine.quote(fields[0]))
} else if len(fields) == 2 {
newColumns = append(newColumns, statement.Engine.quote(fields[0])+"."+
statement.Engine.quote(fields[1]))
} else {
panic(errors.New("unwanted colnames"))
}
}
return newColumns
}
// Distinct generates "DISTINCT col1, col2 " statement
func (statement *Statement) Distinct(columns ...string) *Statement {
statement.IsDistinct = true
@ -826,7 +645,7 @@ func (statement *Statement) Cols(columns ...string) *Statement {
statement.columnMap[strings.ToLower(nc)] = true
}
newColumns := statement.col2NewColsWithQuote(columns...)
newColumns := statement.colmap2NewColsWithQuote()
statement.ColumnStr = strings.Join(newColumns, ", ")
statement.ColumnStr = strings.Replace(statement.ColumnStr, statement.Engine.quote("*"), "*", -1)
return statement
@ -1088,33 +907,50 @@ func (statement *Statement) genDelIndexSQL() []string {
func (statement *Statement) genAddColumnStr(col *core.Column) (string, []interface{}) {
quote := statement.Engine.Quote
sql := fmt.Sprintf("ALTER TABLE %v ADD %v;", quote(statement.TableName()),
sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quote(statement.TableName()),
col.String(statement.Engine.dialect))
if statement.Engine.dialect.DBType() == core.MYSQL && len(col.Comment) > 0 {
sql += " COMMENT '" + col.Comment + "'"
}
sql += ";"
return sql, []interface{}{}
}
func (statement *Statement) buildConds(table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) (builder.Cond, error) {
return buildConds(statement.Engine, table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols,
return statement.Engine.buildConds(table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols,
statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.TableAlias, addedTableName)
}
func (statement *Statement) genConds(bean interface{}) (string, []interface{}, error) {
func (statement *Statement) mergeConds(bean interface{}) error {
if !statement.noAutoCondition {
var addedTableName = (len(statement.JoinStr) > 0)
autoCond, err := statement.buildConds(statement.RefTable, bean, true, true, false, true, addedTableName)
if err != nil {
return "", nil, err
return err
}
statement.cond = statement.cond.And(autoCond)
}
statement.processIDParam()
if err := statement.processIDParam(); err != nil {
return err
}
return nil
}
func (statement *Statement) genConds(bean interface{}) (string, []interface{}, error) {
if err := statement.mergeConds(bean); err != nil {
return "", nil, err
}
return builder.ToSQL(statement.cond)
}
func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}) {
statement.setRefValue(rValue(bean))
func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, error) {
v := rValue(bean)
isStruct := v.Kind() == reflect.Struct
if isStruct {
statement.setRefValue(v)
}
var columnStr = statement.ColumnStr
if len(statement.selectStr) > 0 {
@ -1133,22 +969,46 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{})
if len(columnStr) == 0 {
if len(statement.GroupByStr) > 0 {
columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1))
} else {
}
}
}
}
if len(columnStr) == 0 {
columnStr = "*"
}
if isStruct {
if err := statement.mergeConds(bean); err != nil {
return "", nil, err
}
}
condSQL, condArgs, err := builder.ToSQL(statement.cond)
if err != nil {
return "", nil, err
}
condSQL, condArgs, _ := statement.genConds(bean)
return statement.genSelectSQL(columnStr, condSQL), append(statement.joinArgs, condArgs...)
sqlStr, err := statement.genSelectSQL(columnStr, condSQL)
if err != nil {
return "", nil, err
}
func (statement *Statement) genCountSQL(bean interface{}) (string, []interface{}) {
statement.setRefValue(rValue(bean))
return sqlStr, append(statement.joinArgs, condArgs...), nil
}
condSQL, condArgs, _ := statement.genConds(bean)
func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interface{}, error) {
var condSQL string
var condArgs []interface{}
var err error
if len(beans) > 0 {
statement.setRefValue(rValue(beans[0]))
condSQL, condArgs, err = statement.genConds(beans[0])
} else {
condSQL, condArgs, err = builder.ToSQL(statement.cond)
}
if err != nil {
return "", nil, err
}
var selectSQL = statement.selectStr
if len(selectSQL) <= 0 {
@ -1158,23 +1018,40 @@ func (statement *Statement) genCountSQL(bean interface{}) (string, []interface{}
selectSQL = "count(*)"
}
}
return statement.genSelectSQL(selectSQL, condSQL), append(statement.joinArgs, condArgs...)
sqlStr, err := statement.genSelectSQL(selectSQL, condSQL)
if err != nil {
return "", nil, err
}
func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}) {
return sqlStr, append(statement.joinArgs, condArgs...), nil
}
func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) {
statement.setRefValue(rValue(bean))
var sumStrs = make([]string, 0, len(columns))
for _, colName := range columns {
sumStrs = append(sumStrs, fmt.Sprintf("COALESCE(sum(%s),0)", statement.Engine.Quote(colName)))
if !strings.Contains(colName, " ") && !strings.Contains(colName, "(") {
colName = statement.Engine.Quote(colName)
}
sumStrs = append(sumStrs, fmt.Sprintf("COALESCE(sum(%s),0)", colName))
}
sumSelect := strings.Join(sumStrs, ", ")
condSQL, condArgs, err := statement.genConds(bean)
if err != nil {
return "", nil, err
}
condSQL, condArgs, _ := statement.genConds(bean)
return statement.genSelectSQL(strings.Join(sumStrs, ", "), condSQL), append(statement.joinArgs, condArgs...)
sqlStr, err := statement.genSelectSQL(sumSelect, condSQL)
if err != nil {
return "", nil, err
}
func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string) {
return sqlStr, append(statement.joinArgs, condArgs...), nil
}
func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string, err error) {
var distinct string
if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") {
distinct = "DISTINCT "
@ -1185,15 +1062,23 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string) {
var top string
var mssqlCondi string
statement.processIDParam()
if err := statement.processIDParam(); err != nil {
return "", err
}
var buf bytes.Buffer
if len(condSQL) > 0 {
fmt.Fprintf(&buf, " WHERE %v", condSQL)
}
var whereStr = buf.String()
var fromStr = " FROM "
if dialect.DBType() == core.MSSQL && strings.Contains(statement.TableName(), "..") {
fromStr += statement.TableName()
} else {
fromStr += quote(statement.TableName())
}
var fromStr = " FROM " + quote(statement.TableName())
if statement.TableAlias != "" {
if dialect.DBType() == core.ORACLE {
fromStr += " " + quote(statement.TableAlias)
@ -1246,7 +1131,7 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string) {
}
// !nashtsai! REVIEW Sprintf is considered slowest mean of string concatnation, better to work with builder pattern
a = fmt.Sprintf("SELECT %v%v%v%v%v", top, distinct, columnStr, fromStr, whereStr)
a = fmt.Sprintf("SELECT %v%v%v%v%v", distinct, top, columnStr, fromStr, whereStr)
if len(mssqlCondi) > 0 {
if len(whereStr) > 0 {
a += " AND " + mssqlCondi
@ -1282,19 +1167,23 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string) {
return
}
func (statement *Statement) processIDParam() {
func (statement *Statement) processIDParam() error {
if statement.idParam == nil {
return
return nil
}
if len(statement.RefTable.PrimaryKeys) != len(*statement.idParam) {
return fmt.Errorf("ID condition is error, expect %d primarykeys, there are %d",
len(statement.RefTable.PrimaryKeys),
len(*statement.idParam),
)
}
for i, col := range statement.RefTable.PKColumns() {
var colName = statement.colName(col, statement.TableName())
if i < len(*(statement.idParam)) {
statement.cond = statement.cond.And(builder.Eq{colName: (*(statement.idParam))[i]})
} else {
statement.cond = statement.cond.And(builder.Eq{colName: ""})
}
}
return nil
}
func (statement *Statement) joinColumns(cols []*core.Column, includeTableName bool) string {
@ -1328,7 +1217,8 @@ func (statement *Statement) convertIDSQL(sqlStr string) string {
top = fmt.Sprintf("TOP %d ", statement.LimitN)
}
return fmt.Sprintf("SELECT %s%s FROM %v", top, colstrs, sqls[1])
newsql := fmt.Sprintf("SELECT %s%s FROM %v", top, colstrs, sqls[1])
return newsql
}
return ""
}

View File

@ -54,6 +54,7 @@ var (
"UNIQUE": UniqueTagHandler,
"CACHE": CacheTagHandler,
"NOCACHE": NoCacheTagHandler,
"COMMENT": CommentTagHandler,
}
)
@ -192,6 +193,14 @@ func UniqueTagHandler(ctx *tagContext) error {
return nil
}
// CommentTagHandler add comment to column
func CommentTagHandler(ctx *tagContext) error {
if len(ctx.params) > 0 {
ctx.col.Comment = strings.Trim(ctx.params[0], "' ")
}
return nil
}
// SQLTypeTagHandler describes SQL Type tag handler
func SQLTypeTagHandler(ctx *tagContext) error {
ctx.col.SQLType = core.SQLType{Name: ctx.tagName}

View File

@ -17,7 +17,7 @@ import (
const (
// Version show the xorm's version
Version string = "0.6.2.0326"
Version string = "0.6.4.0910"
)
func regDrvsNDialects() bool {
@ -50,10 +50,13 @@ func close(engine *Engine) {
engine.Close()
}
func init() {
regDrvsNDialects()
}
// NewEngine new a db manager according to the parameter. Currently support four
// drivers
func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
regDrvsNDialects()
driver := core.QueryDriver(driverName)
if driver == nil {
return nil, fmt.Errorf("Unsupported driver name: %v", driverName)
@ -89,6 +92,12 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
tagHandlers: defaultTagHandlers,
}
if uri.DbType == core.SQLITE {
engine.DatabaseTZ = time.UTC
} else {
engine.DatabaseTZ = time.Local
}
logger := NewSimpleLogger(os.Stdout)
logger.SetLevel(core.LOG_INFO)
engine.SetLogger(logger)