Merge branch 'master' into panel_repeat

Conflicts:
	public/app/features/dashboard/dashboardCtrl.js
	public/app/partials/submenu.html
	public/css/less/submenu.less
	public/test/specs/templateSrv-specs.js
	src/app/partials/roweditor.html
This commit is contained in:
Torkel Ödegaard 2015-04-27 13:53:02 +02:00
commit bf6f0f1a65
686 changed files with 69990 additions and 12483 deletions

View File

@ -1,17 +1,17 @@
[run]
init_cmds = [
["go", "build", "-o", "./bin/grafana"],
["./bin/grafana", "web"]
["go", "build", "-o", "./bin/grafana-server"],
["./bin/grafana-server"]
]
watch_all = true
watch_dirs = [
"$WORKDIR/pkg",
"$WORKDIR/src/views",
"$WORKDIR/public/views",
"$WORKDIR/conf",
]
watch_exts = [".go", ".ini"]
build_delay = 1500
cmds = [
["go", "build", "-o", "./bin/grafana"],
["./bin/grafana", "web"]
["go", "build", "-o", "./bin/grafana-server"],
["./bin/grafana-server"]
]

4
.gitignore vendored
View File

@ -13,9 +13,7 @@ docs/changed-files
docs/changed-files
# locally required config files
web.config
config.js
src/css/*.min.css
public/css/*.min.css
# Editor junk
*.sublime-workspace

View File

@ -1,9 +0,0 @@
language: node_js
node_js:
- "0.10"
git:
depth: 1
before_script:
- npm install -g grunt-cli
after_script:
- npm run coveralls

View File

@ -1,6 +1,68 @@
# 2.0.0 (unreleased)
# 2.0.3 (unreleased)
**Fixes**
- [Issue #1872](https://github.com/grafana/grafana/issues/1872). Firefox/IE issue, invisible text in dashboard search fixed
- [Issue #1857](https://github.com/grafana/grafana/issues/1857). /api/login/ping Fix for issue when behind reverse proxy and subpath
- [Issue #1863](https://github.com/grafana/grafana/issues/1863). MySQL: Dashboard.data column type changed to mediumtext (sql migration added)
# 2.0.2 (2015-04-22)
**Fixes**
- [Issue #1832](https://github.com/grafana/grafana/issues/1832). Graph Panel + Legend Table mode: Many series casued zero height graph, now legend will never reduce the height of the graph below 50% of row height.
- [Issue #1846](https://github.com/grafana/grafana/issues/1846). Snapshots: Fixed issue with snapshoting dashboards with an interval template variable
- [Issue #1848](https://github.com/grafana/grafana/issues/1848). Panel timeshift: You can now use panel timeshift without a relative time override
# 2.0.1 (2015-04-20)
**Fixes**
- [Issue #1784](https://github.com/grafana/grafana/issues/1784). Data source proxy: Fixed issue with using data source proxy when grafana is behind nginx suburl
- [Issue #1749](https://github.com/grafana/grafana/issues/1749). Graph Panel: Table legends are now visible when rendered to PNG
- [Issue #1786](https://github.com/grafana/grafana/issues/1786). Graph Panel: Legend in table mode now aligns, graph area is reduced depending on how many series
- [Issue #1734](https://github.com/grafana/grafana/issues/1734). Support for unicode / international characters in dashboard title (improved slugify)
- [Issue #1782](https://github.com/grafana/grafana/issues/1782). Github OAuth: Now works with Github for Enterprise, thanks @williamjoy
- [Issue #1780](https://github.com/grafana/grafana/issues/1780). Dashboard snapshot: Should not require login to view snapshot, Fixes #1780
# 2.0.0-Beta3 (2015-04-12)
**RPM / DEB Package changes (to follow HFS)**
- binary name changed to grafana-server
- does not install to `/opt/grafana` any more, installs to `/usr/share/grafana`
- binary to `/usr/sbin/grafana-server`
- init.d script improvements, renamed to `/etc/init.d/grafana-server`
- added default file with environment variables,
- `/etc/default/grafana-server` (deb/ubuntu)
- `/etc/sysconfig/grafana-server` (centos/redhat)
- added systemd service file, tested on debian jessie and centos7
- config file in same location `/etc/grafana/grafana.ini` (now complete config file but with every setting commented out)
- data directory (where sqlite3) file is stored is now by default `/var/lib/grafana`
- no symlinking current to versions anymore
- For more info see [Issue #1758](https://github.com/grafana/grafana/issues/1758).
**Config breaking change (setting rename)**
- `[log] root_path` has changed to `[paths] logs`
# 2.0.0-Beta2 (...)
**Enhancements**
- [Issue #1701](https://github.com/grafana/grafana/issues/1701). Share modal: Override UI theme via URL param for Share link, rendered panel, or embedded panel
- [Issue #1660](https://github.com/grafana/grafana/issues/1660). OAuth: Specify allowed email address domains for google or and github oauth logins
**Fixes**
- [Issue #1649](https://github.com/grafana/grafana/issues/1649). HTTP API: grafana /render calls nows with api keys
- [Issue #1667](https://github.com/grafana/grafana/issues/1667). Datasource proxy & session timeout fix (casued 401 Unauthorized error after a while)
- [Issue #1707](https://github.com/grafana/grafana/issues/1707). Unsaved changes: Do not show for snapshots, scripted and file based dashboards
- [Issue #1703](https://github.com/grafana/grafana/issues/1703). Unsaved changes: Do not show for users with role `Viewer`
- [Issue #1675](https://github.com/grafana/grafana/issues/1675). Data source proxy: Fixed issue with Gzip enabled and data source proxy
- [Issue #1681](https://github.com/grafana/grafana/issues/1681). MySQL session: fixed problem using mysql as session store
- [Issue #1671](https://github.com/grafana/grafana/issues/1671). Data sources: Fixed issue with changing default data source (should not require full page load to take effect, now fixed)
- [Issue #1685](https://github.com/grafana/grafana/issues/1685). Search: Dashboard results should be sorted alphabetically
- [Issue #1673](https://github.com/grafana/grafana/issues/1673). Basic auth: Fixed issue when using basic auth proxy infront of Grafana
# 2.0.0-Beta1 (2015-03-30)
**New features**
- [Issue #1623](https://github.com/grafana/grafana/issues/1623). Share Dashboard: Dashboard snapshot sharing (dash and data snapshot), save to local or save to public snapshot dashboard snapshots.raintank.io site
- [Issue #1622](https://github.com/grafana/grafana/issues/1622). Share Panel: The share modal now has an embed option, gives you an iframe that you can use to embedd a single graph on another web site
- [Issue #718](https://github.com/grafana/grafana/issues/718). Dashboard: When saving a dashboard and another user has made changes inbetween the user is promted with a warning if he really wants to overwrite the other's changes
- [Issue #1331](https://github.com/grafana/grafana/issues/1331). Graph & Singlestat: New axis/unit format selector and more units (kbytes, Joule, Watt, eV), and new design for graph axis & grid tab and single stat options tab views
@ -19,6 +81,7 @@
- [Issue #599](https://github.com/grafana/grafana/issues/599). Graph: Added right y axis label setting and graph support
- [Issue #1253](https://github.com/grafana/grafana/issues/1253). Graph & Singlestat: Users can now set decimal precision for legend and tooltips (override auto precision)
- [Issue #1255](https://github.com/grafana/grafana/issues/1255). Templating: Dashboard will now wait to load until all template variables that have refresh on load set or are initialized via url to be fully loaded and so all variables are in valid state before panels start issuing metric requests.
- [Issue #1344](https://github.com/grafana/grafana/issues/1344). OpenTSDB: Alias patterns (reference tag values), syntax is: $tag_tagname or [[tag_tagname]]
**Fixes**
- [Issue #1298](https://github.com/grafana/grafana/issues/1298). InfluxDB: Fix handling of empty array in templating variable query

37
Godeps/Godeps.json generated
View File

@ -1,5 +1,5 @@
{
"ImportPath": "github.com/torkelo/grafana-pro",
"ImportPath": "github.com/grafana/grafana",
"GoVersion": "go1.3",
"Packages": [
"./pkg/..."
@ -14,9 +14,12 @@
"Rev": "93de4f3fad97bf246b838f828e2348f46f21f20a"
},
{
"ImportPath": "github.com/codegangsta/cli",
"Comment": "1.2.0-38-g9908e96",
"Rev": "9908e96513e5a94de37004098a3974a567f18111"
"ImportPath": "github.com/dalu/slug",
"Rev": "6dbd13912e9be466e2c1de349a2c7d1466c97e07"
},
{
"ImportPath": "github.com/dalu/unidecode",
"Rev": "339814d47f3e32a6f7036a0a4c56ed9b373dd755"
},
{
"ImportPath": "github.com/go-sql-driver/mysql",
@ -25,12 +28,12 @@
},
{
"ImportPath": "github.com/go-xorm/core",
"Rev": "a949e067ced1cb6e6ef5c38b6f28b074fa718f1e"
"Rev": "be6e7ac47dc57bd0ada25322fa526944f66ccaa6"
},
{
"ImportPath": "github.com/go-xorm/xorm",
"Comment": "v0.4.1-19-g5c23849",
"Rev": "5c23849a66f4593e68909bb6c1fa30651b5b0541"
"Comment": "v0.4.2-58-ge2889e5",
"Rev": "e2889e5517600b82905f1d2ba8b70deb71823ffe"
},
{
"ImportPath": "github.com/jtolds/gls",
@ -47,11 +50,11 @@
},
{
"ImportPath": "github.com/macaron-contrib/session",
"Rev": "65b8817c40cb5bdce08673a15fd2a648c2ba0e16"
"Rev": "31e841d95c7302b9ac456c830ea2d6dfcef4f84a"
},
{
"ImportPath": "github.com/mattn/go-sqlite3",
"Rev": "d10e2c8f62100097910367dee90a9bd89d426a44"
"Rev": "e28cd440fabdd39b9520344bc26829f61db40ece"
},
{
"ImportPath": "github.com/smartystreets/goconvey/convey",
@ -68,12 +71,26 @@
},
{
"ImportPath": "golang.org/x/oauth2",
"Rev": "e5909d4679a1926c774c712b343f10b8298687a3"
"Rev": "c58fcf0ffc1c772aa2e1ee4894bc19f2649263b2"
},
{
"ImportPath": "gopkg.in/bufio.v1",
"Comment": "v1",
"Rev": "567b2bfa514e796916c4747494d6ff5132a1dfce"
},
{
"ImportPath": "gopkg.in/ini.v1",
"Comment": "v0-16-g1772191",
"Rev": "177219109c97e7920c933e21c9b25f874357b237"
},
{
"ImportPath": "gopkg.in/redis.v2",
"Comment": "v2.3.2",
"Rev": "e6179049628164864e6e84e973cfb56335748dea"
},
{
"ImportPath": "gopkgs.com/pool.v1",
"Rev": "c850f092aad1780cbffff25f471c5cc32097932a"
}
]
}

View File

@ -1,6 +0,0 @@
language: go
go: 1.1
script:
- go vet ./...
- go test -v ./...

View File

@ -1,21 +0,0 @@
Copyright (C) 2013 Jeremy Saenz
All Rights Reserved.
MIT LICENSE
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View File

@ -1,285 +0,0 @@
[![Build Status](https://travis-ci.org/codegangsta/cli.png?branch=master)](https://travis-ci.org/codegangsta/cli)
# cli.go
cli.go is simple, fast, and fun package for building command line apps in Go. The goal is to enable developers to write fast and distributable command line applications in an expressive way.
You can view the API docs here:
http://godoc.org/github.com/codegangsta/cli
## Overview
Command line apps are usually so tiny that there is absolutely no reason why your code should *not* be self-documenting. Things like generating help text and parsing command flags/options should not hinder productivity when writing a command line app.
**This is where cli.go comes into play.** cli.go makes command line programming fun, organized, and expressive!
## Installation
Make sure you have a working Go environment (go 1.1 is *required*). [See the install instructions](http://golang.org/doc/install.html).
To install `cli.go`, simply run:
```
$ go get github.com/codegangsta/cli
```
Make sure your `PATH` includes to the `$GOPATH/bin` directory so your commands can be easily used:
```
export PATH=$PATH:$GOPATH/bin
```
## Getting Started
One of the philosophies behind cli.go is that an API should be playful and full of discovery. So a cli.go app can be as little as one line of code in `main()`.
``` go
package main
import (
"os"
"github.com/codegangsta/cli"
)
func main() {
cli.NewApp().Run(os.Args)
}
```
This app will run and show help text, but is not very useful. Let's give an action to execute and some help documentation:
``` go
package main
import (
"os"
"github.com/codegangsta/cli"
)
func main() {
app := cli.NewApp()
app.Name = "boom"
app.Usage = "make an explosive entrance"
app.Action = func(c *cli.Context) {
println("boom! I say!")
}
app.Run(os.Args)
}
```
Running this already gives you a ton of functionality, plus support for things like subcommands and flags, which are covered below.
## Example
Being a programmer can be a lonely job. Thankfully by the power of automation that is not the case! Let's create a greeter app to fend off our demons of loneliness!
Start by creating a directory named `greet`, and within it, add a file, `greet.go` with the following code in it:
``` go
package main
import (
"os"
"github.com/codegangsta/cli"
)
func main() {
app := cli.NewApp()
app.Name = "greet"
app.Usage = "fight the loneliness!"
app.Action = func(c *cli.Context) {
println("Hello friend!")
}
app.Run(os.Args)
}
```
Install our command to the `$GOPATH/bin` directory:
```
$ go install
```
Finally run our new command:
```
$ greet
Hello friend!
```
cli.go also generates some bitchass help text:
```
$ greet help
NAME:
greet - fight the loneliness!
USAGE:
greet [global options] command [command options] [arguments...]
VERSION:
0.0.0
COMMANDS:
help, h Shows a list of commands or help for one command
GLOBAL OPTIONS
--version Shows version information
```
### Arguments
You can lookup arguments by calling the `Args` function on `cli.Context`.
``` go
...
app.Action = func(c *cli.Context) {
println("Hello", c.Args()[0])
}
...
```
### Flags
Setting and querying flags is simple.
``` go
...
app.Flags = []cli.Flag {
cli.StringFlag{
Name: "lang",
Value: "english",
Usage: "language for the greeting",
},
}
app.Action = func(c *cli.Context) {
name := "someone"
if len(c.Args()) > 0 {
name = c.Args()[0]
}
if c.String("lang") == "spanish" {
println("Hola", name)
} else {
println("Hello", name)
}
}
...
```
#### Alternate Names
You can set alternate (or short) names for flags by providing a comma-delimited list for the `Name`. e.g.
``` go
app.Flags = []cli.Flag {
cli.StringFlag{
Name: "lang, l",
Value: "english",
Usage: "language for the greeting",
},
}
```
#### Values from the Environment
You can also have the default value set from the environment via `EnvVar`. e.g.
``` go
app.Flags = []cli.Flag {
cli.StringFlag{
Name: "lang, l",
Value: "english",
Usage: "language for the greeting",
EnvVar: "APP_LANG",
},
}
```
That flag can then be set with `--lang spanish` or `-l spanish`. Note that giving two different forms of the same flag in the same command invocation is an error.
### Subcommands
Subcommands can be defined for a more git-like command line app.
```go
...
app.Commands = []cli.Command{
{
Name: "add",
ShortName: "a",
Usage: "add a task to the list",
Action: func(c *cli.Context) {
println("added task: ", c.Args().First())
},
},
{
Name: "complete",
ShortName: "c",
Usage: "complete a task on the list",
Action: func(c *cli.Context) {
println("completed task: ", c.Args().First())
},
},
{
Name: "template",
ShortName: "r",
Usage: "options for task templates",
Subcommands: []cli.Command{
{
Name: "add",
Usage: "add a new template",
Action: func(c *cli.Context) {
println("new task template: ", c.Args().First())
},
},
{
Name: "remove",
Usage: "remove an existing template",
Action: func(c *cli.Context) {
println("removed task template: ", c.Args().First())
},
},
},
},
}
...
```
### Bash Completion
You can enable completion commands by setting the `EnableBashCompletion`
flag on the `App` object. By default, this setting will only auto-complete to
show an app's subcommands, but you can write your own completion methods for
the App or its subcommands.
```go
...
var tasks = []string{"cook", "clean", "laundry", "eat", "sleep", "code"}
app := cli.NewApp()
app.EnableBashCompletion = true
app.Commands = []cli.Command{
{
Name: "complete",
ShortName: "c",
Usage: "complete a task on the list",
Action: func(c *cli.Context) {
println("completed task: ", c.Args().First())
},
BashComplete: func(c *cli.Context) {
// This will complete if no args are passed
if len(c.Args()) > 0 {
return
}
for _, t := range tasks {
fmt.Println(t)
}
},
}
}
...
```
#### To Enable
Source the `autocomplete/bash_autocomplete` file in your `.bashrc` file while
setting the `PROG` variable to the name of your program:
`PROG=myprogram source /.../cli/autocomplete/bash_autocomplete`
## Contribution Guidelines
Feel free to put up a pull request to fix a bug or maybe add a feature. I will give it a code review and make sure that it does not break backwards compatibility. If I or any other collaborators agree that it is in line with the vision of the project, we will work with you to get the code into a mergeable state and merge it into the master branch.
If you are have contributed something significant to the project, I will most likely add you as a collaborator. As a collaborator you are given the ability to merge others pull requests. It is very important that new code does not break existing code, so be careful about what code you do choose to merge. If you have any questions feel free to link @codegangsta to the issue in question and we can review it together.
If you feel like you have contributed to the project but have not yet been added as a collaborator, I probably forgot to add you. Hit @codegangsta up over email and we will get it figured out.

View File

@ -1,251 +0,0 @@
package cli
import (
"fmt"
"io/ioutil"
"os"
"time"
)
// App is the main structure of a cli application. It is recomended that
// and app be created with the cli.NewApp() function
type App struct {
// The name of the program. Defaults to os.Args[0]
Name string
// Description of the program.
Usage string
// Version of the program
Version string
// List of commands to execute
Commands []Command
// List of flags to parse
Flags []Flag
// Boolean to enable bash completion commands
EnableBashCompletion bool
// Boolean to hide built-in help command
HideHelp bool
// Boolean to hide built-in version flag
HideVersion bool
// An action to execute when the bash-completion flag is set
BashComplete func(context *Context)
// An action to execute before any subcommands are run, but after the context is ready
// If a non-nil error is returned, no subcommands are run
Before func(context *Context) error
// The action to execute when no subcommands are specified
Action func(context *Context)
// Execute this function if the proper command cannot be found
CommandNotFound func(context *Context, command string)
// Compilation date
Compiled time.Time
// Author
Author string
// Author e-mail
Email string
}
// Tries to find out when this binary was compiled.
// Returns the current time if it fails to find it.
func compileTime() time.Time {
info, err := os.Stat(os.Args[0])
if err != nil {
return time.Now()
}
return info.ModTime()
}
// Creates a new cli Application with some reasonable defaults for Name, Usage, Version and Action.
func NewApp() *App {
return &App{
Name: os.Args[0],
Usage: "A new cli application",
Version: "0.0.0",
BashComplete: DefaultAppComplete,
Action: helpCommand.Action,
Compiled: compileTime(),
}
}
// Entry point to the cli app. Parses the arguments slice and routes to the proper flag/args combination
func (a *App) Run(arguments []string) error {
// append help to commands
if a.Command(helpCommand.Name) == nil && !a.HideHelp {
a.Commands = append(a.Commands, helpCommand)
a.appendFlag(HelpFlag)
}
//append version/help flags
if a.EnableBashCompletion {
a.appendFlag(BashCompletionFlag)
}
if !a.HideVersion {
a.appendFlag(VersionFlag)
}
// parse flags
set := flagSet(a.Name, a.Flags)
set.SetOutput(ioutil.Discard)
err := set.Parse(arguments[1:])
nerr := normalizeFlags(a.Flags, set)
if nerr != nil {
fmt.Println(nerr)
context := NewContext(a, set, set)
ShowAppHelp(context)
fmt.Println("")
return nerr
}
context := NewContext(a, set, set)
if err != nil {
fmt.Printf("Incorrect Usage.\n\n")
ShowAppHelp(context)
fmt.Println("")
return err
}
if checkCompletions(context) {
return nil
}
if checkHelp(context) {
return nil
}
if checkVersion(context) {
return nil
}
if a.Before != nil {
err := a.Before(context)
if err != nil {
return err
}
}
args := context.Args()
if args.Present() {
name := args.First()
c := a.Command(name)
if c != nil {
return c.Run(context)
}
}
// Run default Action
a.Action(context)
return nil
}
// Another entry point to the cli app, takes care of passing arguments and error handling
func (a *App) RunAndExitOnError() {
if err := a.Run(os.Args); err != nil {
os.Stderr.WriteString(fmt.Sprintln(err))
os.Exit(1)
}
}
// Invokes the subcommand given the context, parses ctx.Args() to generate command-specific flags
func (a *App) RunAsSubcommand(ctx *Context) error {
// append help to commands
if len(a.Commands) > 0 {
if a.Command(helpCommand.Name) == nil && !a.HideHelp {
a.Commands = append(a.Commands, helpCommand)
a.appendFlag(HelpFlag)
}
}
// append flags
if a.EnableBashCompletion {
a.appendFlag(BashCompletionFlag)
}
// parse flags
set := flagSet(a.Name, a.Flags)
set.SetOutput(ioutil.Discard)
err := set.Parse(ctx.Args().Tail())
nerr := normalizeFlags(a.Flags, set)
context := NewContext(a, set, ctx.globalSet)
if nerr != nil {
fmt.Println(nerr)
if len(a.Commands) > 0 {
ShowSubcommandHelp(context)
} else {
ShowCommandHelp(ctx, context.Args().First())
}
fmt.Println("")
return nerr
}
if err != nil {
fmt.Printf("Incorrect Usage.\n\n")
ShowSubcommandHelp(context)
return err
}
if checkCompletions(context) {
return nil
}
if len(a.Commands) > 0 {
if checkSubcommandHelp(context) {
return nil
}
} else {
if checkCommandHelp(ctx, context.Args().First()) {
return nil
}
}
if a.Before != nil {
err := a.Before(context)
if err != nil {
return err
}
}
args := context.Args()
if args.Present() {
name := args.First()
c := a.Command(name)
if c != nil {
return c.Run(context)
}
}
// Run default Action
if len(a.Commands) > 0 {
a.Action(context)
} else {
a.Action(ctx)
}
return nil
}
// Returns the named command on App. Returns nil if the command does not exist
func (a *App) Command(name string) *Command {
for _, c := range a.Commands {
if c.HasName(name) {
return &c
}
}
return nil
}
func (a *App) hasFlag(flag Flag) bool {
for _, f := range a.Flags {
if flag == f {
return true
}
}
return false
}
func (a *App) appendFlag(flag Flag) {
if !a.hasFlag(flag) {
a.Flags = append(a.Flags, flag)
}
}

View File

@ -1,423 +0,0 @@
package cli_test
import (
"fmt"
"os"
"testing"
"github.com/codegangsta/cli"
)
func ExampleApp() {
// set args for examples sake
os.Args = []string{"greet", "--name", "Jeremy"}
app := cli.NewApp()
app.Name = "greet"
app.Flags = []cli.Flag{
cli.StringFlag{Name: "name", Value: "bob", Usage: "a name to say"},
}
app.Action = func(c *cli.Context) {
fmt.Printf("Hello %v\n", c.String("name"))
}
app.Run(os.Args)
// Output:
// Hello Jeremy
}
func ExampleAppSubcommand() {
// set args for examples sake
os.Args = []string{"say", "hi", "english", "--name", "Jeremy"}
app := cli.NewApp()
app.Name = "say"
app.Commands = []cli.Command{
{
Name: "hello",
ShortName: "hi",
Usage: "use it to see a description",
Description: "This is how we describe hello the function",
Subcommands: []cli.Command{
{
Name: "english",
ShortName: "en",
Usage: "sends a greeting in english",
Description: "greets someone in english",
Flags: []cli.Flag{
cli.StringFlag{
Name: "name",
Value: "Bob",
Usage: "Name of the person to greet",
},
},
Action: func(c *cli.Context) {
fmt.Println("Hello,", c.String("name"))
},
},
},
},
}
app.Run(os.Args)
// Output:
// Hello, Jeremy
}
func ExampleAppHelp() {
// set args for examples sake
os.Args = []string{"greet", "h", "describeit"}
app := cli.NewApp()
app.Name = "greet"
app.Flags = []cli.Flag{
cli.StringFlag{Name: "name", Value: "bob", Usage: "a name to say"},
}
app.Commands = []cli.Command{
{
Name: "describeit",
ShortName: "d",
Usage: "use it to see a description",
Description: "This is how we describe describeit the function",
Action: func(c *cli.Context) {
fmt.Printf("i like to describe things")
},
},
}
app.Run(os.Args)
// Output:
// NAME:
// describeit - use it to see a description
//
// USAGE:
// command describeit [arguments...]
//
// DESCRIPTION:
// This is how we describe describeit the function
}
func ExampleAppBashComplete() {
// set args for examples sake
os.Args = []string{"greet", "--generate-bash-completion"}
app := cli.NewApp()
app.Name = "greet"
app.EnableBashCompletion = true
app.Commands = []cli.Command{
{
Name: "describeit",
ShortName: "d",
Usage: "use it to see a description",
Description: "This is how we describe describeit the function",
Action: func(c *cli.Context) {
fmt.Printf("i like to describe things")
},
}, {
Name: "next",
Usage: "next example",
Description: "more stuff to see when generating bash completion",
Action: func(c *cli.Context) {
fmt.Printf("the next example")
},
},
}
app.Run(os.Args)
// Output:
// describeit
// d
// next
// help
// h
}
func TestApp_Run(t *testing.T) {
s := ""
app := cli.NewApp()
app.Action = func(c *cli.Context) {
s = s + c.Args().First()
}
err := app.Run([]string{"command", "foo"})
expect(t, err, nil)
err = app.Run([]string{"command", "bar"})
expect(t, err, nil)
expect(t, s, "foobar")
}
var commandAppTests = []struct {
name string
expected bool
}{
{"foobar", true},
{"batbaz", true},
{"b", true},
{"f", true},
{"bat", false},
{"nothing", false},
}
func TestApp_Command(t *testing.T) {
app := cli.NewApp()
fooCommand := cli.Command{Name: "foobar", ShortName: "f"}
batCommand := cli.Command{Name: "batbaz", ShortName: "b"}
app.Commands = []cli.Command{
fooCommand,
batCommand,
}
for _, test := range commandAppTests {
expect(t, app.Command(test.name) != nil, test.expected)
}
}
func TestApp_CommandWithArgBeforeFlags(t *testing.T) {
var parsedOption, firstArg string
app := cli.NewApp()
command := cli.Command{
Name: "cmd",
Flags: []cli.Flag{
cli.StringFlag{Name: "option", Value: "", Usage: "some option"},
},
Action: func(c *cli.Context) {
parsedOption = c.String("option")
firstArg = c.Args().First()
},
}
app.Commands = []cli.Command{command}
app.Run([]string{"", "cmd", "my-arg", "--option", "my-option"})
expect(t, parsedOption, "my-option")
expect(t, firstArg, "my-arg")
}
func TestApp_Float64Flag(t *testing.T) {
var meters float64
app := cli.NewApp()
app.Flags = []cli.Flag{
cli.Float64Flag{Name: "height", Value: 1.5, Usage: "Set the height, in meters"},
}
app.Action = func(c *cli.Context) {
meters = c.Float64("height")
}
app.Run([]string{"", "--height", "1.93"})
expect(t, meters, 1.93)
}
func TestApp_ParseSliceFlags(t *testing.T) {
var parsedOption, firstArg string
var parsedIntSlice []int
var parsedStringSlice []string
app := cli.NewApp()
command := cli.Command{
Name: "cmd",
Flags: []cli.Flag{
cli.IntSliceFlag{Name: "p", Value: &cli.IntSlice{}, Usage: "set one or more ip addr"},
cli.StringSliceFlag{Name: "ip", Value: &cli.StringSlice{}, Usage: "set one or more ports to open"},
},
Action: func(c *cli.Context) {
parsedIntSlice = c.IntSlice("p")
parsedStringSlice = c.StringSlice("ip")
parsedOption = c.String("option")
firstArg = c.Args().First()
},
}
app.Commands = []cli.Command{command}
app.Run([]string{"", "cmd", "my-arg", "-p", "22", "-p", "80", "-ip", "8.8.8.8", "-ip", "8.8.4.4"})
IntsEquals := func(a, b []int) bool {
if len(a) != len(b) {
return false
}
for i, v := range a {
if v != b[i] {
return false
}
}
return true
}
StrsEquals := func(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i, v := range a {
if v != b[i] {
return false
}
}
return true
}
var expectedIntSlice = []int{22, 80}
var expectedStringSlice = []string{"8.8.8.8", "8.8.4.4"}
if !IntsEquals(parsedIntSlice, expectedIntSlice) {
t.Errorf("%v does not match %v", parsedIntSlice, expectedIntSlice)
}
if !StrsEquals(parsedStringSlice, expectedStringSlice) {
t.Errorf("%v does not match %v", parsedStringSlice, expectedStringSlice)
}
}
func TestApp_BeforeFunc(t *testing.T) {
beforeRun, subcommandRun := false, false
beforeError := fmt.Errorf("fail")
var err error
app := cli.NewApp()
app.Before = func(c *cli.Context) error {
beforeRun = true
s := c.String("opt")
if s == "fail" {
return beforeError
}
return nil
}
app.Commands = []cli.Command{
cli.Command{
Name: "sub",
Action: func(c *cli.Context) {
subcommandRun = true
},
},
}
app.Flags = []cli.Flag{
cli.StringFlag{Name: "opt"},
}
// run with the Before() func succeeding
err = app.Run([]string{"command", "--opt", "succeed", "sub"})
if err != nil {
t.Fatalf("Run error: %s", err)
}
if beforeRun == false {
t.Errorf("Before() not executed when expected")
}
if subcommandRun == false {
t.Errorf("Subcommand not executed when expected")
}
// reset
beforeRun, subcommandRun = false, false
// run with the Before() func failing
err = app.Run([]string{"command", "--opt", "fail", "sub"})
// should be the same error produced by the Before func
if err != beforeError {
t.Errorf("Run error expected, but not received")
}
if beforeRun == false {
t.Errorf("Before() not executed when expected")
}
if subcommandRun == true {
t.Errorf("Subcommand executed when NOT expected")
}
}
func TestAppHelpPrinter(t *testing.T) {
oldPrinter := cli.HelpPrinter
defer func() {
cli.HelpPrinter = oldPrinter
}()
var wasCalled = false
cli.HelpPrinter = func(template string, data interface{}) {
wasCalled = true
}
app := cli.NewApp()
app.Run([]string{"-h"})
if wasCalled == false {
t.Errorf("Help printer expected to be called, but was not")
}
}
func TestAppVersionPrinter(t *testing.T) {
oldPrinter := cli.VersionPrinter
defer func() {
cli.VersionPrinter = oldPrinter
}()
var wasCalled = false
cli.VersionPrinter = func(c *cli.Context) {
wasCalled = true
}
app := cli.NewApp()
ctx := cli.NewContext(app, nil, nil)
cli.ShowVersion(ctx)
if wasCalled == false {
t.Errorf("Version printer expected to be called, but was not")
}
}
func TestAppCommandNotFound(t *testing.T) {
beforeRun, subcommandRun := false, false
app := cli.NewApp()
app.CommandNotFound = func(c *cli.Context, command string) {
beforeRun = true
}
app.Commands = []cli.Command{
cli.Command{
Name: "bar",
Action: func(c *cli.Context) {
subcommandRun = true
},
},
}
app.Run([]string{"command", "foo"})
expect(t, beforeRun, true)
expect(t, subcommandRun, false)
}
func TestGlobalFlagsInSubcommands(t *testing.T) {
subcommandRun := false
app := cli.NewApp()
app.Flags = []cli.Flag{
cli.BoolFlag{Name: "debug, d", Usage: "Enable debugging"},
}
app.Commands = []cli.Command{
cli.Command{
Name: "foo",
Subcommands: []cli.Command{
{
Name: "bar",
Action: func(c *cli.Context) {
if c.GlobalBool("debug") {
subcommandRun = true
}
},
},
},
},
}
app.Run([]string{"command", "-d", "foo", "bar"})
expect(t, subcommandRun, true)
}

View File

@ -1,13 +0,0 @@
#! /bin/bash
_cli_bash_autocomplete() {
local cur prev opts base
COMPREPLY=()
cur="${COMP_WORDS[COMP_CWORD]}"
prev="${COMP_WORDS[COMP_CWORD-1]}"
opts=$( ${COMP_WORDS[@]:0:$COMP_CWORD} --generate-bash-completion )
COMPREPLY=( $(compgen -W "${opts}" -- ${cur}) )
return 0
}
complete -F _cli_bash_autocomplete $PROG

View File

@ -1,5 +0,0 @@
autoload -U compinit && compinit
autoload -U bashcompinit && bashcompinit
script_dir=$(dirname $0)
source ${script_dir}/bash_autocomplete

View File

@ -1,19 +0,0 @@
// Package cli provides a minimal framework for creating and organizing command line
// Go applications. cli is designed to be easy to understand and write, the most simple
// cli application can be written as follows:
// func main() {
// cli.NewApp().Run(os.Args)
// }
//
// Of course this application does not do much, so let's make this an actual application:
// func main() {
// app := cli.NewApp()
// app.Name = "greet"
// app.Usage = "say a greeting"
// app.Action = func(c *cli.Context) {
// println("Greetings")
// }
//
// app.Run(os.Args)
// }
package cli

View File

@ -1,100 +0,0 @@
package cli_test
import (
"os"
"github.com/codegangsta/cli"
)
func Example() {
app := cli.NewApp()
app.Name = "todo"
app.Usage = "task list on the command line"
app.Commands = []cli.Command{
{
Name: "add",
ShortName: "a",
Usage: "add a task to the list",
Action: func(c *cli.Context) {
println("added task: ", c.Args().First())
},
},
{
Name: "complete",
ShortName: "c",
Usage: "complete a task on the list",
Action: func(c *cli.Context) {
println("completed task: ", c.Args().First())
},
},
}
app.Run(os.Args)
}
func ExampleSubcommand() {
app := cli.NewApp()
app.Name = "say"
app.Commands = []cli.Command{
{
Name: "hello",
ShortName: "hi",
Usage: "use it to see a description",
Description: "This is how we describe hello the function",
Subcommands: []cli.Command{
{
Name: "english",
ShortName: "en",
Usage: "sends a greeting in english",
Description: "greets someone in english",
Flags: []cli.Flag{
cli.StringFlag{
Name: "name",
Value: "Bob",
Usage: "Name of the person to greet",
},
},
Action: func(c *cli.Context) {
println("Hello, ", c.String("name"))
},
}, {
Name: "spanish",
ShortName: "sp",
Usage: "sends a greeting in spanish",
Flags: []cli.Flag{
cli.StringFlag{
Name: "surname",
Value: "Jones",
Usage: "Surname of the person to greet",
},
},
Action: func(c *cli.Context) {
println("Hola, ", c.String("surname"))
},
}, {
Name: "french",
ShortName: "fr",
Usage: "sends a greeting in french",
Flags: []cli.Flag{
cli.StringFlag{
Name: "nickname",
Value: "Stevie",
Usage: "Nickname of the person to greet",
},
},
Action: func(c *cli.Context) {
println("Bonjour, ", c.String("nickname"))
},
},
},
}, {
Name: "bye",
Usage: "says goodbye",
Action: func(c *cli.Context) {
println("bye")
},
},
}
app.Run(os.Args)
}

View File

@ -1,144 +0,0 @@
package cli
import (
"fmt"
"io/ioutil"
"strings"
)
// Command is a subcommand for a cli.App.
type Command struct {
// The name of the command
Name string
// short name of the command. Typically one character
ShortName string
// A short description of the usage of this command
Usage string
// A longer explanation of how the command works
Description string
// The function to call when checking for bash command completions
BashComplete func(context *Context)
// An action to execute before any sub-subcommands are run, but after the context is ready
// If a non-nil error is returned, no sub-subcommands are run
Before func(context *Context) error
// The function to call when this command is invoked
Action func(context *Context)
// List of child commands
Subcommands []Command
// List of flags to parse
Flags []Flag
// Treat all flags as normal arguments if true
SkipFlagParsing bool
// Boolean to hide built-in help command
HideHelp bool
}
// Invokes the command given the context, parses ctx.Args() to generate command-specific flags
func (c Command) Run(ctx *Context) error {
if len(c.Subcommands) > 0 || c.Before != nil {
return c.startApp(ctx)
}
if !c.HideHelp {
// append help to flags
c.Flags = append(
c.Flags,
HelpFlag,
)
}
if ctx.App.EnableBashCompletion {
c.Flags = append(c.Flags, BashCompletionFlag)
}
set := flagSet(c.Name, c.Flags)
set.SetOutput(ioutil.Discard)
firstFlagIndex := -1
for index, arg := range ctx.Args() {
if strings.HasPrefix(arg, "-") {
firstFlagIndex = index
break
}
}
var err error
if firstFlagIndex > -1 && !c.SkipFlagParsing {
args := ctx.Args()
regularArgs := args[1:firstFlagIndex]
flagArgs := args[firstFlagIndex:]
err = set.Parse(append(flagArgs, regularArgs...))
} else {
err = set.Parse(ctx.Args().Tail())
}
if err != nil {
fmt.Printf("Incorrect Usage.\n\n")
ShowCommandHelp(ctx, c.Name)
fmt.Println("")
return err
}
nerr := normalizeFlags(c.Flags, set)
if nerr != nil {
fmt.Println(nerr)
fmt.Println("")
ShowCommandHelp(ctx, c.Name)
fmt.Println("")
return nerr
}
context := NewContext(ctx.App, set, ctx.globalSet)
if checkCommandCompletions(context, c.Name) {
return nil
}
if checkCommandHelp(context, c.Name) {
return nil
}
context.Command = c
c.Action(context)
return nil
}
// Returns true if Command.Name or Command.ShortName matches given name
func (c Command) HasName(name string) bool {
return c.Name == name || c.ShortName == name
}
func (c Command) startApp(ctx *Context) error {
app := NewApp()
// set the name and usage
app.Name = fmt.Sprintf("%s %s", ctx.App.Name, c.Name)
if c.Description != "" {
app.Usage = c.Description
} else {
app.Usage = c.Usage
}
// set CommandNotFound
app.CommandNotFound = ctx.App.CommandNotFound
// set the flags and commands
app.Commands = c.Subcommands
app.Flags = c.Flags
app.HideHelp = c.HideHelp
// bash completion
app.EnableBashCompletion = ctx.App.EnableBashCompletion
if c.BashComplete != nil {
app.BashComplete = c.BashComplete
}
// set the actions
app.Before = c.Before
if c.Action != nil {
app.Action = c.Action
} else {
app.Action = helpSubcommand.Action
}
return app.RunAsSubcommand(ctx)
}

View File

@ -1,49 +0,0 @@
package cli_test
import (
"flag"
"testing"
"github.com/codegangsta/cli"
)
func TestCommandDoNotIgnoreFlags(t *testing.T) {
app := cli.NewApp()
set := flag.NewFlagSet("test", 0)
test := []string{"blah", "blah", "-break"}
set.Parse(test)
c := cli.NewContext(app, set, set)
command := cli.Command{
Name: "test-cmd",
ShortName: "tc",
Usage: "this is for testing",
Description: "testing",
Action: func(_ *cli.Context) {},
}
err := command.Run(c)
expect(t, err.Error(), "flag provided but not defined: -break")
}
func TestCommandIgnoreFlags(t *testing.T) {
app := cli.NewApp()
set := flag.NewFlagSet("test", 0)
test := []string{"blah", "blah"}
set.Parse(test)
c := cli.NewContext(app, set, set)
command := cli.Command{
Name: "test-cmd",
ShortName: "tc",
Usage: "this is for testing",
Description: "testing",
Action: func(_ *cli.Context) {},
SkipFlagParsing: true,
}
err := command.Run(c)
expect(t, err, nil)
}

View File

@ -1,339 +0,0 @@
package cli
import (
"errors"
"flag"
"strconv"
"strings"
"time"
)
// Context is a type that is passed through to
// each Handler action in a cli application. Context
// can be used to retrieve context-specific Args and
// parsed command-line options.
type Context struct {
App *App
Command Command
flagSet *flag.FlagSet
globalSet *flag.FlagSet
setFlags map[string]bool
globalSetFlags map[string]bool
}
// Creates a new context. For use in when invoking an App or Command action.
func NewContext(app *App, set *flag.FlagSet, globalSet *flag.FlagSet) *Context {
return &Context{App: app, flagSet: set, globalSet: globalSet}
}
// Looks up the value of a local int flag, returns 0 if no int flag exists
func (c *Context) Int(name string) int {
return lookupInt(name, c.flagSet)
}
// Looks up the value of a local time.Duration flag, returns 0 if no time.Duration flag exists
func (c *Context) Duration(name string) time.Duration {
return lookupDuration(name, c.flagSet)
}
// Looks up the value of a local float64 flag, returns 0 if no float64 flag exists
func (c *Context) Float64(name string) float64 {
return lookupFloat64(name, c.flagSet)
}
// Looks up the value of a local bool flag, returns false if no bool flag exists
func (c *Context) Bool(name string) bool {
return lookupBool(name, c.flagSet)
}
// Looks up the value of a local boolT flag, returns false if no bool flag exists
func (c *Context) BoolT(name string) bool {
return lookupBoolT(name, c.flagSet)
}
// Looks up the value of a local string flag, returns "" if no string flag exists
func (c *Context) String(name string) string {
return lookupString(name, c.flagSet)
}
// Looks up the value of a local string slice flag, returns nil if no string slice flag exists
func (c *Context) StringSlice(name string) []string {
return lookupStringSlice(name, c.flagSet)
}
// Looks up the value of a local int slice flag, returns nil if no int slice flag exists
func (c *Context) IntSlice(name string) []int {
return lookupIntSlice(name, c.flagSet)
}
// Looks up the value of a local generic flag, returns nil if no generic flag exists
func (c *Context) Generic(name string) interface{} {
return lookupGeneric(name, c.flagSet)
}
// Looks up the value of a global int flag, returns 0 if no int flag exists
func (c *Context) GlobalInt(name string) int {
return lookupInt(name, c.globalSet)
}
// Looks up the value of a global time.Duration flag, returns 0 if no time.Duration flag exists
func (c *Context) GlobalDuration(name string) time.Duration {
return lookupDuration(name, c.globalSet)
}
// Looks up the value of a global bool flag, returns false if no bool flag exists
func (c *Context) GlobalBool(name string) bool {
return lookupBool(name, c.globalSet)
}
// Looks up the value of a global string flag, returns "" if no string flag exists
func (c *Context) GlobalString(name string) string {
return lookupString(name, c.globalSet)
}
// Looks up the value of a global string slice flag, returns nil if no string slice flag exists
func (c *Context) GlobalStringSlice(name string) []string {
return lookupStringSlice(name, c.globalSet)
}
// Looks up the value of a global int slice flag, returns nil if no int slice flag exists
func (c *Context) GlobalIntSlice(name string) []int {
return lookupIntSlice(name, c.globalSet)
}
// Looks up the value of a global generic flag, returns nil if no generic flag exists
func (c *Context) GlobalGeneric(name string) interface{} {
return lookupGeneric(name, c.globalSet)
}
// Determines if the flag was actually set
func (c *Context) IsSet(name string) bool {
if c.setFlags == nil {
c.setFlags = make(map[string]bool)
c.flagSet.Visit(func(f *flag.Flag) {
c.setFlags[f.Name] = true
})
}
return c.setFlags[name] == true
}
// Determines if the global flag was actually set
func (c *Context) GlobalIsSet(name string) bool {
if c.globalSetFlags == nil {
c.globalSetFlags = make(map[string]bool)
c.globalSet.Visit(func(f *flag.Flag) {
c.globalSetFlags[f.Name] = true
})
}
return c.globalSetFlags[name] == true
}
// Returns a slice of flag names used in this context.
func (c *Context) FlagNames() (names []string) {
for _, flag := range c.Command.Flags {
name := strings.Split(flag.getName(), ",")[0]
if name == "help" {
continue
}
names = append(names, name)
}
return
}
// Returns a slice of global flag names used by the app.
func (c *Context) GlobalFlagNames() (names []string) {
for _, flag := range c.App.Flags {
name := strings.Split(flag.getName(), ",")[0]
if name == "help" || name == "version" {
continue
}
names = append(names, name)
}
return
}
type Args []string
// Returns the command line arguments associated with the context.
func (c *Context) Args() Args {
args := Args(c.flagSet.Args())
return args
}
// Returns the nth argument, or else a blank string
func (a Args) Get(n int) string {
if len(a) > n {
return a[n]
}
return ""
}
// Returns the first argument, or else a blank string
func (a Args) First() string {
return a.Get(0)
}
// Return the rest of the arguments (not the first one)
// or else an empty string slice
func (a Args) Tail() []string {
if len(a) >= 2 {
return []string(a)[1:]
}
return []string{}
}
// Checks if there are any arguments present
func (a Args) Present() bool {
return len(a) != 0
}
// Swaps arguments at the given indexes
func (a Args) Swap(from, to int) error {
if from >= len(a) || to >= len(a) {
return errors.New("index out of range")
}
a[from], a[to] = a[to], a[from]
return nil
}
func lookupInt(name string, set *flag.FlagSet) int {
f := set.Lookup(name)
if f != nil {
val, err := strconv.Atoi(f.Value.String())
if err != nil {
return 0
}
return val
}
return 0
}
func lookupDuration(name string, set *flag.FlagSet) time.Duration {
f := set.Lookup(name)
if f != nil {
val, err := time.ParseDuration(f.Value.String())
if err == nil {
return val
}
}
return 0
}
func lookupFloat64(name string, set *flag.FlagSet) float64 {
f := set.Lookup(name)
if f != nil {
val, err := strconv.ParseFloat(f.Value.String(), 64)
if err != nil {
return 0
}
return val
}
return 0
}
func lookupString(name string, set *flag.FlagSet) string {
f := set.Lookup(name)
if f != nil {
return f.Value.String()
}
return ""
}
func lookupStringSlice(name string, set *flag.FlagSet) []string {
f := set.Lookup(name)
if f != nil {
return (f.Value.(*StringSlice)).Value()
}
return nil
}
func lookupIntSlice(name string, set *flag.FlagSet) []int {
f := set.Lookup(name)
if f != nil {
return (f.Value.(*IntSlice)).Value()
}
return nil
}
func lookupGeneric(name string, set *flag.FlagSet) interface{} {
f := set.Lookup(name)
if f != nil {
return f.Value
}
return nil
}
func lookupBool(name string, set *flag.FlagSet) bool {
f := set.Lookup(name)
if f != nil {
val, err := strconv.ParseBool(f.Value.String())
if err != nil {
return false
}
return val
}
return false
}
func lookupBoolT(name string, set *flag.FlagSet) bool {
f := set.Lookup(name)
if f != nil {
val, err := strconv.ParseBool(f.Value.String())
if err != nil {
return true
}
return val
}
return false
}
func copyFlag(name string, ff *flag.Flag, set *flag.FlagSet) {
switch ff.Value.(type) {
case *StringSlice:
default:
set.Set(name, ff.Value.String())
}
}
func normalizeFlags(flags []Flag, set *flag.FlagSet) error {
visited := make(map[string]bool)
set.Visit(func(f *flag.Flag) {
visited[f.Name] = true
})
for _, f := range flags {
parts := strings.Split(f.getName(), ",")
if len(parts) == 1 {
continue
}
var ff *flag.Flag
for _, name := range parts {
name = strings.Trim(name, " ")
if visited[name] {
if ff != nil {
return errors.New("Cannot use two forms of the same flag: " + name + " " + ff.Name)
}
ff = set.Lookup(name)
}
}
if ff == nil {
continue
}
for _, name := range parts {
name = strings.Trim(name, " ")
if !visited[name] {
copyFlag(name, ff, set)
}
}
}
return nil
}

View File

@ -1,99 +0,0 @@
package cli_test
import (
"flag"
"testing"
"time"
"github.com/codegangsta/cli"
)
func TestNewContext(t *testing.T) {
set := flag.NewFlagSet("test", 0)
set.Int("myflag", 12, "doc")
globalSet := flag.NewFlagSet("test", 0)
globalSet.Int("myflag", 42, "doc")
command := cli.Command{Name: "mycommand"}
c := cli.NewContext(nil, set, globalSet)
c.Command = command
expect(t, c.Int("myflag"), 12)
expect(t, c.GlobalInt("myflag"), 42)
expect(t, c.Command.Name, "mycommand")
}
func TestContext_Int(t *testing.T) {
set := flag.NewFlagSet("test", 0)
set.Int("myflag", 12, "doc")
c := cli.NewContext(nil, set, set)
expect(t, c.Int("myflag"), 12)
}
func TestContext_Duration(t *testing.T) {
set := flag.NewFlagSet("test", 0)
set.Duration("myflag", time.Duration(12*time.Second), "doc")
c := cli.NewContext(nil, set, set)
expect(t, c.Duration("myflag"), time.Duration(12*time.Second))
}
func TestContext_String(t *testing.T) {
set := flag.NewFlagSet("test", 0)
set.String("myflag", "hello world", "doc")
c := cli.NewContext(nil, set, set)
expect(t, c.String("myflag"), "hello world")
}
func TestContext_Bool(t *testing.T) {
set := flag.NewFlagSet("test", 0)
set.Bool("myflag", false, "doc")
c := cli.NewContext(nil, set, set)
expect(t, c.Bool("myflag"), false)
}
func TestContext_BoolT(t *testing.T) {
set := flag.NewFlagSet("test", 0)
set.Bool("myflag", true, "doc")
c := cli.NewContext(nil, set, set)
expect(t, c.BoolT("myflag"), true)
}
func TestContext_Args(t *testing.T) {
set := flag.NewFlagSet("test", 0)
set.Bool("myflag", false, "doc")
c := cli.NewContext(nil, set, set)
set.Parse([]string{"--myflag", "bat", "baz"})
expect(t, len(c.Args()), 2)
expect(t, c.Bool("myflag"), true)
}
func TestContext_IsSet(t *testing.T) {
set := flag.NewFlagSet("test", 0)
set.Bool("myflag", false, "doc")
set.String("otherflag", "hello world", "doc")
globalSet := flag.NewFlagSet("test", 0)
globalSet.Bool("myflagGlobal", true, "doc")
c := cli.NewContext(nil, set, globalSet)
set.Parse([]string{"--myflag", "bat", "baz"})
globalSet.Parse([]string{"--myflagGlobal", "bat", "baz"})
expect(t, c.IsSet("myflag"), true)
expect(t, c.IsSet("otherflag"), false)
expect(t, c.IsSet("bogusflag"), false)
expect(t, c.IsSet("myflagGlobal"), false)
}
func TestContext_GlobalIsSet(t *testing.T) {
set := flag.NewFlagSet("test", 0)
set.Bool("myflag", false, "doc")
set.String("otherflag", "hello world", "doc")
globalSet := flag.NewFlagSet("test", 0)
globalSet.Bool("myflagGlobal", true, "doc")
globalSet.Bool("myflagGlobalUnset", true, "doc")
c := cli.NewContext(nil, set, globalSet)
set.Parse([]string{"--myflag", "bat", "baz"})
globalSet.Parse([]string{"--myflagGlobal", "bat", "baz"})
expect(t, c.GlobalIsSet("myflag"), false)
expect(t, c.GlobalIsSet("otherflag"), false)
expect(t, c.GlobalIsSet("bogusflag"), false)
expect(t, c.GlobalIsSet("myflagGlobal"), true)
expect(t, c.GlobalIsSet("myflagGlobalUnset"), false)
expect(t, c.GlobalIsSet("bogusGlobal"), false)
}

View File

@ -1,410 +0,0 @@
package cli
import (
"flag"
"fmt"
"os"
"strconv"
"strings"
"time"
)
// This flag enables bash-completion for all commands and subcommands
var BashCompletionFlag = BoolFlag{
Name: "generate-bash-completion",
}
// This flag prints the version for the application
var VersionFlag = BoolFlag{
Name: "version, v",
Usage: "print the version",
}
// This flag prints the help for all commands and subcommands
var HelpFlag = BoolFlag{
Name: "help, h",
Usage: "show help",
}
// Flag is a common interface related to parsing flags in cli.
// For more advanced flag parsing techniques, it is recomended that
// this interface be implemented.
type Flag interface {
fmt.Stringer
// Apply Flag settings to the given flag set
Apply(*flag.FlagSet)
getName() string
}
func flagSet(name string, flags []Flag) *flag.FlagSet {
set := flag.NewFlagSet(name, flag.ContinueOnError)
for _, f := range flags {
f.Apply(set)
}
return set
}
func eachName(longName string, fn func(string)) {
parts := strings.Split(longName, ",")
for _, name := range parts {
name = strings.Trim(name, " ")
fn(name)
}
}
// Generic is a generic parseable type identified by a specific flag
type Generic interface {
Set(value string) error
String() string
}
// GenericFlag is the flag type for types implementing Generic
type GenericFlag struct {
Name string
Value Generic
Usage string
EnvVar string
}
func (f GenericFlag) String() string {
return withEnvHint(f.EnvVar, fmt.Sprintf("%s%s %v\t`%v` %s", prefixFor(f.Name), f.Name, f.Value, "-"+f.Name+" option -"+f.Name+" option", f.Usage))
}
func (f GenericFlag) Apply(set *flag.FlagSet) {
val := f.Value
if f.EnvVar != "" {
if envVal := os.Getenv(f.EnvVar); envVal != "" {
val.Set(envVal)
}
}
eachName(f.Name, func(name string) {
set.Var(f.Value, name, f.Usage)
})
}
func (f GenericFlag) getName() string {
return f.Name
}
type StringSlice []string
func (f *StringSlice) Set(value string) error {
*f = append(*f, value)
return nil
}
func (f *StringSlice) String() string {
return fmt.Sprintf("%s", *f)
}
func (f *StringSlice) Value() []string {
return *f
}
type StringSliceFlag struct {
Name string
Value *StringSlice
Usage string
EnvVar string
}
func (f StringSliceFlag) String() string {
firstName := strings.Trim(strings.Split(f.Name, ",")[0], " ")
pref := prefixFor(firstName)
return withEnvHint(f.EnvVar, fmt.Sprintf("%s '%v'\t%v", prefixedNames(f.Name), pref+firstName+" option "+pref+firstName+" option", f.Usage))
}
func (f StringSliceFlag) Apply(set *flag.FlagSet) {
if f.EnvVar != "" {
if envVal := os.Getenv(f.EnvVar); envVal != "" {
newVal := &StringSlice{}
for _, s := range strings.Split(envVal, ",") {
newVal.Set(s)
}
f.Value = newVal
}
}
eachName(f.Name, func(name string) {
set.Var(f.Value, name, f.Usage)
})
}
func (f StringSliceFlag) getName() string {
return f.Name
}
type IntSlice []int
func (f *IntSlice) Set(value string) error {
tmp, err := strconv.Atoi(value)
if err != nil {
return err
} else {
*f = append(*f, tmp)
}
return nil
}
func (f *IntSlice) String() string {
return fmt.Sprintf("%d", *f)
}
func (f *IntSlice) Value() []int {
return *f
}
type IntSliceFlag struct {
Name string
Value *IntSlice
Usage string
EnvVar string
}
func (f IntSliceFlag) String() string {
firstName := strings.Trim(strings.Split(f.Name, ",")[0], " ")
pref := prefixFor(firstName)
return withEnvHint(f.EnvVar, fmt.Sprintf("%s '%v'\t%v", prefixedNames(f.Name), pref+firstName+" option "+pref+firstName+" option", f.Usage))
}
func (f IntSliceFlag) Apply(set *flag.FlagSet) {
if f.EnvVar != "" {
if envVal := os.Getenv(f.EnvVar); envVal != "" {
newVal := &IntSlice{}
for _, s := range strings.Split(envVal, ",") {
err := newVal.Set(s)
if err != nil {
fmt.Fprintf(os.Stderr, err.Error())
}
}
f.Value = newVal
}
}
eachName(f.Name, func(name string) {
set.Var(f.Value, name, f.Usage)
})
}
func (f IntSliceFlag) getName() string {
return f.Name
}
type BoolFlag struct {
Name string
Usage string
EnvVar string
}
func (f BoolFlag) String() string {
return withEnvHint(f.EnvVar, fmt.Sprintf("%s\t%v", prefixedNames(f.Name), f.Usage))
}
func (f BoolFlag) Apply(set *flag.FlagSet) {
val := false
if f.EnvVar != "" {
if envVal := os.Getenv(f.EnvVar); envVal != "" {
envValBool, err := strconv.ParseBool(envVal)
if err == nil {
val = envValBool
}
}
}
eachName(f.Name, func(name string) {
set.Bool(name, val, f.Usage)
})
}
func (f BoolFlag) getName() string {
return f.Name
}
type BoolTFlag struct {
Name string
Usage string
EnvVar string
}
func (f BoolTFlag) String() string {
return withEnvHint(f.EnvVar, fmt.Sprintf("%s\t%v", prefixedNames(f.Name), f.Usage))
}
func (f BoolTFlag) Apply(set *flag.FlagSet) {
val := true
if f.EnvVar != "" {
if envVal := os.Getenv(f.EnvVar); envVal != "" {
envValBool, err := strconv.ParseBool(envVal)
if err == nil {
val = envValBool
}
}
}
eachName(f.Name, func(name string) {
set.Bool(name, val, f.Usage)
})
}
func (f BoolTFlag) getName() string {
return f.Name
}
type StringFlag struct {
Name string
Value string
Usage string
EnvVar string
}
func (f StringFlag) String() string {
var fmtString string
fmtString = "%s %v\t%v"
if len(f.Value) > 0 {
fmtString = "%s '%v'\t%v"
} else {
fmtString = "%s %v\t%v"
}
return withEnvHint(f.EnvVar, fmt.Sprintf(fmtString, prefixedNames(f.Name), f.Value, f.Usage))
}
func (f StringFlag) Apply(set *flag.FlagSet) {
if f.EnvVar != "" {
if envVal := os.Getenv(f.EnvVar); envVal != "" {
f.Value = envVal
}
}
eachName(f.Name, func(name string) {
set.String(name, f.Value, f.Usage)
})
}
func (f StringFlag) getName() string {
return f.Name
}
type IntFlag struct {
Name string
Value int
Usage string
EnvVar string
}
func (f IntFlag) String() string {
return withEnvHint(f.EnvVar, fmt.Sprintf("%s '%v'\t%v", prefixedNames(f.Name), f.Value, f.Usage))
}
func (f IntFlag) Apply(set *flag.FlagSet) {
if f.EnvVar != "" {
if envVal := os.Getenv(f.EnvVar); envVal != "" {
envValInt, err := strconv.ParseUint(envVal, 10, 64)
if err == nil {
f.Value = int(envValInt)
}
}
}
eachName(f.Name, func(name string) {
set.Int(name, f.Value, f.Usage)
})
}
func (f IntFlag) getName() string {
return f.Name
}
type DurationFlag struct {
Name string
Value time.Duration
Usage string
EnvVar string
}
func (f DurationFlag) String() string {
return withEnvHint(f.EnvVar, fmt.Sprintf("%s '%v'\t%v", prefixedNames(f.Name), f.Value, f.Usage))
}
func (f DurationFlag) Apply(set *flag.FlagSet) {
if f.EnvVar != "" {
if envVal := os.Getenv(f.EnvVar); envVal != "" {
envValDuration, err := time.ParseDuration(envVal)
if err == nil {
f.Value = envValDuration
}
}
}
eachName(f.Name, func(name string) {
set.Duration(name, f.Value, f.Usage)
})
}
func (f DurationFlag) getName() string {
return f.Name
}
type Float64Flag struct {
Name string
Value float64
Usage string
EnvVar string
}
func (f Float64Flag) String() string {
return withEnvHint(f.EnvVar, fmt.Sprintf("%s '%v'\t%v", prefixedNames(f.Name), f.Value, f.Usage))
}
func (f Float64Flag) Apply(set *flag.FlagSet) {
if f.EnvVar != "" {
if envVal := os.Getenv(f.EnvVar); envVal != "" {
envValFloat, err := strconv.ParseFloat(envVal, 10)
if err == nil {
f.Value = float64(envValFloat)
}
}
}
eachName(f.Name, func(name string) {
set.Float64(name, f.Value, f.Usage)
})
}
func (f Float64Flag) getName() string {
return f.Name
}
func prefixFor(name string) (prefix string) {
if len(name) == 1 {
prefix = "-"
} else {
prefix = "--"
}
return
}
func prefixedNames(fullName string) (prefixed string) {
parts := strings.Split(fullName, ",")
for i, name := range parts {
name = strings.Trim(name, " ")
prefixed += prefixFor(name) + name
if i < len(parts)-1 {
prefixed += ", "
}
}
return
}
func withEnvHint(envVar, str string) string {
envText := ""
if envVar != "" {
envText = fmt.Sprintf(" [$%s]", envVar)
}
return str + envText
}

View File

@ -1,587 +0,0 @@
package cli_test
import (
"fmt"
"os"
"reflect"
"strings"
"testing"
"github.com/codegangsta/cli"
)
var boolFlagTests = []struct {
name string
expected string
}{
{"help", "--help\t"},
{"h", "-h\t"},
}
func TestBoolFlagHelpOutput(t *testing.T) {
for _, test := range boolFlagTests {
flag := cli.BoolFlag{Name: test.name}
output := flag.String()
if output != test.expected {
t.Errorf("%s does not match %s", output, test.expected)
}
}
}
var stringFlagTests = []struct {
name string
value string
expected string
}{
{"help", "", "--help \t"},
{"h", "", "-h \t"},
{"h", "", "-h \t"},
{"test", "Something", "--test 'Something'\t"},
}
func TestStringFlagHelpOutput(t *testing.T) {
for _, test := range stringFlagTests {
flag := cli.StringFlag{Name: test.name, Value: test.value}
output := flag.String()
if output != test.expected {
t.Errorf("%s does not match %s", output, test.expected)
}
}
}
func TestStringFlagWithEnvVarHelpOutput(t *testing.T) {
os.Setenv("APP_FOO", "derp")
for _, test := range stringFlagTests {
flag := cli.StringFlag{Name: test.name, Value: test.value, EnvVar: "APP_FOO"}
output := flag.String()
if !strings.HasSuffix(output, " [$APP_FOO]") {
t.Errorf("%s does not end with [$APP_FOO]", output)
}
}
}
var stringSliceFlagTests = []struct {
name string
value *cli.StringSlice
expected string
}{
{"help", func() *cli.StringSlice {
s := &cli.StringSlice{}
s.Set("")
return s
}(), "--help '--help option --help option'\t"},
{"h", func() *cli.StringSlice {
s := &cli.StringSlice{}
s.Set("")
return s
}(), "-h '-h option -h option'\t"},
{"h", func() *cli.StringSlice {
s := &cli.StringSlice{}
s.Set("")
return s
}(), "-h '-h option -h option'\t"},
{"test", func() *cli.StringSlice {
s := &cli.StringSlice{}
s.Set("Something")
return s
}(), "--test '--test option --test option'\t"},
}
func TestStringSliceFlagHelpOutput(t *testing.T) {
for _, test := range stringSliceFlagTests {
flag := cli.StringSliceFlag{Name: test.name, Value: test.value}
output := flag.String()
if output != test.expected {
t.Errorf("%q does not match %q", output, test.expected)
}
}
}
func TestStringSliceFlagWithEnvVarHelpOutput(t *testing.T) {
os.Setenv("APP_QWWX", "11,4")
for _, test := range stringSliceFlagTests {
flag := cli.StringSliceFlag{Name: test.name, Value: test.value, EnvVar: "APP_QWWX"}
output := flag.String()
if !strings.HasSuffix(output, " [$APP_QWWX]") {
t.Errorf("%q does not end with [$APP_QWWX]", output)
}
}
}
var intFlagTests = []struct {
name string
expected string
}{
{"help", "--help '0'\t"},
{"h", "-h '0'\t"},
}
func TestIntFlagHelpOutput(t *testing.T) {
for _, test := range intFlagTests {
flag := cli.IntFlag{Name: test.name}
output := flag.String()
if output != test.expected {
t.Errorf("%s does not match %s", output, test.expected)
}
}
}
func TestIntFlagWithEnvVarHelpOutput(t *testing.T) {
os.Setenv("APP_BAR", "2")
for _, test := range intFlagTests {
flag := cli.IntFlag{Name: test.name, EnvVar: "APP_BAR"}
output := flag.String()
if !strings.HasSuffix(output, " [$APP_BAR]") {
t.Errorf("%s does not end with [$APP_BAR]", output)
}
}
}
var durationFlagTests = []struct {
name string
expected string
}{
{"help", "--help '0'\t"},
{"h", "-h '0'\t"},
}
func TestDurationFlagHelpOutput(t *testing.T) {
for _, test := range durationFlagTests {
flag := cli.DurationFlag{Name: test.name}
output := flag.String()
if output != test.expected {
t.Errorf("%s does not match %s", output, test.expected)
}
}
}
func TestDurationFlagWithEnvVarHelpOutput(t *testing.T) {
os.Setenv("APP_BAR", "2h3m6s")
for _, test := range durationFlagTests {
flag := cli.DurationFlag{Name: test.name, EnvVar: "APP_BAR"}
output := flag.String()
if !strings.HasSuffix(output, " [$APP_BAR]") {
t.Errorf("%s does not end with [$APP_BAR]", output)
}
}
}
var intSliceFlagTests = []struct {
name string
value *cli.IntSlice
expected string
}{
{"help", &cli.IntSlice{}, "--help '--help option --help option'\t"},
{"h", &cli.IntSlice{}, "-h '-h option -h option'\t"},
{"h", &cli.IntSlice{}, "-h '-h option -h option'\t"},
{"test", func() *cli.IntSlice {
i := &cli.IntSlice{}
i.Set("9")
return i
}(), "--test '--test option --test option'\t"},
}
func TestIntSliceFlagHelpOutput(t *testing.T) {
for _, test := range intSliceFlagTests {
flag := cli.IntSliceFlag{Name: test.name, Value: test.value}
output := flag.String()
if output != test.expected {
t.Errorf("%q does not match %q", output, test.expected)
}
}
}
func TestIntSliceFlagWithEnvVarHelpOutput(t *testing.T) {
os.Setenv("APP_SMURF", "42,3")
for _, test := range intSliceFlagTests {
flag := cli.IntSliceFlag{Name: test.name, Value: test.value, EnvVar: "APP_SMURF"}
output := flag.String()
if !strings.HasSuffix(output, " [$APP_SMURF]") {
t.Errorf("%q does not end with [$APP_SMURF]", output)
}
}
}
var float64FlagTests = []struct {
name string
expected string
}{
{"help", "--help '0'\t"},
{"h", "-h '0'\t"},
}
func TestFloat64FlagHelpOutput(t *testing.T) {
for _, test := range float64FlagTests {
flag := cli.Float64Flag{Name: test.name}
output := flag.String()
if output != test.expected {
t.Errorf("%s does not match %s", output, test.expected)
}
}
}
func TestFloat64FlagWithEnvVarHelpOutput(t *testing.T) {
os.Setenv("APP_BAZ", "99.4")
for _, test := range float64FlagTests {
flag := cli.Float64Flag{Name: test.name, EnvVar: "APP_BAZ"}
output := flag.String()
if !strings.HasSuffix(output, " [$APP_BAZ]") {
t.Errorf("%s does not end with [$APP_BAZ]", output)
}
}
}
var genericFlagTests = []struct {
name string
value cli.Generic
expected string
}{
{"help", &Parser{}, "--help <nil>\t`-help option -help option` "},
{"h", &Parser{}, "-h <nil>\t`-h option -h option` "},
{"test", &Parser{}, "--test <nil>\t`-test option -test option` "},
}
func TestGenericFlagHelpOutput(t *testing.T) {
for _, test := range genericFlagTests {
flag := cli.GenericFlag{Name: test.name}
output := flag.String()
if output != test.expected {
t.Errorf("%q does not match %q", output, test.expected)
}
}
}
func TestGenericFlagWithEnvVarHelpOutput(t *testing.T) {
os.Setenv("APP_ZAP", "3")
for _, test := range genericFlagTests {
flag := cli.GenericFlag{Name: test.name, EnvVar: "APP_ZAP"}
output := flag.String()
if !strings.HasSuffix(output, " [$APP_ZAP]") {
t.Errorf("%s does not end with [$APP_ZAP]", output)
}
}
}
func TestParseMultiString(t *testing.T) {
(&cli.App{
Flags: []cli.Flag{
cli.StringFlag{Name: "serve, s"},
},
Action: func(ctx *cli.Context) {
if ctx.String("serve") != "10" {
t.Errorf("main name not set")
}
if ctx.String("s") != "10" {
t.Errorf("short name not set")
}
},
}).Run([]string{"run", "-s", "10"})
}
func TestParseMultiStringFromEnv(t *testing.T) {
os.Setenv("APP_COUNT", "20")
(&cli.App{
Flags: []cli.Flag{
cli.StringFlag{Name: "count, c", EnvVar: "APP_COUNT"},
},
Action: func(ctx *cli.Context) {
if ctx.String("count") != "20" {
t.Errorf("main name not set")
}
if ctx.String("c") != "20" {
t.Errorf("short name not set")
}
},
}).Run([]string{"run"})
}
func TestParseMultiStringSlice(t *testing.T) {
(&cli.App{
Flags: []cli.Flag{
cli.StringSliceFlag{Name: "serve, s", Value: &cli.StringSlice{}},
},
Action: func(ctx *cli.Context) {
if !reflect.DeepEqual(ctx.StringSlice("serve"), []string{"10", "20"}) {
t.Errorf("main name not set")
}
if !reflect.DeepEqual(ctx.StringSlice("s"), []string{"10", "20"}) {
t.Errorf("short name not set")
}
},
}).Run([]string{"run", "-s", "10", "-s", "20"})
}
func TestParseMultiStringSliceFromEnv(t *testing.T) {
os.Setenv("APP_INTERVALS", "20,30,40")
(&cli.App{
Flags: []cli.Flag{
cli.StringSliceFlag{Name: "intervals, i", Value: &cli.StringSlice{}, EnvVar: "APP_INTERVALS"},
},
Action: func(ctx *cli.Context) {
if !reflect.DeepEqual(ctx.StringSlice("intervals"), []string{"20", "30", "40"}) {
t.Errorf("main name not set from env")
}
if !reflect.DeepEqual(ctx.StringSlice("i"), []string{"20", "30", "40"}) {
t.Errorf("short name not set from env")
}
},
}).Run([]string{"run"})
}
func TestParseMultiInt(t *testing.T) {
a := cli.App{
Flags: []cli.Flag{
cli.IntFlag{Name: "serve, s"},
},
Action: func(ctx *cli.Context) {
if ctx.Int("serve") != 10 {
t.Errorf("main name not set")
}
if ctx.Int("s") != 10 {
t.Errorf("short name not set")
}
},
}
a.Run([]string{"run", "-s", "10"})
}
func TestParseMultiIntFromEnv(t *testing.T) {
os.Setenv("APP_TIMEOUT_SECONDS", "10")
a := cli.App{
Flags: []cli.Flag{
cli.IntFlag{Name: "timeout, t", EnvVar: "APP_TIMEOUT_SECONDS"},
},
Action: func(ctx *cli.Context) {
if ctx.Int("timeout") != 10 {
t.Errorf("main name not set")
}
if ctx.Int("t") != 10 {
t.Errorf("short name not set")
}
},
}
a.Run([]string{"run"})
}
func TestParseMultiIntSlice(t *testing.T) {
(&cli.App{
Flags: []cli.Flag{
cli.IntSliceFlag{Name: "serve, s", Value: &cli.IntSlice{}},
},
Action: func(ctx *cli.Context) {
if !reflect.DeepEqual(ctx.IntSlice("serve"), []int{10, 20}) {
t.Errorf("main name not set")
}
if !reflect.DeepEqual(ctx.IntSlice("s"), []int{10, 20}) {
t.Errorf("short name not set")
}
},
}).Run([]string{"run", "-s", "10", "-s", "20"})
}
func TestParseMultiIntSliceFromEnv(t *testing.T) {
os.Setenv("APP_INTERVALS", "20,30,40")
(&cli.App{
Flags: []cli.Flag{
cli.IntSliceFlag{Name: "intervals, i", Value: &cli.IntSlice{}, EnvVar: "APP_INTERVALS"},
},
Action: func(ctx *cli.Context) {
if !reflect.DeepEqual(ctx.IntSlice("intervals"), []int{20, 30, 40}) {
t.Errorf("main name not set from env")
}
if !reflect.DeepEqual(ctx.IntSlice("i"), []int{20, 30, 40}) {
t.Errorf("short name not set from env")
}
},
}).Run([]string{"run"})
}
func TestParseMultiFloat64(t *testing.T) {
a := cli.App{
Flags: []cli.Flag{
cli.Float64Flag{Name: "serve, s"},
},
Action: func(ctx *cli.Context) {
if ctx.Float64("serve") != 10.2 {
t.Errorf("main name not set")
}
if ctx.Float64("s") != 10.2 {
t.Errorf("short name not set")
}
},
}
a.Run([]string{"run", "-s", "10.2"})
}
func TestParseMultiFloat64FromEnv(t *testing.T) {
os.Setenv("APP_TIMEOUT_SECONDS", "15.5")
a := cli.App{
Flags: []cli.Flag{
cli.Float64Flag{Name: "timeout, t", EnvVar: "APP_TIMEOUT_SECONDS"},
},
Action: func(ctx *cli.Context) {
if ctx.Float64("timeout") != 15.5 {
t.Errorf("main name not set")
}
if ctx.Float64("t") != 15.5 {
t.Errorf("short name not set")
}
},
}
a.Run([]string{"run"})
}
func TestParseMultiBool(t *testing.T) {
a := cli.App{
Flags: []cli.Flag{
cli.BoolFlag{Name: "serve, s"},
},
Action: func(ctx *cli.Context) {
if ctx.Bool("serve") != true {
t.Errorf("main name not set")
}
if ctx.Bool("s") != true {
t.Errorf("short name not set")
}
},
}
a.Run([]string{"run", "--serve"})
}
func TestParseMultiBoolFromEnv(t *testing.T) {
os.Setenv("APP_DEBUG", "1")
a := cli.App{
Flags: []cli.Flag{
cli.BoolFlag{Name: "debug, d", EnvVar: "APP_DEBUG"},
},
Action: func(ctx *cli.Context) {
if ctx.Bool("debug") != true {
t.Errorf("main name not set from env")
}
if ctx.Bool("d") != true {
t.Errorf("short name not set from env")
}
},
}
a.Run([]string{"run"})
}
func TestParseMultiBoolT(t *testing.T) {
a := cli.App{
Flags: []cli.Flag{
cli.BoolTFlag{Name: "serve, s"},
},
Action: func(ctx *cli.Context) {
if ctx.BoolT("serve") != true {
t.Errorf("main name not set")
}
if ctx.BoolT("s") != true {
t.Errorf("short name not set")
}
},
}
a.Run([]string{"run", "--serve"})
}
func TestParseMultiBoolTFromEnv(t *testing.T) {
os.Setenv("APP_DEBUG", "0")
a := cli.App{
Flags: []cli.Flag{
cli.BoolTFlag{Name: "debug, d", EnvVar: "APP_DEBUG"},
},
Action: func(ctx *cli.Context) {
if ctx.BoolT("debug") != false {
t.Errorf("main name not set from env")
}
if ctx.BoolT("d") != false {
t.Errorf("short name not set from env")
}
},
}
a.Run([]string{"run"})
}
type Parser [2]string
func (p *Parser) Set(value string) error {
parts := strings.Split(value, ",")
if len(parts) != 2 {
return fmt.Errorf("invalid format")
}
(*p)[0] = parts[0]
(*p)[1] = parts[1]
return nil
}
func (p *Parser) String() string {
return fmt.Sprintf("%s,%s", p[0], p[1])
}
func TestParseGeneric(t *testing.T) {
a := cli.App{
Flags: []cli.Flag{
cli.GenericFlag{Name: "serve, s", Value: &Parser{}},
},
Action: func(ctx *cli.Context) {
if !reflect.DeepEqual(ctx.Generic("serve"), &Parser{"10", "20"}) {
t.Errorf("main name not set")
}
if !reflect.DeepEqual(ctx.Generic("s"), &Parser{"10", "20"}) {
t.Errorf("short name not set")
}
},
}
a.Run([]string{"run", "-s", "10,20"})
}
func TestParseGenericFromEnv(t *testing.T) {
os.Setenv("APP_SERVE", "20,30")
a := cli.App{
Flags: []cli.Flag{
cli.GenericFlag{Name: "serve, s", Value: &Parser{}, EnvVar: "APP_SERVE"},
},
Action: func(ctx *cli.Context) {
if !reflect.DeepEqual(ctx.Generic("serve"), &Parser{"20", "30"}) {
t.Errorf("main name not set from env")
}
if !reflect.DeepEqual(ctx.Generic("s"), &Parser{"20", "30"}) {
t.Errorf("short name not set from env")
}
},
}
a.Run([]string{"run"})
}

View File

@ -1,224 +0,0 @@
package cli
import (
"fmt"
"os"
"text/tabwriter"
"text/template"
)
// The text template for the Default help topic.
// cli.go uses text/template to render templates. You can
// render custom help text by setting this variable.
var AppHelpTemplate = `NAME:
{{.Name}} - {{.Usage}}
USAGE:
{{.Name}} {{if .Flags}}[global options] {{end}}command{{if .Flags}} [command options]{{end}} [arguments...]
VERSION:
{{.Version}}{{if or .Author .Email}}
AUTHOR:{{if .Author}}
{{.Author}}{{if .Email}} - <{{.Email}}>{{end}}{{else}}
{{.Email}}{{end}}{{end}}
COMMANDS:
{{range .Commands}}{{.Name}}{{with .ShortName}}, {{.}}{{end}}{{ "\t" }}{{.Usage}}
{{end}}{{if .Flags}}
GLOBAL OPTIONS:
{{range .Flags}}{{.}}
{{end}}{{end}}
`
// The text template for the command help topic.
// cli.go uses text/template to render templates. You can
// render custom help text by setting this variable.
var CommandHelpTemplate = `NAME:
{{.Name}} - {{.Usage}}
USAGE:
command {{.Name}}{{if .Flags}} [command options]{{end}} [arguments...]{{if .Description}}
DESCRIPTION:
{{.Description}}{{end}}{{if .Flags}}
OPTIONS:
{{range .Flags}}{{.}}
{{end}}{{ end }}
`
// The text template for the subcommand help topic.
// cli.go uses text/template to render templates. You can
// render custom help text by setting this variable.
var SubcommandHelpTemplate = `NAME:
{{.Name}} - {{.Usage}}
USAGE:
{{.Name}} command{{if .Flags}} [command options]{{end}} [arguments...]
COMMANDS:
{{range .Commands}}{{.Name}}{{with .ShortName}}, {{.}}{{end}}{{ "\t" }}{{.Usage}}
{{end}}{{if .Flags}}
OPTIONS:
{{range .Flags}}{{.}}
{{end}}{{end}}
`
var helpCommand = Command{
Name: "help",
ShortName: "h",
Usage: "Shows a list of commands or help for one command",
Action: func(c *Context) {
args := c.Args()
if args.Present() {
ShowCommandHelp(c, args.First())
} else {
ShowAppHelp(c)
}
},
}
var helpSubcommand = Command{
Name: "help",
ShortName: "h",
Usage: "Shows a list of commands or help for one command",
Action: func(c *Context) {
args := c.Args()
if args.Present() {
ShowCommandHelp(c, args.First())
} else {
ShowSubcommandHelp(c)
}
},
}
// Prints help for the App
var HelpPrinter = printHelp
// Prints version for the App
var VersionPrinter = printVersion
func ShowAppHelp(c *Context) {
HelpPrinter(AppHelpTemplate, c.App)
}
// Prints the list of subcommands as the default app completion method
func DefaultAppComplete(c *Context) {
for _, command := range c.App.Commands {
fmt.Println(command.Name)
if command.ShortName != "" {
fmt.Println(command.ShortName)
}
}
}
// Prints help for the given command
func ShowCommandHelp(c *Context, command string) {
for _, c := range c.App.Commands {
if c.HasName(command) {
HelpPrinter(CommandHelpTemplate, c)
return
}
}
if c.App.CommandNotFound != nil {
c.App.CommandNotFound(c, command)
} else {
fmt.Printf("No help topic for '%v'\n", command)
}
}
// Prints help for the given subcommand
func ShowSubcommandHelp(c *Context) {
HelpPrinter(SubcommandHelpTemplate, c.App)
}
// Prints the version number of the App
func ShowVersion(c *Context) {
VersionPrinter(c)
}
func printVersion(c *Context) {
fmt.Printf("%v version %v\n", c.App.Name, c.App.Version)
}
// Prints the lists of commands within a given context
func ShowCompletions(c *Context) {
a := c.App
if a != nil && a.BashComplete != nil {
a.BashComplete(c)
}
}
// Prints the custom completions for a given command
func ShowCommandCompletions(ctx *Context, command string) {
c := ctx.App.Command(command)
if c != nil && c.BashComplete != nil {
c.BashComplete(ctx)
}
}
func printHelp(templ string, data interface{}) {
w := tabwriter.NewWriter(os.Stdout, 0, 8, 1, '\t', 0)
t := template.Must(template.New("help").Parse(templ))
err := t.Execute(w, data)
if err != nil {
panic(err)
}
w.Flush()
}
func checkVersion(c *Context) bool {
if c.GlobalBool("version") {
ShowVersion(c)
return true
}
return false
}
func checkHelp(c *Context) bool {
if c.GlobalBool("h") || c.GlobalBool("help") {
ShowAppHelp(c)
return true
}
return false
}
func checkCommandHelp(c *Context, name string) bool {
if c.Bool("h") || c.Bool("help") {
ShowCommandHelp(c, name)
return true
}
return false
}
func checkSubcommandHelp(c *Context) bool {
if c.GlobalBool("h") || c.GlobalBool("help") {
ShowSubcommandHelp(c)
return true
}
return false
}
func checkCompletions(c *Context) bool {
if (c.GlobalBool(BashCompletionFlag.Name) || c.Bool(BashCompletionFlag.Name)) && c.App.EnableBashCompletion {
ShowCompletions(c)
return true
}
return false
}
func checkCommandCompletions(c *Context, name string) bool {
if c.Bool(BashCompletionFlag.Name) && c.App.EnableBashCompletion {
ShowCommandCompletions(c, name)
return true
}
return false
}

View File

@ -1,19 +0,0 @@
package cli_test
import (
"reflect"
"testing"
)
/* Test Helpers */
func expect(t *testing.T, a interface{}, b interface{}) {
if a != b {
t.Errorf("Expected %v (type %v) - Got %v (type %v)", b, reflect.TypeOf(b), a, reflect.TypeOf(a))
}
}
func refute(t *testing.T, a interface{}, b interface{}) {
if a == b {
t.Errorf("Did not expect %v (type %v) - Got %v (type %v)", b, reflect.TypeOf(b), a, reflect.TypeOf(a))
}
}

