updated xorm, go-sqlite3 dependencies

This commit is contained in:
Torkel Ödegaard 2015-03-26 12:41:43 +01:00
parent 71aa2ef2c2
commit fcac4c057c
36 changed files with 1875 additions and 680 deletions

8
Godeps/Godeps.json generated
View File

@ -25,12 +25,12 @@
}, },
{ {
"ImportPath": "github.com/go-xorm/core", "ImportPath": "github.com/go-xorm/core",
"Rev": "a949e067ced1cb6e6ef5c38b6f28b074fa718f1e" "Rev": "be6e7ac47dc57bd0ada25322fa526944f66ccaa6"
}, },
{ {
"ImportPath": "github.com/go-xorm/xorm", "ImportPath": "github.com/go-xorm/xorm",
"Comment": "v0.4.1-19-g5c23849", "Comment": "v0.4.2-58-ge2889e5",
"Rev": "5c23849a66f4593e68909bb6c1fa30651b5b0541" "Rev": "e2889e5517600b82905f1d2ba8b70deb71823ffe"
}, },
{ {
"ImportPath": "github.com/jtolds/gls", "ImportPath": "github.com/jtolds/gls",
@ -51,7 +51,7 @@
}, },
{ {
"ImportPath": "github.com/mattn/go-sqlite3", "ImportPath": "github.com/mattn/go-sqlite3",
"Rev": "d10e2c8f62100097910367dee90a9bd89d426a44" "Rev": "e28cd440fabdd39b9520344bc26829f61db40ece"
}, },
{ {
"ImportPath": "github.com/smartystreets/goconvey/convey", "ImportPath": "github.com/smartystreets/goconvey/convey",

114
Godeps/_workspace/src/github.com/go-xorm/core/README.md generated vendored Normal file
View File

@ -0,0 +1,114 @@
Core is a lightweight wrapper of sql.DB.
# Open
```Go
db, _ := core.Open(db, connstr)
```
# SetMapper
```Go
db.SetMapper(SameMapper())
```
## Scan usage
### Scan
```Go
rows, _ := db.Query()
for rows.Next() {
rows.Scan()
}
```
### ScanMap
```Go
rows, _ := db.Query()
for rows.Next() {
rows.ScanMap()
```
### ScanSlice
You can use `[]string`, `[][]byte`, `[]interface{}`, `[]*string`, `[]sql.NullString` to ScanSclice. Notice, slice's length should be equal or less than select columns.
```Go
rows, _ := db.Query()
cols, _ := rows.Columns()
for rows.Next() {
var s = make([]string, len(cols))
rows.ScanSlice(&s)
}
```
```Go
rows, _ := db.Query()
cols, _ := rows.Columns()
for rows.Next() {
var s = make([]*string, len(cols))
rows.ScanSlice(&s)
}
```
### ScanStruct
```Go
rows, _ := db.Query()
for rows.Next() {
rows.ScanStructByName()
rows.ScanStructByIndex()
}
```
## Query usage
```Go
rows, err := db.Query("select * from table where name = ?", name)
user = User{
Name:"lunny",
}
rows, err := db.QueryStruct("select * from table where name = ?Name",
&user)
var user = map[string]interface{}{
"name": "lunny",
}
rows, err = db.QueryMap("select * from table where name = ?name",
&user)
```
## QueryRow usage
```Go
row := db.QueryRow("select * from table where name = ?", name)
user = User{
Name:"lunny",
}
row := db.QueryRowStruct("select * from table where name = ?Name",
&user)
var user = map[string]interface{}{
"name": "lunny",
}
row = db.QueryRowMap("select * from table where name = ?name",
&user)
```
## Exec usage
```Go
db.Exec("insert into user (`name`, title, age, alias, nick_name,created) values (?,?,?,?,?,?)", name, title, age, alias...)
user = User{
Name:"lunny",
Title:"test",
Age: 18,
}
result, err = db.ExecStruct("insert into user (`name`, title, age, alias, nick_name,created) values (?Name,?Title,?Age,?Alias,?NickName,?Created)",
&user)
var user = map[string]interface{}{
"Name": "lunny",
"Title": "test",
"Age": 18,
}
result, err = db.ExecMap("insert into user (`name`, title, age, alias, nick_name,created) values (?Name,?Title,?Age,?Alias,?NickName,?Created)",
&user)
```

View File

@ -121,6 +121,21 @@ func (col *Column) ValueOfV(dataStruct *reflect.Value) (*reflect.Value, error) {
col.fieldPath = strings.Split(col.FieldName, ".") col.fieldPath = strings.Split(col.FieldName, ".")
} }
if dataStruct.Type().Kind() == reflect.Map {
var keyValue reflect.Value
if len(col.fieldPath) == 1 {
keyValue = reflect.ValueOf(col.FieldName)
} else if len(col.fieldPath) == 2 {
keyValue = reflect.ValueOf(col.fieldPath[1])
} else {
return nil, fmt.Errorf("Unsupported mutliderive %v", col.FieldName)
}
fieldValue = dataStruct.MapIndex(keyValue)
return &fieldValue, nil
}
if len(col.fieldPath) == 1 { if len(col.fieldPath) == 1 {
fieldValue = dataStruct.FieldByName(col.FieldName) fieldValue = dataStruct.FieldByName(col.FieldName)
} else if len(col.fieldPath) == 2 { } else if len(col.fieldPath) == 2 {

View File

@ -47,15 +47,13 @@ type Dialect interface {
SupportInsertMany() bool SupportInsertMany() bool
SupportEngine() bool SupportEngine() bool
SupportCharset() bool SupportCharset() bool
SupportDropIfExists() bool
IndexOnTable() bool IndexOnTable() bool
ShowCreateNull() bool ShowCreateNull() bool
IndexCheckSql(tableName, idxName string) (string, []interface{}) IndexCheckSql(tableName, idxName string) (string, []interface{})
TableCheckSql(tableName string) (string, []interface{}) TableCheckSql(tableName string) (string, []interface{})
//ColumnCheckSql(tableName, colName string) (string, []interface{})
//IsTableExist(tableName string) (bool, error)
//IsIndexExist(tableName string, idx *Index) (bool, error)
IsColumnExist(tableName string, col *Column) (bool, error) IsColumnExist(tableName string, col *Column) (bool, error)
CreateTableSql(table *Table, tableName, storeEngine, charset string) string CreateTableSql(table *Table, tableName, storeEngine, charset string) string
@ -65,15 +63,13 @@ type Dialect interface {
ModifyColumnSql(tableName string, col *Column) string ModifyColumnSql(tableName string, col *Column) string
//CreateTableIfNotExists(table *Table, tableName, storeEngine, charset string) error
//MustDropTable(tableName string) error
GetColumns(tableName string) ([]string, map[string]*Column, error) GetColumns(tableName string) ([]string, map[string]*Column, error)
GetTables() ([]*Table, error) GetTables() ([]*Table, error)
GetIndexes(tableName string) (map[string]*Index, error) GetIndexes(tableName string) (map[string]*Index, error)
// Get data from db cell to a struct's field
//GetData(col *Column, fieldValue *reflect.Value, cellData interface{}) error
// Set field data to db
//SetData(col *Column, fieldValue *refelct.Value) (interface{}, error)
Filters() []Filter Filters() []Filter
} }
@ -144,6 +140,10 @@ func (db *Base) RollBackStr() string {
return "ROLL BACK" return "ROLL BACK"
} }
func (db *Base) SupportDropIfExists() bool {
return true
}
func (db *Base) DropTableSql(tableName string) string { func (db *Base) DropTableSql(tableName string) string {
return fmt.Sprintf("DROP TABLE IF EXISTS `%s`", tableName) return fmt.Sprintf("DROP TABLE IF EXISTS `%s`", tableName)
} }
@ -170,35 +170,52 @@ func (db *Base) IsColumnExist(tableName string, col *Column) (bool, error) {
return db.HasRecords(query, db.DbName, tableName, col.Name) return db.HasRecords(query, db.DbName, tableName, col.Name)
} }
/*
func (db *Base) CreateTableIfNotExists(table *Table, tableName, storeEngine, charset string) error {
sql, args := db.dialect.TableCheckSql(tableName)
rows, err := db.DB().Query(sql, args...)
if db.Logger != nil {
db.Logger.Info("[sql]", sql, args)
}
if err != nil {
return err
}
defer rows.Close()
if rows.Next() {
return nil
}
sql = db.dialect.CreateTableSql(table, tableName, storeEngine, charset)
_, err = db.DB().Exec(sql)
if db.Logger != nil {
db.Logger.Info("[sql]", sql)
}
return err
}*/
func (db *Base) CreateIndexSql(tableName string, index *Index) string { func (db *Base) CreateIndexSql(tableName string, index *Index) string {
quote := db.dialect.Quote quote := db.dialect.Quote
var unique string var unique string
var idxName string var idxName string
if index.Type == UniqueType { if index.Type == UniqueType {
unique = " UNIQUE" unique = " UNIQUE"
idxName = fmt.Sprintf("UQE_%v_%v", tableName, index.Name)
} else {
idxName = fmt.Sprintf("IDX_%v_%v", tableName, index.Name)
} }
return fmt.Sprintf("CREATE%s INDEX %v ON %v (%v);", unique, idxName = index.XName(tableName)
return fmt.Sprintf("CREATE%s INDEX %v ON %v (%v)", unique,
quote(idxName), quote(tableName), quote(idxName), quote(tableName),
quote(strings.Join(index.Cols, quote(",")))) quote(strings.Join(index.Cols, quote(","))))
} }
func (db *Base) DropIndexSql(tableName string, index *Index) string { func (db *Base) DropIndexSql(tableName string, index *Index) string {
quote := db.dialect.Quote quote := db.dialect.Quote
//var unique string var name string
var idxName string = index.Name if index.IsRegular {
if !strings.HasPrefix(idxName, "UQE_") && name = index.XName(tableName)
!strings.HasPrefix(idxName, "IDX_") { } else {
if index.Type == UniqueType { name = index.Name
idxName = fmt.Sprintf("UQE_%v_%v", tableName, index.Name)
} else {
idxName = fmt.Sprintf("IDX_%v_%v", tableName, index.Name)
}
} }
return fmt.Sprintf("DROP INDEX %v ON %s", return fmt.Sprintf("DROP INDEX %v ON %s", quote(name), quote(tableName))
quote(idxName), quote(tableName))
} }
func (db *Base) ModifyColumnSql(tableName string, col *Column) string { func (db *Base) ModifyColumnSql(tableName string, col *Column) string {

View File

@ -1,7 +1,9 @@
package core package core
import ( import (
"fmt"
"sort" "sort"
"strings"
) )
const ( const (
@ -11,9 +13,21 @@ const (
// database index // database index
type Index struct { type Index struct {
Name string IsRegular bool
Type int Name string
Cols []string Type int
Cols []string
}
func (index *Index) XName(tableName string) string {
if !strings.HasPrefix(index.Name, "UQE_") &&
!strings.HasPrefix(index.Name, "IDX_") {
if index.Type == UniqueType {
return fmt.Sprintf("UQE_%v_%v", tableName, index.Name)
}
return fmt.Sprintf("IDX_%v_%v", tableName, index.Name)
}
return index.Name
} }
// add columns which will be composite index // add columns which will be composite index
@ -24,6 +38,9 @@ func (index *Index) AddColumn(cols ...string) {
} }
func (index *Index) Equal(dst *Index) bool { func (index *Index) Equal(dst *Index) bool {
if index.Type != dst.Type {
return false
}
if len(index.Cols) != len(dst.Cols) { if len(index.Cols) != len(dst.Cols) {
return false return false
} }
@ -40,5 +57,5 @@ func (index *Index) Equal(dst *Index) bool {
// new an index // new an index
func NewIndex(name string, indexType int) *Index { func NewIndex(name string, indexType int) *Index {
return &Index{name, indexType, make([]string, 0)} return &Index{true, name, indexType, make([]string, 0)}
} }

View File

@ -9,7 +9,6 @@ import (
type IMapper interface { type IMapper interface {
Obj2Table(string) string Obj2Table(string) string
Table2Obj(string) string Table2Obj(string) string
TableName(string) string
} }
type CacheMapper struct { type CacheMapper struct {
@ -56,10 +55,6 @@ func (m *CacheMapper) Table2Obj(t string) string {
return o return o
} }
func (m *CacheMapper) TableName(t string) string {
return t
}
// SameMapper implements IMapper and provides same name between struct and // SameMapper implements IMapper and provides same name between struct and
// database table // database table
type SameMapper struct { type SameMapper struct {
@ -73,10 +68,6 @@ func (m SameMapper) Table2Obj(t string) string {
return t return t
} }
func (m SameMapper) TableName(t string) string {
return t
}
// SnakeMapper implements IMapper and provides name transaltion between // SnakeMapper implements IMapper and provides name transaltion between
// struct and database table // struct and database table
type SnakeMapper struct { type SnakeMapper struct {
@ -97,25 +88,6 @@ func snakeCasedName(name string) string {
return string(newstr) return string(newstr)
} }
/*func pascal2Sql(s string) (d string) {
d = ""
lastIdx := 0
for i := 0; i < len(s); i++ {
if s[i] >= 'A' && s[i] <= 'Z' {
if lastIdx < i {
d += s[lastIdx+1 : i]
}
if i != 0 {
d += "_"
}
d += string(s[i] + 32)
lastIdx = i
}
}
d += s[lastIdx+1:]
return
}*/
func (mapper SnakeMapper) Obj2Table(name string) string { func (mapper SnakeMapper) Obj2Table(name string) string {
return snakeCasedName(name) return snakeCasedName(name)
} }
@ -148,9 +120,103 @@ func (mapper SnakeMapper) Table2Obj(name string) string {
return titleCasedName(name) return titleCasedName(name)
} }
func (mapper SnakeMapper) TableName(t string) string { // GonicMapper implements IMapper. It will consider initialisms when mapping names.
return t // E.g. id -> ID, user -> User and to table names: UserID -> user_id, MyUID -> my_uid
type GonicMapper map[string]bool
func isASCIIUpper(r rune) bool {
return 'A' <= r && r <= 'Z'
} }
func toASCIIUpper(r rune) rune {
if 'a' <= r && r <= 'z' {
r -= ('a' - 'A')
}
return r
}
func gonicCasedName(name string) string {
newstr := make([]rune, 0, len(name)+3)
for idx, chr := range name {
if isASCIIUpper(chr) && idx > 0 {
if !isASCIIUpper(newstr[len(newstr)-1]) {
newstr = append(newstr, '_')
}
}
if !isASCIIUpper(chr) && idx > 1 {
l := len(newstr)
if isASCIIUpper(newstr[l-1]) && isASCIIUpper(newstr[l-2]) {
newstr = append(newstr, newstr[l-1])
newstr[l-1] = '_'
}
}
newstr = append(newstr, chr)
}
return strings.ToLower(string(newstr))
}
func (mapper GonicMapper) Obj2Table(name string) string {
return gonicCasedName(name)
}
func (mapper GonicMapper) Table2Obj(name string) string {
newstr := make([]rune, 0)
name = strings.ToLower(name)
parts := strings.Split(name, "_")
for _, p := range parts {
_, isInitialism := mapper[strings.ToUpper(p)]
for i, r := range p {
if i == 0 || isInitialism {
r = toASCIIUpper(r)
}
newstr = append(newstr, r)
}
}
return string(newstr)
}
// A GonicMapper that contains a list of common initialisms taken from golang/lint
var LintGonicMapper = GonicMapper{
"API": true,
"ASCII": true,
"CPU": true,
"CSS": true,
"DNS": true,
"EOF": true,
"GUID": true,
"HTML": true,
"HTTP": true,
"HTTPS": true,
"ID": true,
"IP": true,
"JSON": true,
"LHS": true,
"QPS": true,
"RAM": true,
"RHS": true,
"RPC": true,
"SLA": true,
"SMTP": true,
"SSH": true,
"TLS": true,
"TTL": true,
"UI": true,
"UID": true,
"UUID": true,
"URI": true,
"URL": true,
"UTF8": true,
"VM": true,
"XML": true,
"XSRF": true,
"XSS": true,
}
// provide prefix table name support // provide prefix table name support
type PrefixMapper struct { type PrefixMapper struct {
Mapper IMapper Mapper IMapper
@ -165,10 +231,6 @@ func (mapper PrefixMapper) Table2Obj(name string) string {
return mapper.Mapper.Table2Obj(name[len(mapper.Prefix):]) return mapper.Mapper.Table2Obj(name[len(mapper.Prefix):])
} }
func (mapper PrefixMapper) TableName(name string) string {
return mapper.Prefix + name
}
func NewPrefixMapper(mapper IMapper, prefix string) PrefixMapper { func NewPrefixMapper(mapper IMapper, prefix string) PrefixMapper {
return PrefixMapper{mapper, prefix} return PrefixMapper{mapper, prefix}
} }
@ -187,10 +249,6 @@ func (mapper SuffixMapper) Table2Obj(name string) string {
return mapper.Mapper.Table2Obj(name[:len(name)-len(mapper.Suffix)]) return mapper.Mapper.Table2Obj(name[:len(name)-len(mapper.Suffix)])
} }
func (mapper SuffixMapper) TableName(name string) string {
return name + mapper.Suffix
}
func NewSuffixMapper(mapper IMapper, suffix string) SuffixMapper { func NewSuffixMapper(mapper IMapper, suffix string) SuffixMapper {
return SuffixMapper{mapper, suffix} return SuffixMapper{mapper, suffix}
} }

View File

@ -0,0 +1,45 @@
package core
import (
"testing"
)
func TestGonicMapperFromObj(t *testing.T) {
testCases := map[string]string{
"HTTPLib": "http_lib",
"id": "id",
"ID": "id",
"IDa": "i_da",
"iDa": "i_da",
"IDAa": "id_aa",
"aID": "a_id",
"aaID": "aa_id",
"aaaID": "aaa_id",
"MyREalFunkYLONgNAME": "my_r_eal_funk_ylo_ng_name",
}
for in, expected := range testCases {
out := gonicCasedName(in)
if out != expected {
t.Errorf("Given %s, expected %s but got %s", in, expected, out)
}
}
}
func TestGonicMapperToObj(t *testing.T) {
testCases := map[string]string{
"http_lib": "HTTPLib",
"id": "ID",
"ida": "Ida",
"id_aa": "IDAa",
"aa_id": "AaID",
"my_r_eal_funk_ylo_ng_name": "MyREalFunkYloNgName",
}
for in, expected := range testCases {
out := LintGonicMapper.Table2Obj(in)
if out != expected {
t.Errorf("Given %s, expected %s but got %s", in, expected, out)
}
}
}

View File

@ -1,7 +1,8 @@
package core package core
import ( import (
"encoding/json" "bytes"
"encoding/gob"
) )
type PK []interface{} type PK []interface{}
@ -12,14 +13,14 @@ func NewPK(pks ...interface{}) *PK {
} }
func (p *PK) ToString() (string, error) { func (p *PK) ToString() (string, error) {
bs, err := json.Marshal(*p) buf := new(bytes.Buffer)
if err != nil { enc := gob.NewEncoder(buf)
return "", nil err := enc.Encode(*p)
} return buf.String(), err
return string(bs), nil
} }
func (p *PK) FromString(content string) error { func (p *PK) FromString(content string) error {
return json.Unmarshal([]byte(content), p) dec := gob.NewDecoder(bytes.NewBufferString(content))
err := dec.Decode(p)
return err
} }

View File

@ -2,6 +2,7 @@ package core
import ( import (
"fmt" "fmt"
"reflect"
"testing" "testing"
) )
@ -19,4 +20,14 @@ func TestPK(t *testing.T) {
t.Error(err) t.Error(err)
} }
fmt.Println(s) fmt.Println(s)
if len(*p) != len(*s) {
t.Fatal("p", *p, "should be equal", *s)
}
for i, ori := range *p {
if ori != (*s)[i] {
t.Fatal("ori", ori, reflect.ValueOf(ori), "should be equal", (*s)[i], reflect.ValueOf((*s)[i]))
}
}
} }

View File

@ -65,13 +65,18 @@ func (table *Table) GetColumnIdx(name string, idx int) *Column {
// if has primary key, return column // if has primary key, return column
func (table *Table) PKColumns() []*Column { func (table *Table) PKColumns() []*Column {
columns := make([]*Column, 0) columns := make([]*Column, len(table.PrimaryKeys))
for _, name := range table.PrimaryKeys { for i, name := range table.PrimaryKeys {
columns = append(columns, table.GetColumn(name)) columns[i] = table.GetColumn(name)
} }
return columns return columns
} }
func (table *Table) ColumnType(name string) reflect.Type {
t, _ := table.Type.FieldByName(name)
return t.Type
}
func (table *Table) AutoIncrColumn() *Column { func (table *Table) AutoIncrColumn() *Column {
return table.GetColumn(table.AutoIncrement) return table.GetColumn(table.AutoIncrement)
} }

View File

@ -70,6 +70,7 @@ var (
NVarchar = "NVARCHAR" NVarchar = "NVARCHAR"
TinyText = "TINYTEXT" TinyText = "TINYTEXT"
Text = "TEXT" Text = "TEXT"
Clob = "CLOB"
MediumText = "MEDIUMTEXT" MediumText = "MEDIUMTEXT"
LongText = "LONGTEXT" LongText = "LONGTEXT"
Uuid = "UUID" Uuid = "UUID"
@ -120,6 +121,7 @@ var (
MediumText: TEXT_TYPE, MediumText: TEXT_TYPE,
LongText: TEXT_TYPE, LongText: TEXT_TYPE,
Uuid: TEXT_TYPE, Uuid: TEXT_TYPE,
Clob: TEXT_TYPE,
Date: TIME_TYPE, Date: TIME_TYPE,
DateTime: TIME_TYPE, DateTime: TIME_TYPE,
@ -250,7 +252,7 @@ func Type2SQLType(t reflect.Type) (st SQLType) {
case reflect.String: case reflect.String:
st = SQLType{Varchar, 255, 0} st = SQLType{Varchar, 255, 0}
case reflect.Struct: case reflect.Struct:
if t == reflect.TypeOf(c_TIME_DEFAULT) { if t.ConvertibleTo(reflect.TypeOf(c_TIME_DEFAULT)) {
st = SQLType{DateTime, 0, 0} st = SQLType{DateTime, 0, 0}
} else { } else {
// TODO need to handle association struct // TODO need to handle association struct
@ -303,7 +305,7 @@ func SQLType2Type(st SQLType) reflect.Type {
return reflect.TypeOf(float32(1)) return reflect.TypeOf(float32(1))
case Double: case Double:
return reflect.TypeOf(float64(1)) return reflect.TypeOf(float64(1))
case Char, Varchar, NVarchar, TinyText, Text, MediumText, LongText, Enum, Set, Uuid: case Char, Varchar, NVarchar, TinyText, Text, MediumText, LongText, Enum, Set, Uuid, Clob:
return reflect.TypeOf("") return reflect.TypeOf("")
case TinyBlob, Blob, LongBlob, Bytea, Binary, MediumBlob, VarBinary: case TinyBlob, Blob, LongBlob, Bytea, Binary, MediumBlob, VarBinary:
return reflect.TypeOf([]byte{}) return reflect.TypeOf([]byte{})

View File

@ -82,11 +82,13 @@ Or
# Cases # Cases
* [Wego](http://github.com/go-tango/wego)
* [Docker.cn](https://docker.cn/) * [Docker.cn](https://docker.cn/)
* [Gogs](http://try.gogits.org) - [github.com/gogits/gogs](http://github.com/gogits/gogs) * [Gogs](http://try.gogits.org) - [github.com/gogits/gogs](http://github.com/gogits/gogs)
* [Gorevel](http://http://gorevel.cn/) - [github.com/goofcc/gorevel](http://github.com/goofcc/gorevel) * [Gorevel](http://gorevel.cn/) - [github.com/goofcc/gorevel](http://github.com/goofcc/gorevel)
* [Gowalker](http://gowalker.org) - [github.com/Unknwon/gowalker](http://github.com/Unknwon/gowalker) * [Gowalker](http://gowalker.org) - [github.com/Unknwon/gowalker](http://github.com/Unknwon/gowalker)

View File

@ -44,16 +44,10 @@ xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作
## 更新日志 ## 更新日志
* **v0.4.0 RC1** * **v0.4.2**
新特性: 新特性:
* 移动xorm cmd [github.com/go-xorm/cmd](github.com/go-xorm/cmd) * deleted标记
* 在重构一般DB操作核心库 [github.com/go-xorm/core](https://github.com/go-xorm/core) * bug fixed
* 移动测试github.com/复XORM/测试 [github.com/go-xorm/tests](github.com/go-xorm/tests)
改进:
* Prepared statement 缓存
* 添加 Incr API
* 指定时区位置
[更多更新日志...](https://github.com/go-xorm/manual-zh-CN/tree/master/chapter-16) [更多更新日志...](https://github.com/go-xorm/manual-zh-CN/tree/master/chapter-16)
@ -78,6 +72,8 @@ xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作
## 案例 ## 案例
* [Wego](http://github.com/go-tango/wego)
* [Docker.cn](https://docker.cn/) * [Docker.cn](https://docker.cn/)
* [Gogs](http://try.gogits.org) - [github.com/gogits/gogs](http://github.com/gogits/gogs) * [Gogs](http://try.gogits.org) - [github.com/gogits/gogs](http://github.com/gogits/gogs)

View File

@ -1 +1 @@
xorm v0.4.1 xorm v0.4.2.0225

View File

@ -63,21 +63,22 @@ There are 7 major ORM methods and many helpful methods to use to operate databas
// SELECT * FROM user // SELECT * FROM user
4. Query multiple records and record by record handle, there two methods, one is Iterate, 4. Query multiple records and record by record handle, there two methods, one is Iterate,
another is Raws another is Rows
err := engine.Iterate(...) err := engine.Iterate(...)
// SELECT * FROM user // SELECT * FROM user
raws, err := engine.Raws(...) rows, err := engine.Rows(...)
// SELECT * FROM user // SELECT * FROM user
defer rows.Close()
bean := new(Struct) bean := new(Struct)
for raws.Next() { for rows.Next() {
err = raws.Scan(bean) err = rows.Scan(bean)
} }
5. Update one or more records 5. Update one or more records
affected, err := engine.Update(&user) affected, err := engine.Id(...).Update(&user)
// UPDATE user SET ... // UPDATE user SET ...
6. Delete one or more records, Delete MUST has conditon 6. Delete one or more records, Delete MUST has conditon
@ -150,6 +151,6 @@ Attention: the above 7 methods should be the last chainable method.
engine.Join("LEFT", "userdetail", "user.id=userdetail.id").Find() engine.Join("LEFT", "userdetail", "user.id=userdetail.id").Find()
//SELECT * FROM user LEFT JOIN userdetail ON user.id=userdetail.id //SELECT * FROM user LEFT JOIN userdetail ON user.id=userdetail.id
More usage, please visit https://github.com/go-xorm/xorm/blob/master/docs/QuickStartEn.md More usage, please visit http://xorm.io/docs
*/ */
package xorm package xorm

View File

@ -344,7 +344,7 @@ func (engine *Engine) DBMetas() ([]*core.Table, error) {
if col := table.GetColumn(name); col != nil { if col := table.GetColumn(name); col != nil {
col.Indexes[index.Name] = true col.Indexes[index.Name] = true
} else { } else {
return nil, fmt.Errorf("Unknown col "+name+" in indexes %v", index) return nil, fmt.Errorf("Unknown col "+name+" in indexes %v of table", index, table.ColumnsSeq())
} }
} }
} }
@ -352,6 +352,9 @@ func (engine *Engine) DBMetas() ([]*core.Table, error) {
return tables, nil return tables, nil
} }
/*
dump database all table structs and data to a file
*/
func (engine *Engine) DumpAllToFile(fp string) error { func (engine *Engine) DumpAllToFile(fp string) error {
f, err := os.Create(fp) f, err := os.Create(fp)
if err != nil { if err != nil {
@ -361,6 +364,9 @@ func (engine *Engine) DumpAllToFile(fp string) error {
return engine.DumpAll(f) return engine.DumpAll(f)
} }
/*
dump database all table structs and data to w
*/
func (engine *Engine) DumpAll(w io.Writer) error { func (engine *Engine) DumpAll(w io.Writer) error {
tables, err := engine.DBMetas() tables, err := engine.DBMetas()
if err != nil { if err != nil {
@ -558,6 +564,13 @@ func (engine *Engine) Decr(column string, arg ...interface{}) *Session {
return session.Decr(column, arg...) return session.Decr(column, arg...)
} }
// Method SetExpr provides a update string like "column = {expression}"
func (engine *Engine) SetExpr(column string, expression string) *Session {
session := engine.NewSession()
session.IsAutoClose = true
return session.SetExpr(column, expression)
}
// Temporarily change the Get, Find, Update's table // Temporarily change the Get, Find, Update's table
func (engine *Engine) Table(tableNameOrBean interface{}) *Session { func (engine *Engine) Table(tableNameOrBean interface{}) *Session {
session := engine.NewSession() session := engine.NewSession()
@ -766,7 +779,12 @@ func (engine *Engine) mapType(v reflect.Value) *core.Table {
col.IsPrimaryKey = true col.IsPrimaryKey = true
col.Nullable = false col.Nullable = false
case k == "NULL": case k == "NULL":
col.Nullable = (strings.ToUpper(tags[j-1]) != "NOT") if j == 0 {
col.Nullable = true
} else {
col.Nullable = (strings.ToUpper(tags[j-1]) != "NOT")
}
// TODO: for postgres how add autoincr?
/*case strings.HasPrefix(k, "AUTOINCR(") && strings.HasSuffix(k, ")"): /*case strings.HasPrefix(k, "AUTOINCR(") && strings.HasSuffix(k, ")"):
col.IsAutoIncrement = true col.IsAutoIncrement = true
@ -915,7 +933,7 @@ func (engine *Engine) mapType(v reflect.Value) *core.Table {
table.AddColumn(col) table.AddColumn(col)
if fieldType.Kind() == reflect.Int64 && (col.FieldName == "Id" || strings.HasSuffix(col.FieldName, ".Id")) { if fieldType.Kind() == reflect.Int64 && (strings.ToUpper(col.FieldName) == "ID" || strings.HasSuffix(strings.ToUpper(col.FieldName), ".ID")) {
idFieldColName = col.Name idFieldColName = col.Name
} }
} // end for } // end for
@ -959,40 +977,25 @@ func (engine *Engine) mapping(beans ...interface{}) (e error) {
// If a table has any reocrd // If a table has any reocrd
func (engine *Engine) IsTableEmpty(bean interface{}) (bool, error) { func (engine *Engine) IsTableEmpty(bean interface{}) (bool, error) {
v := rValue(bean)
t := v.Type()
if t.Kind() != reflect.Struct {
return false, errors.New("bean should be a struct or struct's point")
}
engine.autoMapType(v)
session := engine.NewSession() session := engine.NewSession()
defer session.Close() defer session.Close()
rows, err := session.Count(bean) return session.IsTableEmpty(bean)
return rows == 0, err
} }
// If a table is exist // If a table is exist
func (engine *Engine) IsTableExist(bean interface{}) (bool, error) { func (engine *Engine) IsTableExist(beanOrTableName interface{}) (bool, error) {
v := rValue(bean)
var tableName string
if v.Type().Kind() == reflect.String {
tableName = bean.(string)
} else if v.Type().Kind() == reflect.Struct {
table := engine.autoMapType(v)
tableName = table.Name
} else {
return false, errors.New("bean should be a struct or struct's point")
}
session := engine.NewSession() session := engine.NewSession()
defer session.Close() defer session.Close()
has, err := session.isTableExist(tableName) return session.IsTableExist(beanOrTableName)
return has, err
} }
func (engine *Engine) IdOf(bean interface{}) core.PK { func (engine *Engine) IdOf(bean interface{}) core.PK {
table := engine.TableInfo(bean) return engine.IdOfV(reflect.ValueOf(bean))
v := reflect.Indirect(reflect.ValueOf(bean)) }
func (engine *Engine) IdOfV(rv reflect.Value) core.PK {
v := reflect.Indirect(rv)
table := engine.autoMapType(v)
pk := make([]interface{}, len(table.PrimaryKeys)) pk := make([]interface{}, len(table.PrimaryKeys))
for i, col := range table.PKColumns() { for i, col := range table.PKColumns() {
pkField := v.FieldByName(col.FieldName) pkField := v.FieldByName(col.FieldName)
@ -1109,7 +1112,7 @@ func (engine *Engine) Sync(beans ...interface{}) error {
session := engine.NewSession() session := engine.NewSession()
session.Statement.RefTable = table session.Statement.RefTable = table
defer session.Close() defer session.Close()
isExist, err := session.isColumnExist(table.Name, col) isExist, err := session.Engine.dialect.IsColumnExist(table.Name, col)
if err != nil { if err != nil {
return err return err
} }
@ -1222,8 +1225,9 @@ func (engine *Engine) CreateTables(beans ...interface{}) error {
func (engine *Engine) DropTables(beans ...interface{}) error { func (engine *Engine) DropTables(beans ...interface{}) error {
session := engine.NewSession() session := engine.NewSession()
err := session.Begin()
defer session.Close() defer session.Close()
err := session.Begin()
if err != nil { if err != nil {
return err return err
} }
@ -1258,13 +1262,6 @@ func (engine *Engine) Query(sql string, paramStr ...interface{}) (resultsSlice [
return session.Query(sql, paramStr...) return session.Query(sql, paramStr...)
} }
// Exec a raw sql and return records as []map[string]string
func (engine *Engine) Q(sql string, paramStr ...interface{}) (resultsSlice []map[string]string, err error) {
session := engine.NewSession()
defer session.Close()
return session.Q(sql, paramStr...)
}
// Insert one or more records // Insert one or more records
func (engine *Engine) Insert(beans ...interface{}) (int64, error) { func (engine *Engine) Insert(beans ...interface{}) (int64, error) {
session := engine.NewSession() session := engine.NewSession()
@ -1371,18 +1368,11 @@ func (engine *Engine) Import(r io.Reader) ([]sql.Result, error) {
scanner.Split(semiColSpliter) scanner.Split(semiColSpliter)
session := engine.NewSession()
defer session.Close()
err := session.newDb()
if err != nil {
return results, err
}
for scanner.Scan() { for scanner.Scan() {
query := scanner.Text() query := scanner.Text()
query = strings.Trim(query, " \t") query = strings.Trim(query, " \t")
if len(query) > 0 { if len(query) > 0 {
result, err := session.Db.Exec(query) result, err := engine.DB().Exec(query)
results = append(results, result) results = append(results, result)
if err != nil { if err != nil {
lastError = err lastError = err
@ -1409,7 +1399,15 @@ func (engine *Engine) NowTime(sqlTypeName string) interface{} {
return engine.FormatTime(sqlTypeName, t) return engine.FormatTime(sqlTypeName, t)
} }
func (engine *Engine) NowTime2(sqlTypeName string) (interface{}, time.Time) {
t := time.Now()
return engine.FormatTime(sqlTypeName, t), t
}
func (engine *Engine) FormatTime(sqlTypeName string, t time.Time) (v interface{}) { func (engine *Engine) FormatTime(sqlTypeName string, t time.Time) (v interface{}) {
if engine.dialect.DBType() == core.ORACLE {
return t
}
switch sqlTypeName { switch sqlTypeName {
case core.Time: case core.Time:
s := engine.TZTime(t).Format("2006-01-02 15:04:05") //time.RFC3339 s := engine.TZTime(t).Format("2006-01-02 15:04:05") //time.RFC3339
@ -1419,6 +1417,8 @@ func (engine *Engine) FormatTime(sqlTypeName string, t time.Time) (v interface{}
case core.DateTime, core.TimeStamp: case core.DateTime, core.TimeStamp:
if engine.dialect.DBType() == "ql" { if engine.dialect.DBType() == "ql" {
v = engine.TZTime(t) v = engine.TZTime(t)
} else if engine.dialect.DBType() == "sqlite3" {
v = engine.TZTime(t).UTC().Format("2006-01-02 15:04:05")
} else { } else {
v = engine.TZTime(t).Format("2006-01-02 15:04:05") v = engine.TZTime(t).Format("2006-01-02 15:04:05")
} }
@ -1430,6 +1430,8 @@ func (engine *Engine) FormatTime(sqlTypeName string, t time.Time) (v interface{}
} else { } else {
v = engine.TZTime(t).Format(time.RFC3339Nano) v = engine.TZTime(t).Format(time.RFC3339Nano)
} }
case core.BigInt, core.Int:
v = engine.TZTime(t).Unix()
default: default:
v = engine.TZTime(t) v = engine.TZTime(t)
} }

View File

@ -11,6 +11,43 @@ import (
"github.com/go-xorm/core" "github.com/go-xorm/core"
) )
func isZero(k interface{}) bool {
switch k.(type) {
case int:
return k.(int) == 0
case int8:
return k.(int8) == 0
case int16:
return k.(int16) == 0
case int32:
return k.(int32) == 0
case int64:
return k.(int64) == 0
case uint:
return k.(uint) == 0
case uint8:
return k.(uint8) == 0
case uint16:
return k.(uint16) == 0
case uint32:
return k.(uint32) == 0
case uint64:
return k.(uint64) == 0
case string:
return k.(string) == ""
}
return false
}
func isPKZero(pk core.PK) bool {
for _, k := range pk {
if isZero(k) {
return true
}
}
return false
}
func indexNoCase(s, sep string) int { func indexNoCase(s, sep string) int {
return strings.Index(strings.ToLower(s), strings.ToLower(sep)) return strings.Index(strings.ToLower(s), strings.ToLower(sep))
} }
@ -163,3 +200,182 @@ func rows2maps(rows *core.Rows) (resultsSlice []map[string][]byte, err error) {
return resultsSlice, nil return resultsSlice, nil
} }
func row2map(rows *core.Rows, fields []string) (resultsMap map[string][]byte, err error) {
result := make(map[string][]byte)
scanResultContainers := make([]interface{}, len(fields))
for i := 0; i < len(fields); i++ {
var scanResultContainer interface{}
scanResultContainers[i] = &scanResultContainer
}
if err := rows.Scan(scanResultContainers...); err != nil {
return nil, err
}
for ii, key := range fields {
rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii]))
//if row is null then ignore
if rawValue.Interface() == nil {
//fmt.Println("ignore ...", key, rawValue)
continue
}
if data, err := value2Bytes(&rawValue); err == nil {
result[key] = data
} else {
return nil, err // !nashtsai! REVIEW, should return err or just error log?
}
}
return result, nil
}
func row2mapStr(rows *core.Rows, fields []string) (resultsMap map[string]string, err error) {
result := make(map[string]string)
scanResultContainers := make([]interface{}, len(fields))
for i := 0; i < len(fields); i++ {
var scanResultContainer interface{}
scanResultContainers[i] = &scanResultContainer
}
if err := rows.Scan(scanResultContainers...); err != nil {
return nil, err
}
for ii, key := range fields {
rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii]))
//if row is null then ignore
if rawValue.Interface() == nil {
//fmt.Println("ignore ...", key, rawValue)
continue
}
if data, err := value2String(&rawValue); err == nil {
result[key] = data
} else {
return nil, err // !nashtsai! REVIEW, should return err or just error log?
}
}
return result, nil
}
func txQuery2(tx *core.Tx, sqlStr string, params ...interface{}) (resultsSlice []map[string]string, err error) {
rows, err := tx.Query(sqlStr, params...)
if err != nil {
return nil, err
}
defer rows.Close()
return rows2Strings(rows)
}
func query2(db *core.DB, sqlStr string, params ...interface{}) (resultsSlice []map[string]string, err error) {
s, err := db.Prepare(sqlStr)
if err != nil {
return nil, err
}
defer s.Close()
rows, err := s.Query(params...)
if err != nil {
return nil, err
}
defer rows.Close()
return rows2Strings(rows)
}
func setColumnTime(bean interface{}, col *core.Column, t time.Time) {
v, err := col.ValueOf(bean)
if err != nil {
return
}
if v.CanSet() {
switch v.Type().Kind() {
case reflect.Struct:
v.Set(reflect.ValueOf(t).Convert(v.Type()))
case reflect.Int, reflect.Int64, reflect.Int32:
v.SetInt(t.Unix())
case reflect.Uint, reflect.Uint64, reflect.Uint32:
v.SetUint(uint64(t.Unix()))
}
}
}
func genCols(table *core.Table, session *Session, bean interface{}, useCol bool, includeQuote bool) ([]string, []interface{}, error) {
colNames := make([]string, 0)
args := make([]interface{}, 0)
for _, col := range table.Columns() {
lColName := strings.ToLower(col.Name)
if useCol && !col.IsVersion && !col.IsCreated && !col.IsUpdated {
if _, ok := session.Statement.columnMap[lColName]; !ok {
continue
}
}
if col.MapType == core.ONLYFROMDB {
continue
}
fieldValuePtr, err := col.ValueOf(bean)
if err != nil {
session.Engine.LogError(err)
continue
}
fieldValue := *fieldValuePtr
if col.IsAutoIncrement {
switch fieldValue.Type().Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int, reflect.Int64:
if fieldValue.Int() == 0 {
continue
}
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint, reflect.Uint64:
if fieldValue.Uint() == 0 {
continue
}
case reflect.String:
if len(fieldValue.String()) == 0 {
continue
}
}
}
if col.IsDeleted {
continue
}
if session.Statement.ColumnStr != "" {
if _, ok := session.Statement.columnMap[lColName]; !ok {
continue
}
}
if session.Statement.OmitStr != "" {
if _, ok := session.Statement.columnMap[lColName]; ok {
continue
}
}
if (col.IsCreated || col.IsUpdated) && session.Statement.UseAutoTime {
val, t := session.Engine.NowTime2(col.SQLType.Name)
args = append(args, val)
var colName = col.Name
session.afterClosures = append(session.afterClosures, func(bean interface{}) {
col := table.GetColumn(colName)
setColumnTime(bean, col, t)
})
} else if col.IsVersion && session.Statement.checkVersion {
args = append(args, 1)
} else {
arg, err := session.value2Interface(col, fieldValue)
if err != nil {
return colNames, args, err
}
args = append(args, arg)
}
if includeQuote {
colNames = append(colNames, session.Engine.Quote(col.Name)+" = ?")
} else {
colNames = append(colNames, col.Name)
}
}
return colNames, args, nil
}

View File

@ -270,7 +270,7 @@ func (db *mssql) IsReserved(name string) bool {
} }
func (db *mssql) Quote(name string) string { func (db *mssql) Quote(name string) string {
return "[" + name + "]" return "\"" + name + "\""
} }
func (db *mssql) QuoteStr() string { func (db *mssql) QuoteStr() string {

View File

@ -218,6 +218,9 @@ func (db *mysql) SqlType(c *core.Column) string {
res += ")" res += ")"
case core.NVarchar: case core.NVarchar:
res = core.Varchar res = core.Varchar
case core.Uuid:
res = core.Varchar
c.Length = 40
default: default:
res = t res = t
} }
@ -317,7 +320,6 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*core.Column
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
//fmt.Println(columnName, isNullable, colType, colKey, extra, colDefault)
col.Name = strings.Trim(columnName, "` ") col.Name = strings.Trim(columnName, "` ")
if "YES" == isNullable { if "YES" == isNullable {
col.Nullable = true col.Nullable = true
@ -467,15 +469,17 @@ func (db *mysql) GetIndexes(tableName string) (map[string]*core.Index, error) {
} }
colName = strings.Trim(colName, "` ") colName = strings.Trim(colName, "` ")
var isRegular bool
if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) { if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) {
indexName = indexName[5+len(tableName) : len(indexName)] indexName = indexName[5+len(tableName) : len(indexName)]
isRegular = true
} }
var index *core.Index var index *core.Index
var ok bool var ok bool
if index, ok = indexes[indexName]; !ok { if index, ok = indexes[indexName]; !ok {
index = new(core.Index) index = new(core.Index)
index.IsRegular = isRegular
index.Type = indexType index.Type = indexType
index.Name = indexName index.Name = indexName
indexes[indexName] = index indexes[indexName] = index

View File

@ -509,7 +509,7 @@ func (db *oracle) SqlType(c *core.Column) string {
var res string var res string
switch t := c.SQLType.Name; t { switch t := c.SQLType.Name; t {
case core.Bit, core.TinyInt, core.SmallInt, core.MediumInt, core.Int, core.Integer, core.BigInt, core.Bool, core.Serial, core.BigSerial: case core.Bit, core.TinyInt, core.SmallInt, core.MediumInt, core.Int, core.Integer, core.BigInt, core.Bool, core.Serial, core.BigSerial:
return "NUMBER" res = "NUMBER"
case core.Binary, core.VarBinary, core.Blob, core.TinyBlob, core.MediumBlob, core.LongBlob, core.Bytea: case core.Binary, core.VarBinary, core.Blob, core.TinyBlob, core.MediumBlob, core.LongBlob, core.Bytea:
return core.Blob return core.Blob
case core.Time, core.DateTime, core.TimeStamp: case core.Time, core.DateTime, core.TimeStamp:
@ -521,7 +521,7 @@ func (db *oracle) SqlType(c *core.Column) string {
case core.Text, core.MediumText, core.LongText: case core.Text, core.MediumText, core.LongText:
res = "CLOB" res = "CLOB"
case core.Char, core.Varchar, core.TinyText: case core.Char, core.Varchar, core.TinyText:
return "VARCHAR2" res = "VARCHAR2"
default: default:
res = t res = t
} }
@ -536,6 +536,10 @@ func (db *oracle) SqlType(c *core.Column) string {
return res return res
} }
func (db *oracle) AutoIncrStr() string {
return "AUTO_INCREMENT"
}
func (db *oracle) SupportInsertMany() bool { func (db *oracle) SupportInsertMany() bool {
return true return true
} }
@ -553,10 +557,6 @@ func (db *oracle) QuoteStr() string {
return "\"" return "\""
} }
func (db *oracle) AutoIncrStr() string {
return ""
}
func (db *oracle) SupportEngine() bool { func (db *oracle) SupportEngine() bool {
return false return false
} }
@ -565,19 +565,94 @@ func (db *oracle) SupportCharset() bool {
return false return false
} }
func (db *oracle) SupportDropIfExists() bool {
return false
}
func (db *oracle) IndexOnTable() bool { func (db *oracle) IndexOnTable() bool {
return false return false
} }
func (db *oracle) DropTableSql(tableName string) string {
return fmt.Sprintf("DROP TABLE `%s`", tableName)
}
func (b *oracle) CreateTableSql(table *core.Table, tableName, storeEngine, charset string) string {
var sql string
sql = "CREATE TABLE "
if tableName == "" {
tableName = table.Name
}
sql += b.Quote(tableName) + " ("
pkList := table.PrimaryKeys
for _, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName)
/*if col.IsPrimaryKey && len(pkList) == 1 {
sql += col.String(b.dialect)
} else {*/
sql += col.StringNoPk(b)
//}
sql = strings.TrimSpace(sql)
sql += ", "
}
if len(pkList) > 0 {
sql += "PRIMARY KEY ( "
sql += b.Quote(strings.Join(pkList, b.Quote(",")))
sql += " ), "
}
sql = sql[:len(sql)-2] + ")"
if b.SupportEngine() && storeEngine != "" {
sql += " ENGINE=" + storeEngine
}
if b.SupportCharset() {
if len(charset) == 0 {
charset = b.URI().Charset
}
if len(charset) > 0 {
sql += " DEFAULT CHARSET " + charset
}
}
return sql
}
func (db *oracle) IndexCheckSql(tableName, idxName string) (string, []interface{}) { func (db *oracle) IndexCheckSql(tableName, idxName string) (string, []interface{}) {
args := []interface{}{strings.ToUpper(tableName), strings.ToUpper(idxName)} args := []interface{}{tableName, idxName}
return `SELECT INDEX_NAME FROM USER_INDEXES ` + return `SELECT INDEX_NAME FROM USER_INDEXES ` +
`WHERE TABLE_NAME = ? AND INDEX_NAME = ?`, args `WHERE TABLE_NAME = :1 AND INDEX_NAME = :2`, args
} }
func (db *oracle) TableCheckSql(tableName string) (string, []interface{}) { func (db *oracle) TableCheckSql(tableName string) (string, []interface{}) {
args := []interface{}{strings.ToUpper(tableName)} args := []interface{}{tableName}
return `SELECT table_name FROM user_tables WHERE table_name = ?`, args return `SELECT table_name FROM user_tables WHERE table_name = :1`, args
}
func (db *oracle) MustDropTable(tableName string) error {
sql, args := db.TableCheckSql(tableName)
if db.Logger != nil {
db.Logger.Info("[sql]", sql, args)
}
rows, err := db.DB().Query(sql, args...)
if err != nil {
return err
}
defer rows.Close()
if !rows.Next() {
return nil
}
sql = "Drop Table \"" + tableName + "\""
if db.Logger != nil {
db.Logger.Info("[sql]", sql)
}
_, err = db.DB().Exec(sql)
return err
} }
/*func (db *oracle) ColumnCheckSql(tableName, colName string) (string, []interface{}) { /*func (db *oracle) ColumnCheckSql(tableName, colName string) (string, []interface{}) {
@ -587,9 +662,9 @@ func (db *oracle) TableCheckSql(tableName string) (string, []interface{}) {
}*/ }*/
func (db *oracle) IsColumnExist(tableName string, col *core.Column) (bool, error) { func (db *oracle) IsColumnExist(tableName string, col *core.Column) (bool, error) {
args := []interface{}{strings.ToUpper(tableName), strings.ToUpper(col.Name)} args := []interface{}{tableName, col.Name}
query := "SELECT column_name FROM USER_TAB_COLUMNS WHERE table_name = ?" + query := "SELECT column_name FROM USER_TAB_COLUMNS WHERE table_name = :1" +
" AND column_name = ?" " AND column_name = :2"
rows, err := db.DB().Query(query, args...) rows, err := db.DB().Query(query, args...)
if db.Logger != nil { if db.Logger != nil {
db.Logger.Info("[sql]", query, args) db.Logger.Info("[sql]", query, args)
@ -606,7 +681,7 @@ func (db *oracle) IsColumnExist(tableName string, col *core.Column) (bool, error
} }
func (db *oracle) GetColumns(tableName string) ([]string, map[string]*core.Column, error) { func (db *oracle) GetColumns(tableName string) ([]string, map[string]*core.Column, error) {
args := []interface{}{strings.ToUpper(tableName)} args := []interface{}{tableName}
s := "SELECT column_name,data_default,data_type,data_length,data_precision,data_scale," + s := "SELECT column_name,data_default,data_type,data_length,data_precision,data_scale," +
"nullable FROM USER_TAB_COLUMNS WHERE table_name = :1" "nullable FROM USER_TAB_COLUMNS WHERE table_name = :1"
@ -625,7 +700,7 @@ func (db *oracle) GetColumns(tableName string) ([]string, map[string]*core.Colum
col := new(core.Column) col := new(core.Column)
col.Indexes = make(map[string]bool) col.Indexes = make(map[string]bool)
var colName, colDefault, nullable, dataType, dataPrecision, dataScale string var colName, colDefault, nullable, dataType, dataPrecision, dataScale *string
var dataLen int var dataLen int
err = rows.Scan(&colName, &colDefault, &dataType, &dataLen, &dataPrecision, err = rows.Scan(&colName, &colDefault, &dataType, &dataLen, &dataPrecision,
@ -634,36 +709,66 @@ func (db *oracle) GetColumns(tableName string) ([]string, map[string]*core.Colum
return nil, nil, err return nil, nil, err
} }
col.Name = strings.Trim(colName, `" `) col.Name = strings.Trim(*colName, `" `)
col.Default = colDefault if colDefault != nil {
col.Default = *colDefault
col.DefaultIsEmpty = false
}
if nullable == "Y" { if *nullable == "Y" {
col.Nullable = true col.Nullable = true
} else { } else {
col.Nullable = false col.Nullable = false
} }
switch dataType { var ignore bool
var dt string
var len1, len2 int
dts := strings.Split(*dataType, "(")
dt = dts[0]
if len(dts) > 1 {
lens := strings.Split(dts[1][:len(dts[1])-1], ",")
if len(lens) > 1 {
len1, _ = strconv.Atoi(lens[0])
len2, _ = strconv.Atoi(lens[1])
} else {
len1, _ = strconv.Atoi(lens[0])
}
}
switch dt {
case "VARCHAR2": case "VARCHAR2":
col.SQLType = core.SQLType{core.Varchar, 0, 0} col.SQLType = core.SQLType{core.Varchar, len1, len2}
case "TIMESTAMP WITH TIME ZONE": case "TIMESTAMP WITH TIME ZONE":
col.SQLType = core.SQLType{core.TimeStampz, 0, 0} col.SQLType = core.SQLType{core.TimeStampz, 0, 0}
case "NUMBER":
col.SQLType = core.SQLType{core.Double, len1, len2}
case "LONG", "LONG RAW":
col.SQLType = core.SQLType{core.Text, 0, 0}
case "RAW":
col.SQLType = core.SQLType{core.Binary, 0, 0}
case "ROWID":
col.SQLType = core.SQLType{core.Varchar, 18, 0}
case "AQ$_SUBSCRIBERS":
ignore = true
default: default:
col.SQLType = core.SQLType{strings.ToUpper(dataType), 0, 0} col.SQLType = core.SQLType{strings.ToUpper(dt), len1, len2}
} }
if ignore {
continue
}
if _, ok := core.SqlTypes[col.SQLType.Name]; !ok { if _, ok := core.SqlTypes[col.SQLType.Name]; !ok {
return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", dataType)) return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v %v", *dataType, col.SQLType))
} }
col.Length = dataLen col.Length = dataLen
if col.SQLType.IsText() || col.SQLType.IsTime() { if col.SQLType.IsText() || col.SQLType.IsTime() {
if col.Default != "" { if !col.DefaultIsEmpty {
col.Default = "'" + col.Default + "'" col.Default = "'" + col.Default + "'"
} else {
if col.DefaultIsEmpty {
col.Default = "''"
}
} }
} }
cols[col.Name] = col cols[col.Name] = col

View File

@ -25,11 +25,6 @@ func newRows(session *Session, bean interface{}) (*Rows, error) {
rows.session = session rows.session = session
rows.beanType = reflect.Indirect(reflect.ValueOf(bean)).Type() rows.beanType = reflect.Indirect(reflect.ValueOf(bean)).Type()
err := rows.session.newDb()
if err != nil {
return nil, err
}
defer rows.session.Statement.Init() defer rows.session.Statement.Init()
var sqlStr string var sqlStr string
@ -47,8 +42,8 @@ func newRows(session *Session, bean interface{}) (*Rows, error) {
} }
rows.session.Engine.logSQL(sqlStr, args) rows.session.Engine.logSQL(sqlStr, args)
var err error
rows.stmt, err = rows.session.Db.Prepare(sqlStr) rows.stmt, err = rows.session.DB().Prepare(sqlStr)
if err != nil { if err != nil {
rows.lastError = err rows.lastError = err
defer rows.Close() defer rows.Close()

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,7 @@
package xorm package xorm
import ( import (
"database/sql"
"errors" "errors"
"fmt" "fmt"
"strings" "strings"
@ -152,7 +153,7 @@ func (db *sqlite3) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName st
func (db *sqlite3) SqlType(c *core.Column) string { func (db *sqlite3) SqlType(c *core.Column) string {
switch t := c.SQLType.Name; t { switch t := c.SQLType.Name; t {
case core.Date, core.DateTime, core.TimeStamp, core.Time: case core.Date, core.DateTime, core.TimeStamp, core.Time:
return core.Numeric return core.DateTime
case core.TimeStampz: case core.TimeStampz:
return core.Text return core.Text
case core.Char, core.Varchar, core.NVarchar, core.TinyText, core.Text, core.MediumText, core.LongText: case core.Char, core.Varchar, core.NVarchar, core.TinyText, core.Text, core.MediumText, core.LongText:
@ -297,6 +298,7 @@ func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*core.Colu
col := new(core.Column) col := new(core.Column)
col.Indexes = make(map[string]bool) col.Indexes = make(map[string]bool)
col.Nullable = true col.Nullable = true
col.DefaultIsEmpty = true
for idx, field := range fields { for idx, field := range fields {
if idx == 0 { if idx == 0 {
col.Name = strings.Trim(field, "`[] ") col.Name = strings.Trim(field, "`[] ")
@ -315,8 +317,14 @@ func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*core.Colu
} else { } else {
col.Nullable = true col.Nullable = true
} }
case "DEFAULT":
col.Default = fields[idx+1]
col.DefaultIsEmpty = false
} }
} }
if !col.SQLType.IsNumeric() && !col.DefaultIsEmpty {
col.Default = "'" + col.Default + "'"
}
cols[col.Name] = col cols[col.Name] = col
colSeq = append(colSeq, col.Name) colSeq = append(colSeq, col.Name)
} }
@ -366,15 +374,16 @@ func (db *sqlite3) GetIndexes(tableName string) (map[string]*core.Index, error)
indexes := make(map[string]*core.Index, 0) indexes := make(map[string]*core.Index, 0)
for rows.Next() { for rows.Next() {
var sql string var tmpSql sql.NullString
err = rows.Scan(&sql) err = rows.Scan(&tmpSql)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if sql == "" { if !tmpSql.Valid {
continue continue
} }
sql := tmpSql.String
index := new(core.Index) index := new(core.Index)
nNStart := strings.Index(sql, "INDEX") nNStart := strings.Index(sql, "INDEX")
@ -384,7 +393,6 @@ func (db *sqlite3) GetIndexes(tableName string) (map[string]*core.Index, error)
} }
indexName := strings.Trim(sql[nNStart+6:nNEnd], "` []") indexName := strings.Trim(sql[nNStart+6:nNEnd], "` []")
//fmt.Println(indexName)
if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) { if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) {
index.Name = indexName[5+len(tableName) : len(indexName)] index.Name = indexName[5+len(tableName) : len(indexName)]
} else { } else {

View File

@ -26,6 +26,11 @@ type decrParam struct {
arg interface{} arg interface{}
} }
type exprParam struct {
colName string
expr string
}
// statement save all the sql info for executing SQL // statement save all the sql info for executing SQL
type Statement struct { type Statement struct {
RefTable *core.Table RefTable *core.Table
@ -63,6 +68,7 @@ type Statement struct {
inColumns map[string]*inParam inColumns map[string]*inParam
incrColumns map[string]incrParam incrColumns map[string]incrParam
decrColumns map[string]decrParam decrColumns map[string]decrParam
exprColumns map[string]exprParam
} }
// init // init
@ -98,6 +104,7 @@ func (statement *Statement) Init() {
statement.inColumns = make(map[string]*inParam) statement.inColumns = make(map[string]*inParam)
statement.incrColumns = make(map[string]incrParam) statement.incrColumns = make(map[string]incrParam)
statement.decrColumns = make(map[string]decrParam) statement.decrColumns = make(map[string]decrParam)
statement.exprColumns = make(map[string]exprParam)
} }
// add the raw sql statement // add the raw sql statement
@ -153,9 +160,6 @@ func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
t := v.Type() t := v.Type()
if t.Kind() == reflect.String { if t.Kind() == reflect.String {
statement.AltTableName = tableNameOrBean.(string) statement.AltTableName = tableNameOrBean.(string)
if statement.AltTableName[0] == '~' {
statement.AltTableName = statement.Engine.TableMapper.TableName(statement.AltTableName[1:])
}
} else if t.Kind() == reflect.Struct { } else if t.Kind() == reflect.Struct {
statement.RefTable = statement.Engine.autoMapType(v) statement.RefTable = statement.Engine.autoMapType(v)
} }
@ -282,7 +286,7 @@ func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
func buildUpdates(engine *Engine, table *core.Table, bean interface{}, func buildUpdates(engine *Engine, table *core.Table, bean interface{},
includeVersion bool, includeUpdated bool, includeNil bool, includeVersion bool, includeUpdated bool, includeNil bool,
includeAutoIncr bool, allUseBool bool, useAllCols bool, includeAutoIncr bool, allUseBool bool, useAllCols bool,
mustColumnMap map[string]bool, update bool) ([]string, []interface{}) { mustColumnMap map[string]bool, columnMap map[string]bool, update bool) ([]string, []interface{}) {
colNames := make([]string, 0) colNames := make([]string, 0)
var args = make([]interface{}, 0) var args = make([]interface{}, 0)
@ -302,6 +306,9 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{},
if col.IsDeleted { if col.IsDeleted {
continue continue
} }
if use, ok := columnMap[col.Name]; ok && !use {
continue
}
if engine.dialect.DBType() == core.MSSQL && col.SQLType.Name == core.Text { if engine.dialect.DBType() == core.MSSQL && col.SQLType.Name == core.Text {
continue continue
@ -414,13 +421,16 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{},
if table, ok := engine.Tables[fieldValue.Type()]; ok { if table, ok := engine.Tables[fieldValue.Type()]; ok {
if len(table.PrimaryKeys) == 1 { if len(table.PrimaryKeys) == 1 {
pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName) pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName)
if pkField.Int() != 0 { // fix non-int pk issues
//if pkField.Int() != 0 {
if pkField.IsValid() && !isZero(pkField.Interface()) {
val = pkField.Interface() val = pkField.Interface()
} else { } else {
continue continue
} }
} else { } else {
//TODO: how to handler? //TODO: how to handler?
panic("not supported")
} }
} else { } else {
val = fieldValue.Interface() val = fieldValue.Interface()
@ -579,24 +589,29 @@ func buildConditions(engine *Engine, table *core.Table, bean interface{},
t := int64(fieldValue.Uint()) t := int64(fieldValue.Uint())
val = reflect.ValueOf(&t).Interface() val = reflect.ValueOf(&t).Interface()
case reflect.Struct: case reflect.Struct:
if fieldType == reflect.TypeOf(time.Now()) { if fieldType.ConvertibleTo(core.TimeType) {
t := fieldValue.Interface().(time.Time) t := fieldValue.Convert(core.TimeType).Interface().(time.Time)
if !requiredField && (t.IsZero() || !fieldValue.IsValid()) { if !requiredField && (t.IsZero() || !fieldValue.IsValid()) {
continue continue
} }
val = engine.FormatTime(col.SQLType.Name, t) val = engine.FormatTime(col.SQLType.Name, t)
} else if _, ok := reflect.New(fieldType).Interface().(core.Conversion); ok {
continue
} else { } else {
engine.autoMapType(fieldValue) engine.autoMapType(fieldValue)
if table, ok := engine.Tables[fieldValue.Type()]; ok { if table, ok := engine.Tables[fieldValue.Type()]; ok {
if len(table.PrimaryKeys) == 1 { if len(table.PrimaryKeys) == 1 {
pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName) pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName)
if pkField.Int() != 0 { // fix non-int pk issues
//if pkField.Int() != 0 {
if pkField.IsValid() && !isZero(pkField.Interface()) {
val = pkField.Interface() val = pkField.Interface()
} else { } else {
continue continue
} }
} else { } else {
//TODO: how to handler? //TODO: how to handler?
panic("not supported")
} }
} else { } else {
val = fieldValue.Interface() val = fieldValue.Interface()
@ -716,6 +731,13 @@ func (statement *Statement) Decr(column string, arg ...interface{}) *Statement {
return statement return statement
} }
// Generate "Update ... Set column = {expression}" statment
func (statement *Statement) SetExpr(column string, expression string) *Statement {
k := strings.ToLower(column)
statement.exprColumns[k] = exprParam{column, expression}
return statement
}
// Generate "Update ... Set column = column + arg" statment // Generate "Update ... Set column = column + arg" statment
func (statement *Statement) getInc() map[string]incrParam { func (statement *Statement) getInc() map[string]incrParam {
return statement.incrColumns return statement.incrColumns
@ -726,6 +748,11 @@ func (statement *Statement) getDec() map[string]decrParam {
return statement.decrColumns return statement.decrColumns
} }
// Generate "Update ... Set column = {expression}" statment
func (statement *Statement) getExpr() map[string]exprParam {
return statement.exprColumns
}
// Generate "Where column IN (?) " statment // Generate "Where column IN (?) " statment
func (statement *Statement) In(column string, args ...interface{}) *Statement { func (statement *Statement) In(column string, args ...interface{}) *Statement {
k := strings.ToLower(column) k := strings.ToLower(column)
@ -941,15 +968,9 @@ func (statement *Statement) Join(join_operator string, tablename interface{}, co
l := len(t) l := len(t)
if l > 1 { if l > 1 {
table := t[0] table := t[0]
if table[0] == '~' {
table = statement.Engine.TableMapper.TableName(table[1:])
}
joinTable = statement.Engine.Quote(table) + " AS " + statement.Engine.Quote(t[1]) joinTable = statement.Engine.Quote(table) + " AS " + statement.Engine.Quote(t[1])
} else if l == 1 { } else if l == 1 {
table := t[0] table := t[0]
if table[0] == '~' {
table = statement.Engine.TableMapper.TableName(table[1:])
}
joinTable = statement.Engine.Quote(table) joinTable = statement.Engine.Quote(table)
} }
case []interface{}: case []interface{}:
@ -962,9 +983,6 @@ func (statement *Statement) Join(join_operator string, tablename interface{}, co
t := v.Type() t := v.Type()
if t.Kind() == reflect.String { if t.Kind() == reflect.String {
table = f.(string) table = f.(string)
if table[0] == '~' {
table = statement.Engine.TableMapper.TableName(table[1:])
}
} else if t.Kind() == reflect.Struct { } else if t.Kind() == reflect.Struct {
r := statement.Engine.autoMapType(v) r := statement.Engine.autoMapType(v)
table = r.Name table = r.Name
@ -977,9 +995,6 @@ func (statement *Statement) Join(join_operator string, tablename interface{}, co
} }
default: default:
t := fmt.Sprintf("%v", tablename) t := fmt.Sprintf("%v", tablename)
if t[0] == '~' {
t = statement.Engine.TableMapper.TableName(t[1:])
}
joinTable = statement.Engine.Quote(t) joinTable = statement.Engine.Quote(t)
} }
if statement.JoinStr != "" { if statement.JoinStr != "" {
@ -1105,9 +1120,10 @@ func (s *Statement) genDelIndexSQL() []string {
return sqls return sqls
} }
/*
func (s *Statement) genDropSQL() string { func (s *Statement) genDropSQL() string {
return s.Engine.dialect.DropTableSql(s.TableName()) + ";" return s.Engine.dialect.MustDropTa(s.TableName()) + ";"
} }*/
func (statement *Statement) genGetSql(bean interface{}) (string, []interface{}) { func (statement *Statement) genGetSql(bean interface{}) (string, []interface{}) {
var table *core.Table var table *core.Table
@ -1126,13 +1142,21 @@ func (statement *Statement) genGetSql(bean interface{}) (string, []interface{})
statement.BeanArgs = args statement.BeanArgs = args
var columnStr string = statement.ColumnStr var columnStr string = statement.ColumnStr
if statement.JoinStr == "" { if len(statement.JoinStr) == 0 {
if columnStr == "" { if len(columnStr) == 0 {
columnStr = statement.genColumnStr() if statement.GroupByStr != "" {
columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1))
} else {
columnStr = statement.genColumnStr()
}
} }
} else { } else {
if columnStr == "" { if len(columnStr) == 0 {
columnStr = "*" if statement.GroupByStr != "" {
columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1))
} else {
columnStr = "*"
}
} }
} }
@ -1178,14 +1202,16 @@ func (statement *Statement) genCountSql(bean interface{}) (string, []interface{}
id = "" id = ""
} }
statement.attachInSql() statement.attachInSql()
return statement.genSelectSql(fmt.Sprintf("count(%v) AS %v", id, statement.Engine.Quote("total"))), append(statement.Params, statement.BeanArgs...) return statement.genSelectSql(fmt.Sprintf("count(%v)", id)), append(statement.Params, statement.BeanArgs...)
} }
func (statement *Statement) genSelectSql(columnStr string) (a string) { func (statement *Statement) genSelectSql(columnStr string) (a string) {
if statement.GroupByStr != "" { /*if statement.GroupByStr != "" {
columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1)) if columnStr == "" {
statement.GroupByStr = columnStr columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1))
} }
//statement.GroupByStr = columnStr
}*/
var distinct string var distinct string
if statement.IsDistinct { if statement.IsDistinct {
distinct = "DISTINCT " distinct = "DISTINCT "
@ -1210,7 +1236,11 @@ func (statement *Statement) genSelectSql(columnStr string) (a string) {
} }
var fromStr string = " FROM " + statement.Engine.Quote(statement.TableName()) var fromStr string = " FROM " + statement.Engine.Quote(statement.TableName())
if statement.TableAlias != "" { if statement.TableAlias != "" {
fromStr += " AS " + statement.Engine.Quote(statement.TableAlias) if statement.Engine.dialect.DBType() == core.ORACLE {
fromStr += " " + statement.Engine.Quote(statement.TableAlias)
} else {
fromStr += " AS " + statement.Engine.Quote(statement.TableAlias)
}
} }
if statement.JoinStr != "" { if statement.JoinStr != "" {
fromStr = fmt.Sprintf("%v %v", fromStr, statement.JoinStr) fromStr = fmt.Sprintf("%v %v", fromStr, statement.JoinStr)
@ -1233,8 +1263,16 @@ func (statement *Statement) genSelectSql(columnStr string) (a string) {
column = statement.RefTable.ColumnsSeq()[0] column = statement.RefTable.ColumnsSeq()[0]
} }
} }
mssqlCondi = fmt.Sprintf("(%s NOT IN (SELECT TOP %d %s%s%s))", var orderStr string
column, statement.Start, column, fromStr, whereStr) if len(statement.OrderStr) > 0 {
orderStr = " ORDER BY " + statement.OrderStr
}
var groupStr string
if len(statement.GroupByStr) > 0 {
groupStr = " GROUP BY " + statement.GroupByStr
}
mssqlCondi = fmt.Sprintf("(%s NOT IN (SELECT TOP %d %s%s%s%s%s))",
column, statement.Start, column, fromStr, whereStr, orderStr, groupStr)
} }
} }
@ -1258,12 +1296,16 @@ func (statement *Statement) genSelectSql(columnStr string) (a string) {
if statement.OrderStr != "" { if statement.OrderStr != "" {
a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr) a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr)
} }
if statement.Engine.dialect.DBType() != core.MSSQL { if statement.Engine.dialect.DBType() != core.MSSQL && statement.Engine.dialect.DBType() != core.ORACLE {
if statement.Start > 0 { if statement.Start > 0 {
a = fmt.Sprintf("%v LIMIT %v OFFSET %v", a, statement.LimitN, statement.Start) a = fmt.Sprintf("%v LIMIT %v OFFSET %v", a, statement.LimitN, statement.Start)
} else if statement.LimitN > 0 { } else if statement.LimitN > 0 {
a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN) a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN)
} }
} else if statement.Engine.dialect.DBType() == core.ORACLE {
if statement.Start != 0 || statement.LimitN != 0 {
a = fmt.Sprintf("SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", columnStr, columnStr, a, statement.Start+statement.LimitN, statement.Start)
}
} }
return return

View File

@ -13,7 +13,7 @@ import (
) )
const ( const (
Version string = "0.4.1" Version string = "0.4.2.0225"
) )
func regDrvsNDialects() bool { func regDrvsNDialects() bool {
@ -84,17 +84,16 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
TZLocation: time.Local, TZLocation: time.Local,
} }
engine.dialect.SetLogger(engine.Logger)
engine.SetMapper(core.NewCacheMapper(new(core.SnakeMapper))) engine.SetMapper(core.NewCacheMapper(new(core.SnakeMapper)))
//engine.Filters = dialect.Filters()
//engine.Cacher = NewLRUCacher()
//err = engine.SetPool(NewSysConnectPool())
runtime.SetFinalizer(engine, close) runtime.SetFinalizer(engine, close)
return engine, err
return engine, nil
} }
// clone an engine // clone an engine
func (engine *Engine) Clone() (*Engine, error) { func (engine *Engine) Clone() (*Engine, error) {
return NewEngine(engine.dialect.DriverName(), engine.dialect.DataSourceName()) return NewEngine(engine.DriverName(), engine.DataSourceName())
} }

View File

@ -41,12 +41,18 @@ FAQ
> See: https://github.com/mattn/go-sqlite3/issues/106 > See: https://github.com/mattn/go-sqlite3/issues/106
> See also: http://www.limitlessfx.com/cross-compile-golang-app-for-windows-from-linux.html > See also: http://www.limitlessfx.com/cross-compile-golang-app-for-windows-from-linux.html
* Want to get time.Time with current locale
Use `loc=auto` in SQLite3 filename schema like `file:foo.db?loc=auto`.
License License
------- -------
MIT: http://mattn.mit-license.org/2012 MIT: http://mattn.mit-license.org/2012
sqlite.c, sqlite3.h, sqlite3ext.h sqlite3-binding.c, sqlite3-binding.h, sqlite3ext.h
The -binding suffix was added to avoid build failures under gccgo.
In this repository, those files are amalgamation code that copied from SQLite3. The license of those codes are depend on the license of SQLite3. In this repository, those files are amalgamation code that copied from SQLite3. The license of those codes are depend on the license of SQLite3.

View File

@ -6,7 +6,7 @@
package sqlite3 package sqlite3
/* /*
#include <sqlite3.h> #include <sqlite3-binding.h>
#include <stdlib.h> #include <stdlib.h>
*/ */
import "C" import "C"

View File

@ -231,6 +231,12 @@ func TestExtendedErrorCodes_Unique(t *testing.T) {
t.Errorf("Wrong extended error code: %d != %d", t.Errorf("Wrong extended error code: %d != %d",
sqliteErr.ExtendedCode, ErrConstraintUnique) sqliteErr.ExtendedCode, ErrConstraintUnique)
} }
extended := sqliteErr.Code.Extend(3).Error()
expected := "constraint failed"
if extended != expected {
t.Errorf("Wrong basic error code: %q != %q",
extended, expected)
}
} }
} }

View File

@ -6,7 +6,10 @@
package sqlite3 package sqlite3
/* /*
#include <sqlite3.h> #cgo CFLAGS: -std=gnu99
#cgo CFLAGS: -DSQLITE_ENABLE_RTREE -DSQLITE_THREADSAFE
#cgo CFLAGS: -DSQLITE_ENABLE_FTS3 -DSQLITE_ENABLE_FTS3_PARENTHESIS
#include <sqlite3-binding.h>
#include <stdlib.h> #include <stdlib.h>
#include <string.h> #include <string.h>
@ -44,14 +47,23 @@ _sqlite3_bind_blob(sqlite3_stmt *stmt, int n, void *p, int np) {
#include <stdio.h> #include <stdio.h>
#include <stdint.h> #include <stdint.h>
static long static int
_sqlite3_last_insert_rowid(sqlite3* db) { _sqlite3_exec(sqlite3* db, const char* pcmd, long* rowid, long* changes)
return (long) sqlite3_last_insert_rowid(db); {
int rv = sqlite3_exec(db, pcmd, 0, 0, 0);
*rowid = (long) sqlite3_last_insert_rowid(db);
*changes = (long) sqlite3_changes(db);
return rv;
} }
static long static int
_sqlite3_changes(sqlite3* db) { _sqlite3_step(sqlite3_stmt* stmt, long* rowid, long* changes)
return (long) sqlite3_changes(db); {
int rv = sqlite3_step(stmt);
sqlite3* db = sqlite3_db_handle(stmt);
*rowid = (long) sqlite3_last_insert_rowid(db);
*changes = (long) sqlite3_changes(db);
return rv;
} }
*/ */
@ -60,8 +72,11 @@ import (
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
"errors" "errors"
"fmt"
"io" "io"
"net/url"
"runtime" "runtime"
"strconv"
"strings" "strings"
"time" "time"
"unsafe" "unsafe"
@ -102,7 +117,8 @@ type SQLiteDriver struct {
// Conn struct. // Conn struct.
type SQLiteConn struct { type SQLiteConn struct {
db *C.sqlite3 db *C.sqlite3
loc *time.Location
} }
// Tx struct. // Tx struct.
@ -114,6 +130,8 @@ type SQLiteTx struct {
type SQLiteStmt struct { type SQLiteStmt struct {
c *SQLiteConn c *SQLiteConn
s *C.sqlite3_stmt s *C.sqlite3_stmt
nv int
nn []string
t string t string
closed bool closed bool
cls bool cls bool
@ -174,7 +192,7 @@ func (c *SQLiteConn) Exec(query string, args []driver.Value) (driver.Result, err
if s.(*SQLiteStmt).s != nil { if s.(*SQLiteStmt).s != nil {
na := s.NumInput() na := s.NumInput()
if len(args) < na { if len(args) < na {
return nil, errors.New("args is not enough to execute query") return nil, fmt.Errorf("Not enough args to execute query. Expected %d, got %d.", na, len(args))
} }
res, err = s.Exec(args[:na]) res, err = s.Exec(args[:na])
if err != nil && err != driver.ErrSkip { if err != nil && err != driver.ErrSkip {
@ -201,6 +219,9 @@ func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, erro
} }
s.(*SQLiteStmt).cls = true s.(*SQLiteStmt).cls = true
na := s.NumInput() na := s.NumInput()
if len(args) < na {
return nil, fmt.Errorf("Not enough args to execute query. Expected %d, got %d.", na, len(args))
}
rows, err := s.Query(args[:na]) rows, err := s.Query(args[:na])
if err != nil && err != driver.ErrSkip { if err != nil && err != driver.ErrSkip {
s.Close() s.Close()
@ -220,14 +241,13 @@ func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, erro
func (c *SQLiteConn) exec(cmd string) (driver.Result, error) { func (c *SQLiteConn) exec(cmd string) (driver.Result, error) {
pcmd := C.CString(cmd) pcmd := C.CString(cmd)
defer C.free(unsafe.Pointer(pcmd)) defer C.free(unsafe.Pointer(pcmd))
rv := C.sqlite3_exec(c.db, pcmd, nil, nil, nil)
var rowid, changes C.long
rv := C._sqlite3_exec(c.db, pcmd, &rowid, &changes)
if rv != C.SQLITE_OK { if rv != C.SQLITE_OK {
return nil, c.lastError() return nil, c.lastError()
} }
return &SQLiteResult{ return &SQLiteResult{int64(rowid), int64(changes)}, nil
int64(C._sqlite3_last_insert_rowid(c.db)),
int64(C._sqlite3_changes(c.db)),
}, nil
} }
// Begin transaction. // Begin transaction.
@ -248,11 +268,51 @@ func errorString(err Error) string {
// file:test.db?cache=shared&mode=memory // file:test.db?cache=shared&mode=memory
// :memory: // :memory:
// file::memory: // file::memory:
// go-sqlite handle especially query parameters.
// _loc=XXX
// Specify location of time format. It's possible to specify "auto".
// _busy_timeout=XXX
// Specify value for sqlite3_busy_timeout.
func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
if C.sqlite3_threadsafe() == 0 { if C.sqlite3_threadsafe() == 0 {
return nil, errors.New("sqlite library was not compiled for thread-safe operation") return nil, errors.New("sqlite library was not compiled for thread-safe operation")
} }
var loc *time.Location
busy_timeout := 5000
pos := strings.IndexRune(dsn, '?')
if pos >= 1 {
params, err := url.ParseQuery(dsn[pos+1:])
if err != nil {
return nil, err
}
// _loc
if val := params.Get("_loc"); val != "" {
if val == "auto" {
loc = time.Local
} else {
loc, err = time.LoadLocation(val)
if err != nil {
return nil, fmt.Errorf("Invalid _loc: %v: %v", val, err)
}
}
}
// _busy_timeout
if val := params.Get("_busy_timeout"); val != "" {
iv, err := strconv.ParseInt(val, 10, 64)
if err != nil {
return nil, fmt.Errorf("Invalid _busy_timeout: %v: %v", val, err)
}
busy_timeout = int(iv)
}
if !strings.HasPrefix(dsn, "file:") {
dsn = dsn[:pos]
}
}
var db *C.sqlite3 var db *C.sqlite3
name := C.CString(dsn) name := C.CString(dsn)
defer C.free(unsafe.Pointer(name)) defer C.free(unsafe.Pointer(name))
@ -268,12 +328,12 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
return nil, errors.New("sqlite succeeded without returning a database") return nil, errors.New("sqlite succeeded without returning a database")
} }
rv = C.sqlite3_busy_timeout(db, 5000) rv = C.sqlite3_busy_timeout(db, C.int(busy_timeout))
if rv != C.SQLITE_OK { if rv != C.SQLITE_OK {
return nil, Error{Code: ErrNo(rv)} return nil, Error{Code: ErrNo(rv)}
} }
conn := &SQLiteConn{db} conn := &SQLiteConn{db: db, loc: loc}
if len(d.Extensions) > 0 { if len(d.Extensions) > 0 {
rv = C.sqlite3_enable_load_extension(db, 1) rv = C.sqlite3_enable_load_extension(db, 1)
@ -281,21 +341,15 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
return nil, errors.New(C.GoString(C.sqlite3_errmsg(db))) return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
} }
stmt, err := conn.Prepare("SELECT load_extension(?);")
if err != nil {
return nil, err
}
for _, extension := range d.Extensions { for _, extension := range d.Extensions {
if _, err = stmt.Exec([]driver.Value{extension}); err != nil { cext := C.CString(extension)
return nil, err defer C.free(unsafe.Pointer(cext))
rv = C.sqlite3_load_extension(db, cext, nil, nil)
if rv != C.SQLITE_OK {
return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
} }
} }
if err = stmt.Close(); err != nil {
return nil, err
}
rv = C.sqlite3_enable_load_extension(db, 0) rv = C.sqlite3_enable_load_extension(db, 0)
if rv != C.SQLITE_OK { if rv != C.SQLITE_OK {
return nil, errors.New(C.GoString(C.sqlite3_errmsg(db))) return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
@ -333,10 +387,18 @@ func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) {
return nil, c.lastError() return nil, c.lastError()
} }
var t string var t string
if tail != nil && C.strlen(tail) > 0 { if tail != nil && *tail != '\000' {
t = strings.TrimSpace(C.GoString(tail)) t = strings.TrimSpace(C.GoString(tail))
} }
ss := &SQLiteStmt{c: c, s: s, t: t} nv := int(C.sqlite3_bind_parameter_count(s))
var nn []string
for i := 0; i < nv; i++ {
pn := C.GoString(C.sqlite3_bind_parameter_name(s, C.int(i+1)))
if len(pn) > 1 && pn[0] == '$' && 48 <= pn[1] && pn[1] <= 57 {
nn = append(nn, C.GoString(C.sqlite3_bind_parameter_name(s, C.int(i+1))))
}
}
ss := &SQLiteStmt{c: c, s: s, nv: nv, nn: nn, t: t}
runtime.SetFinalizer(ss, (*SQLiteStmt).Close) runtime.SetFinalizer(ss, (*SQLiteStmt).Close)
return ss, nil return ss, nil
} }
@ -360,7 +422,12 @@ func (s *SQLiteStmt) Close() error {
// Return a number of parameters. // Return a number of parameters.
func (s *SQLiteStmt) NumInput() int { func (s *SQLiteStmt) NumInput() int {
return int(C.sqlite3_bind_parameter_count(s.s)) return s.nv
}
type bindArg struct {
n int
v driver.Value
} }
func (s *SQLiteStmt) bind(args []driver.Value) error { func (s *SQLiteStmt) bind(args []driver.Value) error {
@ -369,8 +436,24 @@ func (s *SQLiteStmt) bind(args []driver.Value) error {
return s.c.lastError() return s.c.lastError()
} }
for i, v := range args { var vargs []bindArg
n := C.int(i + 1) narg := len(args)
vargs = make([]bindArg, narg)
if len(s.nn) > 0 {
for i, v := range s.nn {
if pi, err := strconv.Atoi(v[1:]); err == nil {
vargs[i] = bindArg{pi, args[i]}
}
}
} else {
for i, v := range args {
vargs[i] = bindArg{i + 1, v}
}
}
for _, varg := range vargs {
n := C.int(varg.n)
v := varg.v
switch v := v.(type) { switch v := v.(type) {
case nil: case nil:
rv = C.sqlite3_bind_null(s.s, n) rv = C.sqlite3_bind_null(s.s, n)
@ -431,19 +514,18 @@ func (r *SQLiteResult) RowsAffected() (int64, error) {
func (s *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) { func (s *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) {
if err := s.bind(args); err != nil { if err := s.bind(args); err != nil {
C.sqlite3_reset(s.s) C.sqlite3_reset(s.s)
C.sqlite3_clear_bindings(s.s)
return nil, err return nil, err
} }
rv := C.sqlite3_step(s.s) var rowid, changes C.long
rv := C._sqlite3_step(s.s, &rowid, &changes)
if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE { if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
err := s.c.lastError()
C.sqlite3_reset(s.s) C.sqlite3_reset(s.s)
return nil, s.c.lastError() C.sqlite3_clear_bindings(s.s)
return nil, err
} }
return &SQLiteResult{int64(rowid), int64(changes)}, nil
res := &SQLiteResult{
int64(C._sqlite3_last_insert_rowid(s.c.db)),
int64(C._sqlite3_changes(s.c.db)),
}
return res, nil
} }
// Close the rows. // Close the rows.
@ -499,7 +581,22 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error {
val := int64(C.sqlite3_column_int64(rc.s.s, C.int(i))) val := int64(C.sqlite3_column_int64(rc.s.s, C.int(i)))
switch rc.decltype[i] { switch rc.decltype[i] {
case "timestamp", "datetime", "date": case "timestamp", "datetime", "date":
dest[i] = time.Unix(val, 0).Local() unixTimestamp := strconv.FormatInt(val, 10)
var t time.Time
if len(unixTimestamp) == 13 {
duration, err := time.ParseDuration(unixTimestamp + "ms")
if err != nil {
return fmt.Errorf("error parsing %s value %d, %s", rc.decltype[i], val, err)
}
epoch := time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC)
t = epoch.Add(duration)
} else {
t = time.Unix(val, 0)
}
if rc.s.c.loc != nil {
t = t.In(rc.s.c.loc)
}
dest[i] = t
case "boolean": case "boolean":
dest[i] = val > 0 dest[i] = val > 0
default: default:
@ -531,16 +628,21 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error {
switch rc.decltype[i] { switch rc.decltype[i] {
case "timestamp", "datetime", "date": case "timestamp", "datetime", "date":
var t time.Time
for _, format := range SQLiteTimestampFormats { for _, format := range SQLiteTimestampFormats {
if timeVal, err = time.ParseInLocation(format, s, time.UTC); err == nil { if timeVal, err = time.ParseInLocation(format, s, time.UTC); err == nil {
dest[i] = timeVal.Local() t = timeVal
break break
} }
} }
if err != nil { if err != nil {
// The column is a time value, so return the zero time on parse failure. // The column is a time value, so return the zero time on parse failure.
dest[i] = time.Time{} t = time.Time{}
} }
if rc.s.c.loc != nil {
t = t.In(rc.s.c.loc)
}
dest[i] = t
default: default:
dest[i] = []byte(s) dest[i] = []byte(s)
} }

View File

@ -0,0 +1,83 @@
// Copyright (C) 2015 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
package sqlite3
import (
"database/sql"
"os"
"testing"
)
func TestFTS3(t *testing.T) {
tempFilename := TempFilename()
db, err := sql.Open("sqlite3", tempFilename)
if err != nil {
t.Fatal("Failed to open database:", err)
}
defer os.Remove(tempFilename)
defer db.Close()
_, err = db.Exec("DROP TABLE foo")
_, err = db.Exec("CREATE VIRTUAL TABLE foo USING fts3(id INTEGER PRIMARY KEY, value TEXT)")
if err != nil {
t.Fatal("Failed to create table:", err)
}
_, err = db.Exec("INSERT INTO foo(id, value) VALUES(?, ?)", 1, `今日の 晩御飯は 天麩羅よ`)
if err != nil {
t.Fatal("Failed to insert value:", err)
}
_, err = db.Exec("INSERT INTO foo(id, value) VALUES(?, ?)", 2, `今日は いい 天気だ`)
if err != nil {
t.Fatal("Failed to insert value:", err)
}
rows, err := db.Query("SELECT id, value FROM foo WHERE value MATCH '今日* 天*'")
if err != nil {
t.Fatal("Unable to query foo table:", err)
}
defer rows.Close()
for rows.Next() {
var id int
var value string
if err := rows.Scan(&id, &value); err != nil {
t.Error("Unable to scan results:", err)
continue
}
if id == 1 && value != `今日の 晩御飯は 天麩羅よ` {
t.Error("Value for id 1 should be `今日の 晩御飯は 天麩羅よ`, but:", value)
} else if id == 2 && value != `今日は いい 天気だ` {
t.Error("Value for id 2 should be `今日は いい 天気だ`, but:", value)
}
}
rows, err = db.Query("SELECT value FROM foo WHERE value MATCH '今日* 天麩羅*'")
if err != nil {
t.Fatal("Unable to query foo table:", err)
}
defer rows.Close()
var value string
if !rows.Next() {
t.Fatal("Result should be only one")
}
if err := rows.Scan(&value); err != nil {
t.Fatal("Unable to scan results:", err)
}
if value != `今日の 晩御飯は 天麩羅よ` {
t.Fatal("Value should be `今日の 晩御飯は 天麩羅よ`, but:", value)
}
if rows.Next() {
t.Fatal("Result should be only one")
}
}

View File

@ -9,6 +9,6 @@ package sqlite3
/* /*
#cgo CFLAGS: -I. #cgo CFLAGS: -I.
#cgo linux LDFLAGS: -ldl #cgo linux LDFLAGS: -ldl
#cgo CFLAGS: -DSQLITE_ENABLE_RTREE -DSQLITE_THREADSAFE #cgo LDFLAGS: -lpthread
*/ */
import "C" import "C"

View File

@ -9,8 +9,10 @@ import (
"crypto/rand" "crypto/rand"
"database/sql" "database/sql"
"encoding/hex" "encoding/hex"
"net/url"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"testing" "testing"
"time" "time"
@ -309,6 +311,7 @@ func TestTimestamp(t *testing.T) {
{"0000-00-00 00:00:00", time.Time{}}, {"0000-00-00 00:00:00", time.Time{}},
{timestamp1, timestamp1}, {timestamp1, timestamp1},
{timestamp1.Unix(), timestamp1}, {timestamp1.Unix(), timestamp1},
{timestamp1.UnixNano() / int64(time.Millisecond), timestamp1},
{timestamp1.In(time.FixedZone("TEST", -7*3600)), timestamp1}, {timestamp1.In(time.FixedZone("TEST", -7*3600)), timestamp1},
{timestamp1.Format("2006-01-02 15:04:05.000"), timestamp1}, {timestamp1.Format("2006-01-02 15:04:05.000"), timestamp1},
{timestamp1.Format("2006-01-02T15:04:05.000"), timestamp1}, {timestamp1.Format("2006-01-02T15:04:05.000"), timestamp1},
@ -633,6 +636,102 @@ func TestWAL(t *testing.T) {
} }
} }
func TestTimezoneConversion(t *testing.T) {
zones := []string{"UTC", "US/Central", "US/Pacific", "Local"}
for _, tz := range zones {
tempFilename := TempFilename()
db, err := sql.Open("sqlite3", tempFilename+"?_loc="+url.QueryEscape(tz))
if err != nil {
t.Fatal("Failed to open database:", err)
}
defer os.Remove(tempFilename)
defer db.Close()
_, err = db.Exec("DROP TABLE foo")
_, err = db.Exec("CREATE TABLE foo(id INTEGER, ts TIMESTAMP, dt DATETIME)")
if err != nil {
t.Fatal("Failed to create table:", err)
}
loc, err := time.LoadLocation(tz)
if err != nil {
t.Fatal("Failed to load location:", err)
}
timestamp1 := time.Date(2012, time.April, 6, 22, 50, 0, 0, time.UTC)
timestamp2 := time.Date(2006, time.January, 2, 15, 4, 5, 123456789, time.UTC)
timestamp3 := time.Date(2012, time.November, 4, 0, 0, 0, 0, time.UTC)
tests := []struct {
value interface{}
expected time.Time
}{
{"nonsense", time.Time{}.In(loc)},
{"0000-00-00 00:00:00", time.Time{}.In(loc)},
{timestamp1, timestamp1.In(loc)},
{timestamp1.Unix(), timestamp1.In(loc)},
{timestamp1.In(time.FixedZone("TEST", -7*3600)), timestamp1.In(loc)},
{timestamp1.Format("2006-01-02 15:04:05.000"), timestamp1.In(loc)},
{timestamp1.Format("2006-01-02T15:04:05.000"), timestamp1.In(loc)},
{timestamp1.Format("2006-01-02 15:04:05"), timestamp1.In(loc)},
{timestamp1.Format("2006-01-02T15:04:05"), timestamp1.In(loc)},
{timestamp2, timestamp2.In(loc)},
{"2006-01-02 15:04:05.123456789", timestamp2.In(loc)},
{"2006-01-02T15:04:05.123456789", timestamp2.In(loc)},
{"2012-11-04", timestamp3.In(loc)},
{"2012-11-04 00:00", timestamp3.In(loc)},
{"2012-11-04 00:00:00", timestamp3.In(loc)},
{"2012-11-04 00:00:00.000", timestamp3.In(loc)},
{"2012-11-04T00:00", timestamp3.In(loc)},
{"2012-11-04T00:00:00", timestamp3.In(loc)},
{"2012-11-04T00:00:00.000", timestamp3.In(loc)},
}
for i := range tests {
_, err = db.Exec("INSERT INTO foo(id, ts, dt) VALUES(?, ?, ?)", i, tests[i].value, tests[i].value)
if err != nil {
t.Fatal("Failed to insert timestamp:", err)
}
}
rows, err := db.Query("SELECT id, ts, dt FROM foo ORDER BY id ASC")
if err != nil {
t.Fatal("Unable to query foo table:", err)
}
defer rows.Close()
seen := 0
for rows.Next() {
var id int
var ts, dt time.Time
if err := rows.Scan(&id, &ts, &dt); err != nil {
t.Error("Unable to scan results:", err)
continue
}
if id < 0 || id >= len(tests) {
t.Error("Bad row id: ", id)
continue
}
seen++
if !tests[id].expected.Equal(ts) {
t.Errorf("Timestamp value for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected, ts)
}
if !tests[id].expected.Equal(dt) {
t.Errorf("Datetime value for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected, dt)
}
if tests[id].expected.Location().String() != ts.Location().String() {
t.Errorf("Location for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected.Location().String(), ts.Location().String())
}
if tests[id].expected.Location().String() != dt.Location().String() {
t.Errorf("Location for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected.Location().String(), dt.Location().String())
}
}
if seen != len(tests) {
t.Errorf("Expected to see %d rows", len(tests))
}
}
}
func TestSuite(t *testing.T) { func TestSuite(t *testing.T) {
db, err := sql.Open("sqlite3", ":memory:") db, err := sql.Open("sqlite3", ":memory:")
if err != nil { if err != nil {
@ -742,3 +841,107 @@ func TestStress(t *testing.T) {
db.Close() db.Close()
} }
} }
func TestDateTimeLocal(t *testing.T) {
zone := "Asia/Tokyo"
tempFilename := TempFilename()
db, err := sql.Open("sqlite3", tempFilename+"?_loc="+zone)
if err != nil {
t.Fatal("Failed to open database:", err)
}
db.Exec("CREATE TABLE foo (dt datetime);")
db.Exec("INSERT INTO foo VALUES('2015-03-05 15:16:17');")
row := db.QueryRow("select * from foo")
var d time.Time
err = row.Scan(&d)
if err != nil {
t.Fatal("Failed to scan datetime:", err)
}
if d.Hour() == 15 || !strings.Contains(d.String(), "JST") {
t.Fatal("Result should have timezone", d)
}
db.Close()
db, err = sql.Open("sqlite3", tempFilename)
if err != nil {
t.Fatal("Failed to open database:", err)
}
row = db.QueryRow("select * from foo")
err = row.Scan(&d)
if err != nil {
t.Fatal("Failed to scan datetime:", err)
}
if d.UTC().Hour() != 15 || !strings.Contains(d.String(), "UTC") {
t.Fatalf("Result should not have timezone %v %v", zone, d.String())
}
_, err = db.Exec("DELETE FROM foo")
if err != nil {
t.Fatal("Failed to delete table:", err)
}
dt, err := time.Parse("2006/1/2 15/4/5 -0700 MST", "2015/3/5 15/16/17 +0900 JST")
if err != nil {
t.Fatal("Failed to parse datetime:", err)
}
db.Exec("INSERT INTO foo VALUES(?);", dt)
db.Close()
db, err = sql.Open("sqlite3", tempFilename+"?_loc="+zone)
if err != nil {
t.Fatal("Failed to open database:", err)
}
row = db.QueryRow("select * from foo")
err = row.Scan(&d)
if err != nil {
t.Fatal("Failed to scan datetime:", err)
}
if d.Hour() != 15 || !strings.Contains(d.String(), "JST") {
t.Fatalf("Result should have timezone %v %v", zone, d.String())
}
}
func TestVersion(t *testing.T) {
s, n, id := Version()
if s == "" || n == 0 || id == "" {
t.Errorf("Version failed %q, %d, %q\n", s, n, id)
}
}
func TestNumberNamedParams(t *testing.T) {
tempFilename := TempFilename()
db, err := sql.Open("sqlite3", tempFilename)
if err != nil {
t.Fatal("Failed to open database:", err)
}
defer os.Remove(tempFilename)
defer db.Close()
_, err = db.Exec(`
create table foo (id integer, name text, extra text);
`)
if err != nil {
t.Error("Failed to call db.Query:", err)
}
_, err = db.Exec(`insert into foo(id, name, extra) values($1, $2, $2)`, 1, "foo")
if err != nil {
t.Error("Failed to call db.Exec:", err)
}
row := db.QueryRow(`select id, extra from foo where id = $1 and extra = $2`, 1, "foo")
if row == nil {
t.Error("Failed to call db.QueryRow")
}
var id int
var extra string
err = row.Scan(&id, &extra)
if err != nil {
t.Error("Failed to db.Scan:", err)
}
if id != 1 || extra != "foo" {
t.Error("Failed to db.QueryRow: not matched results")
}
}

View File

@ -2,6 +2,7 @@
// //
// Use of this source code is governed by an MIT-style // Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// +build windows
package sqlite3 package sqlite3
@ -9,6 +10,5 @@ package sqlite3
#cgo CFLAGS: -I. -fno-stack-check -fno-stack-protector -mno-stack-arg-probe #cgo CFLAGS: -I. -fno-stack-check -fno-stack-protector -mno-stack-arg-probe
#cgo windows,386 CFLAGS: -D_localtime32=localtime #cgo windows,386 CFLAGS: -D_localtime32=localtime
#cgo LDFLAGS: -lmingwex -lmingw32 #cgo LDFLAGS: -lmingwex -lmingw32
#cgo CFLAGS: -DSQLITE_ENABLE_RTREE -DSQLITE_THREADSAFE
*/ */
import "C" import "C"

View File

@ -17,7 +17,7 @@
*/ */
#ifndef _SQLITE3EXT_H_ #ifndef _SQLITE3EXT_H_
#define _SQLITE3EXT_H_ #define _SQLITE3EXT_H_
#include "sqlite3.h" #include "sqlite3-binding.h"
typedef struct sqlite3_api_routines sqlite3_api_routines; typedef struct sqlite3_api_routines sqlite3_api_routines;