50
Godeps/_workspace/src/github.com/dalu/slug/README.md generated vendored Normal file
View File

@ -0,0 +1,50 @@
slug
====
Package `slug` generate slug from unicode string, URL-friendly slugify with
multiple languages support.
[![GoDoc](https://godoc.org/github.com/dalu/slug?status.png)](https://godoc.org/github.com/dalu/slug)
[Documentation online](http://godoc.org/github.com/dalu/slug)
## Example
package main
import(
"github.com/gosimple/slug"
"fmt"
)
func main () {
text := slug.Make("Hellö Wörld хелло ворлд")
fmt.Println(text) // Will print hello-world-khello-vorld
someText := slug.Make("影師")
fmt.Println(someText) // Will print: ying-shi
enText := slug.MakeLang("This & that", "en")
fmt.Println(enText) // Will print 'this-and-that'
deText := slug.MakeLang("Diese & Dass", "de")
fmt.Println(deText) // Will print 'diese-und-dass'
slug.CustomSub = map[string]string{
"water": "sand",
}
textSub := slug.Make("water is hot")
fmt.Println(textSub) // Will print 'sand-is-hot'
}
## Installation
go get -u github.com/dalu/slug
## License
The source files are distributed under the
[Mozilla Public License, version 2.0](http://mozilla.org/MPL/2.0/),
unless otherwise noted.
Please read the [FAQ](http://www.mozilla.org/MPL/2.0/FAQ.html)
if you have further questions regarding the license.

View File

@ -0,0 +1,16 @@
// Copyright 2013 by Dobrosław Żybort. All rights reserved.
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package slug
var defaultSub = map[rune]string{
'"': "",
'\'': "",
'': "",
'': "-", // figure dash
'': "-", // en dash
'—': "-", // em dash
'―': "-", // horizontal bar
}

39
Godeps/_workspace/src/github.com/dalu/slug/doc.go generated vendored Normal file
View File

@ -0,0 +1,39 @@
// Copyright 2013 by Dobrosław Żybort. All rights reserved.
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
/*
Package slug generate slug from unicode string, URL-friendly slugify with
multiple languages support.
Example:
package main
import(
"github.com/dalu/slug"
"fmt"
)
func main () {
text := slug.Make("Hellö Wörld хелло ворлд")
fmt.Println(text) // Will print hello-world-khello-vorld
someText := slug.Make("影師")
fmt.Println(someText) // Will print: ying-shi
enText := slug.MakeLang("This & that", "en")
fmt.Println(enText) // Will print 'this-and-that'
deText := slug.MakeLang("Diese & Dass", "de")
fmt.Println(deText) // Will print 'diese-und-dass'
slug.CustomSub = map[string]string{
"water": "sand",
}
textSub := slug.Make("water is hot")
fmt.Println(textSub) // Will print 'sand-is-hot'
}
*/
package slug

View File

@ -0,0 +1,26 @@
// Copyright 2013 by Dobrosław Żybort. All rights reserved.
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package slug
var deSub = map[rune]string{
'&': "und",
'@': "an",
}
var enSub = map[rune]string{
'&': "and",
'@': "at",
}
var plSub = map[rune]string{
'&': "i",
'@': "na",
}
var esSub = map[rune]string{
'&': "y",
'@': "en",
}

122
Godeps/_workspace/src/github.com/dalu/slug/slug.go generated vendored Normal file
View File

@ -0,0 +1,122 @@
// Copyright 2013 by Dobrosław Żybort. All rights reserved.
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package slug
import (
"github.com/dalu/unidecode"
"regexp"
"strings"
)
var (
// Custom substitution map
CustomSub map[string]string
// Custom rune substitution map
CustomRuneSub map[rune]string
// Maximum slug length. It's smart so it will cat slug after full word.
// By default slugs aren't shortened.
// If MaxLength is smaller than length of the first word, then returned
// slug will contain only substring from the first word truncated
// after MaxLength.
MaxLength int
)
//=============================================================================
// Make returns slug generated from provided string. Will use "en" as language
// substitution.
func Make(s string) (slug string) {
return MakeLang(s, "en")
}
// MakeLang returns slug generated from provided string and will use provided
// language for chars substitution.
func MakeLang(s string, lang string) (slug string) {
slug = strings.TrimSpace(s)
// Custom substitutions
// Always substitute runes first
slug = SubstituteRune(slug, CustomRuneSub)
slug = Substitute(slug, CustomSub)
// Process string with selected substitution language
switch lang {
case "de":
slug = SubstituteRune(slug, deSub)
case "en":
slug = SubstituteRune(slug, enSub)
case "pl":
slug = SubstituteRune(slug, plSub)
case "es":
slug = SubstituteRune(slug, esSub)
default: // fallback to "en" if lang not found
slug = SubstituteRune(slug, enSub)
}
slug = SubstituteRune(slug, defaultSub)
// Process all non ASCII symbols
slug = unidecode.Unidecode(slug)
slug = strings.ToLower(slug)
// Process all remaining symbols
slug = regexp.MustCompile("[^a-z0-9-_]").ReplaceAllString(slug, "-")
slug = regexp.MustCompile("-+").ReplaceAllString(slug, "-")
slug = strings.Trim(slug, "-")
if MaxLength > 0 {
slug = smartTruncate(slug)
}
return slug
}
// Substitute returns string with superseded all substrings from
// provided substitution map.
func Substitute(s string, sub map[string]string) (buf string) {
buf = s
for key, val := range sub {
buf = strings.Replace(s, key, val, -1)
}
return
}
// SubstituteRune substitutes string chars with provided rune
// substitution map.
func SubstituteRune(s string, sub map[rune]string) (buf string) {
for _, c := range s {
if d, ok := sub[c]; ok {
buf += d
} else {
buf += string(c)
}
}
return
}
func smartTruncate(text string) string {
if len(text) < MaxLength {
return text
}
var truncated string
words := strings.SplitAfter(text, "-")
// If MaxLength is smaller than length of the first word return word
// truncated after MaxLength.
if len(words[0]) > MaxLength {
return words[0][:MaxLength]
}
for _, word := range words {
if len(truncated)+len(word)-1 <= MaxLength {
truncated = truncated + word
} else {
break
}
}
return strings.Trim(truncated, "-")
}

337
Godeps/_workspace/src/github.com/dalu/slug/slug_test.go generated vendored Normal file
View File

@ -0,0 +1,337 @@
// Copyright 2013 by Dobrosław Żybort. All rights reserved.
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package slug
import (
"testing"
)
//=============================================================================
func TestSlugMake(t *testing.T) {
var testCases = []struct {
in string
want string
}{
{"DOBROSLAWZYBORT", "dobroslawzybort"},
{"Dobroslaw Zybort", "dobroslaw-zybort"},
{" Dobroslaw Zybort ?", "dobroslaw-zybort"},
{"Dobrosław Żybort", "dobroslaw-zybort"},
{"Ala ma 6 kotów.", "ala-ma-6-kotow"},
{"áÁàÀãÃâÂäÄąĄą̊Ą̊", "aaaaaaaaaaaaaa"},
{"ćĆĉĈçÇ", "cccccc"},
{"éÉèÈẽẼêÊëËęĘ", "eeeeeeeeeeee"},
{"íÍìÌĩĨîÎïÏįĮ", "iiiiiiiiiiii"},
{"łŁ", "ll"},
{"ńŃ", "nn"},
{"óÓòÒõÕôÔöÖǫǪǭǬø", "ooooooooooooooo"},
{"śŚ", "ss"},
{"úÚùÙũŨûÛüÜųŲ", "uuuuuuuuuuuu"},
{"y̨Y̨", "yy"},
{"źŹżŹ", "zzzz"},
{"·/,:;`˜'\"", ""},
{"20002013", "2000-2013"},
{"style—not", "style-not"},
{"test_slug", "test_slug"},
{"Æ", "ae"},
{"Ich heiße", "ich-heisse"},
{"This & that", "this-and-that"},
{"fácil €", "facil-eu"},
{"smile ☺", "smile"},
{"Hellö Wörld хелло ворлд", "hello-world-khello-vorld"},
{"\"C'est déjà lété.\"", "cest-deja-lete"},
{"jaja---lol-méméméoo--a", "jaja-lol-mememeoo-a"},
{"影師", "ying-shi"},
}
for index, st := range testCases {
got := Make(st.in)
if got != st.want {
t.Errorf(
"%d. Make(%#v) = %#v; want %#v",
index, st.in, got, st.want)
}
}
}
func TestSlugMakeLang(t *testing.T) {
var testCases = []struct {
lang string
in string
want string
}{
{"en", "This & that", "this-and-that"},
{"de", "This & that", "this-und-that"},
{"pl", "This & that", "this-i-that"},
{"es", "This & that", "this-y-that"},
{"test", "This & that", "this-and-that"}, // unknown lang, fallback to "en"
}
for index, smlt := range testCases {
got := MakeLang(smlt.in, smlt.lang)
if got != smlt.want {
t.Errorf(
"%d. MakeLang(%#v, %#v) = %#v; want %#v",
index, smlt.in, smlt.lang, got, smlt.want)
}
}
}
func TestSlugMakeUserSubstituteLang(t *testing.T) {
var testCases = []struct {
cSub map[string]string
lang string
in string
want string
}{
{map[string]string{"'": " "}, "en", "That's great", "that-s-great"},
{map[string]string{"&": "or"}, "en", "This & that", "this-or-that"}, // by default "&" => "and"
{map[string]string{"&": "or"}, "de", "This & that", "this-or-that"}, // by default "&" => "und"
}
for index, smust := range testCases {
CustomSub = smust.cSub
got := MakeLang(smust.in, smust.lang)
if got != smust.want {
t.Errorf(
"%d. %#v; MakeLang(%#v, %#v) = %#v; want %#v",
index, smust.cSub, smust.in, smust.lang,
got, smust.want)
}
}
}
func TestSlugMakeSubstituteOrderLang(t *testing.T) {
// Always substitute runes first
var testCases = []struct {
rSub map[rune]string
sSub map[string]string
in string
want string
}{
{map[rune]string{'o': "left"}, map[string]string{"o": "right"}, "o o", "left-left"},
{map[rune]string{'&': "down"}, map[string]string{"&": "up"}, "&", "down"},
}
for index, smsot := range testCases {
CustomRuneSub = smsot.rSub
CustomSub = smsot.sSub
got := Make(smsot.in)
if got != smsot.want {
t.Errorf(
"%d. %#v; %#v; Make(%#v) = %#v; want %#v",
index, smsot.rSub, smsot.sSub, smsot.in,
got, smsot.want)
}
}
}
func TestSubstituteLang(t *testing.T) {
var testCases = []struct {
cSub map[string]string
in string
want string
}{
{map[string]string{"o": "no"}, "o o o", "no no no"},
{map[string]string{"'": " "}, "That's great", "That s great"},
}
for index, sst := range testCases {
got := Substitute(sst.in, sst.cSub)
if got != sst.want {
t.Errorf(
"%d. Substitute(%#v, %#v) = %#v; want %#v",
index, sst.in, sst.cSub, got, sst.want)
}
}
}
func TestSubstituteRuneLang(t *testing.T) {
var testCases = []struct {
cSub map[rune]string
in string
want string
}{
{map[rune]string{'o': "no"}, "o o o", "no no no"},
{map[rune]string{'\'': " "}, "That's great", "That s great"},
}
for index, ssrt := range testCases {
got := SubstituteRune(ssrt.in, ssrt.cSub)
if got != ssrt.want {
t.Errorf(
"%d. SubstituteRune(%#v, %#v) = %#v; want %#v",
index, ssrt.in, ssrt.cSub, got, ssrt.want)
}
}
}
func TestSlugMakeSmartTruncate(t *testing.T) {
var testCases = []struct {
in string
maxLength int
want string
}{
{"DOBROSLAWZYBORT", 100, "dobroslawzybort"},
{"Dobroslaw Zybort", 100, "dobroslaw-zybort"},
{"Dobroslaw Zybort", 12, "dobroslaw"},
{" Dobroslaw Zybort ?", 12, "dobroslaw"},
{"Ala ma 6 kotów.", 10, "ala-ma-6"},
{"Dobrosław Żybort", 5, "dobro"},
}
for index, smstt := range testCases {
MaxLength = smstt.maxLength
got := Make(smstt.in)
if got != smstt.want {
t.Errorf(
"%d. MaxLength = %v; Make(%#v) = %#v; want %#v",
index, smstt.maxLength, smstt.in, got, smstt.want)
}
}
}
func BenchmarkMakeShortAscii(b *testing.B) {
b.ReportAllocs()
for n := 0; n < b.N; n++ {
Make("Hello world")
}
}
func BenchmarkMakeShort(b *testing.B) {
b.ReportAllocs()
for n := 0; n < b.N; n++ {
Make("хелло ворлд")
}
}
func BenchmarkMakeShortSymbols(b *testing.B) {
b.ReportAllocs()
for n := 0; n < b.N; n++ {
Make("·/,:;`˜'\" &€£¥")
}
}
func BenchmarkMakeMediumAscii(b *testing.B) {
b.ReportAllocs()
for n := 0; n < b.N; n++ {
Make("ABCDE FGHIJ KLMNO PQRST UWXYZ ABCDE FGHIJ KLMNO PQRST UWXYZ ABCDE")
}
}
func BenchmarkMakeMedium(b *testing.B) {
b.ReportAllocs()
for n := 0; n < b.N; n++ {
Make("ヲァィゥェ ォャュョッ ーアイウエ オカキクケ コサシスセ ソタチツテ トナニヌネ ノハヒフヘ ホマミムメ モヤユヨラ リルレロワ")
}
}
func BenchmarkMakeLongAscii(b *testing.B) {
longStr := "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Morbi " +
"pulvinar sodales ultrices. Nulla facilisi. Sed at vestibulum erat. Ut " +
"sit amet urna posuere, sagittis eros ac, varius nisi. Morbi ullamcorper " +
"odio at nunc pulvinar mattis. Vestibulum rutrum, ante eu dictum mattis, " +
"elit risus finibus nunc, consectetur facilisis eros leo ut sapien. Sed " +
"pulvinar volutpat mi. Cras semper mi ac eros accumsan, at feugiat massa " +
"elementum. Morbi eget dolor sit amet purus condimentum egestas non ut " +
"sapien. Duis feugiat magna vitae nisi lobortis, quis finibus sem " +
"sollicitudin. Pellentesque eleifend blandit ipsum, ut porta arcu " +
"ultricies et. Fusce vel ipsum porta, placerat diam ac, consectetur " +
"magna. Nulla in porta sem. Suspendisse commodo, felis in molestie " +
"ultricies, arcu ipsum aliquet turpis, elementum dapibus ipsum lorem a " +
"nisl. Etiam varius imperdiet placerat. Aliquam euismod lacus arcu, " +
"ultrices hendrerit est pellentesque vel. Aliquam sit amet laoreet leo. " +
"Integer eros libero, mollis sed posuere."
b.ReportAllocs()
b.ResetTimer()
for n := 0; n < b.N; n++ {
Make(longStr)
}
}
func BenchmarkSubstituteRuneShort(b *testing.B) {
shortStr := "Hello/Hi world"
subs := map[rune]string{'o': "no", '/': "slash"}
b.ReportAllocs()
b.ResetTimer()
for n := 0; n < b.N; n++ {
SubstituteRune(shortStr, subs)
}
}
func BenchmarkSubstituteRuneLong(b *testing.B) {
longStr := "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Morbi " +
"pulvinar sodales ultrices. Nulla facilisi. Sed at vestibulum erat. Ut " +
"sit amet urna posuere, sagittis eros ac, varius nisi. Morbi ullamcorper " +
"odio at nunc pulvinar mattis. Vestibulum rutrum, ante eu dictum mattis, " +
"elit risus finibus nunc, consectetur facilisis eros leo ut sapien. Sed " +
"pulvinar volutpat mi. Cras semper mi ac eros accumsan, at feugiat massa " +
"elementum. Morbi eget dolor sit amet purus condimentum egestas non ut " +
"sapien. Duis feugiat magna vitae nisi lobortis, quis finibus sem " +
"sollicitudin. Pellentesque eleifend blandit ipsum, ut porta arcu " +
"ultricies et. Fusce vel ipsum porta, placerat diam ac, consectetur " +
"magna. Nulla in porta sem. Suspendisse commodo, felis in molestie " +
"ultricies, arcu ipsum aliquet turpis, elementum dapibus ipsum lorem a " +
"nisl. Etiam varius imperdiet placerat. Aliquam euismod lacus arcu, " +
"ultrices hendrerit est pellentesque vel. Aliquam sit amet laoreet leo. " +
"Integer eros libero, mollis sed posuere."
subs := map[rune]string{
'o': "no",
'/': "slash",
'i': "done",
'E': "es",
'a': "ASD",
'1': "one",
'l': "onetwo",
}
b.ReportAllocs()
b.ResetTimer()
for n := 0; n < b.N; n++ {
SubstituteRune(longStr, subs)
}
}
func BenchmarkSmartTruncateShort(b *testing.B) {
shortStr := "Hello-world"
MaxLength = 8
b.ReportAllocs()
b.ResetTimer()
for n := 0; n < b.N; n++ {
smartTruncate(shortStr)
}
}
func BenchmarkSmartTruncateLong(b *testing.B) {
longStr := "Lorem-ipsum-dolor-sit-amet,-consectetur-adipiscing-elit.-Morbi-" +
"pulvinar-sodales-ultrices.-Nulla-facilisi.-Sed-at-vestibulum-erat.-Ut-" +
"sit-amet-urna-posuere,-sagittis-eros-ac,-varius-nisi.-Morbi-ullamcorper-" +
"odio-at-nunc-pulvinar-mattis.-Vestibulum-rutrum,-ante-eu-dictum-mattis,-" +
"elit-risus-finibus-nunc,-consectetur-facilisis-eros-leo-ut-sapien.-Sed-" +
"pulvinar-volutpat-mi.-Cras-semper-mi-ac-eros-accumsan,-at-feugiat-massa-" +
"elementum.-Morbi-eget-dolor-sit-amet-purus-condimentum-egestas-non-ut-" +
"sapien.-Duis-feugiat-magna-vitae-nisi-lobortis,-quis-finibus-sem-" +
"sollicitudin.-Pellentesque-eleifend-blandit-ipsum,-ut-porta-arcu-" +
"ultricies-et.-Fusce-vel-ipsum-porta,-placerat-diam-ac,-consectetur-" +
"magna.-Nulla-in-porta-sem.-Suspendisse-commodo,-felis-in-molestie-" +
"ultricies,-arcu-ipsum-aliquet-turpis,-elementum-dapibus-ipsum-lorem-a-" +
"nisl.-Etiam-varius-imperdiet-placerat.-Aliquam-euismod-lacus-arcu,-" +
"ultrices-hendrerit-est-pellentesque-vel.-Aliquam-sit-amet-laoreet-leo.-" +
"Integer-eros-libero,-mollis-sed-posuere."
MaxLength = 256
b.ReportAllocs()
b.ResetTimer()
for n := 0; n < b.N; n++ {
smartTruncate(longStr)
}
}

View File

@ -0,0 +1,23 @@
# Compiled Object files, Static and Dynamic libs (Shared Objects)
*.o
*.a
*.so
# Folders
_obj
_test
# Architecture specific extensions/prefixes
*.[568vq]
[568vq].out
*.cgo1.go
*.cgo2.c
_cgo_defun.c
_cgo_gotypes.go
_cgo_export.*
_testmain.go
*.exe
*.test

201
Godeps/_workspace/src/github.com/dalu/unidecode/LICENSE generated vendored Normal file
View File

@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "{}"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright {yyyy} {name of copyright owner}
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@ -0,0 +1,6 @@
unidecode
=========
Unicode transliterator in Golang - Replaces non-ASCII characters with their ASCII approximations.
View other available versions, documentation and examples at http://gopkgs.com/unidecode

View File

@ -0,0 +1,44 @@
package unidecode
import (
"compress/zlib"
"encoding/binary"
"io"
"strings"
"sync"
)
var (
decoded = false
mutex sync.Mutex
transliterations [65536][]rune
transCount = rune(len(transliterations))
getUint16 = binary.LittleEndian.Uint16
)
func decodeTransliterations() {
r, err := zlib.NewReader(strings.NewReader(tableData))
if err != nil {
panic(err)
}
defer r.Close()
tmp1 := make([]byte, 2)
tmp2 := tmp1[:1]
for {
if _, err := io.ReadAtLeast(r, tmp1, 2); err != nil {
if err == io.EOF {
break
}
panic(err)
}
chr := getUint16(tmp1)
if _, err := io.ReadAtLeast(r, tmp2, 1); err != nil {
panic(err)
}
b := make([]byte, int(tmp2[0]))
if _, err := io.ReadFull(r, b); err != nil {
panic(err)
}
transliterations[int(chr)] = []rune(string(b))
}
}

View File

@ -0,0 +1,71 @@
// +build none
package main
import (
"bytes"
"compress/zlib"
"encoding/binary"
"fmt"
"go/format"
"io/ioutil"
"strconv"
"strings"
)
func main() {
data, err := ioutil.ReadFile("table.txt")
if err != nil {
panic(err)
}
var buf bytes.Buffer
for _, line := range strings.Split(string(data), "\n") {
if strings.HasPrefix(line, "/*") || line == "" {
continue
}
sep := strings.IndexByte(line, ':')
if sep == -1 {
panic(line)
}
val, err := strconv.ParseInt(line[:sep], 0, 32)
if err != nil {
panic(err)
}
s, err := strconv.Unquote(line[sep+2:])
if err != nil {
panic(err)
}
if s == "" {
continue
}
if err := binary.Write(&buf, binary.LittleEndian, uint16(val)); err != nil {
panic(err)
}
if err := binary.Write(&buf, binary.LittleEndian, uint8(len(s))); err != nil {
panic(err)
}
buf.WriteString(s)
}
var cbuf bytes.Buffer
w, err := zlib.NewWriterLevel(&cbuf, zlib.BestCompression)
if err != nil {
panic(err)
}
if _, err := w.Write(buf.Bytes()); err != nil {
panic(err)
}
if err := w.Close(); err != nil {
panic(err)
}
buf.Reset()
buf.WriteString("package unidecode\n")
buf.WriteString("// AUTOGENERATED - DO NOT EDIT!\n\n")
fmt.Fprintf(&buf, "var tableData = %q;\n", cbuf.String())
dst, err := format.Source(buf.Bytes())
if err != nil {
panic(err)
}
if err := ioutil.WriteFile("table.go", dst, 0644); err != nil {
panic(err)
}
}

File diff suppressed because one or more lines are too long

46731
Godeps/_workspace/src/github.com/dalu/unidecode/table.txt generated vendored Normal file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,63 @@
// Package unidecode implements a unicode transliterator
// which replaces non-ASCII characters with their ASCII
// approximations.
package unidecode
import (
"unicode"
"gopkgs.com/pool.v1"
)
const pooledCapacity = 64
var (
slicePool = pool.New(0)
)
// Unidecode implements a unicode transliterator, which
// replaces non-ASCII characters with their ASCII
// counterparts.
// Given an unicode encoded string, returns
// another string with non-ASCII characters replaced
// with their closest ASCII counterparts.
// e.g. Unicode("áéíóú") => "aeiou"
func Unidecode(s string) string {
if !decoded {
mutex.Lock()
if !decoded {
decodeTransliterations()
decoded = true
}
mutex.Unlock()
}
l := len(s)
var r []rune
if l > pooledCapacity {
r = make([]rune, 0, len(s))
} else {
if x := slicePool.Get(); x != nil {
r = x.([]rune)[:0]
} else {
r = make([]rune, 0, pooledCapacity)
}
}
for _, c := range s {
if c <= unicode.MaxASCII {
r = append(r, c)
continue
}
if c > unicode.MaxRune || c > transCount {
/* Ignore reserved chars */
continue
}
if d := transliterations[c]; d != nil {
r = append(r, d...)
}
}
res := string(r)
if l <= pooledCapacity {
slicePool.Put(r)
}
return res
}

View File

@ -0,0 +1,57 @@
package unidecode
import (
"testing"
)
func testTransliteration(original string, decoded string, t *testing.T) {
if r := Unidecode(original); r != decoded {
t.Errorf("Expected '%s', got '%s'\n", decoded, r)
}
}
func TestASCII(t *testing.T) {
s := "ABCDEF"
testTransliteration(s, s, t)
}
func TestKnosos(t *testing.T) {
o := "Κνωσός"
d := "Knosos"
testTransliteration(o, d, t)
}
func TestBeiJing(t *testing.T) {
o := "\u5317\u4EB0"
d := "Bei Jing "
testTransliteration(o, d, t)
}
func TestEmoji(t *testing.T) {
o := "Hey Luna t belle 😵😂"
d := "Hey Luna t belle "
testTransliteration(o, d, t)
}
func BenchmarkUnidecode(b *testing.B) {
cases := []string{
"ABCDEF",
"Κνωσός",
"\u5317\u4EB0",
}
for ii := 0; ii < b.N; ii++ {
for _, v := range cases {
_ = Unidecode(v)
}
}
}
func BenchmarkDecodeTable(b *testing.B) {
for ii := 0; ii < b.N; ii++ {
decodeTransliterations()
}
}
func init() {
decodeTransliterations()
}

114
Godeps/_workspace/src/github.com/go-xorm/core/README.md generated vendored Normal file
View File

@ -0,0 +1,114 @@
Core is a lightweight wrapper of sql.DB.
# Open
```Go
db, _ := core.Open(db, connstr)
```
# SetMapper
```Go
db.SetMapper(SameMapper())
```
## Scan usage
### Scan
```Go
rows, _ := db.Query()
for rows.Next() {
rows.Scan()
}
```
### ScanMap
```Go
rows, _ := db.Query()
for rows.Next() {
rows.ScanMap()
```
### ScanSlice
You can use `[]string`, `[][]byte`, `[]interface{}`, `[]*string`, `[]sql.NullString` to ScanSclice. Notice, slice's length should be equal or less than select columns.
```Go
rows, _ := db.Query()
cols, _ := rows.Columns()
for rows.Next() {
var s = make([]string, len(cols))
rows.ScanSlice(&s)
}
```
```Go
rows, _ := db.Query()
cols, _ := rows.Columns()
for rows.Next() {
var s = make([]*string, len(cols))
rows.ScanSlice(&s)
}
```
### ScanStruct
```Go
rows, _ := db.Query()
for rows.Next() {
rows.ScanStructByName()
rows.ScanStructByIndex()
}
```
## Query usage
```Go
rows, err := db.Query("select * from table where name = ?", name)
user = User{
Name:"lunny",
}
rows, err := db.QueryStruct("select * from table where name = ?Name",
&user)
var user = map[string]interface{}{
"name": "lunny",
}
rows, err = db.QueryMap("select * from table where name = ?name",
&user)
```
## QueryRow usage
```Go
row := db.QueryRow("select * from table where name = ?", name)
user = User{
Name:"lunny",
}
row := db.QueryRowStruct("select * from table where name = ?Name",
&user)
var user = map[string]interface{}{
"name": "lunny",
}
row = db.QueryRowMap("select * from table where name = ?name",
&user)
```
## Exec usage
```Go
db.Exec("insert into user (`name`, title, age, alias, nick_name,created) values (?,?,?,?,?,?)", name, title, age, alias...)
user = User{
Name:"lunny",
Title:"test",
Age: 18,
}
result, err = db.ExecStruct("insert into user (`name`, title, age, alias, nick_name,created) values (?Name,?Title,?Age,?Alias,?NickName,?Created)",
&user)
var user = map[string]interface{}{
"Name": "lunny",
"Title": "test",
"Age": 18,
}
result, err = db.ExecMap("insert into user (`name`, title, age, alias, nick_name,created) values (?Name,?Title,?Age,?Alias,?NickName,?Created)",
&user)
```

View File

@ -121,6 +121,21 @@ func (col *Column) ValueOfV(dataStruct *reflect.Value) (*reflect.Value, error) {
col.fieldPath = strings.Split(col.FieldName, ".")
}
if dataStruct.Type().Kind() == reflect.Map {
var keyValue reflect.Value
if len(col.fieldPath) == 1 {
keyValue = reflect.ValueOf(col.FieldName)
} else if len(col.fieldPath) == 2 {
keyValue = reflect.ValueOf(col.fieldPath[1])
} else {
return nil, fmt.Errorf("Unsupported mutliderive %v", col.FieldName)
}
fieldValue = dataStruct.MapIndex(keyValue)
return &fieldValue, nil
}
if len(col.fieldPath) == 1 {
fieldValue = dataStruct.FieldByName(col.FieldName)
} else if len(col.fieldPath) == 2 {

View File

@ -47,15 +47,13 @@ type Dialect interface {
SupportInsertMany() bool
SupportEngine() bool
SupportCharset() bool
SupportDropIfExists() bool
IndexOnTable() bool
ShowCreateNull() bool
IndexCheckSql(tableName, idxName string) (string, []interface{})
TableCheckSql(tableName string) (string, []interface{})
//ColumnCheckSql(tableName, colName string) (string, []interface{})
//IsTableExist(tableName string) (bool, error)
//IsIndexExist(tableName string, idx *Index) (bool, error)
IsColumnExist(tableName string, col *Column) (bool, error)
CreateTableSql(table *Table, tableName, storeEngine, charset string) string
@ -65,15 +63,13 @@ type Dialect interface {
ModifyColumnSql(tableName string, col *Column) string
//CreateTableIfNotExists(table *Table, tableName, storeEngine, charset string) error
//MustDropTable(tableName string) error
GetColumns(tableName string) ([]string, map[string]*Column, error)
GetTables() ([]*Table, error)
GetIndexes(tableName string) (map[string]*Index, error)
// Get data from db cell to a struct's field
//GetData(col *Column, fieldValue *reflect.Value, cellData interface{}) error
// Set field data to db
//SetData(col *Column, fieldValue *refelct.Value) (interface{}, error)
Filters() []Filter
}
@ -144,6 +140,10 @@ func (db *Base) RollBackStr() string {
return "ROLL BACK"
}
func (db *Base) SupportDropIfExists() bool {
return true
}
func (db *Base) DropTableSql(tableName string) string {
return fmt.Sprintf("DROP TABLE IF EXISTS `%s`", tableName)
}
@ -170,35 +170,52 @@ func (db *Base) IsColumnExist(tableName string, col *Column) (bool, error) {
return db.HasRecords(query, db.DbName, tableName, col.Name)
}
/*
func (db *Base) CreateTableIfNotExists(table *Table, tableName, storeEngine, charset string) error {
sql, args := db.dialect.TableCheckSql(tableName)
rows, err := db.DB().Query(sql, args...)
if db.Logger != nil {
db.Logger.Info("[sql]", sql, args)
}
if err != nil {
return err
}
defer rows.Close()
if rows.Next() {
return nil
}
sql = db.dialect.CreateTableSql(table, tableName, storeEngine, charset)
_, err = db.DB().Exec(sql)
if db.Logger != nil {
db.Logger.Info("[sql]", sql)
}
return err
}*/
func (db *Base) CreateIndexSql(tableName string, index *Index) string {
quote := db.dialect.Quote
var unique string
var idxName string
if index.Type == UniqueType {
unique = " UNIQUE"
idxName = fmt.Sprintf("UQE_%v_%v", tableName, index.Name)
} else {
idxName = fmt.Sprintf("IDX_%v_%v", tableName, index.Name)
}
return fmt.Sprintf("CREATE%s INDEX %v ON %v (%v);", unique,
idxName = index.XName(tableName)
return fmt.Sprintf("CREATE%s INDEX %v ON %v (%v)", unique,
quote(idxName), quote(tableName),
quote(strings.Join(index.Cols, quote(","))))
}
func (db *Base) DropIndexSql(tableName string, index *Index) string {
quote := db.dialect.Quote
//var unique string
var idxName string = index.Name
if !strings.HasPrefix(idxName, "UQE_") &&
!strings.HasPrefix(idxName, "IDX_") {
if index.Type == UniqueType {
idxName = fmt.Sprintf("UQE_%v_%v", tableName, index.Name)
var name string
if index.IsRegular {
name = index.XName(tableName)
} else {
idxName = fmt.Sprintf("IDX_%v_%v", tableName, index.Name)
name = index.Name
}
}
return fmt.Sprintf("DROP INDEX %v ON %s",
quote(idxName), quote(tableName))
return fmt.Sprintf("DROP INDEX %v ON %s", quote(name), quote(tableName))
}
func (db *Base) ModifyColumnSql(tableName string, col *Column) string {

View File

@ -1,7 +1,9 @@
package core
import (
"fmt"
"sort"
"strings"
)
const (
@ -11,11 +13,23 @@ const (
// database index
type Index struct {
IsRegular bool
Name string
Type int
Cols []string
}
func (index *Index) XName(tableName string) string {
if !strings.HasPrefix(index.Name, "UQE_") &&
!strings.HasPrefix(index.Name, "IDX_") {
if index.Type == UniqueType {
return fmt.Sprintf("UQE_%v_%v", tableName, index.Name)
}
return fmt.Sprintf("IDX_%v_%v", tableName, index.Name)
}
return index.Name
}
// add columns which will be composite index
func (index *Index) AddColumn(cols ...string) {
for _, col := range cols {
@ -24,6 +38,9 @@ func (index *Index) AddColumn(cols ...string) {
}
func (index *Index) Equal(dst *Index) bool {
if index.Type != dst.Type {
return false
}
if len(index.Cols) != len(dst.Cols) {
return false
}
@ -40,5 +57,5 @@ func (index *Index) Equal(dst *Index) bool {
// new an index
func NewIndex(name string, indexType int) *Index {
return &Index{name, indexType, make([]string, 0)}
return &Index{true, name, indexType, make([]string, 0)}
}

View File

@ -9,7 +9,6 @@ import (
type IMapper interface {
Obj2Table(string) string
Table2Obj(string) string
TableName(string) string
}
type CacheMapper struct {
@ -56,10 +55,6 @@ func (m *CacheMapper) Table2Obj(t string) string {
return o
}
func (m *CacheMapper) TableName(t string) string {
return t
}
// SameMapper implements IMapper and provides same name between struct and
// database table
type SameMapper struct {
@ -73,10 +68,6 @@ func (m SameMapper) Table2Obj(t string) string {
return t
}
func (m SameMapper) TableName(t string) string {
return t
}
// SnakeMapper implements IMapper and provides name transaltion between
// struct and database table
type SnakeMapper struct {
@ -97,25 +88,6 @@ func snakeCasedName(name string) string {
return string(newstr)
}
/*func pascal2Sql(s string) (d string) {
d = ""
lastIdx := 0
for i := 0; i < len(s); i++ {
if s[i] >= 'A' && s[i] <= 'Z' {
if lastIdx < i {
d += s[lastIdx+1 : i]
}
if i != 0 {
d += "_"
}
d += string(s[i] + 32)
lastIdx = i
}
}
d += s[lastIdx+1:]
return
}*/
func (mapper SnakeMapper) Obj2Table(name string) string {
return snakeCasedName(name)
}
@ -148,9 +120,103 @@ func (mapper SnakeMapper) Table2Obj(name string) string {
return titleCasedName(name)
}
func (mapper SnakeMapper) TableName(t string) string {
return t
// GonicMapper implements IMapper. It will consider initialisms when mapping names.
// E.g. id -> ID, user -> User and to table names: UserID -> user_id, MyUID -> my_uid
type GonicMapper map[string]bool
func isASCIIUpper(r rune) bool {
return 'A' <= r && r <= 'Z'
}
func toASCIIUpper(r rune) rune {
if 'a' <= r && r <= 'z' {
r -= ('a' - 'A')
}
return r
}
func gonicCasedName(name string) string {
newstr := make([]rune, 0, len(name)+3)
for idx, chr := range name {
if isASCIIUpper(chr) && idx > 0 {
if !isASCIIUpper(newstr[len(newstr)-1]) {
newstr = append(newstr, '_')
}
}
if !isASCIIUpper(chr) && idx > 1 {
l := len(newstr)
if isASCIIUpper(newstr[l-1]) && isASCIIUpper(newstr[l-2]) {
newstr = append(newstr, newstr[l-1])
newstr[l-1] = '_'
}
}
newstr = append(newstr, chr)
}
return strings.ToLower(string(newstr))
}
func (mapper GonicMapper) Obj2Table(name string) string {
return gonicCasedName(name)
}
func (mapper GonicMapper) Table2Obj(name string) string {
newstr := make([]rune, 0)
name = strings.ToLower(name)
parts := strings.Split(name, "_")
for _, p := range parts {
_, isInitialism := mapper[strings.ToUpper(p)]
for i, r := range p {
if i == 0 || isInitialism {
r = toASCIIUpper(r)
}
newstr = append(newstr, r)
}
}
return string(newstr)
}
// A GonicMapper that contains a list of common initialisms taken from golang/lint
var LintGonicMapper = GonicMapper{
"API": true,
"ASCII": true,
"CPU": true,
"CSS": true,
"DNS": true,
"EOF": true,
"GUID": true,
"HTML": true,
"HTTP": true,
"HTTPS": true,
"ID": true,
"IP": true,
"JSON": true,
"LHS": true,
"QPS": true,
"RAM": true,
"RHS": true,
"RPC": true,
"SLA": true,
"SMTP": true,
"SSH": true,
"TLS": true,
"TTL": true,
"UI": true,
"UID": true,
"UUID": true,
"URI": true,
"URL": true,
"UTF8": true,
"VM": true,
"XML": true,
"XSRF": true,
"XSS": true,
}
// provide prefix table name support
type PrefixMapper struct {
Mapper IMapper
@ -165,10 +231,6 @@ func (mapper PrefixMapper) Table2Obj(name string) string {
return mapper.Mapper.Table2Obj(name[len(mapper.Prefix):])
}
func (mapper PrefixMapper) TableName(name string) string {
return mapper.Prefix + name
}
func NewPrefixMapper(mapper IMapper, prefix string) PrefixMapper {
return PrefixMapper{mapper, prefix}
}
@ -187,10 +249,6 @@ func (mapper SuffixMapper) Table2Obj(name string) string {
return mapper.Mapper.Table2Obj(name[:len(name)-len(mapper.Suffix)])
}
func (mapper SuffixMapper) TableName(name string) string {
return name + mapper.Suffix
}
func NewSuffixMapper(mapper IMapper, suffix string) SuffixMapper {
return SuffixMapper{mapper, suffix}
}

View File

@ -0,0 +1,45 @@
package core
import (
"testing"
)
func TestGonicMapperFromObj(t *testing.T) {
testCases := map[string]string{
"HTTPLib": "http_lib",
"id": "id",
"ID": "id",
"IDa": "i_da",
"iDa": "i_da",
"IDAa": "id_aa",
"aID": "a_id",
"aaID": "aa_id",
"aaaID": "aaa_id",
"MyREalFunkYLONgNAME": "my_r_eal_funk_ylo_ng_name",
}
for in, expected := range testCases {
out := gonicCasedName(in)
if out != expected {
t.Errorf("Given %s, expected %s but got %s", in, expected, out)
}
}
}
func TestGonicMapperToObj(t *testing.T) {
testCases := map[string]string{
"http_lib": "HTTPLib",
"id": "ID",
"ida": "Ida",
"id_aa": "IDAa",
"aa_id": "AaID",
"my_r_eal_funk_ylo_ng_name": "MyREalFunkYloNgName",
}
for in, expected := range testCases {
out := LintGonicMapper.Table2Obj(in)
if out != expected {
t.Errorf("Given %s, expected %s but got %s", in, expected, out)
}
}
}

View File

@ -1,7 +1,8 @@
package core
import (
"encoding/json"
"bytes"
"encoding/gob"
)
type PK []interface{}
@ -12,14 +13,14 @@ func NewPK(pks ...interface{}) *PK {
}
func (p *PK) ToString() (string, error) {
bs, err := json.Marshal(*p)
if err != nil {
return "", nil
}
return string(bs), nil
buf := new(bytes.Buffer)
enc := gob.NewEncoder(buf)
err := enc.Encode(*p)
return buf.String(), err
}
func (p *PK) FromString(content string) error {
return json.Unmarshal([]byte(content), p)
dec := gob.NewDecoder(bytes.NewBufferString(content))
err := dec.Decode(p)
return err
}

View File

@ -2,6 +2,7 @@ package core
import (
"fmt"
"reflect"
"testing"
)
@ -19,4 +20,14 @@ func TestPK(t *testing.T) {
t.Error(err)
}
fmt.Println(s)
if len(*p) != len(*s) {
t.Fatal("p", *p, "should be equal", *s)
}
for i, ori := range *p {
if ori != (*s)[i] {
t.Fatal("ori", ori, reflect.ValueOf(ori), "should be equal", (*s)[i], reflect.ValueOf((*s)[i]))
}
}
}

View File

@ -65,13 +65,18 @@ func (table *Table) GetColumnIdx(name string, idx int) *Column {
// if has primary key, return column
func (table *Table) PKColumns() []*Column {
columns := make([]*Column, 0)
for _, name := range table.PrimaryKeys {
columns = append(columns, table.GetColumn(name))
columns := make([]*Column, len(table.PrimaryKeys))
for i, name := range table.PrimaryKeys {
columns[i] = table.GetColumn(name)
}
return columns
}
func (table *Table) ColumnType(name string) reflect.Type {
t, _ := table.Type.FieldByName(name)
return t.Type
}
func (table *Table) AutoIncrColumn() *Column {
return table.GetColumn(table.AutoIncrement)
}

View File

@ -70,6 +70,7 @@ var (
NVarchar = "NVARCHAR"
TinyText = "TINYTEXT"
Text = "TEXT"
Clob = "CLOB"
MediumText = "MEDIUMTEXT"
LongText = "LONGTEXT"
Uuid = "UUID"
@ -120,6 +121,7 @@ var (
MediumText: TEXT_TYPE,
LongText: TEXT_TYPE,
Uuid: TEXT_TYPE,
Clob: TEXT_TYPE,
Date: TIME_TYPE,
DateTime: TIME_TYPE,
@ -250,7 +252,7 @@ func Type2SQLType(t reflect.Type) (st SQLType) {
case reflect.String:
st = SQLType{Varchar, 255, 0}
case reflect.Struct:
if t == reflect.TypeOf(c_TIME_DEFAULT) {
if t.ConvertibleTo(reflect.TypeOf(c_TIME_DEFAULT)) {
st = SQLType{DateTime, 0, 0}
} else {
// TODO need to handle association struct
@ -303,7 +305,7 @@ func SQLType2Type(st SQLType) reflect.Type {
return reflect.TypeOf(float32(1))
case Double:
return reflect.TypeOf(float64(1))
case Char, Varchar, NVarchar, TinyText, Text, MediumText, LongText, Enum, Set, Uuid:
case Char, Varchar, NVarchar, TinyText, Text, MediumText, LongText, Enum, Set, Uuid, Clob:
return reflect.TypeOf("")
case TinyBlob, Blob, LongBlob, Bytea, Binary, MediumBlob, VarBinary:
return reflect.TypeOf([]byte{})

View File

@ -82,11 +82,13 @@ Or
# Cases
* [Wego](http://github.com/go-tango/wego)
* [Docker.cn](https://docker.cn/)
* [Gogs](http://try.gogits.org) - [github.com/gogits/gogs](http://github.com/gogits/gogs)
* [Gorevel](http://http://gorevel.cn/) - [github.com/goofcc/gorevel](http://github.com/goofcc/gorevel)
* [Gorevel](http://gorevel.cn/) - [github.com/goofcc/gorevel](http://github.com/goofcc/gorevel)
* [Gowalker](http://gowalker.org) - [github.com/Unknwon/gowalker](http://github.com/Unknwon/gowalker)

View File

@ -44,16 +44,10 @@ xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作
## 更新日志
* **v0.4.0 RC1**
新特性:
* 移动xorm cmd [github.com/go-xorm/cmd](github.com/go-xorm/cmd)
* 在重构一般DB操作核心库 [github.com/go-xorm/core](https://github.com/go-xorm/core)
* 移动测试github.com/复XORM/测试 [github.com/go-xorm/tests](github.com/go-xorm/tests)
改进:
* Prepared statement 缓存
* 添加 Incr API
* 指定时区位置
* **v0.4.2**
新特性:
* deleted标记
* bug fixed
[更多更新日志...](https://github.com/go-xorm/manual-zh-CN/tree/master/chapter-16)
@ -78,6 +72,8 @@ xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作
## 案例
* [Wego](http://github.com/go-tango/wego)
* [Docker.cn](https://docker.cn/)
* [Gogs](http://try.gogits.org) - [github.com/gogits/gogs](http://github.com/gogits/gogs)

View File

@ -1 +1 @@
xorm v0.4.1
xorm v0.4.2.0225

View File

@ -63,21 +63,22 @@ There are 7 major ORM methods and many helpful methods to use to operate databas
// SELECT * FROM user
4. Query multiple records and record by record handle, there two methods, one is Iterate,
another is Raws
another is Rows
err := engine.Iterate(...)
// SELECT * FROM user
raws, err := engine.Raws(...)
rows, err := engine.Rows(...)
// SELECT * FROM user
defer rows.Close()
bean := new(Struct)
for raws.Next() {
err = raws.Scan(bean)
for rows.Next() {
err = rows.Scan(bean)
}
5. Update one or more records
affected, err := engine.Update(&user)
affected, err := engine.Id(...).Update(&user)
// UPDATE user SET ...
6. Delete one or more records, Delete MUST has conditon
@ -150,6 +151,6 @@ Attention: the above 7 methods should be the last chainable method.
engine.Join("LEFT", "userdetail", "user.id=userdetail.id").Find()
//SELECT * FROM user LEFT JOIN userdetail ON user.id=userdetail.id
More usage, please visit https://github.com/go-xorm/xorm/blob/master/docs/QuickStartEn.md
More usage, please visit http://xorm.io/docs
*/
package xorm

View File

@ -344,7 +344,7 @@ func (engine *Engine) DBMetas() ([]*core.Table, error) {
if col := table.GetColumn(name); col != nil {
col.Indexes[index.Name] = true
} else {
return nil, fmt.Errorf("Unknown col "+name+" in indexes %v", index)
return nil, fmt.Errorf("Unknown col "+name+" in indexes %v of table", index, table.ColumnsSeq())
}
}
}
@ -352,6 +352,9 @@ func (engine *Engine) DBMetas() ([]*core.Table, error) {
return tables, nil
}
/*
dump database all table structs and data to a file
*/
func (engine *Engine) DumpAllToFile(fp string) error {
f, err := os.Create(fp)
if err != nil {
@ -361,6 +364,9 @@ func (engine *Engine) DumpAllToFile(fp string) error {
return engine.DumpAll(f)
}
/*
dump database all table structs and data to w
*/
func (engine *Engine) DumpAll(w io.Writer) error {
tables, err := engine.DBMetas()
if err != nil {
@ -558,6 +564,13 @@ func (engine *Engine) Decr(column string, arg ...interface{}) *Session {
return session.Decr(column, arg...)
}
// Method SetExpr provides a update string like "column = {expression}"
func (engine *Engine) SetExpr(column string, expression string) *Session {
session := engine.NewSession()
session.IsAutoClose = true
return session.SetExpr(column, expression)
}
// Temporarily change the Get, Find, Update's table
func (engine *Engine) Table(tableNameOrBean interface{}) *Session {
session := engine.NewSession()
@ -766,7 +779,12 @@ func (engine *Engine) mapType(v reflect.Value) *core.Table {
col.IsPrimaryKey = true
col.Nullable = false
case k == "NULL":
if j == 0 {
col.Nullable = true
} else {
col.Nullable = (strings.ToUpper(tags[j-1]) != "NOT")
}
// TODO: for postgres how add autoincr?
/*case strings.HasPrefix(k, "AUTOINCR(") && strings.HasSuffix(k, ")"):
col.IsAutoIncrement = true
@ -915,7 +933,7 @@ func (engine *Engine) mapType(v reflect.Value) *core.Table {
table.AddColumn(col)
if fieldType.Kind() == reflect.Int64 && (col.FieldName == "Id" || strings.HasSuffix(col.FieldName, ".Id")) {
if fieldType.Kind() == reflect.Int64 && (strings.ToUpper(col.FieldName) == "ID" || strings.HasSuffix(strings.ToUpper(col.FieldName), ".ID")) {
idFieldColName = col.Name
}
} // end for
@ -959,40 +977,25 @@ func (engine *Engine) mapping(beans ...interface{}) (e error) {
// If a table has any reocrd
func (engine *Engine) IsTableEmpty(bean interface{}) (bool, error) {
v := rValue(bean)
t := v.Type()
if t.Kind() != reflect.Struct {
return false, errors.New("bean should be a struct or struct's point")
}
engine.autoMapType(v)
session := engine.NewSession()
defer session.Close()
rows, err := session.Count(bean)
return rows == 0, err
return session.IsTableEmpty(bean)
}
// If a table is exist
func (engine *Engine) IsTableExist(bean interface{}) (bool, error) {
v := rValue(bean)
var tableName string
if v.Type().Kind() == reflect.String {
tableName = bean.(string)
} else if v.Type().Kind() == reflect.Struct {
table := engine.autoMapType(v)
tableName = table.Name
} else {
return false, errors.New("bean should be a struct or struct's point")
}
func (engine *Engine) IsTableExist(beanOrTableName interface{}) (bool, error) {
session := engine.NewSession()
defer session.Close()
has, err := session.isTableExist(tableName)
return has, err
return session.IsTableExist(beanOrTableName)
}
func (engine *Engine) IdOf(bean interface{}) core.PK {
table := engine.TableInfo(bean)
v := reflect.Indirect(reflect.ValueOf(bean))
return engine.IdOfV(reflect.ValueOf(bean))
}
func (engine *Engine) IdOfV(rv reflect.Value) core.PK {
v := reflect.Indirect(rv)
table := engine.autoMapType(v)
pk := make([]interface{}, len(table.PrimaryKeys))
for i, col := range table.PKColumns() {
pkField := v.FieldByName(col.FieldName)
@ -1109,7 +1112,7 @@ func (engine *Engine) Sync(beans ...interface{}) error {
session := engine.NewSession()
session.Statement.RefTable = table
defer session.Close()
isExist, err := session.isColumnExist(table.Name, col)
isExist, err := session.Engine.dialect.IsColumnExist(table.Name, col)
if err != nil {
return err
}
@ -1222,8 +1225,9 @@ func (engine *Engine) CreateTables(beans ...interface{}) error {
func (engine *Engine) DropTables(beans ...interface{}) error {
session := engine.NewSession()
err := session.Begin()
defer session.Close()
err := session.Begin()
if err != nil {
return err
}
@ -1258,13 +1262,6 @@ func (engine *Engine) Query(sql string, paramStr ...interface{}) (resultsSlice [
return session.Query(sql, paramStr...)
}
// Exec a raw sql and return records as []map[string]string
func (engine *Engine) Q(sql string, paramStr ...interface{}) (resultsSlice []map[string]string, err error) {
session := engine.NewSession()
defer session.Close()
return session.Q(sql, paramStr...)
}
// Insert one or more records
func (engine *Engine) Insert(beans ...interface{}) (int64, error) {
session := engine.NewSession()
@ -1371,18 +1368,11 @@ func (engine *Engine) Import(r io.Reader) ([]sql.Result, error) {
scanner.Split(semiColSpliter)
session := engine.NewSession()
defer session.Close()
err := session.newDb()
if err != nil {
return results, err
}
for scanner.Scan() {
query := scanner.Text()
query = strings.Trim(query, " \t")
if len(query) > 0 {
result, err := session.Db.Exec(query)
result, err := engine.DB().Exec(query)
results = append(results, result)
if err != nil {
lastError = err
@ -1409,7 +1399,15 @@ func (engine *Engine) NowTime(sqlTypeName string) interface{} {
return engine.FormatTime(sqlTypeName, t)
}
func (engine *Engine) NowTime2(sqlTypeName string) (interface{}, time.Time) {
t := time.Now()
return engine.FormatTime(sqlTypeName, t), t
}
func (engine *Engine) FormatTime(sqlTypeName string, t time.Time) (v interface{}) {
if engine.dialect.DBType() == core.ORACLE {
return t
}
switch sqlTypeName {
case core.Time:
s := engine.TZTime(t).Format("2006-01-02 15:04:05") //time.RFC3339
@ -1419,6 +1417,8 @@ func (engine *Engine) FormatTime(sqlTypeName string, t time.Time) (v interface{}
case core.DateTime, core.TimeStamp:
if engine.dialect.DBType() == "ql" {
v = engine.TZTime(t)
} else if engine.dialect.DBType() == "sqlite3" {
v = engine.TZTime(t).UTC().Format("2006-01-02 15:04:05")
} else {
v = engine.TZTime(t).Format("2006-01-02 15:04:05")
}
@ -1430,6 +1430,8 @@ func (engine *Engine) FormatTime(sqlTypeName string, t time.Time) (v interface{}
} else {
v = engine.TZTime(t).Format(time.RFC3339Nano)
}
case core.BigInt, core.Int:
v = engine.TZTime(t).Unix()
default:
v = engine.TZTime(t)
}

View File

@ -11,6 +11,43 @@ import (
"github.com/go-xorm/core"
)
func isZero(k interface{}) bool {
switch k.(type) {
case int:
return k.(int) == 0
case int8:
return k.(int8) == 0
case int16:
return k.(int16) == 0
case int32:
return k.(int32) == 0
case int64:
return k.(int64) == 0
case uint:
return k.(uint) == 0
case uint8:
return k.(uint8) == 0
case uint16:
return k.(uint16) == 0
case uint32:
return k.(uint32) == 0
case uint64:
return k.(uint64) == 0
case string:
return k.(string) == ""
}
return false
}
func isPKZero(pk core.PK) bool {
for _, k := range pk {
if isZero(k) {
return true
}
}
return false
}
func indexNoCase(s, sep string) int {
return strings.Index(strings.ToLower(s), strings.ToLower(sep))
}
@ -163,3 +200,182 @@ func rows2maps(rows *core.Rows) (resultsSlice []map[string][]byte, err error) {
return resultsSlice, nil
}
func row2map(rows *core.Rows, fields []string) (resultsMap map[string][]byte, err error) {
result := make(map[string][]byte)
scanResultContainers := make([]interface{}, len(fields))
for i := 0; i < len(fields); i++ {
var scanResultContainer interface{}
scanResultContainers[i] = &scanResultContainer
}
if err := rows.Scan(scanResultContainers...); err != nil {
return nil, err
}
for ii, key := range fields {
rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii]))
//if row is null then ignore
if rawValue.Interface() == nil {
//fmt.Println("ignore ...", key, rawValue)
continue
}
if data, err := value2Bytes(&rawValue); err == nil {
result[key] = data
} else {
return nil, err // !nashtsai! REVIEW, should return err or just error log?
}
}
return result, nil
}
func row2mapStr(rows *core.Rows, fields []string) (resultsMap map[string]string, err error) {
result := make(map[string]string)
scanResultContainers := make([]interface{}, len(fields))
for i := 0; i < len(fields); i++ {
var scanResultContainer interface{}
scanResultContainers[i] = &scanResultContainer
}
if err := rows.Scan(scanResultContainers...); err != nil {
return nil, err
}
for ii, key := range fields {
rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii]))
//if row is null then ignore
if rawValue.Interface() == nil {
//fmt.Println("ignore ...", key, rawValue)
continue
}
if data, err := value2String(&rawValue); err == nil {
result[key] = data
} else {
return nil, err // !nashtsai! REVIEW, should return err or just error log?
}
}
return result, nil
}
func txQuery2(tx *core.Tx, sqlStr string, params ...interface{}) (resultsSlice []map[string]string, err error) {
rows, err := tx.Query(sqlStr, params...)
if err != nil {
return nil, err
}
defer rows.Close()
return rows2Strings(rows)
}
func query2(db *core.DB, sqlStr string, params ...interface{}) (resultsSlice []map[string]string, err error) {
s, err := db.Prepare(sqlStr)
if err != nil {
return nil, err
}
defer s.Close()
rows, err := s.Query(params...)
if err != nil {
return nil, err
}
defer rows.Close()
return rows2Strings(rows)
}
func setColumnTime(bean interface{}, col *core.Column, t time.Time) {
v, err := col.ValueOf(bean)
if err != nil {
return
}
if v.CanSet() {
switch v.Type().Kind() {
case reflect.Struct:
v.Set(reflect.ValueOf(t).Convert(v.Type()))
case reflect.Int, reflect.Int64, reflect.Int32:
v.SetInt(t.Unix())
case reflect.Uint, reflect.Uint64, reflect.Uint32:
v.SetUint(uint64(t.Unix()))
}
}
}
func genCols(table *core.Table, session *Session, bean interface{}, useCol bool, includeQuote bool) ([]string, []interface{}, error) {
colNames := make([]string, 0)
args := make([]interface{}, 0)
for _, col := range table.Columns() {
lColName := strings.ToLower(col.Name)
if useCol && !col.IsVersion && !col.IsCreated && !col.IsUpdated {
if _, ok := session.Statement.columnMap[lColName]; !ok {
continue
}
}
if col.MapType == core.ONLYFROMDB {
continue
}
fieldValuePtr, err := col.ValueOf(bean)
if err != nil {
session.Engine.LogError(err)
continue
}
fieldValue := *fieldValuePtr
if col.IsAutoIncrement {
switch fieldValue.Type().Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int, reflect.Int64:
if fieldValue.Int() == 0 {
continue
}
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint, reflect.Uint64:
if fieldValue.Uint() == 0 {
continue
}
case reflect.String:
if len(fieldValue.String()) == 0 {
continue
}
}
}
if col.IsDeleted {
continue
}
if session.Statement.ColumnStr != "" {
if _, ok := session.Statement.columnMap[lColName]; !ok {
continue
}
}
if session.Statement.OmitStr != "" {
if _, ok := session.Statement.columnMap[lColName]; ok {
continue
}
}
if (col.IsCreated || col.IsUpdated) && session.Statement.UseAutoTime {
val, t := session.Engine.NowTime2(col.SQLType.Name)
args = append(args, val)
var colName = col.Name
session.afterClosures = append(session.afterClosures, func(bean interface{}) {
col := table.GetColumn(colName)
setColumnTime(bean, col, t)
})
} else if col.IsVersion && session.Statement.checkVersion {
args = append(args, 1)
} else {
arg, err := session.value2Interface(col, fieldValue)
if err != nil {
return colNames, args, err
}
args = append(args, arg)
}
if includeQuote {
colNames = append(colNames, session.Engine.Quote(col.Name)+" = ?")
} else {
colNames = append(colNames, col.Name)
}
}
return colNames, args, nil
}

View File

@ -270,7 +270,7 @@ func (db *mssql) IsReserved(name string) bool {
}
func (db *mssql) Quote(name string) string {
return "[" + name + "]"
return "\"" + name + "\""
}
func (db *mssql) QuoteStr() string {

View File

@ -218,6 +218,9 @@ func (db *mysql) SqlType(c *core.Column) string {
res += ")"
case core.NVarchar:
res = core.Varchar
case core.Uuid:
res = core.Varchar
c.Length = 40
default:
res = t
}
@ -317,7 +320,6 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*core.Column
if err != nil {
return nil, nil, err
}
//fmt.Println(columnName, isNullable, colType, colKey, extra, colDefault)
col.Name = strings.Trim(columnName, "` ")
if "YES" == isNullable {
col.Nullable = true
@ -467,15 +469,17 @@ func (db *mysql) GetIndexes(tableName string) (map[string]*core.Index, error) {
}
colName = strings.Trim(colName, "` ")
var isRegular bool
if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) {
indexName = indexName[5+len(tableName) : len(indexName)]
isRegular = true
}
var index *core.Index
var ok bool
if index, ok = indexes[indexName]; !ok {
index = new(core.Index)
index.IsRegular = isRegular
index.Type = indexType
index.Name = indexName
indexes[indexName] = index

View File

@ -509,7 +509,7 @@ func (db *oracle) SqlType(c *core.Column) string {
var res string
switch t := c.SQLType.Name; t {
case core.Bit, core.TinyInt, core.SmallInt, core.MediumInt, core.Int, core.Integer, core.BigInt, core.Bool, core.Serial, core.BigSerial:
return "NUMBER"
res = "NUMBER"
case core.Binary, core.VarBinary, core.Blob, core.TinyBlob, core.MediumBlob, core.LongBlob, core.Bytea:
return core.Blob
case core.Time, core.DateTime, core.TimeStamp:
@ -521,7 +521,7 @@ func (db *oracle) SqlType(c *core.Column) string {
case core.Text, core.MediumText, core.LongText:
res = "CLOB"
case core.Char, core.Varchar, core.TinyText:
return "VARCHAR2"
res = "VARCHAR2"
default:
res = t
}
@ -536,6 +536,10 @@ func (db *oracle) SqlType(c *core.Column) string {
return res
}
func (db *oracle) AutoIncrStr() string {
return "AUTO_INCREMENT"
}
func (db *oracle) SupportInsertMany() bool {
return true
}
@ -553,10 +557,6 @@ func (db *oracle) QuoteStr() string {
return "\""
}
func (db *oracle) AutoIncrStr() string {
return ""
}
func (db *oracle) SupportEngine() bool {
return false
}
@ -565,19 +565,94 @@ func (db *oracle) SupportCharset() bool {
return false
}
func (db *oracle) SupportDropIfExists() bool {
return false
}
func (db *oracle) IndexOnTable() bool {
return false
}
func (db *oracle) DropTableSql(tableName string) string {
return fmt.Sprintf("DROP TABLE `%s`", tableName)
}
func (b *oracle) CreateTableSql(table *core.Table, tableName, storeEngine, charset string) string {
var sql string
sql = "CREATE TABLE "
if tableName == "" {
tableName = table.Name
}
sql += b.Quote(tableName) + " ("
pkList := table.PrimaryKeys
for _, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName)
/*if col.IsPrimaryKey && len(pkList) == 1 {
sql += col.String(b.dialect)
} else {*/
sql += col.StringNoPk(b)
//}
sql = strings.TrimSpace(sql)
sql += ", "
}
if len(pkList) > 0 {
sql += "PRIMARY KEY ( "
sql += b.Quote(strings.Join(pkList, b.Quote(",")))
sql += " ), "
}
sql = sql[:len(sql)-2] + ")"
if b.SupportEngine() && storeEngine != "" {
sql += " ENGINE=" + storeEngine
}
if b.SupportCharset() {
if len(charset) == 0 {
charset = b.URI().Charset
}
if len(charset) > 0 {
sql += " DEFAULT CHARSET " + charset
}
}
return sql
}
func (db *oracle) IndexCheckSql(tableName, idxName string) (string, []interface{}) {
args := []interface{}{strings.ToUpper(tableName), strings.ToUpper(idxName)}
args := []interface{}{tableName, idxName}
return `SELECT INDEX_NAME FROM USER_INDEXES ` +
`WHERE TABLE_NAME = ? AND INDEX_NAME = ?`, args
`WHERE TABLE_NAME = :1 AND INDEX_NAME = :2`, args
}
func (db *oracle) TableCheckSql(tableName string) (string, []interface{}) {
args := []interface{}{strings.ToUpper(tableName)}
return `SELECT table_name FROM user_tables WHERE table_name = ?`, args
args := []interface{}{tableName}
return `SELECT table_name FROM user_tables WHERE table_name = :1`, args
}
func (db *oracle) MustDropTable(tableName string) error {
sql, args := db.TableCheckSql(tableName)
if db.Logger != nil {
db.Logger.Info("[sql]", sql, args)
}
rows, err := db.DB().Query(sql, args...)
if err != nil {
return err
}
defer rows.Close()
if !rows.Next() {
return nil
}
sql = "Drop Table \"" + tableName + "\""
if db.Logger != nil {
db.Logger.Info("[sql]", sql)
}
_, err = db.DB().Exec(sql)
return err
}
/*func (db *oracle) ColumnCheckSql(tableName, colName string) (string, []interface{}) {
@ -587,9 +662,9 @@ func (db *oracle) TableCheckSql(tableName string) (string, []interface{}) {
}*/
func (db *oracle) IsColumnExist(tableName string, col *core.Column) (bool, error) {
args := []interface{}{strings.ToUpper(tableName), strings.ToUpper(col.Name)}
query := "SELECT column_name FROM USER_TAB_COLUMNS WHERE table_name = ?" +
" AND column_name = ?"
args := []interface{}{tableName, col.Name}
query := "SELECT column_name FROM USER_TAB_COLUMNS WHERE table_name = :1" +
" AND column_name = :2"
rows, err := db.DB().Query(query, args...)
if db.Logger != nil {
db.Logger.Info("[sql]", query, args)
@ -606,7 +681,7 @@ func (db *oracle) IsColumnExist(tableName string, col *core.Column) (bool, error
}
func (db *oracle) GetColumns(tableName string) ([]string, map[string]*core.Column, error) {
args := []interface{}{strings.ToUpper(tableName)}
args := []interface{}{tableName}
s := "SELECT column_name,data_default,data_type,data_length,data_precision,data_scale," +
"nullable FROM USER_TAB_COLUMNS WHERE table_name = :1"
@ -625,7 +700,7 @@ func (db *oracle) GetColumns(tableName string) ([]string, map[string]*core.Colum
col := new(core.Column)
col.Indexes = make(map[string]bool)
var colName, colDefault, nullable, dataType, dataPrecision, dataScale string
var colName, colDefault, nullable, dataType, dataPrecision, dataScale *string
var dataLen int
err = rows.Scan(&colName, &colDefault, &dataType, &dataLen, &dataPrecision,
@ -634,36 +709,66 @@ func (db *oracle) GetColumns(tableName string) ([]string, map[string]*core.Colum
return nil, nil, err
}
col.Name = strings.Trim(colName, `" `)
col.Default = colDefault
col.Name = strings.Trim(*colName, `" `)
if colDefault != nil {
col.Default = *colDefault
col.DefaultIsEmpty = false
}
if nullable == "Y" {
if *nullable == "Y" {
col.Nullable = true
} else {
col.Nullable = false
}
switch dataType {
var ignore bool
var dt string
var len1, len2 int
dts := strings.Split(*dataType, "(")
dt = dts[0]
if len(dts) > 1 {
lens := strings.Split(dts[1][:len(dts[1])-1], ",")
if len(lens) > 1 {
len1, _ = strconv.Atoi(lens[0])
len2, _ = strconv.Atoi(lens[1])
} else {
len1, _ = strconv.Atoi(lens[0])
}
}
switch dt {
case "VARCHAR2":
col.SQLType = core.SQLType{core.Varchar, 0, 0}
col.SQLType = core.SQLType{core.Varchar, len1, len2}
case "TIMESTAMP WITH TIME ZONE":
col.SQLType = core.SQLType{core.TimeStampz, 0, 0}
case "NUMBER":
col.SQLType = core.SQLType{core.Double, len1, len2}
case "LONG", "LONG RAW":
col.SQLType = core.SQLType{core.Text, 0, 0}
case "RAW":
col.SQLType = core.SQLType{core.Binary, 0, 0}
case "ROWID":
col.SQLType = core.SQLType{core.Varchar, 18, 0}
case "AQ$_SUBSCRIBERS":
ignore = true
default:
col.SQLType = core.SQLType{strings.ToUpper(dataType), 0, 0}
col.SQLType = core.SQLType{strings.ToUpper(dt), len1, len2}
}
if ignore {
continue
}
if _, ok := core.SqlTypes[col.SQLType.Name]; !ok {
return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", dataType))
return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v %v", *dataType, col.SQLType))
}
col.Length = dataLen
if col.SQLType.IsText() || col.SQLType.IsTime() {
if col.Default != "" {
if !col.DefaultIsEmpty {
col.Default = "'" + col.Default + "'"
} else {
if col.DefaultIsEmpty {
col.Default = "''"
}
}
}
cols[col.Name] = col

View File

@ -25,11 +25,6 @@ func newRows(session *Session, bean interface{}) (*Rows, error) {
rows.session = session
rows.beanType = reflect.Indirect(reflect.ValueOf(bean)).Type()
err := rows.session.newDb()
if err != nil {
return nil, err
}
defer rows.session.Statement.Init()
var sqlStr string
@ -47,8 +42,8 @@ func newRows(session *Session, bean interface{}) (*Rows, error) {
}
rows.session.Engine.logSQL(sqlStr, args)
rows.stmt, err = rows.session.Db.Prepare(sqlStr)
var err error
rows.stmt, err = rows.session.DB().Prepare(sqlStr)
if err != nil {
rows.lastError = err
defer rows.Close()

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,7 @@
package xorm
import (
"database/sql"
"errors"
"fmt"
"strings"
@ -152,7 +153,7 @@ func (db *sqlite3) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName st
func (db *sqlite3) SqlType(c *core.Column) string {
switch t := c.SQLType.Name; t {
case core.Date, core.DateTime, core.TimeStamp, core.Time:
return core.Numeric
return core.DateTime
case core.TimeStampz:
return core.Text
case core.Char, core.Varchar, core.NVarchar, core.TinyText, core.Text, core.MediumText, core.LongText:
@ -297,6 +298,7 @@ func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*core.Colu
col := new(core.Column)
col.Indexes = make(map[string]bool)
col.Nullable = true
col.DefaultIsEmpty = true
for idx, field := range fields {
if idx == 0 {
col.Name = strings.Trim(field, "`[] ")
@ -315,8 +317,14 @@ func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*core.Colu
} else {
col.Nullable = true
}
case "DEFAULT":
col.Default = fields[idx+1]
col.DefaultIsEmpty = false
}
}
if !col.SQLType.IsNumeric() && !col.DefaultIsEmpty {
col.Default = "'" + col.Default + "'"
}
cols[col.Name] = col
colSeq = append(colSeq, col.Name)
}
@ -366,15 +374,16 @@ func (db *sqlite3) GetIndexes(tableName string) (map[string]*core.Index, error)
indexes := make(map[string]*core.Index, 0)
for rows.Next() {
var sql string
err = rows.Scan(&sql)
var tmpSql sql.NullString
err = rows.Scan(&tmpSql)
if err != nil {
return nil, err
}
if sql == "" {
if !tmpSql.Valid {
continue
}
sql := tmpSql.String
index := new(core.Index)
nNStart := strings.Index(sql, "INDEX")
@ -384,7 +393,6 @@ func (db *sqlite3) GetIndexes(tableName string) (map[string]*core.Index, error)
}
indexName := strings.Trim(sql[nNStart+6:nNEnd], "` []")
//fmt.Println(indexName)
if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) {
index.Name = indexName[5+len(tableName) : len(indexName)]
} else {

View File

@ -26,6 +26,11 @@ type decrParam struct {
arg interface{}
}
type exprParam struct {
colName string
expr string
}
// statement save all the sql info for executing SQL
type Statement struct {
RefTable *core.Table
@ -63,6 +68,7 @@ type Statement struct {
inColumns map[string]*inParam
incrColumns map[string]incrParam
decrColumns map[string]decrParam
exprColumns map[string]exprParam
}
// init
@ -98,6 +104,7 @@ func (statement *Statement) Init() {
statement.inColumns = make(map[string]*inParam)
statement.incrColumns = make(map[string]incrParam)
statement.decrColumns = make(map[string]decrParam)
statement.exprColumns = make(map[string]exprParam)
}
// add the raw sql statement
@ -153,9 +160,6 @@ func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
t := v.Type()
if t.Kind() == reflect.String {
statement.AltTableName = tableNameOrBean.(string)
if statement.AltTableName[0] == '~' {
statement.AltTableName = statement.Engine.TableMapper.TableName(statement.AltTableName[1:])
}
} else if t.Kind() == reflect.Struct {
statement.RefTable = statement.Engine.autoMapType(v)
}
@ -282,7 +286,7 @@ func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
func buildUpdates(engine *Engine, table *core.Table, bean interface{},
includeVersion bool, includeUpdated bool, includeNil bool,
includeAutoIncr bool, allUseBool bool, useAllCols bool,
mustColumnMap map[string]bool, update bool) ([]string, []interface{}) {
mustColumnMap map[string]bool, columnMap map[string]bool, update bool) ([]string, []interface{}) {
colNames := make([]string, 0)
var args = make([]interface{}, 0)
@ -302,6 +306,9 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{},
if col.IsDeleted {
continue
}
if use, ok := columnMap[col.Name]; ok && !use {
continue
}
if engine.dialect.DBType() == core.MSSQL && col.SQLType.Name == core.Text {
continue
@ -414,13 +421,16 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{},
if table, ok := engine.Tables[fieldValue.Type()]; ok {
if len(table.PrimaryKeys) == 1 {
pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName)
if pkField.Int() != 0 {
// fix non-int pk issues
//if pkField.Int() != 0 {
if pkField.IsValid() && !isZero(pkField.Interface()) {
val = pkField.Interface()
} else {
continue
}
} else {
//TODO: how to handler?
panic("not supported")
}
} else {
val = fieldValue.Interface()
@ -579,24 +589,29 @@ func buildConditions(engine *Engine, table *core.Table, bean interface{},
t := int64(fieldValue.Uint())
val = reflect.ValueOf(&t).Interface()
case reflect.Struct:
if fieldType == reflect.TypeOf(time.Now()) {
t := fieldValue.Interface().(time.Time)
if fieldType.ConvertibleTo(core.TimeType) {
t := fieldValue.Convert(core.TimeType).Interface().(time.Time)
if !requiredField && (t.IsZero() || !fieldValue.IsValid()) {
continue
}
val = engine.FormatTime(col.SQLType.Name, t)
} else if _, ok := reflect.New(fieldType).Interface().(core.Conversion); ok {
continue
} else {
engine.autoMapType(fieldValue)
if table, ok := engine.Tables[fieldValue.Type()]; ok {
if len(table.PrimaryKeys) == 1 {
pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName)
if pkField.Int() != 0 {
// fix non-int pk issues
//if pkField.Int() != 0 {
if pkField.IsValid() && !isZero(pkField.Interface()) {
val = pkField.Interface()
} else {
continue
}
} else {
//TODO: how to handler?
panic("not supported")
}
} else {
val = fieldValue.Interface()
@ -716,6 +731,13 @@ func (statement *Statement) Decr(column string, arg ...interface{}) *Statement {
return statement
}
// Generate "Update ... Set column = {expression}" statment
func (statement *Statement) SetExpr(column string, expression string) *Statement {
k := strings.ToLower(column)
statement.exprColumns[k] = exprParam{column, expression}
return statement
}
// Generate "Update ... Set column = column + arg" statment
func (statement *Statement) getInc() map[string]incrParam {
return statement.incrColumns
@ -726,6 +748,11 @@ func (statement *Statement) getDec() map[string]decrParam {
return statement.decrColumns
}
// Generate "Update ... Set column = {expression}" statment
func (statement *Statement) getExpr() map[string]exprParam {
return statement.exprColumns
}
// Generate "Where column IN (?) " statment
func (statement *Statement) In(column string, args ...interface{}) *Statement {
k := strings.ToLower(column)
@ -941,15 +968,9 @@ func (statement *Statement) Join(join_operator string, tablename interface{}, co
l := len(t)
if l > 1 {
table := t[0]
if table[0] == '~' {
table = statement.Engine.TableMapper.TableName(table[1:])
}
joinTable = statement.Engine.Quote(table) + " AS " + statement.Engine.Quote(t[1])
} else if l == 1 {
table := t[0]
if table[0] == '~' {
table = statement.Engine.TableMapper.TableName(table[1:])
}
joinTable = statement.Engine.Quote(table)
}
case []interface{}:
@ -962,9 +983,6 @@ func (statement *Statement) Join(join_operator string, tablename interface{}, co
t := v.Type()
if t.Kind() == reflect.String {
table = f.(string)
if table[0] == '~' {
table = statement.Engine.TableMapper.TableName(table[1:])
}
} else if t.Kind() == reflect.Struct {
r := statement.Engine.autoMapType(v)
table = r.Name
@ -977,9 +995,6 @@ func (statement *Statement) Join(join_operator string, tablename interface{}, co
}
default:
t := fmt.Sprintf("%v", tablename)
if t[0] == '~' {
t = statement.Engine.TableMapper.TableName(t[1:])
}
joinTable = statement.Engine.Quote(t)
}
if statement.JoinStr != "" {
@ -1105,9 +1120,10 @@ func (s *Statement) genDelIndexSQL() []string {
return sqls
}
/*
func (s *Statement) genDropSQL() string {
return s.Engine.dialect.DropTableSql(s.TableName()) + ";"
}
return s.Engine.dialect.MustDropTa(s.TableName()) + ";"
}*/
func (statement *Statement) genGetSql(bean interface{}) (string, []interface{}) {
var table *core.Table
@ -1126,15 +1142,23 @@ func (statement *Statement) genGetSql(bean interface{}) (string, []interface{})
statement.BeanArgs = args
var columnStr string = statement.ColumnStr
if statement.JoinStr == "" {
if columnStr == "" {
if len(statement.JoinStr) == 0 {
if len(columnStr) == 0 {
if statement.GroupByStr != "" {
columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1))
} else {
columnStr = statement.genColumnStr()
}
}
} else {
if len(columnStr) == 0 {
if statement.GroupByStr != "" {
columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1))
} else {
if columnStr == "" {
columnStr = "*"
}
}
}
statement.attachInSql() // !admpub! fix bug:Iterate func missing "... IN (...)"
return statement.genSelectSql(columnStr), append(statement.Params, statement.BeanArgs...)
@ -1178,14 +1202,16 @@ func (statement *Statement) genCountSql(bean interface{}) (string, []interface{}
id = ""
}
statement.attachInSql()
return statement.genSelectSql(fmt.Sprintf("count(%v) AS %v", id, statement.Engine.Quote("total"))), append(statement.Params, statement.BeanArgs...)
return statement.genSelectSql(fmt.Sprintf("count(%v)", id)), append(statement.Params, statement.BeanArgs...)
}
func (statement *Statement) genSelectSql(columnStr string) (a string) {
if statement.GroupByStr != "" {
/*if statement.GroupByStr != "" {
if columnStr == "" {
columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1))
statement.GroupByStr = columnStr
}
//statement.GroupByStr = columnStr
}*/
var distinct string
if statement.IsDistinct {
distinct = "DISTINCT "
@ -1210,8 +1236,12 @@ func (statement *Statement) genSelectSql(columnStr string) (a string) {
}
var fromStr string = " FROM " + statement.Engine.Quote(statement.TableName())
if statement.TableAlias != "" {
if statement.Engine.dialect.DBType() == core.ORACLE {
fromStr += " " + statement.Engine.Quote(statement.TableAlias)
} else {
fromStr += " AS " + statement.Engine.Quote(statement.TableAlias)
}
}
if statement.JoinStr != "" {
fromStr = fmt.Sprintf("%v %v", fromStr, statement.JoinStr)
}
@ -1233,8 +1263,16 @@ func (statement *Statement) genSelectSql(columnStr string) (a string) {
column = statement.RefTable.ColumnsSeq()[0]
}
}
mssqlCondi = fmt.Sprintf("(%s NOT IN (SELECT TOP %d %s%s%s))",
column, statement.Start, column, fromStr, whereStr)
var orderStr string
if len(statement.OrderStr) > 0 {
orderStr = " ORDER BY " + statement.OrderStr
}
var groupStr string
if len(statement.GroupByStr) > 0 {
groupStr = " GROUP BY " + statement.GroupByStr
}
mssqlCondi = fmt.Sprintf("(%s NOT IN (SELECT TOP %d %s%s%s%s%s))",
column, statement.Start, column, fromStr, whereStr, orderStr, groupStr)
}
}
@ -1258,12 +1296,16 @@ func (statement *Statement) genSelectSql(columnStr string) (a string) {
if statement.OrderStr != "" {
a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr)
}
if statement.Engine.dialect.DBType() != core.MSSQL {
if statement.Engine.dialect.DBType() != core.MSSQL && statement.Engine.dialect.DBType() != core.ORACLE {
if statement.Start > 0 {
a = fmt.Sprintf("%v LIMIT %v OFFSET %v", a, statement.LimitN, statement.Start)
} else if statement.LimitN > 0 {
a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN)
}
} else if statement.Engine.dialect.DBType() == core.ORACLE {
if statement.Start != 0 || statement.LimitN != 0 {
a = fmt.Sprintf("SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", columnStr, columnStr, a, statement.Start+statement.LimitN, statement.Start)
}
}
return

View File

@ -13,7 +13,7 @@ import (
)
const (
Version string = "0.4.1"
Version string = "0.4.2.0225"
)
func regDrvsNDialects() bool {
@ -84,17 +84,16 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
TZLocation: time.Local,
}
engine.dialect.SetLogger(engine.Logger)
engine.SetMapper(core.NewCacheMapper(new(core.SnakeMapper)))
//engine.Filters = dialect.Filters()
//engine.Cacher = NewLRUCacher()
//err = engine.SetPool(NewSysConnectPool())
runtime.SetFinalizer(engine, close)
return engine, err
return engine, nil
}
// clone an engine
func (engine *Engine) Clone() (*Engine, error) {
return NewEngine(engine.dialect.DriverName(), engine.dialect.DataSourceName())
return NewEngine(engine.DriverName(), engine.DataSourceName())
}

View File

@ -0,0 +1,2 @@
ledis/tmp.db
nodb/tmp.db

View File

@ -1,7 +1,7 @@
session [![Build Status](https://drone.io/github.com/macaron-contrib/session/status.png)](https://drone.io/github.com/macaron-contrib/session/latest) [![](http://gocover.io/_badge/github.com/macaron-contrib/session)](http://gocover.io/github.com/macaron-contrib/session)
=======
Middleware session provides session management for [Macaron](https://github.com/Unknwon/macaron). It can use many session providers, including memory, file, Redis, Memcache, PostgreSQL, MySQL, Couchbase and Ledis.
Middleware session provides session management for [Macaron](https://github.com/Unknwon/macaron). It can use many session providers, including memory, file, Redis, Memcache, PostgreSQL, MySQL, Couchbase, Ledis and Nodb.
### Installation
@ -12,6 +12,10 @@ Middleware session provides session management for [Macaron](https://github.com/
- [API Reference](https://gowalker.org/github.com/macaron-contrib/session)
- [Documentation](http://macaron.gogs.io/docs/middlewares/session)
## Credits
This package is forked from [beego/session](https://github.com/astaxie/beego/tree/master/session) with reconstruction(over 80%).
## License
This project is under Apache v2 License. See the [LICENSE](LICENSE) file for the full license text.

View File

@ -28,17 +28,17 @@ import (
"github.com/Unknwon/com"
)
// FileSessionStore represents a file session store implementation.
type FileSessionStore struct {
// FileStore represents a file session store implementation.
type FileStore struct {
p *FileProvider
sid string
lock sync.RWMutex
data map[interface{}]interface{}
}
// NewFileSessionStore creates and returns a file session store.
func NewFileSessionStore(p *FileProvider, sid string, kv map[interface{}]interface{}) *FileSessionStore {
return &FileSessionStore{
// NewFileStore creates and returns a file session store.
func NewFileStore(p *FileProvider, sid string, kv map[interface{}]interface{}) *FileStore {
return &FileStore{
p: p,
sid: sid,
data: kv,
@ -46,7 +46,7 @@ func NewFileSessionStore(p *FileProvider, sid string, kv map[interface{}]interfa
}
// Set sets value to given key in session.
func (s *FileSessionStore) Set(key, val interface{}) error {
func (s *FileStore) Set(key, val interface{}) error {
s.lock.Lock()
defer s.lock.Unlock()
@ -55,7 +55,7 @@ func (s *FileSessionStore) Set(key, val interface{}) error {
}
// Get gets value by given key in session.
func (s *FileSessionStore) Get(key interface{}) interface{} {
func (s *FileStore) Get(key interface{}) interface{} {
s.lock.RLock()
defer s.lock.RUnlock()
@ -63,7 +63,7 @@ func (s *FileSessionStore) Get(key interface{}) interface{} {
}
// Delete delete a key from session.
func (s *FileSessionStore) Delete(key interface{}) error {
func (s *FileStore) Delete(key interface{}) error {
s.lock.Lock()
defer s.lock.Unlock()
@ -72,12 +72,12 @@ func (s *FileSessionStore) Delete(key interface{}) error {
}
// ID returns current session ID.
func (s *FileSessionStore) ID() string {
func (s *FileStore) ID() string {
return s.sid
}
// Release releases resource and save data to provider.
func (s *FileSessionStore) Release() error {
func (s *FileStore) Release() error {
data, err := EncodeGob(s.data)
if err != nil {
return err
@ -87,7 +87,7 @@ func (s *FileSessionStore) Release() error {
}
// Flush deletes all session data.
func (s *FileSessionStore) Flush() error {
func (s *FileStore) Flush() error {
s.lock.Lock()
defer s.lock.Unlock()
@ -97,7 +97,6 @@ func (s *FileSessionStore) Flush() error {
// FileProvider represents a file session provider implementation.
type FileProvider struct {
lock sync.RWMutex
maxlifetime int64
rootPath string
}
@ -115,9 +114,6 @@ func (p *FileProvider) filepath(sid string) string {
// Read returns raw session store by session ID.
func (p *FileProvider) Read(sid string) (_ RawStore, err error) {
p.lock.Lock()
defer p.lock.Unlock()
filename := p.filepath(sid)
if err = os.MkdirAll(path.Dir(filename), os.ModePerm); err != nil {
return nil, err
@ -151,22 +147,16 @@ func (p *FileProvider) Read(sid string) (_ RawStore, err error) {
return nil, err
}
}
return NewFileSessionStore(p, sid, kv), nil
return NewFileStore(p, sid, kv), nil
}
// Exist returns true if session with given ID exists.
func (p *FileProvider) Exist(sid string) bool {
p.lock.Lock()
defer p.lock.Unlock()
return com.IsFile(p.filepath(sid))
}
// Destory deletes a session by session ID.
func (p *FileProvider) Destory(sid string) error {
p.lock.Lock()
defer p.lock.Unlock()
return os.Remove(p.filepath(sid))
}
@ -201,12 +191,9 @@ func (p *FileProvider) regenerate(oldsid, sid string) (err error) {
// Regenerate regenerates a session store from old session ID to new one.
func (p *FileProvider) Regenerate(oldsid, sid string) (_ RawStore, err error) {
p.lock.Lock()
if err := p.regenerate(oldsid, sid); err != nil {
p.lock.Unlock()
return nil, err
}
p.lock.Unlock()
return p.Read(sid)
}
@ -236,9 +223,6 @@ func (p *FileProvider) GC() {
return
}
p.lock.Lock()
defer p.lock.Unlock()
if err := filepath.Walk(p.rootPath, func(path string, fi os.FileInfo, err error) error {
if err != nil {
return err

View File

@ -16,26 +16,39 @@
package session
import (
"fmt"
"strings"
"sync"
"github.com/Unknwon/com"
"github.com/siddontang/ledisdb/config"
"github.com/siddontang/ledisdb/ledis"
"gopkg.in/ini.v1"
"github.com/macaron-contrib/session"
)
var c *ledis.DB
// LedisSessionStore represents a ledis session store implementation.
type LedisSessionStore struct {
// LedisStore represents a ledis session store implementation.
type LedisStore struct {
c *ledis.DB
sid string
expire int64
lock sync.RWMutex
data map[interface{}]interface{}
maxlifetime int64
}
// NewLedisStore creates and returns a ledis session store.
func NewLedisStore(c *ledis.DB, sid string, expire int64, kv map[interface{}]interface{}) *LedisStore {
return &LedisStore{
c: c,
expire: expire,
sid: sid,
data: kv,
}
}
// Set sets value to given key in session.
func (s *LedisSessionStore) Set(key, val interface{}) error {
func (s *LedisStore) Set(key, val interface{}) error {
s.lock.Lock()
defer s.lock.Unlock()
@ -44,7 +57,7 @@ func (s *LedisSessionStore) Set(key, val interface{}) error {
}
// Get gets value by given key in session.
func (s *LedisSessionStore) Get(key interface{}) interface{} {
func (s *LedisStore) Get(key interface{}) interface{} {
s.lock.RLock()
defer s.lock.RUnlock()
@ -52,7 +65,7 @@ func (s *LedisSessionStore) Get(key interface{}) interface{} {
}
// Delete delete a key from session.
func (s *LedisSessionStore) Delete(key interface{}) error {
func (s *LedisStore) Delete(key interface{}) error {
s.lock.Lock()
defer s.lock.Unlock()
@ -61,25 +74,26 @@ func (s *LedisSessionStore) Delete(key interface{}) error {
}
// ID returns current session ID.
func (s *LedisSessionStore) ID() string {
func (s *LedisStore) ID() string {
return s.sid
}
// Release releases resource and save data to provider.
func (s *LedisSessionStore) Release() error {
func (s *LedisStore) Release() error {
data, err := session.EncodeGob(s.data)
if err != nil {
return err
}
if err = c.Set([]byte(s.sid), data); err != nil {
if err = s.c.Set([]byte(s.sid), data); err != nil {
return err
}
_, err = c.Expire([]byte(s.sid), s.maxlifetime)
_, err = s.c.Expire([]byte(s.sid), s.expire)
return err
}
// Flush deletes all session data.
func (s *LedisSessionStore) Flush() error {
func (s *LedisStore) Flush() error {
s.lock.Lock()
defer s.lock.Unlock()
@ -89,30 +103,54 @@ func (s *LedisSessionStore) Flush() error {
// LedisProvider represents a ledis session provider implementation.
type LedisProvider struct {
maxlifetime int64
savePath string
c *ledis.DB
expire int64
}
// Init initializes memory session provider.
func (p *LedisProvider) Init(maxlifetime int64, savePath string) error {
p.maxlifetime = maxlifetime
p.savePath = savePath
cfg := new(config.Config)
cfg.DataDir = p.savePath
var err error
nowLedis, err := ledis.Open(cfg)
c, err = nowLedis.Select(0)
// Init initializes ledis session provider.
// configs: data_dir=./app.db,db=0
func (p *LedisProvider) Init(expire int64, configs string) error {
p.expire = expire
cfg, err := ini.Load([]byte(strings.Replace(configs, ",", "\n", -1)))
if err != nil {
println(err)
return nil
return err
}
return nil
db := 0
opt := new(config.Config)
for k, v := range cfg.Section("").KeysHash() {
switch k {
case "data_dir":
opt.DataDir = v
case "db":
db = com.StrTo(v).MustInt()
default:
return fmt.Errorf("session/ledis: unsupported option '%s'", k)
}
}
l, err := ledis.Open(opt)
if err != nil {
return fmt.Errorf("session/ledis: error opening db: %v", err)
}
p.c, err = l.Select(db)
return err
}
// Read returns raw session store by session ID.
func (p *LedisProvider) Read(sid string) (session.RawStore, error) {
kvs, err := c.Get([]byte(sid))
if !p.Exist(sid) {
if err := p.c.Set([]byte(sid), []byte("")); err != nil {
return nil, err
}
}
var kv map[interface{}]interface{}
kvs, err := p.c.Get([]byte(sid))
if err != nil {
return nil, err
}
if len(kvs) == 0 {
kv = make(map[interface{}]interface{})
} else {
@ -121,41 +159,40 @@ func (p *LedisProvider) Read(sid string) (session.RawStore, error) {
return nil, err
}
}
ls := &LedisSessionStore{sid: sid, data: kv, maxlifetime: p.maxlifetime}
return ls, nil
return NewLedisStore(p.c, sid, p.expire, kv), nil
}
// Exist returns true if session with given ID exists.
func (p *LedisProvider) Exist(sid string) bool {
count, _ := c.Exists([]byte(sid))
if count == 0 {
return false
} else {
return true
}
count, err := p.c.Exists([]byte(sid))
return err == nil && count > 0
}
// Destory deletes a session by session ID.
func (p *LedisProvider) Destory(sid string) error {
_, err := c.Del([]byte(sid))
_, err := p.c.Del([]byte(sid))
return err
}
// Regenerate regenerates a session store from old session ID to new one.
func (p *LedisProvider) Regenerate(oldsid, sid string) (session.RawStore, error) {
count, _ := c.Exists([]byte(sid))
if count == 0 {
// oldsid doesn't exists, set the new sid directly
// ignore error here, since if it return error
// the existed value will be 0
c.Set([]byte(sid), []byte(""))
c.Expire([]byte(sid), p.maxlifetime)
} else {
data, _ := c.Get([]byte(oldsid))
c.Set([]byte(sid), data)
c.Expire([]byte(sid), p.maxlifetime)
func (p *LedisProvider) Regenerate(oldsid, sid string) (_ session.RawStore, err error) {
if p.Exist(sid) {
return nil, fmt.Errorf("new sid '%s' already exists", sid)
}
kvs, err := c.Get([]byte(sid))
kvs := make([]byte, 0)
if p.Exist(oldsid) {
if kvs, err = p.c.Get([]byte(oldsid)); err != nil {
return nil, err
} else if _, err = p.c.Del([]byte(oldsid)); err != nil {
return nil, err
}
}
if err = p.c.SetEX([]byte(sid), p.expire, kvs); err != nil {
return nil, err
}
var kv map[interface{}]interface{}
if len(kvs) == 0 {
kv = make(map[interface{}]interface{})
@ -165,18 +202,20 @@ func (p *LedisProvider) Regenerate(oldsid, sid string) (session.RawStore, error)
return nil, err
}
}
ls := &LedisSessionStore{sid: sid, data: kv, maxlifetime: p.maxlifetime}
return ls, nil
return NewLedisStore(p.c, sid, p.expire, kv), nil
}
// Count counts and returns number of sessions.
func (p *LedisProvider) Count() int {
// FIXME
return 0
// FIXME: how come this library does not have DbSize() method?
return -1
}
// GC calls GC to clean expired sessions.
func (p *LedisProvider) GC() {}
func (p *LedisProvider) GC() {
// FIXME: wtf???
}
func init() {
session.Register("ledis", &LedisProvider{})

View File

@ -0,0 +1 @@
ignore

View File

@ -0,0 +1,105 @@
// Copyright 2014 Unknwon
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package session
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/Unknwon/macaron"
. "github.com/smartystreets/goconvey/convey"
"github.com/macaron-contrib/session"
)
func Test_LedisProvider(t *testing.T) {
Convey("Test ledis session provider", t, func() {
opt := session.Options{
Provider: "ledis",
ProviderConfig: "data_dir=./tmp.db",
}
Convey("Basic operation", func() {
m := macaron.New()
m.Use(session.Sessioner(opt))
m.Get("/", func(ctx *macaron.Context, sess session.Store) {
sess.Set("uname", "unknwon")
})
m.Get("/reg", func(ctx *macaron.Context, sess session.Store) {
raw, err := sess.RegenerateId(ctx)
So(err, ShouldBeNil)
So(raw, ShouldNotBeNil)
uname := raw.Get("uname")
So(uname, ShouldNotBeNil)
So(uname, ShouldEqual, "unknwon")
})
m.Get("/get", func(ctx *macaron.Context, sess session.Store) {
sid := sess.ID()
So(sid, ShouldNotBeEmpty)
raw, err := sess.Read(sid)
So(err, ShouldBeNil)
So(raw, ShouldNotBeNil)
uname := sess.Get("uname")
So(uname, ShouldNotBeNil)
So(uname, ShouldEqual, "unknwon")
So(sess.Delete("uname"), ShouldBeNil)
So(sess.Get("uname"), ShouldBeNil)
So(sess.Destory(ctx), ShouldBeNil)
})
resp := httptest.NewRecorder()
req, err := http.NewRequest("GET", "/", nil)
So(err, ShouldBeNil)
m.ServeHTTP(resp, req)
cookie := resp.Header().Get("Set-Cookie")
resp = httptest.NewRecorder()
req, err = http.NewRequest("GET", "/reg", nil)
So(err, ShouldBeNil)
req.Header.Set("Cookie", cookie)
m.ServeHTTP(resp, req)
cookie = resp.Header().Get("Set-Cookie")
resp = httptest.NewRecorder()
req, err = http.NewRequest("GET", "/get", nil)
So(err, ShouldBeNil)
req.Header.Set("Cookie", cookie)
m.ServeHTTP(resp, req)
Convey("Regenrate empty session", func() {
m.Get("/empty", func(ctx *macaron.Context, sess session.Store) {
raw, err := sess.RegenerateId(ctx)
So(err, ShouldBeNil)
So(raw, ShouldNotBeNil)
})
resp = httptest.NewRecorder()
req, err = http.NewRequest("GET", "/empty", nil)
So(err, ShouldBeNil)
req.Header.Set("Cookie", "MacaronSession=ad2c7e3cbecfcf486; Path=/;")
m.ServeHTTP(resp, req)
})
})
})
}

View File

@ -16,6 +16,7 @@
package session
import (
"fmt"
"strings"
"sync"
@ -24,20 +25,35 @@ import (
"github.com/macaron-contrib/session"
)
var (
client *memcache.Client
)
// MemcacheSessionStore represents a memcache session store implementation.
type MemcacheSessionStore struct {
// MemcacheStore represents a memcache session store implementation.
type MemcacheStore struct {
c *memcache.Client
sid string
expire int32
lock sync.RWMutex
data map[interface{}]interface{}
maxlifetime int64
}
// NewMemcacheStore creates and returns a memcache session store.
func NewMemcacheStore(c *memcache.Client, sid string, expire int32, kv map[interface{}]interface{}) *MemcacheStore {
return &MemcacheStore{
c: c,
sid: sid,
expire: expire,
data: kv,
}
}
func NewItem(sid string, data []byte, expire int32) *memcache.Item {
return &memcache.Item{
Key: sid,
Value: data,
Expiration: expire,
}
}
// Set sets value to given key in session.
func (s *MemcacheSessionStore) Set(key, val interface{}) error {
func (s *MemcacheStore) Set(key, val interface{}) error {
s.lock.Lock()
defer s.lock.Unlock()
@ -46,7 +62,7 @@ func (s *MemcacheSessionStore) Set(key, val interface{}) error {
}
// Get gets value by given key in session.
func (s *MemcacheSessionStore) Get(key interface{}) interface{} {
func (s *MemcacheStore) Get(key interface{}) interface{} {
s.lock.RLock()
defer s.lock.RUnlock()
@ -54,7 +70,7 @@ func (s *MemcacheSessionStore) Get(key interface{}) interface{} {
}
// Delete delete a key from session.
func (s *MemcacheSessionStore) Delete(key interface{}) error {
func (s *MemcacheStore) Delete(key interface{}) error {
s.lock.Lock()
defer s.lock.Unlock()
@ -63,26 +79,22 @@ func (s *MemcacheSessionStore) Delete(key interface{}) error {
}
// ID returns current session ID.
func (s *MemcacheSessionStore) ID() string {
func (s *MemcacheStore) ID() string {
return s.sid
}
// Release releases resource and save data to provider.
func (s *MemcacheSessionStore) Release() error {
func (s *MemcacheStore) Release() error {
data, err := session.EncodeGob(s.data)
if err != nil {
return err
}
return client.Set(&memcache.Item{
Key: s.sid,
Value: data,
Expiration: int32(s.maxlifetime),
})
return s.c.Set(NewItem(s.sid, data, s.expire))
}
// Flush deletes all session data.
func (s *MemcacheSessionStore) Flush() error {
func (s *MemcacheStore) Flush() error {
s.lock.Lock()
defer s.lock.Unlock()
@ -90,41 +102,75 @@ func (s *MemcacheSessionStore) Flush() error {
return nil
}
// MemProvider represents a memcache session provider implementation.
type MemProvider struct {
maxlifetime int64
conninfo []string
poolsize int
password string
// MemcacheProvider represents a memcache session provider implementation.
type MemcacheProvider struct {
c *memcache.Client
expire int32
}
// Init initializes memory session provider.
// connStrs can be multiple connection strings separate by ;
// e.g. 127.0.0.1:9090
func (p *MemProvider) Init(maxlifetime int64, connStrs string) error {
p.maxlifetime = maxlifetime
p.conninfo = strings.Split(connStrs, ";")
client = memcache.New(p.conninfo...)
return nil
}
func (p *MemProvider) connectInit() error {
client = memcache.New(p.conninfo...)
// Init initializes memcache session provider.
// connStrs: 127.0.0.1:9090;127.0.0.1:9091
func (p *MemcacheProvider) Init(expire int64, connStrs string) error {
p.expire = int32(expire)
p.c = memcache.New(strings.Split(connStrs, ";")...)
return nil
}
// Read returns raw session store by session ID.
func (p *MemProvider) Read(sid string) (session.RawStore, error) {
if client == nil {
if err := p.connectInit(); err != nil {
func (p *MemcacheProvider) Read(sid string) (session.RawStore, error) {
if !p.Exist(sid) {
if err := p.c.Set(NewItem(sid, []byte(""), p.expire)); err != nil {
return nil, err
}
}
item, err := client.Get(sid)
var kv map[interface{}]interface{}
item, err := p.c.Get(sid)
if err != nil {
return nil, err
}
if len(item.Value) == 0 {
kv = make(map[interface{}]interface{})
} else {
kv, err = session.DecodeGob(item.Value)
if err != nil {
return nil, err
}
}
return NewMemcacheStore(p.c, sid, p.expire, kv), nil
}
// Exist returns true if session with given ID exists.
func (p *MemcacheProvider) Exist(sid string) bool {
_, err := p.c.Get(sid)
return err == nil
}
// Destory deletes a session by session ID.
func (p *MemcacheProvider) Destory(sid string) error {
return p.c.Delete(sid)
}
// Regenerate regenerates a session store from old session ID to new one.
func (p *MemcacheProvider) Regenerate(oldsid, sid string) (_ session.RawStore, err error) {
if p.Exist(sid) {
return nil, fmt.Errorf("new sid '%s' already exists", sid)
}
item := NewItem(sid, []byte(""), p.expire)
if p.Exist(oldsid) {
item, err = p.c.Get(oldsid)
if err != nil {
return nil, err
} else if err = p.c.Delete(oldsid); err != nil {
return nil, err
}
item.Key = sid
}
if err = p.c.Set(item); err != nil {
return nil, err
}
var kv map[interface{}]interface{}
if len(item.Value) == 0 {
@ -136,86 +182,18 @@ func (p *MemProvider) Read(sid string) (session.RawStore, error) {
}
}
rs := &MemcacheSessionStore{sid: sid, data: kv, maxlifetime: p.maxlifetime}
return rs, nil
}
// Exist returns true if session with given ID exists.
func (p *MemProvider) Exist(sid string) bool {
if client == nil {
if err := p.connectInit(); err != nil {
return false
}
}
if item, err := client.Get(sid); err != nil || len(item.Value) == 0 {
return false
} else {
return true
}
}
// Destory deletes a session by session ID.
func (p *MemProvider) Destory(sid string) error {
if client == nil {
if err := p.connectInit(); err != nil {
return err
}
}
return client.Delete(sid)
}
// Regenerate regenerates a session store from old session ID to new one.
func (p *MemProvider) Regenerate(oldsid, sid string) (session.RawStore, error) {
if client == nil {
if err := p.connectInit(); err != nil {
return nil, err
}
}
var contain []byte
if item, err := client.Get(sid); err != nil || len(item.Value) == 0 {
// oldsid doesn't exists, set the new sid directly
// ignore error here, since if it return error
// the existed value will be 0
item.Key = sid
item.Value = []byte("")
item.Expiration = int32(p.maxlifetime)
client.Set(item)
} else {
client.Delete(oldsid)
item.Key = sid
item.Value = item.Value
item.Expiration = int32(p.maxlifetime)
client.Set(item)
contain = item.Value
}
var kv map[interface{}]interface{}
if len(contain) == 0 {
kv = make(map[interface{}]interface{})
} else {
var err error
kv, err = session.DecodeGob(contain)
if err != nil {
return nil, err
}
}
rs := &MemcacheSessionStore{sid: sid, data: kv, maxlifetime: p.maxlifetime}
return rs, nil
return NewMemcacheStore(p.c, sid, p.expire, kv), nil
}
// Count counts and returns number of sessions.
func (p *MemProvider) Count() int {
// FIXME
return 0
func (p *MemcacheProvider) Count() int {
// FIXME: how come this library does not have Stats method?
return -1
}
// GC calls GC to clean expired sessions.
func (p *MemProvider) GC() {}
func (p *MemcacheProvider) GC() {}
func init() {
session.Register("memcache", &MemProvider{})
session.Register("memcache", &MemcacheProvider{})
}

View File

@ -0,0 +1 @@
ignore

View File

@ -0,0 +1,107 @@
// Copyright 2014 Unknwon
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package session
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/Unknwon/macaron"
. "github.com/smartystreets/goconvey/convey"
"github.com/macaron-contrib/session"
)
func Test_MemcacheProvider(t *testing.T) {
Convey("Test memcache session provider", t, func() {
opt := session.Options{
Provider: "memcache",
ProviderConfig: "127.0.0.1:9090",
}
Convey("Basic operation", func() {
m := macaron.New()
m.Use(session.Sessioner(opt))
m.Get("/", func(ctx *macaron.Context, sess session.Store) {
sess.Set("uname", "unknwon")
})
m.Get("/reg", func(ctx *macaron.Context, sess session.Store) {
raw, err := sess.RegenerateId(ctx)
So(err, ShouldBeNil)
So(raw, ShouldNotBeNil)
uname := raw.Get("uname")
So(uname, ShouldNotBeNil)
So(uname, ShouldEqual, "unknwon")
})
m.Get("/get", func(ctx *macaron.Context, sess session.Store) {
sid := sess.ID()
So(sid, ShouldNotBeEmpty)
raw, err := sess.Read(sid)
So(err, ShouldBeNil)
So(raw, ShouldNotBeNil)
uname := sess.Get("uname")
So(uname, ShouldNotBeNil)
So(uname, ShouldEqual, "unknwon")
So(sess.Delete("uname"), ShouldBeNil)
So(sess.Get("uname"), ShouldBeNil)
So(sess.Destory(ctx), ShouldBeNil)
})
resp := httptest.NewRecorder()
req, err := http.NewRequest("GET", "/", nil)
So(err, ShouldBeNil)
m.ServeHTTP(resp, req)
cookie := resp.Header().Get("Set-Cookie")
resp = httptest.NewRecorder()
req, err = http.NewRequest("GET", "/reg", nil)
So(err, ShouldBeNil)
req.Header.Set("Cookie", cookie)
m.ServeHTTP(resp, req)
cookie = resp.Header().Get("Set-Cookie")
resp = httptest.NewRecorder()
req, err = http.NewRequest("GET", "/get", nil)
So(err, ShouldBeNil)
req.Header.Set("Cookie", cookie)
m.ServeHTTP(resp, req)
})
Convey("Regenrate empty session", func() {
m := macaron.New()
m.Use(session.Sessioner(opt))
m.Get("/", func(ctx *macaron.Context, sess session.Store) {
raw, err := sess.RegenerateId(ctx)
So(err, ShouldBeNil)
So(raw, ShouldNotBeNil)
})
resp := httptest.NewRecorder()
req, err := http.NewRequest("GET", "/", nil)
So(err, ShouldBeNil)
req.Header.Set("Cookie", "MacaronSession=ad2c7e3cbecfcf486; Path=/;")
m.ServeHTTP(resp, req)
})
})
}

View File

@ -22,17 +22,17 @@ import (
"time"
)
// MemSessionStore represents a in-memory session store implementation.
type MemSessionStore struct {
// MemStore represents a in-memory session store implementation.
type MemStore struct {
sid string
lock sync.RWMutex
data map[interface{}]interface{}
lastAccess time.Time
}
// NewMemSessionStore creates and returns a memory session store.
func NewMemSessionStore(sid string) *MemSessionStore {
return &MemSessionStore{
// NewMemStore creates and returns a memory session store.
func NewMemStore(sid string) *MemStore {
return &MemStore{
sid: sid,
data: make(map[interface{}]interface{}),
lastAccess: time.Now(),
@ -40,7 +40,7 @@ func NewMemSessionStore(sid string) *MemSessionStore {
}
// Set sets value to given key in session.
func (s *MemSessionStore) Set(key, val interface{}) error {
func (s *MemStore) Set(key, val interface{}) error {
s.lock.Lock()
defer s.lock.Unlock()
@ -49,15 +49,15 @@ func (s *MemSessionStore) Set(key, val interface{}) error {
}
// Get gets value by given key in session.
func (s *MemSessionStore) Get(key interface{}) interface{} {
func (s *MemStore) Get(key interface{}) interface{} {
s.lock.RLock()
defer s.lock.RUnlock()
return s.data[key]
}
// Delete delete a key from session.
func (s *MemSessionStore) Delete(key interface{}) error {
// Delete deletes a key from session.
func (s *MemStore) Delete(key interface{}) error {
s.lock.Lock()
defer s.lock.Unlock()
@ -66,17 +66,17 @@ func (s *MemSessionStore) Delete(key interface{}) error {
}
// ID returns current session ID.
func (s *MemSessionStore) ID() string {
func (s *MemStore) ID() string {
return s.sid
}
// Release releases resource and save data to provider.
func (_ *MemSessionStore) Release() error {
func (_ *MemStore) Release() error {
return nil
}
// Flush deletes all session data.
func (s *MemSessionStore) Flush() error {
func (s *MemStore) Flush() error {
s.lock.Lock()
defer s.lock.Unlock()
@ -105,7 +105,7 @@ func (p *MemProvider) update(sid string) error {
defer p.lock.Unlock()
if e, ok := p.data[sid]; ok {
e.Value.(*MemSessionStore).lastAccess = time.Now()
e.Value.(*MemStore).lastAccess = time.Now()
p.list.MoveToFront(e)
return nil
}
@ -122,14 +122,14 @@ func (p *MemProvider) Read(sid string) (_ RawStore, err error) {
if err = p.update(sid); err != nil {
return nil, err
}
return e.Value.(*MemSessionStore), nil
return e.Value.(*MemStore), nil
}
// Create a new session.
p.lock.Lock()
defer p.lock.Unlock()
s := NewMemSessionStore(sid)
s := NewMemStore(sid)
p.data[sid] = p.list.PushBack(s)
return s, nil
}
@ -173,7 +173,7 @@ func (p *MemProvider) Regenerate(oldsid, sid string) (RawStore, error) {
return nil, err
}
s.(*MemSessionStore).sid = sid
s.(*MemStore).sid = sid
p.data[sid] = p.list.PushBack(s)
return s, nil
}
@ -193,11 +193,11 @@ func (p *MemProvider) GC() {
break
}
if (e.Value.(*MemSessionStore).lastAccess.Unix() + p.maxLifetime) < time.Now().Unix() {
if (e.Value.(*MemStore).lastAccess.Unix() + p.maxLifetime) < time.Now().Unix() {
p.lock.RUnlock()
p.lock.Lock()
p.list.Remove(e)
delete(p.data, e.Value.(*MemSessionStore).sid)
delete(p.data, e.Value.(*MemStore).sid)
p.lock.Unlock()
p.lock.RLock()
} else {

View File

@ -17,6 +17,8 @@ package session
import (
"database/sql"
"fmt"
"log"
"sync"
"time"
@ -25,16 +27,25 @@ import (
"github.com/macaron-contrib/session"
)
// MysqlSessionStore represents a mysql session store implementation.
type MysqlSessionStore struct {
// MysqlStore represents a mysql session store implementation.
type MysqlStore struct {
c *sql.DB
sid string
lock sync.RWMutex
data map[interface{}]interface{}
}
// NewMysqlStore creates and returns a mysql session store.
func NewMysqlStore(c *sql.DB, sid string, kv map[interface{}]interface{}) *MysqlStore {
return &MysqlStore{
c: c,
sid: sid,
data: kv,
}
}
// Set sets value to given key in session.
func (s *MysqlSessionStore) Set(key, val interface{}) error {
func (s *MysqlStore) Set(key, val interface{}) error {
s.lock.Lock()
defer s.lock.Unlock()
@ -43,7 +54,7 @@ func (s *MysqlSessionStore) Set(key, val interface{}) error {
}
// Get gets value by given key in session.
func (s *MysqlSessionStore) Get(key interface{}) interface{} {
func (s *MysqlStore) Get(key interface{}) interface{} {
s.lock.RLock()
defer s.lock.RUnlock()
@ -51,7 +62,7 @@ func (s *MysqlSessionStore) Get(key interface{}) interface{} {
}
// Delete delete a key from session.
func (s *MysqlSessionStore) Delete(key interface{}) error {
func (s *MysqlStore) Delete(key interface{}) error {
s.lock.Lock()
defer s.lock.Unlock()
@ -60,24 +71,24 @@ func (s *MysqlSessionStore) Delete(key interface{}) error {
}
// ID returns current session ID.
func (s *MysqlSessionStore) ID() string {
func (s *MysqlStore) ID() string {
return s.sid
}
// Release releases resource and save data to provider.
func (s *MysqlSessionStore) Release() error {
defer s.c.Close()
func (s *MysqlStore) Release() error {
data, err := session.EncodeGob(s.data)
if err != nil {
return err
}
_, err = s.c.Exec("UPDATE session set `session_data`=?, `session_expiry`=? where session_key=?",
_, err = s.c.Exec("UPDATE session SET data=?, expiry=? WHERE `key`=?",
data, time.Now().Unix(), s.sid)
return err
}
// Flush deletes all session data.
func (s *MysqlSessionStore) Flush() error {
func (s *MysqlStore) Flush() error {
s.lock.Lock()
defer s.lock.Unlock()
@ -87,113 +98,96 @@ func (s *MysqlSessionStore) Flush() error {
// MysqlProvider represents a mysql session provider implementation.
type MysqlProvider struct {
maxlifetime int64
connStr string
c *sql.DB
expire int64
}
func (p *MysqlProvider) connectInit() *sql.DB {
db, e := sql.Open("mysql", p.connStr)
if e != nil {
return nil
// Init initializes mysql session provider.
// connStr: username:password@protocol(address)/dbname?param=value
func (p *MysqlProvider) Init(expire int64, connStr string) (err error) {
p.expire = expire
p.c, err = sql.Open("mysql", connStr)
if err != nil {
return err
}
return db
}
// Init initializes memory session provider.
func (p *MysqlProvider) Init(maxlifetime int64, connStr string) error {
p.maxlifetime = maxlifetime
p.connStr = connStr
return nil
return p.c.Ping()
}
// Read returns raw session store by session ID.
func (p *MysqlProvider) Read(sid string) (session.RawStore, error) {
c := p.connectInit()
row := c.QueryRow("select session_data from session where session_key=?", sid)
var sessiondata []byte
err := row.Scan(&sessiondata)
var data []byte
err := p.c.QueryRow("SELECT data FROM session WHERE `key`=?", sid).Scan(&data)
if err == sql.ErrNoRows {
c.Exec("insert into session(`session_key`,`session_data`,`session_expiry`) values(?,?,?)",
_, err = p.c.Exec("INSERT INTO session(`key`,data,expiry) VALUES(?,?,?)",
sid, "", time.Now().Unix())
}
if err != nil {
return nil, err
}
var kv map[interface{}]interface{}
if len(sessiondata) == 0 {
if len(data) == 0 {
kv = make(map[interface{}]interface{})
} else {
kv, err = session.DecodeGob(sessiondata)
kv, err = session.DecodeGob(data)
if err != nil {
return nil, err
}
}
rs := &MysqlSessionStore{c: c, sid: sid, data: kv}
return rs, nil
return NewMysqlStore(p.c, sid, kv), nil
}
// Exist returns true if session with given ID exists.
func (p *MysqlProvider) Exist(sid string) bool {
c := p.connectInit()
defer c.Close()
row := c.QueryRow("select session_data from session where session_key=?", sid)
var sessiondata []byte
err := row.Scan(&sessiondata)
if err == sql.ErrNoRows {
return false
} else {
return true
var data []byte
err := p.c.QueryRow("SELECT data FROM session WHERE `key`=?", sid).Scan(&data)
if err != nil && err != sql.ErrNoRows {
panic("session/mysql: error checking existence: " + err.Error())
}
return err != sql.ErrNoRows
}
// Destory deletes a session by session ID.
func (p *MysqlProvider) Destory(sid string) (err error) {
c := p.connectInit()
if _, err = c.Exec("DELETE FROM session where session_key=?", sid); err != nil {
func (p *MysqlProvider) Destory(sid string) error {
_, err := p.c.Exec("DELETE FROM session WHERE `key`=?", sid)
return err
}
return c.Close()
}
// Regenerate regenerates a session store from old session ID to new one.
func (p *MysqlProvider) Regenerate(oldsid, sid string) (session.RawStore, error) {
c := p.connectInit()
row := c.QueryRow("select session_data from session where session_key=?", oldsid)
var sessiondata []byte
err := row.Scan(&sessiondata)
if err == sql.ErrNoRows {
c.Exec("insert into session(`session_key`,`session_data`,`session_expiry`) values(?,?,?)", oldsid, "", time.Now().Unix())
func (p *MysqlProvider) Regenerate(oldsid, sid string) (_ session.RawStore, err error) {
if p.Exist(sid) {
return nil, fmt.Errorf("new sid '%s' already exists", sid)
}
c.Exec("update session set `session_key`=? where session_key=?", sid, oldsid)
var kv map[interface{}]interface{}
if len(sessiondata) == 0 {
kv = make(map[interface{}]interface{})
} else {
kv, err = session.DecodeGob(sessiondata)
if err != nil {
if !p.Exist(oldsid) {
if _, err = p.c.Exec("INSERT INTO session(`key`,data,expiry) VALUES(?,?,?)",
oldsid, "", time.Now().Unix()); err != nil {
return nil, err
}
}
rs := &MysqlSessionStore{c: c, sid: sid, data: kv}
return rs, nil
if _, err = p.c.Exec("UPDATE session SET `key`=? WHERE `key`=?", sid, oldsid); err != nil {
return nil, err
}
return p.Read(sid)
}
// Count counts and returns number of sessions.
func (p *MysqlProvider) Count() int {
c := p.connectInit()
defer c.Close()
var total int
err := c.QueryRow("SELECT count(*) as num from session").Scan(&total)
if err != nil {
return 0
func (p *MysqlProvider) Count() (total int) {
if err := p.c.QueryRow("SELECT COUNT(*) AS NUM FROM session").Scan(&total); err != nil {
panic("session/mysql: error counting records: " + err.Error())
}
return total
}
// GC calls GC to clean expired sessions.
func (mp *MysqlProvider) GC() {
c := mp.connectInit()
c.Exec("DELETE from session where session_expiry < ?", time.Now().Unix()-mp.maxlifetime)
c.Close()
func (p *MysqlProvider) GC() {
if _, err := p.c.Exec("DELETE FROM session WHERE UNIX_TIMESTAMP(NOW()) - expiry > ?", p.expire); err != nil {
log.Printf("session/mysql: error garbage collecting: %v", err)
}
}
func init() {

View File

@ -0,0 +1 @@
ignore

View File

@ -0,0 +1,138 @@
// Copyright 2014 Unknwon
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package session
import (
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Unknwon/macaron"
. "github.com/smartystreets/goconvey/convey"
"github.com/macaron-contrib/session"
)
func Test_MysqlProvider(t *testing.T) {
Convey("Test mysql session provider", t, func() {
opt := session.Options{
Provider: "mysql",
ProviderConfig: "root:@tcp(localhost:3306)/macaron?charset=utf8",
}
Convey("Basic operation", func() {
m := macaron.New()
m.Use(session.Sessioner(opt))
m.Get("/", func(ctx *macaron.Context, sess session.Store) {
sess.Set("uname", "unknwon")
})
m.Get("/reg", func(ctx *macaron.Context, sess session.Store) {
raw, err := sess.RegenerateId(ctx)
So(err, ShouldBeNil)
So(raw, ShouldNotBeNil)
uname := raw.Get("uname")
So(uname, ShouldNotBeNil)
So(uname, ShouldEqual, "unknwon")
})
m.Get("/get", func(ctx *macaron.Context, sess session.Store) {
sid := sess.ID()
So(sid, ShouldNotBeEmpty)
raw, err := sess.Read(sid)
So(err, ShouldBeNil)
So(raw, ShouldNotBeNil)
So(raw.Release(), ShouldBeNil)
uname := sess.Get("uname")
So(uname, ShouldNotBeNil)
So(uname, ShouldEqual, "unknwon")
So(sess.Delete("uname"), ShouldBeNil)
So(sess.Get("uname"), ShouldBeNil)
So(sess.Destory(ctx), ShouldBeNil)
})
resp := httptest.NewRecorder()
req, err := http.NewRequest("GET", "/", nil)
So(err, ShouldBeNil)
m.ServeHTTP(resp, req)
cookie := resp.Header().Get("Set-Cookie")
resp = httptest.NewRecorder()
req, err = http.NewRequest("GET", "/reg", nil)
So(err, ShouldBeNil)
req.Header.Set("Cookie", cookie)
m.ServeHTTP(resp, req)
cookie = resp.Header().Get("Set-Cookie")
resp = httptest.NewRecorder()
req, err = http.NewRequest("GET", "/get", nil)
So(err, ShouldBeNil)
req.Header.Set("Cookie", cookie)
m.ServeHTTP(resp, req)
})
Convey("Regenrate empty session", func() {
m := macaron.New()
m.Use(session.Sessioner(opt))
m.Get("/", func(ctx *macaron.Context, sess session.Store) {
raw, err := sess.RegenerateId(ctx)
So(err, ShouldBeNil)
So(raw, ShouldNotBeNil)
So(sess.Destory(ctx), ShouldBeNil)
})
resp := httptest.NewRecorder()
req, err := http.NewRequest("GET", "/", nil)
So(err, ShouldBeNil)
req.Header.Set("Cookie", "MacaronSession=ad2c7e3cbecfcf48; Path=/;")
m.ServeHTTP(resp, req)
})
Convey("GC session", func() {
m := macaron.New()
opt2 := opt
opt2.Gclifetime = 1
m.Use(session.Sessioner(opt2))
m.Get("/", func(sess session.Store) {
sess.Set("uname", "unknwon")
So(sess.ID(), ShouldNotBeEmpty)
uname := sess.Get("uname")
So(uname, ShouldNotBeNil)
So(uname, ShouldEqual, "unknwon")
So(sess.Flush(), ShouldBeNil)
So(sess.Get("uname"), ShouldBeNil)
time.Sleep(2 * time.Second)
sess.GC()
So(sess.Count(), ShouldEqual, 0)
})
resp := httptest.NewRecorder()
req, err := http.NewRequest("GET", "/", nil)
So(err, ShouldBeNil)
m.ServeHTTP(resp, req)
})
})
}

View File

@ -0,0 +1,203 @@
// Copyright 2015 Unknwon
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package session
import (
"fmt"
"sync"
"github.com/lunny/nodb"
"github.com/lunny/nodb/config"
"github.com/macaron-contrib/session"
)
// NodbStore represents a nodb session store implementation.
type NodbStore struct {
c *nodb.DB
sid string
expire int64
lock sync.RWMutex
data map[interface{}]interface{}
}
// NewNodbStore creates and returns a ledis session store.
func NewNodbStore(c *nodb.DB, sid string, expire int64, kv map[interface{}]interface{}) *NodbStore {
return &NodbStore{
c: c,
expire: expire,
sid: sid,
data: kv,
}
}
// Set sets value to given key in session.
func (s *NodbStore) Set(key, val interface{}) error {
s.lock.Lock()
defer s.lock.Unlock()
s.data[key] = val
return nil
}
// Get gets value by given key in session.
func (s *NodbStore) Get(key interface{}) interface{} {
s.lock.RLock()
defer s.lock.RUnlock()
return s.data[key]
}
// Delete delete a key from session.
func (s *NodbStore) Delete(key interface{}) error {
s.lock.Lock()
defer s.lock.Unlock()
delete(s.data, key)
return nil
}
// ID returns current session ID.
func (s *NodbStore) ID() string {
return s.sid
}
// Release releases resource and save data to provider.
func (s *NodbStore) Release() error {
data, err := session.EncodeGob(s.data)
if err != nil {
return err
}
if err = s.c.Set([]byte(s.sid), data); err != nil {
return err
}
_, err = s.c.Expire([]byte(s.sid), s.expire)
return err
}
// Flush deletes all session data.
func (s *NodbStore) Flush() error {
s.lock.Lock()
defer s.lock.Unlock()
s.data = make(map[interface{}]interface{})
return nil
}
// NodbProvider represents a ledis session provider implementation.
type NodbProvider struct {
c *nodb.DB
expire int64
}
// Init initializes nodb session provider.
func (p *NodbProvider) Init(expire int64, configs string) error {
p.expire = expire
cfg := new(config.Config)
cfg.DataDir = configs
dbs, err := nodb.Open(cfg)
if err != nil {
return fmt.Errorf("session/nodb: error opening db: %v", err)
}
p.c, err = dbs.Select(0)
return err
}
// Read returns raw session store by session ID.
func (p *NodbProvider) Read(sid string) (session.RawStore, error) {
if !p.Exist(sid) {
if err := p.c.Set([]byte(sid), []byte("")); err != nil {
return nil, err
}
}
var kv map[interface{}]interface{}
kvs, err := p.c.Get([]byte(sid))
if err != nil {
return nil, err
}
if len(kvs) == 0 {
kv = make(map[interface{}]interface{})
} else {
kv, err = session.DecodeGob(kvs)
if err != nil {
return nil, err
}
}
return NewNodbStore(p.c, sid, p.expire, kv), nil
}
// Exist returns true if session with given ID exists.
func (p *NodbProvider) Exist(sid string) bool {
count, err := p.c.Exists([]byte(sid))
return err == nil && count > 0
}
// Destory deletes a session by session ID.
func (p *NodbProvider) Destory(sid string) error {
_, err := p.c.Del([]byte(sid))
return err
}
// Regenerate regenerates a session store from old session ID to new one.
func (p *NodbProvider) Regenerate(oldsid, sid string) (_ session.RawStore, err error) {
if p.Exist(sid) {
return nil, fmt.Errorf("new sid '%s' already exists", sid)
}
kvs := make([]byte, 0)
if p.Exist(oldsid) {
if kvs, err = p.c.Get([]byte(oldsid)); err != nil {
return nil, err
} else if _, err = p.c.Del([]byte(oldsid)); err != nil {
return nil, err
}
}
if err = p.c.Set([]byte(sid), kvs); err != nil {
return nil, err
} else if _, err = p.c.Expire([]byte(sid), p.expire); err != nil {
return nil, err
}
var kv map[interface{}]interface{}
if len(kvs) == 0 {
kv = make(map[interface{}]interface{})
} else {
kv, err = session.DecodeGob([]byte(kvs))
if err != nil {
return nil, err
}
}
return NewNodbStore(p.c, sid, p.expire, kv), nil
}
// Count counts and returns number of sessions.
func (p *NodbProvider) Count() int {
// FIXME: how come this library does not have DbSize() method?
return -1
}
// GC calls GC to clean expired sessions.
func (p *NodbProvider) GC() {}
func init() {
session.Register("nodb", &NodbProvider{})
}

View File

@ -0,0 +1 @@
ignore

View File

@ -0,0 +1,105 @@
// Copyright 2015 Unknwon
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package session
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/Unknwon/macaron"
. "github.com/smartystreets/goconvey/convey"
"github.com/macaron-contrib/session"
)
func Test_LedisProvider(t *testing.T) {
Convey("Test nodb session provider", t, func() {
opt := session.Options{
Provider: "nodb",
ProviderConfig: "./tmp.db",
}
Convey("Basic operation", func() {
m := macaron.New()
m.Use(session.Sessioner(opt))
m.Get("/", func(ctx *macaron.Context, sess session.Store) {
sess.Set("uname", "unknwon")
})
m.Get("/reg", func(ctx *macaron.Context, sess session.Store) {
raw, err := sess.RegenerateId(ctx)
So(err, ShouldBeNil)
So(raw, ShouldNotBeNil)
uname := raw.Get("uname")
So(uname, ShouldNotBeNil)
So(uname, ShouldEqual, "unknwon")
})
m.Get("/get", func(ctx *macaron.Context, sess session.Store) {
sid := sess.ID()
So(sid, ShouldNotBeEmpty)
raw, err := sess.Read(sid)
So(err, ShouldBeNil)
So(raw, ShouldNotBeNil)
uname := sess.Get("uname")
So(uname, ShouldNotBeNil)
So(uname, ShouldEqual, "unknwon")
So(sess.Delete("uname"), ShouldBeNil)
So(sess.Get("uname"), ShouldBeNil)
So(sess.Destory(ctx), ShouldBeNil)
})
resp := httptest.NewRecorder()
req, err := http.NewRequest("GET", "/", nil)
So(err, ShouldBeNil)
m.ServeHTTP(resp, req)
cookie := resp.Header().Get("Set-Cookie")
resp = httptest.NewRecorder()
req, err = http.NewRequest("GET", "/reg", nil)
So(err, ShouldBeNil)
req.Header.Set("Cookie", cookie)
m.ServeHTTP(resp, req)
cookie = resp.Header().Get("Set-Cookie")
resp = httptest.NewRecorder()
req, err = http.NewRequest("GET", "/get", nil)
So(err, ShouldBeNil)
req.Header.Set("Cookie", cookie)
m.ServeHTTP(resp, req)
Convey("Regenrate empty session", func() {
m.Get("/empty", func(ctx *macaron.Context, sess session.Store) {
raw, err := sess.RegenerateId(ctx)
So(err, ShouldBeNil)
So(raw, ShouldNotBeNil)
})
resp = httptest.NewRecorder()
req, err = http.NewRequest("GET", "/empty", nil)
So(err, ShouldBeNil)
req.Header.Set("Cookie", "MacaronSession=ad2c7e3cbecfcf486; Path=/;")
m.ServeHTTP(resp, req)
})
})
})
}

View File

@ -0,0 +1,196 @@
// Copyright 2013 Beego Authors
// Copyright 2014 Unknwon
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package session
import (
"database/sql"
"fmt"
"log"
"sync"
"time"
_ "github.com/lib/pq"
"github.com/macaron-contrib/session"
)
// PostgresStore represents a postgres session store implementation.
type PostgresStore struct {
c *sql.DB
sid string
lock sync.RWMutex
data map[interface{}]interface{}
}
// NewPostgresStore creates and returns a postgres session store.
func NewPostgresStore(c *sql.DB, sid string, kv map[interface{}]interface{}) *PostgresStore {
return &PostgresStore{
c: c,
sid: sid,
data: kv,
}
}
// Set sets value to given key in session.
func (s *PostgresStore) Set(key, value interface{}) error {
s.lock.Lock()
defer s.lock.Unlock()
s.data[key] = value
return nil
}
// Get gets value by given key in session.
func (s *PostgresStore) Get(key interface{}) interface{} {
s.lock.RLock()
defer s.lock.RUnlock()
return s.data[key]
}
// Delete delete a key from session.
func (s *PostgresStore) Delete(key interface{}) error {
s.lock.Lock()
defer s.lock.Unlock()
delete(s.data, key)
return nil
}
// ID returns current session ID.
func (s *PostgresStore) ID() string {
return s.sid
}
// save postgres session values to database.
// must call this method to save values to database.
func (s *PostgresStore) Release() error {
data, err := session.EncodeGob(s.data)
if err != nil {
return err
}
_, err = s.c.Exec("UPDATE session SET data=$1, expiry=$2 WHERE key=$3",
data, time.Now().Unix(), s.sid)
return err
}
// Flush deletes all session data.
func (s *PostgresStore) Flush() error {
s.lock.Lock()
defer s.lock.Unlock()
s.data = make(map[interface{}]interface{})
return nil
}
// PostgresProvider represents a postgres session provider implementation.
type PostgresProvider struct {
c *sql.DB
maxlifetime int64
}
// Init initializes postgres session provider.
// connStr: user=a password=b host=localhost port=5432 dbname=c sslmode=disable
func (p *PostgresProvider) Init(maxlifetime int64, connStr string) (err error) {
p.maxlifetime = maxlifetime
p.c, err = sql.Open("postgres", connStr)
if err != nil {
return err
}
return p.c.Ping()
}
// Read returns raw session store by session ID.
func (p *PostgresProvider) Read(sid string) (session.RawStore, error) {
var data []byte
err := p.c.QueryRow("SELECT data FROM session WHERE key=$1", sid).Scan(&data)
if err == sql.ErrNoRows {
_, err = p.c.Exec("INSERT INTO session(key,data,expiry) VALUES($1,$2,$3)",
sid, "", time.Now().Unix())
}
if err != nil {
return nil, err
}
var kv map[interface{}]interface{}
if len(data) == 0 {
kv = make(map[interface{}]interface{})
} else {
kv, err = session.DecodeGob(data)
if err != nil {
return nil, err
}
}
return NewPostgresStore(p.c, sid, kv), nil
}
// Exist returns true if session with given ID exists.
func (p *PostgresProvider) Exist(sid string) bool {
var data []byte
err := p.c.QueryRow("SELECT data FROM session WHERE key=$1", sid).Scan(&data)
if err != nil && err != sql.ErrNoRows {
panic("session/postgres: error checking existence: " + err.Error())
}
return err != sql.ErrNoRows
}
// Destory deletes a session by session ID.
func (p *PostgresProvider) Destory(sid string) error {
_, err := p.c.Exec("DELETE FROM session WHERE key=$1", sid)
return err
}
// Regenerate regenerates a session store from old session ID to new one.
func (p *PostgresProvider) Regenerate(oldsid, sid string) (_ session.RawStore, err error) {
if p.Exist(sid) {
return nil, fmt.Errorf("new sid '%s' already exists", sid)
}
if !p.Exist(oldsid) {
if _, err = p.c.Exec("INSERT INTO session(key,data,expiry) VALUES($1,$2,$3)",
oldsid, "", time.Now().Unix()); err != nil {
return nil, err
}
}
if _, err = p.c.Exec("UPDATE session SET key=$1 WHERE key=$2", sid, oldsid); err != nil {
return nil, err
}
return p.Read(sid)
}
// Count counts and returns number of sessions.
func (p *PostgresProvider) Count() (total int) {
if err := p.c.QueryRow("SELECT COUNT(*) AS NUM FROM session").Scan(&total); err != nil {
panic("session/postgres: error counting records: " + err.Error())
}
return total
}
// GC calls GC to clean expired sessions.
func (p *PostgresProvider) GC() {
if _, err := p.c.Exec("DELETE FROM session WHERE EXTRACT(EPOCH FROM NOW()) - expiry > $1", p.maxlifetime); err != nil {
log.Printf("session/postgres: error garbage collecting: %v", err)
}
}
func init() {
session.Register("postgres", &PostgresProvider{})
}

View File

@ -0,0 +1 @@
ignore

View File

@ -0,0 +1,138 @@
// Copyright 2014 Unknwon
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package session
import (
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Unknwon/macaron"
. "github.com/smartystreets/goconvey/convey"
"github.com/macaron-contrib/session"
)
func Test_PostgresProvider(t *testing.T) {
Convey("Test postgres session provider", t, func() {
opt := session.Options{
Provider: "postgres",
ProviderConfig: "user=jiahuachen dbname=macaron port=5432 sslmode=disable",
}
Convey("Basic operation", func() {
m := macaron.New()
m.Use(session.Sessioner(opt))
m.Get("/", func(ctx *macaron.Context, sess session.Store) {
sess.Set("uname", "unknwon")
})
m.Get("/reg", func(ctx *macaron.Context, sess session.Store) {
raw, err := sess.RegenerateId(ctx)
So(err, ShouldBeNil)
So(raw, ShouldNotBeNil)
uname := raw.Get("uname")
So(uname, ShouldNotBeNil)
So(uname, ShouldEqual, "unknwon")
})
m.Get("/get", func(ctx *macaron.Context, sess session.Store) {
sid := sess.ID()
So(sid, ShouldNotBeEmpty)
raw, err := sess.Read(sid)
So(err, ShouldBeNil)
So(raw, ShouldNotBeNil)
So(raw.Release(), ShouldBeNil)
uname := sess.Get("uname")
So(uname, ShouldNotBeNil)
So(uname, ShouldEqual, "unknwon")
So(sess.Delete("uname"), ShouldBeNil)
So(sess.Get("uname"), ShouldBeNil)
So(sess.Destory(ctx), ShouldBeNil)
})
resp := httptest.NewRecorder()
req, err := http.NewRequest("GET", "/", nil)
So(err, ShouldBeNil)
m.ServeHTTP(resp, req)
cookie := resp.Header().Get("Set-Cookie")
resp = httptest.NewRecorder()
req, err = http.NewRequest("GET", "/reg", nil)
So(err, ShouldBeNil)
req.Header.Set("Cookie", cookie)
m.ServeHTTP(resp, req)
cookie = resp.Header().Get("Set-Cookie")
resp = httptest.NewRecorder()
req, err = http.NewRequest("GET", "/get", nil)
So(err, ShouldBeNil)
req.Header.Set("Cookie", cookie)
m.ServeHTTP(resp, req)
})
Convey("Regenrate empty session", func() {
m := macaron.New()
m.Use(session.Sessioner(opt))
m.Get("/", func(ctx *macaron.Context, sess session.Store) {
raw, err := sess.RegenerateId(ctx)
So(err, ShouldBeNil)
So(raw, ShouldNotBeNil)
So(sess.Destory(ctx), ShouldBeNil)
})
resp := httptest.NewRecorder()
req, err := http.NewRequest("GET", "/", nil)
So(err, ShouldBeNil)
req.Header.Set("Cookie", "MacaronSession=ad2c7e3cbecfcf48; Path=/;")
m.ServeHTTP(resp, req)
})
Convey("GC session", func() {
m := macaron.New()
opt2 := opt
opt2.Gclifetime = 1
m.Use(session.Sessioner(opt2))
m.Get("/", func(sess session.Store) {
sess.Set("uname", "unknwon")
So(sess.ID(), ShouldNotBeEmpty)
uname := sess.Get("uname")
So(uname, ShouldNotBeNil)
So(uname, ShouldEqual, "unknwon")
So(sess.Flush(), ShouldBeNil)
So(sess.Get("uname"), ShouldBeNil)
time.Sleep(2 * time.Second)
sess.GC()
So(sess.Count(), ShouldEqual, 0)
})
resp := httptest.NewRecorder()
req, err := http.NewRequest("GET", "/", nil)
So(err, ShouldBeNil)
m.ServeHTTP(resp, req)
})
})
}

View File

@ -1,211 +0,0 @@
// Copyright 2013 Beego Authors
// Copyright 2014 Unknwon
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package session
import (
"database/sql"
"sync"
"time"
_ "github.com/lib/pq"
"github.com/macaron-contrib/session"
)
// PostgresqlSessionStore represents a postgresql session store implementation.
type PostgresqlSessionStore struct {
c *sql.DB
sid string
lock sync.RWMutex
data map[interface{}]interface{}
}
// Set sets value to given key in session.
func (s *PostgresqlSessionStore) Set(key, value interface{}) error {
s.lock.Lock()
defer s.lock.Unlock()
s.data[key] = value
return nil
}
// Get gets value by given key in session.
func (s *PostgresqlSessionStore) Get(key interface{}) interface{} {
s.lock.RLock()
defer s.lock.RUnlock()
return s.data[key]
}
// Delete delete a key from session.
func (s *PostgresqlSessionStore) Delete(key interface{}) error {
s.lock.Lock()
defer s.lock.Unlock()
delete(s.data, key)
return nil
}
// ID returns current session ID.
func (s *PostgresqlSessionStore) ID() string {
return s.sid
}
// save postgresql session values to database.
// must call this method to save values to database.
func (s *PostgresqlSessionStore) Release() error {
defer s.c.Close()
data, err := session.EncodeGob(s.data)
if err != nil {
return err
}
_, err = s.c.Exec("UPDATE session set session_data=$1, session_expiry=$2 where session_key=$3",
data, time.Now().Format(time.RFC3339), s.sid)
return err
}
// Flush deletes all session data.
func (s *PostgresqlSessionStore) Flush() error {
s.lock.Lock()
defer s.lock.Unlock()
s.data = make(map[interface{}]interface{})
return nil
}
// PostgresqlProvider represents a postgresql session provider implementation.
type PostgresqlProvider struct {
maxlifetime int64
connStr string
}
func (p *PostgresqlProvider) connectInit() *sql.DB {
db, e := sql.Open("postgres", p.connStr)
if e != nil {
return nil
}
return db
}
// Init initializes memory session provider.
func (p *PostgresqlProvider) Init(maxlifetime int64, connStr string) error {
p.maxlifetime = maxlifetime
p.connStr = connStr
return nil
}
// Read returns raw session store by session ID.
func (p *PostgresqlProvider) Read(sid string) (session.RawStore, error) {
c := p.connectInit()
row := c.QueryRow("select session_data from session where session_key=$1", sid)
var sessiondata []byte
err := row.Scan(&sessiondata)
if err == sql.ErrNoRows {
_, err = c.Exec("insert into session(session_key,session_data,session_expiry) values($1,$2,$3)",
sid, "", time.Now().Format(time.RFC3339))
if err != nil {
return nil, err
}
} else if err != nil {
return nil, err
}
var kv map[interface{}]interface{}
if len(sessiondata) == 0 {
kv = make(map[interface{}]interface{})
} else {
kv, err = session.DecodeGob(sessiondata)
if err != nil {
return nil, err
}
}
rs := &PostgresqlSessionStore{c: c, sid: sid, data: kv}
return rs, nil
}
// Exist returns true if session with given ID exists.
func (p *PostgresqlProvider) Exist(sid string) bool {
c := p.connectInit()
defer c.Close()
row := c.QueryRow("select session_data from session where session_key=$1", sid)
var sessiondata []byte
err := row.Scan(&sessiondata)
if err == sql.ErrNoRows {
return false
} else {
return true
}
}
// Destory deletes a session by session ID.
func (p *PostgresqlProvider) Destory(sid string) (err error) {
c := p.connectInit()
if _, err = c.Exec("DELETE FROM session where session_key=$1", sid); err != nil {
return err
}
return c.Close()
}
// Regenerate regenerates a session store from old session ID to new one.
func (p *PostgresqlProvider) Regenerate(oldsid, sid string) (session.RawStore, error) {
c := p.connectInit()
row := c.QueryRow("select session_data from session where session_key=$1", oldsid)
var sessiondata []byte
err := row.Scan(&sessiondata)
if err == sql.ErrNoRows {
c.Exec("insert into session(session_key,session_data,session_expiry) values($1,$2,$3)",
oldsid, "", time.Now().Format(time.RFC3339))
}
c.Exec("update session set session_key=$1 where session_key=$2", sid, oldsid)
var kv map[interface{}]interface{}
if len(sessiondata) == 0 {
kv = make(map[interface{}]interface{})
} else {
kv, err = session.DecodeGob(sessiondata)
if err != nil {
return nil, err
}
}
rs := &PostgresqlSessionStore{c: c, sid: sid, data: kv}
return rs, nil
}
// Count counts and returns number of sessions.
func (p *PostgresqlProvider) Count() int {
c := p.connectInit()
defer c.Close()
var total int
err := c.QueryRow("SELECT count(*) as num from session").Scan(&total)
if err != nil {
return 0
}
return total
}
// GC calls GC to clean expired sessions.
func (mp *PostgresqlProvider) GC() {
c := mp.connectInit()
c.Exec("DELETE from session where EXTRACT(EPOCH FROM (current_timestamp - session_expiry)) > $1", mp.maxlifetime)
c.Close()
}
func init() {
session.Register("postgresql", &PostgresqlProvider{})
}

View File

@ -16,31 +16,39 @@
package session
import (
"strconv"
"fmt"
"strings"
"sync"
"time"
"github.com/beego/redigo/redis"
"github.com/Unknwon/com"
"gopkg.in/ini.v1"
"gopkg.in/redis.v2"
"github.com/macaron-contrib/session"
)
// redis max pool size
var MAX_POOL_SIZE = 100
var redisPool chan redis.Conn
// RedisSessionStore represents a redis session store implementation.
type RedisSessionStore struct {
p *redis.Pool
// RedisStore represents a redis session store implementation.
type RedisStore struct {
c *redis.Client
sid string
duration time.Duration
lock sync.RWMutex
data map[interface{}]interface{}
maxlifetime int64
}
// NewRedisStore creates and returns a redis session store.
func NewRedisStore(c *redis.Client, sid string, dur time.Duration, kv map[interface{}]interface{}) *RedisStore {
return &RedisStore{
c: c,
sid: sid,
duration: dur,
data: kv,
}
}
// Set sets value to given key in session.
func (s *RedisSessionStore) Set(key, val interface{}) error {
func (s *RedisStore) Set(key, val interface{}) error {
s.lock.Lock()
defer s.lock.Unlock()
@ -49,7 +57,7 @@ func (s *RedisSessionStore) Set(key, val interface{}) error {
}
// Get gets value by given key in session.
func (s *RedisSessionStore) Get(key interface{}) interface{} {
func (s *RedisStore) Get(key interface{}) interface{} {
s.lock.RLock()
defer s.lock.RUnlock()
@ -57,7 +65,7 @@ func (s *RedisSessionStore) Get(key interface{}) interface{} {
}
// Delete delete a key from session.
func (s *RedisSessionStore) Delete(key interface{}) error {
func (s *RedisStore) Delete(key interface{}) error {
s.lock.Lock()
defer s.lock.Unlock()
@ -66,26 +74,22 @@ func (s *RedisSessionStore) Delete(key interface{}) error {
}
// ID returns current session ID.
func (s *RedisSessionStore) ID() string {
func (s *RedisStore) ID() string {
return s.sid
}
// Release releases resource and save data to provider.
func (s *RedisSessionStore) Release() error {
c := s.p.Get()
defer c.Close()
func (s *RedisStore) Release() error {
data, err := session.EncodeGob(s.data)
if err != nil {
return err
}
_, err = c.Do("SETEX", s.sid, s.maxlifetime, string(data))
return err
return s.c.SetEx(s.sid, s.duration, string(data)).Err()
}
// Flush deletes all session data.
func (s *RedisSessionStore) Flush() error {
func (s *RedisStore) Flush() error {
s.lock.Lock()
defer s.lock.Unlock()
@ -95,59 +99,65 @@ func (s *RedisSessionStore) Flush() error {
// RedisProvider represents a redis session provider implementation.
type RedisProvider struct {
maxlifetime int64
connAddr string
poolsize int
password string
poollist *redis.Pool
c *redis.Client
duration time.Duration
}
// Init initializes memory session provider.
// connStr: <redis server addr>,<pool size>,<password>
// e.g. 127.0.0.1:6379,100,macaron
func (p *RedisProvider) Init(maxlifetime int64, connStr string) error {
p.maxlifetime = maxlifetime
configs := strings.Split(connStr, ",")
if len(configs) > 0 {
p.connAddr = configs[0]
}
if len(configs) > 1 {
poolsize, err := strconv.Atoi(configs[1])
if err != nil || poolsize <= 0 {
p.poolsize = MAX_POOL_SIZE
} else {
p.poolsize = poolsize
}
} else {
p.poolsize = MAX_POOL_SIZE
}
if len(configs) > 2 {
p.password = configs[2]
}
p.poollist = redis.NewPool(func() (redis.Conn, error) {
c, err := redis.Dial("tcp", p.connAddr)
// Init initializes redis session provider.
// configs: network=tcp,addr=:6379,password=macaron,db=0,pool_size=100,idle_timeout=180
func (p *RedisProvider) Init(maxlifetime int64, configs string) (err error) {
p.duration, err = time.ParseDuration(fmt.Sprintf("%ds", maxlifetime))
if err != nil {
return nil, err
return err
}
if p.password != "" {
if _, err := c.Do("AUTH", p.password); err != nil {
c.Close()
return nil, err
}
}
return c, err
}, p.poolsize)
return p.poollist.Get().Err()
cfg, err := ini.Load([]byte(strings.Replace(configs, ",", "\n", -1)))
if err != nil {
return err
}
opt := &redis.Options{
Network: "tcp",
}
for k, v := range cfg.Section("").KeysHash() {
switch k {
case "network":
opt.Network = v
case "addr":
opt.Addr = v
case "password":
opt.Password = v
case "db":
opt.DB = com.StrTo(v).MustInt64()
case "pool_size":
opt.PoolSize = com.StrTo(v).MustInt()
case "idle_timeout":
opt.IdleTimeout, err = time.ParseDuration(v + "s")
if err != nil {
return fmt.Errorf("error parsing idle timeout: %v", err)
}
default:
return fmt.Errorf("session/redis: unsupported option '%s'", k)
}
}
p.c = redis.NewClient(opt)
return p.c.Ping().Err()
}
// Read returns raw session store by session ID.
func (p *RedisProvider) Read(sid string) (session.RawStore, error) {
c := p.poollist.Get()
defer c.Close()
if !p.Exist(sid) {
if err := p.c.Set(sid, "").Err(); err != nil {
return nil, err
}
}
kvs, err := redis.String(c.Do("GET", sid))
var kv map[interface{}]interface{}
kvs, err := p.c.Get(sid).Result()
if err != nil {
return nil, err
}
if len(kvs) == 0 {
kv = make(map[interface{}]interface{})
} else {
@ -157,48 +167,41 @@ func (p *RedisProvider) Read(sid string) (session.RawStore, error) {
}
}
rs := &RedisSessionStore{p: p.poollist, sid: sid, data: kv, maxlifetime: p.maxlifetime}
return rs, nil
return NewRedisStore(p.c, sid, p.duration, kv), nil
}
// Exist returns true if session with given ID exists.
func (p *RedisProvider) Exist(sid string) bool {
c := p.poollist.Get()
defer c.Close()
if existed, err := redis.Int(c.Do("EXISTS", sid)); err != nil || existed == 0 {
return false
} else {
return true
}
has, err := p.c.Exists(sid).Result()
return err == nil && has
}
// Destory deletes a session by session ID.
func (p *RedisProvider) Destory(sid string) error {
c := p.poollist.Get()
defer c.Close()
_, err := c.Do("DEL", sid)
return err
return p.c.Del(sid).Err()
}
// Regenerate regenerates a session store from old session ID to new one.
func (p *RedisProvider) Regenerate(oldsid, sid string) (session.RawStore, error) {
c := p.poollist.Get()
defer c.Close()
if existed, _ := redis.Int(c.Do("EXISTS", oldsid)); existed == 0 {
// oldsid doesn't exists, set the new sid directly
// ignore error here, since if it return error
// the existed value will be 0
c.Do("SET", sid, "", "EX", p.maxlifetime)
} else {
c.Do("RENAME", oldsid, sid)
c.Do("EXPIRE", sid, p.maxlifetime)
func (p *RedisProvider) Regenerate(oldsid, sid string) (_ session.RawStore, err error) {
if p.Exist(sid) {
return nil, fmt.Errorf("new sid '%s' already exists", sid)
} else if !p.Exist(oldsid) {
// Make a fake old session.
if err = p.c.SetEx(oldsid, p.duration, "").Err(); err != nil {
return nil, err
}
}
if err = p.c.Rename(oldsid, sid).Err(); err != nil {
return nil, err
}
kvs, err := redis.String(c.Do("GET", sid))
var kv map[interface{}]interface{}
kvs, err := p.c.Get(sid).Result()
if err != nil {
return nil, err
}
if len(kvs) == 0 {
kv = make(map[interface{}]interface{})
} else {
@ -208,14 +211,12 @@ func (p *RedisProvider) Regenerate(oldsid, sid string) (session.RawStore, error)
}
}
rs := &RedisSessionStore{p: p.poollist, sid: sid, data: kv, maxlifetime: p.maxlifetime}
return rs, nil
return NewRedisStore(p.c, sid, p.duration, kv), nil
}
// Count counts and returns number of sessions.
func (p *RedisProvider) Count() int {
// FIXME
return 0
return int(p.c.DbSize().Val())
}
// GC calls GC to clean expired sessions.

View File

@ -0,0 +1 @@
ignore

View File

@ -0,0 +1,107 @@
// Copyright 2014 Unknwon
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package session
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/Unknwon/macaron"
. "github.com/smartystreets/goconvey/convey"
"github.com/macaron-contrib/session"
)
func Test_RedisProvider(t *testing.T) {
Convey("Test redis session provider", t, func() {
opt := session.Options{
Provider: "redis",
ProviderConfig: "addr=:6379",
}
Convey("Basic operation", func() {
m := macaron.New()
m.Use(session.Sessioner(opt))
m.Get("/", func(ctx *macaron.Context, sess session.Store) {
sess.Set("uname", "unknwon")
})
m.Get("/reg", func(ctx *macaron.Context, sess session.Store) {
raw, err := sess.RegenerateId(ctx)
So(err, ShouldBeNil)
So(raw, ShouldNotBeNil)
uname := raw.Get("uname")
So(uname, ShouldNotBeNil)
So(uname, ShouldEqual, "unknwon")
})
m.Get("/get", func(ctx *macaron.Context, sess session.Store) {
sid := sess.ID()
So(sid, ShouldNotBeEmpty)
raw, err := sess.Read(sid)
So(err, ShouldBeNil)
So(raw, ShouldNotBeNil)
uname := sess.Get("uname")
So(uname, ShouldNotBeNil)
So(uname, ShouldEqual, "unknwon")
So(sess.Delete("uname"), ShouldBeNil)
So(sess.Get("uname"), ShouldBeNil)
So(sess.Destory(ctx), ShouldBeNil)
})
resp := httptest.NewRecorder()
req, err := http.NewRequest("GET", "/", nil)
So(err, ShouldBeNil)
m.ServeHTTP(resp, req)
cookie := resp.Header().Get("Set-Cookie")
resp = httptest.NewRecorder()
req, err = http.NewRequest("GET", "/reg", nil)
So(err, ShouldBeNil)
req.Header.Set("Cookie", cookie)
m.ServeHTTP(resp, req)
cookie = resp.Header().Get("Set-Cookie")
resp = httptest.NewRecorder()
req, err = http.NewRequest("GET", "/get", nil)
So(err, ShouldBeNil)
req.Header.Set("Cookie", cookie)
m.ServeHTTP(resp, req)
})
Convey("Regenrate empty session", func() {
m := macaron.New()
m.Use(session.Sessioner(opt))
m.Get("/", func(ctx *macaron.Context, sess session.Store) {
raw, err := sess.RegenerateId(ctx)
So(err, ShouldBeNil)
So(raw, ShouldNotBeNil)
})
resp := httptest.NewRecorder()
req, err := http.NewRequest("GET", "/", nil)
So(err, ShouldBeNil)
req.Header.Set("Cookie", "MacaronSession=ad2c7e3cbecfcf486; Path=/;")
m.ServeHTTP(resp, req)
})
})
}

View File

@ -13,7 +13,7 @@
// License for the specific language governing permissions and limitations
// under the License.
// Package session a middleware that provides the session manager of Macaron.
// Package session a middleware that provides the session management of Macaron.
package session
// NOTE: last sync 000033e on Nov 4, 2014.
@ -28,7 +28,7 @@ import (
"github.com/Unknwon/macaron"
)
const _VERSION = "0.1.1"
const _VERSION = "0.1.6"
func Version() string {
return _VERSION
@ -37,11 +37,11 @@ func Version() string {
// RawStore is the interface that operates the session data.
type RawStore interface {
// Set sets value to given key in session.
Set(key, value interface{}) error
Set(interface{}, interface{}) error
// Get gets value by given key in session.
Get(key interface{}) interface{}
// Delete delete a key from session.
Delete(key interface{}) error
Get(interface{}) interface{}
// Delete deletes a key from session.
Delete(interface{}) error
// ID returns current session ID.
ID() string
// Release releases session resource and save data to provider.
@ -54,7 +54,7 @@ type RawStore interface {
type Store interface {
RawStore
// Read returns raw session store by session ID.
Read(sid string) (RawStore, error)
Read(string) (RawStore, error)
// Destory deletes a session.
Destory(*macaron.Context) error
// RegenerateId regenerates a session store from old session ID to new one.
@ -111,7 +111,7 @@ func prepareOptions(options []Options) Options {
if len(opt.Provider) == 0 {
opt.Provider = sec.Key("PROVIDER").MustString("memory")
}
if len(opt.ProviderConfig) == 0 && opt.Provider == "file" {
if len(opt.ProviderConfig) == 0 {
opt.ProviderConfig = sec.Key("PROVIDER_CONFIG").MustString("data/sessions")
}
if len(opt.CookieName) == 0 {
@ -155,7 +155,7 @@ func Sessioner(options ...Options) macaron.Handler {
return func(ctx *macaron.Context) {
sess, err := manager.Start(ctx)
if err != nil {
panic("session: " + err.Error())
panic("session(start): " + err.Error())
}
// Get flash.
@ -187,8 +187,8 @@ func Sessioner(options ...Options) macaron.Handler {
ctx.Next()
if sess.Release() != nil {
panic("session: " + err.Error())
if err = sess.Release(); err != nil {
panic("session(release): " + err.Error())
}
}
}
@ -242,17 +242,14 @@ type Manager struct {
func NewManager(name string, opt Options) (*Manager, error) {
p, ok := providers[name]
if !ok {
return nil, fmt.Errorf("session: unknown provider %q(forgotten import?)", name)
return nil, fmt.Errorf("session: unknown provider '%s'(forgotten import?)", name)
}
if err := p.Init(opt.Maxlifetime, opt.ProviderConfig); err != nil {
return nil, err
}
return &Manager{p, opt}, nil
return &Manager{p, opt}, p.Init(opt.Maxlifetime, opt.ProviderConfig)
}
// sessionId generates a new session ID with rand string, unix nano time, remote addr by hash function.
func (m *Manager) sessionId() string {
return hex.EncodeToString(generateRandomKey(m.opt.IDLength))
return hex.EncodeToString(generateRandomKey(m.opt.IDLength / 2))
}
// Start starts a session by generating new one
@ -315,17 +312,10 @@ func (m *Manager) Destory(ctx *macaron.Context) error {
func (m *Manager) RegenerateId(ctx *macaron.Context) (sess RawStore, err error) {
sid := m.sessionId()
oldsid := ctx.GetCookie(m.opt.CookieName)
if len(oldsid) == 0 {
sess, err = m.provider.Read(oldsid)
if err != nil {
return nil, err
}
} else {
sess, err = m.provider.Regenerate(oldsid, sid)
if err != nil {
return nil, err
}
}
ck := &http.Cookie{
Name: m.opt.CookieName,
Value: sid,

View File

@ -42,7 +42,7 @@ func Test_Sessioner(t *testing.T) {
m.ServeHTTP(resp, req)
})
Convey("Register invalid provider that", t, func() {
Convey("Register invalid provider", t, func() {
Convey("Provider not exists", func() {
defer func() {
So(recover(), ShouldNotBeNil)

View File

@ -24,39 +24,19 @@ import (
"github.com/Unknwon/com"
)
func init() {
gob.Register([]interface{}{})
gob.Register(map[int]interface{}{})
gob.Register(map[string]interface{}{})
gob.Register(map[interface{}]interface{}{})
gob.Register(map[string]string{})
gob.Register(map[int]string{})
gob.Register(map[int]int{})
gob.Register(map[int]int64{})
}
func EncodeGob(obj map[interface{}]interface{}) ([]byte, error) {
for _, v := range obj {
gob.Register(v)
}
buf := bytes.NewBuffer(nil)
enc := gob.NewEncoder(buf)
err := enc.Encode(obj)
if err != nil {
return []byte(""), err
}
return buf.Bytes(), nil
err := gob.NewEncoder(buf).Encode(obj)
return buf.Bytes(), err
}
func DecodeGob(encoded []byte) (map[interface{}]interface{}, error) {
func DecodeGob(encoded []byte) (out map[interface{}]interface{}, err error) {
buf := bytes.NewBuffer(encoded)
dec := gob.NewDecoder(buf)
var out map[interface{}]interface{}
err := dec.Decode(&out)
if err != nil {
return nil, err
}
return out, nil
err = gob.NewDecoder(buf).Decode(&out)
return out, err
}
// generateRandomKey creates a random key with the given strength.

View File

@ -41,12 +41,18 @@ FAQ
> See: https://github.com/mattn/go-sqlite3/issues/106
> See also: http://www.limitlessfx.com/cross-compile-golang-app-for-windows-from-linux.html
* Want to get time.Time with current locale
Use `loc=auto` in SQLite3 filename schema like `file:foo.db?loc=auto`.
License
-------
MIT: http://mattn.mit-license.org/2012
sqlite.c, sqlite3.h, sqlite3ext.h
sqlite3-binding.c, sqlite3-binding.h, sqlite3ext.h
The -binding suffix was added to avoid build failures under gccgo.
In this repository, those files are amalgamation code that copied from SQLite3. The license of those codes are depend on the license of SQLite3.

View File

@ -6,7 +6,7 @@
package sqlite3
/*
#include <sqlite3.h>
#include <sqlite3-binding.h>
#include <stdlib.h>
*/
import "C"

View File

@ -231,6 +231,12 @@ func TestExtendedErrorCodes_Unique(t *testing.T) {
t.Errorf("Wrong extended error code: %d != %d",
sqliteErr.ExtendedCode, ErrConstraintUnique)
}
extended := sqliteErr.Code.Extend(3).Error()
expected := "constraint failed"
if extended != expected {
t.Errorf("Wrong basic error code: %q != %q",
extended, expected)
}
}
}

View File

@ -6,7 +6,10 @@
package sqlite3
/*
#include <sqlite3.h>
#cgo CFLAGS: -std=gnu99
#cgo CFLAGS: -DSQLITE_ENABLE_RTREE -DSQLITE_THREADSAFE
#cgo CFLAGS: -DSQLITE_ENABLE_FTS3 -DSQLITE_ENABLE_FTS3_PARENTHESIS
#include <sqlite3-binding.h>
#include <stdlib.h>
#include <string.h>
@ -44,14 +47,23 @@ _sqlite3_bind_blob(sqlite3_stmt *stmt, int n, void *p, int np) {
#include <stdio.h>
#include <stdint.h>
static long
_sqlite3_last_insert_rowid(sqlite3* db) {
return (long) sqlite3_last_insert_rowid(db);
static int
_sqlite3_exec(sqlite3* db, const char* pcmd, long* rowid, long* changes)
{
int rv = sqlite3_exec(db, pcmd, 0, 0, 0);
*rowid = (long) sqlite3_last_insert_rowid(db);
*changes = (long) sqlite3_changes(db);
return rv;
}
static long
_sqlite3_changes(sqlite3* db) {
return (long) sqlite3_changes(db);
static int
_sqlite3_step(sqlite3_stmt* stmt, long* rowid, long* changes)
{
int rv = sqlite3_step(stmt);
sqlite3* db = sqlite3_db_handle(stmt);
*rowid = (long) sqlite3_last_insert_rowid(db);
*changes = (long) sqlite3_changes(db);
return rv;
}
*/
@ -60,8 +72,11 @@ import (
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"io"
"net/url"
"runtime"
"strconv"
"strings"
"time"
"unsafe"
@ -103,6 +118,7 @@ type SQLiteDriver struct {
// Conn struct.
type SQLiteConn struct {
db *C.sqlite3
loc *time.Location
}
// Tx struct.
@ -114,6 +130,8 @@ type SQLiteTx struct {
type SQLiteStmt struct {
c *SQLiteConn
s *C.sqlite3_stmt
nv int
nn []string
t string
closed bool
cls bool
@ -174,7 +192,7 @@ func (c *SQLiteConn) Exec(query string, args []driver.Value) (driver.Result, err
if s.(*SQLiteStmt).s != nil {
na := s.NumInput()
if len(args) < na {
return nil, errors.New("args is not enough to execute query")
return nil, fmt.Errorf("Not enough args to execute query. Expected %d, got %d.", na, len(args))
}
res, err = s.Exec(args[:na])
if err != nil && err != driver.ErrSkip {
@ -201,6 +219,9 @@ func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, erro
}
s.(*SQLiteStmt).cls = true
na := s.NumInput()
if len(args) < na {
return nil, fmt.Errorf("Not enough args to execute query. Expected %d, got %d.", na, len(args))
}
rows, err := s.Query(args[:na])
if err != nil && err != driver.ErrSkip {
s.Close()
@ -220,14 +241,13 @@ func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, erro
func (c *SQLiteConn) exec(cmd string) (driver.Result, error) {
pcmd := C.CString(cmd)
defer C.free(unsafe.Pointer(pcmd))
rv := C.sqlite3_exec(c.db, pcmd, nil, nil, nil)
var rowid, changes C.long
rv := C._sqlite3_exec(c.db, pcmd, &rowid, &changes)
if rv != C.SQLITE_OK {
return nil, c.lastError()
}
return &SQLiteResult{
int64(C._sqlite3_last_insert_rowid(c.db)),
int64(C._sqlite3_changes(c.db)),
}, nil
return &SQLiteResult{int64(rowid), int64(changes)}, nil
}
// Begin transaction.
@ -248,11 +268,51 @@ func errorString(err Error) string {
// file:test.db?cache=shared&mode=memory
// :memory:
// file::memory:
// go-sqlite handle especially query parameters.
// _loc=XXX
// Specify location of time format. It's possible to specify "auto".
// _busy_timeout=XXX
// Specify value for sqlite3_busy_timeout.
func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
if C.sqlite3_threadsafe() == 0 {
return nil, errors.New("sqlite library was not compiled for thread-safe operation")
}
var loc *time.Location
busy_timeout := 5000
pos := strings.IndexRune(dsn, '?')
if pos >= 1 {
params, err := url.ParseQuery(dsn[pos+1:])
if err != nil {
return nil, err
}
// _loc
if val := params.Get("_loc"); val != "" {
if val == "auto" {
loc = time.Local
} else {
loc, err = time.LoadLocation(val)
if err != nil {
return nil, fmt.Errorf("Invalid _loc: %v: %v", val, err)
}
}
}
// _busy_timeout
if val := params.Get("_busy_timeout"); val != "" {
iv, err := strconv.ParseInt(val, 10, 64)
if err != nil {
return nil, fmt.Errorf("Invalid _busy_timeout: %v: %v", val, err)
}
busy_timeout = int(iv)
}
if !strings.HasPrefix(dsn, "file:") {
dsn = dsn[:pos]
}
}
var db *C.sqlite3
name := C.CString(dsn)
defer C.free(unsafe.Pointer(name))
@ -268,12 +328,12 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
return nil, errors.New("sqlite succeeded without returning a database")
}
rv = C.sqlite3_busy_timeout(db, 5000)
rv = C.sqlite3_busy_timeout(db, C.int(busy_timeout))
if rv != C.SQLITE_OK {
return nil, Error{Code: ErrNo(rv)}
}
conn := &SQLiteConn{db}
conn := &SQLiteConn{db: db, loc: loc}
if len(d.Extensions) > 0 {
rv = C.sqlite3_enable_load_extension(db, 1)
@ -281,21 +341,15 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
}
stmt, err := conn.Prepare("SELECT load_extension(?);")
if err != nil {
return nil, err
}
for _, extension := range d.Extensions {
if _, err = stmt.Exec([]driver.Value{extension}); err != nil {
return nil, err
cext := C.CString(extension)
defer C.free(unsafe.Pointer(cext))
rv = C.sqlite3_load_extension(db, cext, nil, nil)
if rv != C.SQLITE_OK {
return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
}
}
if err = stmt.Close(); err != nil {
return nil, err
}
rv = C.sqlite3_enable_load_extension(db, 0)
if rv != C.SQLITE_OK {
return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
@ -333,10 +387,18 @@ func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) {
return nil, c.lastError()
}
var t string
if tail != nil && C.strlen(tail) > 0 {
if tail != nil && *tail != '\000' {
t = strings.TrimSpace(C.GoString(tail))
}
ss := &SQLiteStmt{c: c, s: s, t: t}
nv := int(C.sqlite3_bind_parameter_count(s))
var nn []string
for i := 0; i < nv; i++ {
pn := C.GoString(C.sqlite3_bind_parameter_name(s, C.int(i+1)))
if len(pn) > 1 && pn[0] == '$' && 48 <= pn[1] && pn[1] <= 57 {
nn = append(nn, C.GoString(C.sqlite3_bind_parameter_name(s, C.int(i+1))))
}
}
ss := &SQLiteStmt{c: c, s: s, nv: nv, nn: nn, t: t}
runtime.SetFinalizer(ss, (*SQLiteStmt).Close)
return ss, nil
}
@ -360,7 +422,12 @@ func (s *SQLiteStmt) Close() error {
// Return a number of parameters.
func (s *SQLiteStmt) NumInput() int {
return int(C.sqlite3_bind_parameter_count(s.s))
return s.nv
}
type bindArg struct {
n int
v driver.Value
}
func (s *SQLiteStmt) bind(args []driver.Value) error {
@ -369,8 +436,24 @@ func (s *SQLiteStmt) bind(args []driver.Value) error {
return s.c.lastError()
}
var vargs []bindArg
narg := len(args)
vargs = make([]bindArg, narg)
if len(s.nn) > 0 {
for i, v := range s.nn {
if pi, err := strconv.Atoi(v[1:]); err == nil {
vargs[i] = bindArg{pi, args[i]}
}
}
} else {
for i, v := range args {
n := C.int(i + 1)
vargs[i] = bindArg{i + 1, v}
}
}
for _, varg := range vargs {
n := C.int(varg.n)
v := varg.v
switch v := v.(type) {
case nil:
rv = C.sqlite3_bind_null(s.s, n)
@ -431,19 +514,18 @@ func (r *SQLiteResult) RowsAffected() (int64, error) {
func (s *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) {
if err := s.bind(args); err != nil {
C.sqlite3_reset(s.s)
C.sqlite3_clear_bindings(s.s)
return nil, err
}
rv := C.sqlite3_step(s.s)
var rowid, changes C.long
rv := C._sqlite3_step(s.s, &rowid, &changes)
if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
err := s.c.lastError()
C.sqlite3_reset(s.s)
return nil, s.c.lastError()
C.sqlite3_clear_bindings(s.s)
return nil, err
}
res := &SQLiteResult{
int64(C._sqlite3_last_insert_rowid(s.c.db)),
int64(C._sqlite3_changes(s.c.db)),
}
return res, nil
return &SQLiteResult{int64(rowid), int64(changes)}, nil
}
// Close the rows.
@ -499,7 +581,22 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error {
val := int64(C.sqlite3_column_int64(rc.s.s, C.int(i)))
switch rc.decltype[i] {
case "timestamp", "datetime", "date":
dest[i] = time.Unix(val, 0).Local()
unixTimestamp := strconv.FormatInt(val, 10)
var t time.Time
if len(unixTimestamp) == 13 {
duration, err := time.ParseDuration(unixTimestamp + "ms")
if err != nil {
return fmt.Errorf("error parsing %s value %d, %s", rc.decltype[i], val, err)
}
epoch := time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC)
t = epoch.Add(duration)
} else {
t = time.Unix(val, 0)
}
if rc.s.c.loc != nil {
t = t.In(rc.s.c.loc)
}
dest[i] = t
case "boolean":
dest[i] = val > 0
default:
@ -531,16 +628,21 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error {
switch rc.decltype[i] {
case "timestamp", "datetime", "date":
var t time.Time
for _, format := range SQLiteTimestampFormats {
if timeVal, err = time.ParseInLocation(format, s, time.UTC); err == nil {
dest[i] = timeVal.Local()
t = timeVal
break
}
}
if err != nil {
// The column is a time value, so return the zero time on parse failure.
dest[i] = time.Time{}
t = time.Time{}
}
if rc.s.c.loc != nil {
t = t.In(rc.s.c.loc)
}
dest[i] = t
default:
dest[i] = []byte(s)
}

View File

@ -0,0 +1,83 @@
// Copyright (C) 2015 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
package sqlite3
import (
"database/sql"
"os"
"testing"
)
func TestFTS3(t *testing.T) {
tempFilename := TempFilename()
db, err := sql.Open("sqlite3", tempFilename)
if err != nil {
t.Fatal("Failed to open database:", err)
}
defer os.Remove(tempFilename)
defer db.Close()
_, err = db.Exec("DROP TABLE foo")
_, err = db.Exec("CREATE VIRTUAL TABLE foo USING fts3(id INTEGER PRIMARY KEY, value TEXT)")
if err != nil {
t.Fatal("Failed to create table:", err)
}
_, err = db.Exec("INSERT INTO foo(id, value) VALUES(?, ?)", 1, `今日の 晩御飯は 天麩羅よ`)
if err != nil {
t.Fatal("Failed to insert value:", err)
}
_, err = db.Exec("INSERT INTO foo(id, value) VALUES(?, ?)", 2, `今日は いい 天気だ`)
if err != nil {
t.Fatal("Failed to insert value:", err)
}
rows, err := db.Query("SELECT id, value FROM foo WHERE value MATCH '今日* 天*'")
if err != nil {
t.Fatal("Unable to query foo table:", err)
}
defer rows.Close()
for rows.Next() {
var id int
var value string
if err := rows.Scan(&id, &value); err != nil {
t.Error("Unable to scan results:", err)
continue
}
if id == 1 && value != `今日の 晩御飯は 天麩羅よ` {
t.Error("Value for id 1 should be `今日の 晩御飯は 天麩羅よ`, but:", value)
} else if id == 2 && value != `今日は いい 天気だ` {
t.Error("Value for id 2 should be `今日は いい 天気だ`, but:", value)
}
}
rows, err = db.Query("SELECT value FROM foo WHERE value MATCH '今日* 天麩羅*'")
if err != nil {
t.Fatal("Unable to query foo table:", err)
}
defer rows.Close()
var value string
if !rows.Next() {
t.Fatal("Result should be only one")
}
if err := rows.Scan(&value); err != nil {
t.Fatal("Unable to scan results:", err)
}
if value != `今日の 晩御飯は 天麩羅よ` {
t.Fatal("Value should be `今日の 晩御飯は 天麩羅よ`, but:", value)
}
if rows.Next() {
t.Fatal("Result should be only one")
}
}

View File

@ -9,6 +9,6 @@ package sqlite3
/*
#cgo CFLAGS: -I.
#cgo linux LDFLAGS: -ldl
#cgo CFLAGS: -DSQLITE_ENABLE_RTREE -DSQLITE_THREADSAFE
#cgo LDFLAGS: -lpthread
*/
import "C"

View File

@ -9,8 +9,10 @@ import (
"crypto/rand"
"database/sql"
"encoding/hex"
"net/url"
"os"
"path/filepath"
"strings"
"testing"
"time"
@ -309,6 +311,7 @@ func TestTimestamp(t *testing.T) {
{"0000-00-00 00:00:00", time.Time{}},
{timestamp1, timestamp1},
{timestamp1.Unix(), timestamp1},
{timestamp1.UnixNano() / int64(time.Millisecond), timestamp1},
{timestamp1.In(time.FixedZone("TEST", -7*3600)), timestamp1},
{timestamp1.Format("2006-01-02 15:04:05.000"), timestamp1},
{timestamp1.Format("2006-01-02T15:04:05.000"), timestamp1},
@ -633,6 +636,102 @@ func TestWAL(t *testing.T) {
}
}
func TestTimezoneConversion(t *testing.T) {
zones := []string{"UTC", "US/Central", "US/Pacific", "Local"}
for _, tz := range zones {
tempFilename := TempFilename()
db, err := sql.Open("sqlite3", tempFilename+"?_loc="+url.QueryEscape(tz))
if err != nil {
t.Fatal("Failed to open database:", err)
}
defer os.Remove(tempFilename)
defer db.Close()
_, err = db.Exec("DROP TABLE foo")
_, err = db.Exec("CREATE TABLE foo(id INTEGER, ts TIMESTAMP, dt DATETIME)")
if err != nil {
t.Fatal("Failed to create table:", err)
}
loc, err := time.LoadLocation(tz)
if err != nil {
t.Fatal("Failed to load location:", err)
}
timestamp1 := time.Date(2012, time.April, 6, 22, 50, 0, 0, time.UTC)
timestamp2 := time.Date(2006, time.January, 2, 15, 4, 5, 123456789, time.UTC)
timestamp3 := time.Date(2012, time.November, 4, 0, 0, 0, 0, time.UTC)
tests := []struct {
value interface{}
expected time.Time
}{
{"nonsense", time.Time{}.In(loc)},
{"0000-00-00 00:00:00", time.Time{}.In(loc)},
{timestamp1, timestamp1.In(loc)},
{timestamp1.Unix(), timestamp1.In(loc)},
{timestamp1.In(time.FixedZone("TEST", -7*3600)), timestamp1.In(loc)},
{timestamp1.Format("2006-01-02 15:04:05.000"), timestamp1.In(loc)},
{timestamp1.Format("2006-01-02T15:04:05.000"), timestamp1.In(loc)},
{timestamp1.Format("2006-01-02 15:04:05"), timestamp1.In(loc)},
{timestamp1.Format("2006-01-02T15:04:05"), timestamp1.In(loc)},
{timestamp2, timestamp2.In(loc)},
{"2006-01-02 15:04:05.123456789", timestamp2.In(loc)},
{"2006-01-02T15:04:05.123456789", timestamp2.In(loc)},
{"2012-11-04", timestamp3.In(loc)},
{"2012-11-04 00:00", timestamp3.In(loc)},
{"2012-11-04 00:00:00", timestamp3.In(loc)},
{"2012-11-04 00:00:00.000", timestamp3.In(loc)},
{"2012-11-04T00:00", timestamp3.In(loc)},
{"2012-11-04T00:00:00", timestamp3.In(loc)},
{"2012-11-04T00:00:00.000", timestamp3.In(loc)},
}
for i := range tests {
_, err = db.Exec("INSERT INTO foo(id, ts, dt) VALUES(?, ?, ?)", i, tests[i].value, tests[i].value)
if err != nil {
t.Fatal("Failed to insert timestamp:", err)
}
}
rows, err := db.Query("SELECT id, ts, dt FROM foo ORDER BY id ASC")
if err != nil {
t.Fatal("Unable to query foo table:", err)
}
defer rows.Close()
seen := 0
for rows.Next() {
var id int
var ts, dt time.Time
if err := rows.Scan(&id, &ts, &dt); err != nil {
t.Error("Unable to scan results:", err)
continue
}
if id < 0 || id >= len(tests) {
t.Error("Bad row id: ", id)
continue
}
seen++
if !tests[id].expected.Equal(ts) {
t.Errorf("Timestamp value for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected, ts)
}
if !tests[id].expected.Equal(dt) {
t.Errorf("Datetime value for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected, dt)
}
if tests[id].expected.Location().String() != ts.Location().String() {
t.Errorf("Location for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected.Location().String(), ts.Location().String())
}
if tests[id].expected.Location().String() != dt.Location().String() {
t.Errorf("Location for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected.Location().String(), dt.Location().String())
}
}
if seen != len(tests) {
t.Errorf("Expected to see %d rows", len(tests))
}
}
}
func TestSuite(t *testing.T) {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
@ -742,3 +841,107 @@ func TestStress(t *testing.T) {
db.Close()
}
}
func TestDateTimeLocal(t *testing.T) {
zone := "Asia/Tokyo"
tempFilename := TempFilename()
db, err := sql.Open("sqlite3", tempFilename+"?_loc="+zone)
if err != nil {
t.Fatal("Failed to open database:", err)
}
db.Exec("CREATE TABLE foo (dt datetime);")
db.Exec("INSERT INTO foo VALUES('2015-03-05 15:16:17');")
row := db.QueryRow("select * from foo")
var d time.Time
err = row.Scan(&d)
if err != nil {
t.Fatal("Failed to scan datetime:", err)
}
if d.Hour() == 15 || !strings.Contains(d.String(), "JST") {
t.Fatal("Result should have timezone", d)
}
db.Close()
db, err = sql.Open("sqlite3", tempFilename)
if err != nil {
t.Fatal("Failed to open database:", err)
}
row = db.QueryRow("select * from foo")
err = row.Scan(&d)
if err != nil {
t.Fatal("Failed to scan datetime:", err)
}
if d.UTC().Hour() != 15 || !strings.Contains(d.String(), "UTC") {
t.Fatalf("Result should not have timezone %v %v", zone, d.String())
}
_, err = db.Exec("DELETE FROM foo")
if err != nil {
t.Fatal("Failed to delete table:", err)
}
dt, err := time.Parse("2006/1/2 15/4/5 -0700 MST", "2015/3/5 15/16/17 +0900 JST")
if err != nil {
t.Fatal("Failed to parse datetime:", err)
}
db.Exec("INSERT INTO foo VALUES(?);", dt)
db.Close()
db, err = sql.Open("sqlite3", tempFilename+"?_loc="+zone)
if err != nil {
t.Fatal("Failed to open database:", err)
}
row = db.QueryRow("select * from foo")
err = row.Scan(&d)
if err != nil {
t.Fatal("Failed to scan datetime:", err)
}
if d.Hour() != 15 || !strings.Contains(d.String(), "JST") {
t.Fatalf("Result should have timezone %v %v", zone, d.String())
}
}
func TestVersion(t *testing.T) {
s, n, id := Version()
if s == "" || n == 0 || id == "" {
t.Errorf("Version failed %q, %d, %q\n", s, n, id)
}
}
func TestNumberNamedParams(t *testing.T) {
tempFilename := TempFilename()
db, err := sql.Open("sqlite3", tempFilename)
if err != nil {
t.Fatal("Failed to open database:", err)
}
defer os.Remove(tempFilename)
defer db.Close()
_, err = db.Exec(`
create table foo (id integer, name text, extra text);
`)
if err != nil {
t.Error("Failed to call db.Query:", err)
}
_, err = db.Exec(`insert into foo(id, name, extra) values($1, $2, $2)`, 1, "foo")
if err != nil {
t.Error("Failed to call db.Exec:", err)
}
row := db.QueryRow(`select id, extra from foo where id = $1 and extra = $2`, 1, "foo")
if row == nil {
t.Error("Failed to call db.QueryRow")
}
var id int
var extra string
err = row.Scan(&id, &extra)
if err != nil {
t.Error("Failed to db.Scan:", err)
}
if id != 1 || extra != "foo" {
t.Error("Failed to db.QueryRow: not matched results")
}
}

View File

@ -2,6 +2,7 @@
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
// +build windows
package sqlite3
@ -9,6 +10,5 @@ package sqlite3
#cgo CFLAGS: -I. -fno-stack-check -fno-stack-protector -mno-stack-arg-probe
#cgo windows,386 CFLAGS: -D_localtime32=localtime
#cgo LDFLAGS: -lmingwex -lmingw32
#cgo CFLAGS: -DSQLITE_ENABLE_RTREE -DSQLITE_THREADSAFE
*/
import "C"

View File

@ -17,7 +17,7 @@
*/
#ifndef _SQLITE3EXT_H_
#define _SQLITE3EXT_H_
#include "sqlite3.h"
#include "sqlite3-binding.h"
typedef struct sqlite3_api_routines sqlite3_api_routines;

View File

@ -8,7 +8,7 @@ install:
- export GOPATH="$HOME/gopath"
- mkdir -p "$GOPATH/src/golang.org/x"
- mv "$TRAVIS_BUILD_DIR" "$GOPATH/src/golang.org/x/oauth2"
- go get -v -t -d -tags='appengine appenginevm' golang.org/x/oauth2/...
- go get -v -t -d golang.org/x/oauth2/...
script:
- go test -v -tags='appengine appenginevm' golang.org/x/oauth2/...
- go test -v golang.org/x/oauth2/...

View File

@ -1,25 +1,31 @@
# Contributing
# Contributing to Go
We don't use GitHub pull requests but use Gerrit for code reviews,
similar to the Go project.
Go is an open source project.
1. Sign one of the contributor license agreements below.
2. `go get golang.org/x/review/git-codereview` to install the code reviewing tool.
3. Get the package by running `go get -d golang.org/x/oauth2`.
Make changes and create a change by running `git codereview change <name>`, provide a command message, and use `git codereview mail` to create a Gerrit CL.
Keep amending to the change and mail as your recieve feedback.
It is the work of hundreds of contributors. We appreciate your help!
For more information about the workflow, see Go's [Contribution Guidelines](https://golang.org/doc/contribute.html).
Before we can accept any pull requests
we have to jump through a couple of legal hurdles,
primarily a Contributor License Agreement (CLA):
## Filing issues
- **If you are an individual writing original source code**
and you're sure you own the intellectual property,
then you'll need to sign an [individual CLA](http://code.google.com/legal/individual-cla-v1.0.html).
- **If you work for a company that wants to allow you to contribute your work**,
then you'll need to sign a [corporate CLA](http://code.google.com/legal/corporate-cla-v1.0.html).
When [filing an issue](https://github.com/golang/oauth2/issues), make sure to answer these five questions:
1. What version of Go are you using (`go version`)?
2. What operating system and processor architecture are you using?
3. What did you do?
4. What did you expect to see?
5. What did you see instead?
General questions should go to the [golang-nuts mailing list](https://groups.google.com/group/golang-nuts) instead of the issue tracker.
The gophers there will answer or ask you to file an issue if you've tripped over a bug.
## Contributing code
Please read the [Contribution Guidelines](https://golang.org/doc/contribute.html)
before sending patches.
**We do not accept GitHub pull requests**
(we use [Gerrit](https://code.google.com/p/gerrit/) instead for code review).
Unless otherwise noted, the Go source files are distributed under
the BSD-style license found in the LICENSE file.
You can sign these electronically (just scroll to the bottom).
After that, we'll be able to accept your pull requests.

Some files were not shown because too many files have changed in this diff Show More