diff --git a/conf/defaults.ini b/conf/defaults.ini index ceb916917ab..f0156b70511 100644 --- a/conf/defaults.ini +++ b/conf/defaults.ini @@ -76,8 +76,10 @@ password = # Example: mysql://user:secret@host:port/database url = +# Max idle conn setting default is 2 +max_idle_conn = 2 + # Max conn setting default is 0 (mean not set) -max_idle_conn = max_open_conn = # For "postgres", use either "disable", "require" or "verify-full" diff --git a/conf/sample.ini b/conf/sample.ini index e33f408ecad..80c7464f89c 100644 --- a/conf/sample.ini +++ b/conf/sample.ini @@ -85,8 +85,10 @@ # For "sqlite3" only, path relative to data_path setting ;path = grafana.db +# Max idle conn setting default is 2 +;max_idle_conn = 2 + # Max conn setting default is 0 (mean not set) -;max_idle_conn = ;max_open_conn = diff --git a/docker/blocks/postgres/fig b/docker/blocks/postgres/fig index 33499688810..049339a15c7 100644 --- a/docker/blocks/postgres/fig +++ b/docker/blocks/postgres/fig @@ -6,3 +6,4 @@ postgrestest: POSTGRES_DATABASE: grafana ports: - "5432:5432" + command: postgres -c log_connections=on -c logging_collector=on -c log_destination=stderr -c log_directory=/var/log/postgresql diff --git a/vendor/github.com/lib/pq/.gitignore b/vendor/github.com/lib/pq/.gitignore deleted file mode 100644 index 0f1d00e1196..00000000000 --- a/vendor/github.com/lib/pq/.gitignore +++ /dev/null @@ -1,4 +0,0 @@ -.db -*.test -*~ -*.swp diff --git a/vendor/github.com/lib/pq/.travis.yml b/vendor/github.com/lib/pq/.travis.yml deleted file mode 100644 index 82df2a0aabf..00000000000 --- a/vendor/github.com/lib/pq/.travis.yml +++ /dev/null @@ -1,61 +0,0 @@ -language: go - -go: - - 1.1 - - 1.2 - - 1.3 - - 1.4 - - tip - -before_install: - - psql --version - - sudo /etc/init.d/postgresql stop - - sudo apt-get -y --purge remove postgresql libpq-dev libpq5 postgresql-client-common postgresql-common - - sudo rm -rf /var/lib/postgresql - - wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add - - - sudo sh -c "echo deb http://apt.postgresql.org/pub/repos/apt/ $(lsb_release -cs)-pgdg main $PGVERSION >> /etc/apt/sources.list.d/postgresql.list" - - sudo apt-get update -qq - - sudo apt-get -y -o Dpkg::Options::=--force-confdef -o Dpkg::Options::="--force-confnew" install postgresql-$PGVERSION postgresql-server-dev-$PGVERSION postgresql-contrib-$PGVERSION - - sudo chmod 777 /etc/postgresql/$PGVERSION/main/pg_hba.conf - - echo "local all postgres trust" > /etc/postgresql/$PGVERSION/main/pg_hba.conf - - echo "local all all trust" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - - echo "hostnossl all pqgossltest 127.0.0.1/32 reject" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - - echo "hostnossl all pqgosslcert 127.0.0.1/32 reject" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - - echo "hostssl all pqgossltest 127.0.0.1/32 trust" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - - echo "hostssl all pqgosslcert 127.0.0.1/32 cert" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - - echo "host all all 127.0.0.1/32 trust" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - - echo "hostnossl all pqgossltest ::1/128 reject" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - - echo "hostnossl all pqgosslcert ::1/128 reject" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - - echo "hostssl all pqgossltest ::1/128 trust" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - - echo "hostssl all pqgosslcert ::1/128 cert" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - - echo "host all all ::1/128 trust" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - - sudo install -o postgres -g postgres -m 600 -t /var/lib/postgresql/$PGVERSION/main/ certs/server.key certs/server.crt certs/root.crt - - sudo bash -c "[[ '${PGVERSION}' < '9.2' ]] || (echo \"ssl_cert_file = 'server.crt'\" >> /etc/postgresql/$PGVERSION/main/postgresql.conf)" - - sudo bash -c "[[ '${PGVERSION}' < '9.2' ]] || (echo \"ssl_key_file = 'server.key'\" >> /etc/postgresql/$PGVERSION/main/postgresql.conf)" - - sudo bash -c "[[ '${PGVERSION}' < '9.2' ]] || (echo \"ssl_ca_file = 'root.crt'\" >> /etc/postgresql/$PGVERSION/main/postgresql.conf)" - - sudo sh -c "echo 127.0.0.1 postgres >> /etc/hosts" - - sudo ls -l /var/lib/postgresql/$PGVERSION/main/ - - sudo cat /etc/postgresql/$PGVERSION/main/postgresql.conf - - sudo chmod 600 $PQSSLCERTTEST_PATH/postgresql.key - - sudo /etc/init.d/postgresql restart - -env: - global: - - PGUSER=postgres - - PQGOSSLTESTS=1 - - PQSSLCERTTEST_PATH=$PWD/certs - matrix: - - PGVERSION=9.4 - - PGVERSION=9.3 - - PGVERSION=9.2 - - PGVERSION=9.1 - - PGVERSION=9.0 - - PGVERSION=8.4 - -script: - - go test -v ./... - -before_script: - - psql -c 'create database pqgotest' -U postgres - - psql -c 'create user pqgossltest' -U postgres - - psql -c 'create user pqgosslcert' -U postgres diff --git a/vendor/github.com/lib/pq/README.md b/vendor/github.com/lib/pq/README.md index cb15d6f0db0..7670fc87a51 100644 --- a/vendor/github.com/lib/pq/README.md +++ b/vendor/github.com/lib/pq/README.md @@ -1,6 +1,6 @@ # pq - A pure Go postgres driver for Go's database/sql package -[![Build Status](https://travis-ci.org/lib/pq.png?branch=master)](https://travis-ci.org/lib/pq) +[![Build Status](https://travis-ci.org/lib/pq.svg?branch=master)](https://travis-ci.org/lib/pq) ## Install @@ -20,11 +20,11 @@ variables. Example: - PGHOST=/var/run/postgresql go test github.com/lib/pq + PGHOST=/run/postgresql go test github.com/lib/pq Optionally, a benchmark suite can be run as part of the tests: - PGHOST=/var/run/postgresql go test -bench . + PGHOST=/run/postgresql go test -bench . ## Features @@ -38,6 +38,7 @@ Optionally, a benchmark suite can be run as part of the tests: * Many libpq compatible environment variables * Unix socket support * Notifications: `LISTEN`/`NOTIFY` +* pgpass support ## Future / Things you can help with @@ -57,13 +58,17 @@ code still exists in here. * Brad Fitzpatrick (bradfitz) * Charlie Melbye (cmelbye) * Chris Bandy (cbandy) +* Chris Gilling (cgilling) * Chris Walsh (cwds) * Dan Sosedoff (sosedoff) * Daniel Farina (fdr) * Eric Chlebek (echlebek) +* Eric Garrido (minusnine) +* Eric Urban (hydrogen18) * Everyone at The Go Team * Evan Shaw (edsrzf) * Ewan Chou (coocood) +* Fazal Majid (fazalmajid) * Federico Romero (federomero) * Fumin (fumin) * Gary Burd (garyburd) @@ -80,7 +85,7 @@ code still exists in here. * Keith Rarick (kr) * Kir Shatrov (kirs) * Lann Martin (lann) -* Maciek Sakrejda (deafbybeheading) +* Maciek Sakrejda (uhoh-itsmaciek) * Marc Brinkmann (mbr) * Marko Tiikkaja (johto) * Matt Newberry (MattNewberry) @@ -94,5 +99,7 @@ code still exists in here. * Ryan Smith (ryandotsmith) * Samuel Stauffer (samuel) * Timothée Peignier (cyberdelia) +* Travis Cline (tmc) * TruongSinh Tran-Nguyen (truongsinh) +* Yaismel Miranda (ympons) * notedit (notedit) diff --git a/vendor/github.com/lib/pq/array.go b/vendor/github.com/lib/pq/array.go new file mode 100644 index 00000000000..e7b2145d67b --- /dev/null +++ b/vendor/github.com/lib/pq/array.go @@ -0,0 +1,756 @@ +package pq + +import ( + "bytes" + "database/sql" + "database/sql/driver" + "encoding/hex" + "fmt" + "reflect" + "strconv" + "strings" +) + +var typeByteSlice = reflect.TypeOf([]byte{}) +var typeDriverValuer = reflect.TypeOf((*driver.Valuer)(nil)).Elem() +var typeSqlScanner = reflect.TypeOf((*sql.Scanner)(nil)).Elem() + +// Array returns the optimal driver.Valuer and sql.Scanner for an array or +// slice of any dimension. +// +// For example: +// db.Query(`SELECT * FROM t WHERE id = ANY($1)`, pq.Array([]int{235, 401})) +// +// var x []sql.NullInt64 +// db.QueryRow('SELECT ARRAY[235, 401]').Scan(pq.Array(&x)) +// +// Scanning multi-dimensional arrays is not supported. Arrays where the lower +// bound is not one (such as `[0:0]={1}') are not supported. +func Array(a interface{}) interface { + driver.Valuer + sql.Scanner +} { + switch a := a.(type) { + case []bool: + return (*BoolArray)(&a) + case []float64: + return (*Float64Array)(&a) + case []int64: + return (*Int64Array)(&a) + case []string: + return (*StringArray)(&a) + + case *[]bool: + return (*BoolArray)(a) + case *[]float64: + return (*Float64Array)(a) + case *[]int64: + return (*Int64Array)(a) + case *[]string: + return (*StringArray)(a) + } + + return GenericArray{a} +} + +// ArrayDelimiter may be optionally implemented by driver.Valuer or sql.Scanner +// to override the array delimiter used by GenericArray. +type ArrayDelimiter interface { + // ArrayDelimiter returns the delimiter character(s) for this element's type. + ArrayDelimiter() string +} + +// BoolArray represents a one-dimensional array of the PostgreSQL boolean type. +type BoolArray []bool + +// Scan implements the sql.Scanner interface. +func (a *BoolArray) Scan(src interface{}) error { + switch src := src.(type) { + case []byte: + return a.scanBytes(src) + case string: + return a.scanBytes([]byte(src)) + case nil: + *a = nil + return nil + } + + return fmt.Errorf("pq: cannot convert %T to BoolArray", src) +} + +func (a *BoolArray) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "BoolArray") + if err != nil { + return err + } + if *a != nil && len(elems) == 0 { + *a = (*a)[:0] + } else { + b := make(BoolArray, len(elems)) + for i, v := range elems { + if len(v) != 1 { + return fmt.Errorf("pq: could not parse boolean array index %d: invalid boolean %q", i, v) + } + switch v[0] { + case 't': + b[i] = true + case 'f': + b[i] = false + default: + return fmt.Errorf("pq: could not parse boolean array index %d: invalid boolean %q", i, v) + } + } + *a = b + } + return nil +} + +// Value implements the driver.Valuer interface. +func (a BoolArray) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be exactly two curly brackets, N bytes of values, + // and N-1 bytes of delimiters. + b := make([]byte, 1+2*n) + + for i := 0; i < n; i++ { + b[2*i] = ',' + if a[i] { + b[1+2*i] = 't' + } else { + b[1+2*i] = 'f' + } + } + + b[0] = '{' + b[2*n] = '}' + + return string(b), nil + } + + return "{}", nil +} + +// ByteaArray represents a one-dimensional array of the PostgreSQL bytea type. +type ByteaArray [][]byte + +// Scan implements the sql.Scanner interface. +func (a *ByteaArray) Scan(src interface{}) error { + switch src := src.(type) { + case []byte: + return a.scanBytes(src) + case string: + return a.scanBytes([]byte(src)) + case nil: + *a = nil + return nil + } + + return fmt.Errorf("pq: cannot convert %T to ByteaArray", src) +} + +func (a *ByteaArray) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "ByteaArray") + if err != nil { + return err + } + if *a != nil && len(elems) == 0 { + *a = (*a)[:0] + } else { + b := make(ByteaArray, len(elems)) + for i, v := range elems { + b[i], err = parseBytea(v) + if err != nil { + return fmt.Errorf("could not parse bytea array index %d: %s", i, err.Error()) + } + } + *a = b + } + return nil +} + +// Value implements the driver.Valuer interface. It uses the "hex" format which +// is only supported on PostgreSQL 9.0 or newer. +func (a ByteaArray) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be at least two curly brackets, 2*N bytes of quotes, + // 3*N bytes of hex formatting, and N-1 bytes of delimiters. + size := 1 + 6*n + for _, x := range a { + size += hex.EncodedLen(len(x)) + } + + b := make([]byte, size) + + for i, s := 0, b; i < n; i++ { + o := copy(s, `,"\\x`) + o += hex.Encode(s[o:], a[i]) + s[o] = '"' + s = s[o+1:] + } + + b[0] = '{' + b[size-1] = '}' + + return string(b), nil + } + + return "{}", nil +} + +// Float64Array represents a one-dimensional array of the PostgreSQL double +// precision type. +type Float64Array []float64 + +// Scan implements the sql.Scanner interface. +func (a *Float64Array) Scan(src interface{}) error { + switch src := src.(type) { + case []byte: + return a.scanBytes(src) + case string: + return a.scanBytes([]byte(src)) + case nil: + *a = nil + return nil + } + + return fmt.Errorf("pq: cannot convert %T to Float64Array", src) +} + +func (a *Float64Array) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "Float64Array") + if err != nil { + return err + } + if *a != nil && len(elems) == 0 { + *a = (*a)[:0] + } else { + b := make(Float64Array, len(elems)) + for i, v := range elems { + if b[i], err = strconv.ParseFloat(string(v), 64); err != nil { + return fmt.Errorf("pq: parsing array element index %d: %v", i, err) + } + } + *a = b + } + return nil +} + +// Value implements the driver.Valuer interface. +func (a Float64Array) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be at least two curly brackets, N bytes of values, + // and N-1 bytes of delimiters. + b := make([]byte, 1, 1+2*n) + b[0] = '{' + + b = strconv.AppendFloat(b, a[0], 'f', -1, 64) + for i := 1; i < n; i++ { + b = append(b, ',') + b = strconv.AppendFloat(b, a[i], 'f', -1, 64) + } + + return string(append(b, '}')), nil + } + + return "{}", nil +} + +// GenericArray implements the driver.Valuer and sql.Scanner interfaces for +// an array or slice of any dimension. +type GenericArray struct{ A interface{} } + +func (GenericArray) evaluateDestination(rt reflect.Type) (reflect.Type, func([]byte, reflect.Value) error, string) { + var assign func([]byte, reflect.Value) error + var del = "," + + // TODO calculate the assign function for other types + // TODO repeat this section on the element type of arrays or slices (multidimensional) + { + if reflect.PtrTo(rt).Implements(typeSqlScanner) { + // dest is always addressable because it is an element of a slice. + assign = func(src []byte, dest reflect.Value) (err error) { + ss := dest.Addr().Interface().(sql.Scanner) + if src == nil { + err = ss.Scan(nil) + } else { + err = ss.Scan(src) + } + return + } + goto FoundType + } + + assign = func([]byte, reflect.Value) error { + return fmt.Errorf("pq: scanning to %s is not implemented; only sql.Scanner", rt) + } + } + +FoundType: + + if ad, ok := reflect.Zero(rt).Interface().(ArrayDelimiter); ok { + del = ad.ArrayDelimiter() + } + + return rt, assign, del +} + +// Scan implements the sql.Scanner interface. +func (a GenericArray) Scan(src interface{}) error { + dpv := reflect.ValueOf(a.A) + switch { + case dpv.Kind() != reflect.Ptr: + return fmt.Errorf("pq: destination %T is not a pointer to array or slice", a.A) + case dpv.IsNil(): + return fmt.Errorf("pq: destination %T is nil", a.A) + } + + dv := dpv.Elem() + switch dv.Kind() { + case reflect.Slice: + case reflect.Array: + default: + return fmt.Errorf("pq: destination %T is not a pointer to array or slice", a.A) + } + + switch src := src.(type) { + case []byte: + return a.scanBytes(src, dv) + case string: + return a.scanBytes([]byte(src), dv) + case nil: + if dv.Kind() == reflect.Slice { + dv.Set(reflect.Zero(dv.Type())) + return nil + } + } + + return fmt.Errorf("pq: cannot convert %T to %s", src, dv.Type()) +} + +func (a GenericArray) scanBytes(src []byte, dv reflect.Value) error { + dtype, assign, del := a.evaluateDestination(dv.Type().Elem()) + dims, elems, err := parseArray(src, []byte(del)) + if err != nil { + return err + } + + // TODO allow multidimensional + + if len(dims) > 1 { + return fmt.Errorf("pq: scanning from multidimensional ARRAY%s is not implemented", + strings.Replace(fmt.Sprint(dims), " ", "][", -1)) + } + + // Treat a zero-dimensional array like an array with a single dimension of zero. + if len(dims) == 0 { + dims = append(dims, 0) + } + + for i, rt := 0, dv.Type(); i < len(dims); i, rt = i+1, rt.Elem() { + switch rt.Kind() { + case reflect.Slice: + case reflect.Array: + if rt.Len() != dims[i] { + return fmt.Errorf("pq: cannot convert ARRAY%s to %s", + strings.Replace(fmt.Sprint(dims), " ", "][", -1), dv.Type()) + } + default: + // TODO handle multidimensional + } + } + + values := reflect.MakeSlice(reflect.SliceOf(dtype), len(elems), len(elems)) + for i, e := range elems { + if err := assign(e, values.Index(i)); err != nil { + return fmt.Errorf("pq: parsing array element index %d: %v", i, err) + } + } + + // TODO handle multidimensional + + switch dv.Kind() { + case reflect.Slice: + dv.Set(values.Slice(0, dims[0])) + case reflect.Array: + for i := 0; i < dims[0]; i++ { + dv.Index(i).Set(values.Index(i)) + } + } + + return nil +} + +// Value implements the driver.Valuer interface. +func (a GenericArray) Value() (driver.Value, error) { + if a.A == nil { + return nil, nil + } + + rv := reflect.ValueOf(a.A) + + switch rv.Kind() { + case reflect.Slice: + if rv.IsNil() { + return nil, nil + } + case reflect.Array: + default: + return nil, fmt.Errorf("pq: Unable to convert %T to array", a.A) + } + + if n := rv.Len(); n > 0 { + // There will be at least two curly brackets, N bytes of values, + // and N-1 bytes of delimiters. + b := make([]byte, 0, 1+2*n) + + b, _, err := appendArray(b, rv, n) + return string(b), err + } + + return "{}", nil +} + +// Int64Array represents a one-dimensional array of the PostgreSQL integer types. +type Int64Array []int64 + +// Scan implements the sql.Scanner interface. +func (a *Int64Array) Scan(src interface{}) error { + switch src := src.(type) { + case []byte: + return a.scanBytes(src) + case string: + return a.scanBytes([]byte(src)) + case nil: + *a = nil + return nil + } + + return fmt.Errorf("pq: cannot convert %T to Int64Array", src) +} + +func (a *Int64Array) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "Int64Array") + if err != nil { + return err + } + if *a != nil && len(elems) == 0 { + *a = (*a)[:0] + } else { + b := make(Int64Array, len(elems)) + for i, v := range elems { + if b[i], err = strconv.ParseInt(string(v), 10, 64); err != nil { + return fmt.Errorf("pq: parsing array element index %d: %v", i, err) + } + } + *a = b + } + return nil +} + +// Value implements the driver.Valuer interface. +func (a Int64Array) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be at least two curly brackets, N bytes of values, + // and N-1 bytes of delimiters. + b := make([]byte, 1, 1+2*n) + b[0] = '{' + + b = strconv.AppendInt(b, a[0], 10) + for i := 1; i < n; i++ { + b = append(b, ',') + b = strconv.AppendInt(b, a[i], 10) + } + + return string(append(b, '}')), nil + } + + return "{}", nil +} + +// StringArray represents a one-dimensional array of the PostgreSQL character types. +type StringArray []string + +// Scan implements the sql.Scanner interface. +func (a *StringArray) Scan(src interface{}) error { + switch src := src.(type) { + case []byte: + return a.scanBytes(src) + case string: + return a.scanBytes([]byte(src)) + case nil: + *a = nil + return nil + } + + return fmt.Errorf("pq: cannot convert %T to StringArray", src) +} + +func (a *StringArray) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "StringArray") + if err != nil { + return err + } + if *a != nil && len(elems) == 0 { + *a = (*a)[:0] + } else { + b := make(StringArray, len(elems)) + for i, v := range elems { + if b[i] = string(v); v == nil { + return fmt.Errorf("pq: parsing array element index %d: cannot convert nil to string", i) + } + } + *a = b + } + return nil +} + +// Value implements the driver.Valuer interface. +func (a StringArray) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be at least two curly brackets, 2*N bytes of quotes, + // and N-1 bytes of delimiters. + b := make([]byte, 1, 1+3*n) + b[0] = '{' + + b = appendArrayQuotedBytes(b, []byte(a[0])) + for i := 1; i < n; i++ { + b = append(b, ',') + b = appendArrayQuotedBytes(b, []byte(a[i])) + } + + return string(append(b, '}')), nil + } + + return "{}", nil +} + +// appendArray appends rv to the buffer, returning the extended buffer and +// the delimiter used between elements. +// +// It panics when n <= 0 or rv's Kind is not reflect.Array nor reflect.Slice. +func appendArray(b []byte, rv reflect.Value, n int) ([]byte, string, error) { + var del string + var err error + + b = append(b, '{') + + if b, del, err = appendArrayElement(b, rv.Index(0)); err != nil { + return b, del, err + } + + for i := 1; i < n; i++ { + b = append(b, del...) + if b, del, err = appendArrayElement(b, rv.Index(i)); err != nil { + return b, del, err + } + } + + return append(b, '}'), del, nil +} + +// appendArrayElement appends rv to the buffer, returning the extended buffer +// and the delimiter to use before the next element. +// +// When rv's Kind is neither reflect.Array nor reflect.Slice, it is converted +// using driver.DefaultParameterConverter and the resulting []byte or string +// is double-quoted. +// +// See http://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO +func appendArrayElement(b []byte, rv reflect.Value) ([]byte, string, error) { + if k := rv.Kind(); k == reflect.Array || k == reflect.Slice { + if t := rv.Type(); t != typeByteSlice && !t.Implements(typeDriverValuer) { + if n := rv.Len(); n > 0 { + return appendArray(b, rv, n) + } + + return b, "", nil + } + } + + var del string = "," + var err error + var iv interface{} = rv.Interface() + + if ad, ok := iv.(ArrayDelimiter); ok { + del = ad.ArrayDelimiter() + } + + if iv, err = driver.DefaultParameterConverter.ConvertValue(iv); err != nil { + return b, del, err + } + + switch v := iv.(type) { + case nil: + return append(b, "NULL"...), del, nil + case []byte: + return appendArrayQuotedBytes(b, v), del, nil + case string: + return appendArrayQuotedBytes(b, []byte(v)), del, nil + } + + b, err = appendValue(b, iv) + return b, del, err +} + +func appendArrayQuotedBytes(b, v []byte) []byte { + b = append(b, '"') + for { + i := bytes.IndexAny(v, `"\`) + if i < 0 { + b = append(b, v...) + break + } + if i > 0 { + b = append(b, v[:i]...) + } + b = append(b, '\\', v[i]) + v = v[i+1:] + } + return append(b, '"') +} + +func appendValue(b []byte, v driver.Value) ([]byte, error) { + return append(b, encode(nil, v, 0)...), nil +} + +// parseArray extracts the dimensions and elements of an array represented in +// text format. Only representations emitted by the backend are supported. +// Notably, whitespace around brackets and delimiters is significant, and NULL +// is case-sensitive. +// +// See http://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO +func parseArray(src, del []byte) (dims []int, elems [][]byte, err error) { + var depth, i int + + if len(src) < 1 || src[0] != '{' { + return nil, nil, fmt.Errorf("pq: unable to parse array; expected %q at offset %d", '{', 0) + } + +Open: + for i < len(src) { + switch src[i] { + case '{': + depth++ + i++ + case '}': + elems = make([][]byte, 0) + goto Close + default: + break Open + } + } + dims = make([]int, i) + +Element: + for i < len(src) { + switch src[i] { + case '{': + if depth == len(dims) { + break Element + } + depth++ + dims[depth-1] = 0 + i++ + case '"': + var elem = []byte{} + var escape bool + for i++; i < len(src); i++ { + if escape { + elem = append(elem, src[i]) + escape = false + } else { + switch src[i] { + default: + elem = append(elem, src[i]) + case '\\': + escape = true + case '"': + elems = append(elems, elem) + i++ + break Element + } + } + } + default: + for start := i; i < len(src); i++ { + if bytes.HasPrefix(src[i:], del) || src[i] == '}' { + elem := src[start:i] + if len(elem) == 0 { + return nil, nil, fmt.Errorf("pq: unable to parse array; unexpected %q at offset %d", src[i], i) + } + if bytes.Equal(elem, []byte("NULL")) { + elem = nil + } + elems = append(elems, elem) + break Element + } + } + } + } + + for i < len(src) { + if bytes.HasPrefix(src[i:], del) && depth > 0 { + dims[depth-1]++ + i += len(del) + goto Element + } else if src[i] == '}' && depth > 0 { + dims[depth-1]++ + depth-- + i++ + } else { + return nil, nil, fmt.Errorf("pq: unable to parse array; unexpected %q at offset %d", src[i], i) + } + } + +Close: + for i < len(src) { + if src[i] == '}' && depth > 0 { + depth-- + i++ + } else { + return nil, nil, fmt.Errorf("pq: unable to parse array; unexpected %q at offset %d", src[i], i) + } + } + if depth > 0 { + err = fmt.Errorf("pq: unable to parse array; expected %q at offset %d", '}', i) + } + if err == nil { + for _, d := range dims { + if (len(elems) % d) != 0 { + err = fmt.Errorf("pq: multidimensional arrays must have elements with matching dimensions") + } + } + } + return +} + +func scanLinearArray(src, del []byte, typ string) (elems [][]byte, err error) { + dims, elems, err := parseArray(src, del) + if err != nil { + return nil, err + } + if len(dims) > 1 { + return nil, fmt.Errorf("pq: cannot convert ARRAY%s to %s", strings.Replace(fmt.Sprint(dims), " ", "][", -1), typ) + } + return elems, err +} diff --git a/vendor/github.com/lib/pq/bench_test.go b/vendor/github.com/lib/pq/bench_test.go deleted file mode 100644 index f2eb1dab367..00000000000 --- a/vendor/github.com/lib/pq/bench_test.go +++ /dev/null @@ -1,434 +0,0 @@ -// +build go1.1 - -package pq - -import ( - "bufio" - "bytes" - "database/sql" - "database/sql/driver" - "github.com/lib/pq/oid" - "io" - "math/rand" - "net" - "runtime" - "strconv" - "strings" - "sync" - "testing" - "time" -) - -var ( - selectStringQuery = "SELECT '" + strings.Repeat("0123456789", 10) + "'" - selectSeriesQuery = "SELECT generate_series(1, 100)" -) - -func BenchmarkSelectString(b *testing.B) { - var result string - benchQuery(b, selectStringQuery, &result) -} - -func BenchmarkSelectSeries(b *testing.B) { - var result int - benchQuery(b, selectSeriesQuery, &result) -} - -func benchQuery(b *testing.B, query string, result interface{}) { - b.StopTimer() - db := openTestConn(b) - defer db.Close() - b.StartTimer() - - for i := 0; i < b.N; i++ { - benchQueryLoop(b, db, query, result) - } -} - -func benchQueryLoop(b *testing.B, db *sql.DB, query string, result interface{}) { - rows, err := db.Query(query) - if err != nil { - b.Fatal(err) - } - defer rows.Close() - for rows.Next() { - err = rows.Scan(result) - if err != nil { - b.Fatal("failed to scan", err) - } - } -} - -// reading from circularConn yields content[:prefixLen] once, followed by -// content[prefixLen:] over and over again. It never returns EOF. -type circularConn struct { - content string - prefixLen int - pos int - net.Conn // for all other net.Conn methods that will never be called -} - -func (r *circularConn) Read(b []byte) (n int, err error) { - n = copy(b, r.content[r.pos:]) - r.pos += n - if r.pos >= len(r.content) { - r.pos = r.prefixLen - } - return -} - -func (r *circularConn) Write(b []byte) (n int, err error) { return len(b), nil } - -func (r *circularConn) Close() error { return nil } - -func fakeConn(content string, prefixLen int) *conn { - c := &circularConn{content: content, prefixLen: prefixLen} - return &conn{buf: bufio.NewReader(c), c: c} -} - -// This benchmark is meant to be the same as BenchmarkSelectString, but takes -// out some of the factors this package can't control. The numbers are less noisy, -// but also the costs of network communication aren't accurately represented. -func BenchmarkMockSelectString(b *testing.B) { - b.StopTimer() - // taken from a recorded run of BenchmarkSelectString - // See: http://www.postgresql.org/docs/current/static/protocol-message-formats.html - const response = "1\x00\x00\x00\x04" + - "t\x00\x00\x00\x06\x00\x00" + - "T\x00\x00\x00!\x00\x01?column?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\xc1\xff\xfe\xff\xff\xff\xff\x00\x00" + - "Z\x00\x00\x00\x05I" + - "2\x00\x00\x00\x04" + - "D\x00\x00\x00n\x00\x01\x00\x00\x00d0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789" + - "C\x00\x00\x00\rSELECT 1\x00" + - "Z\x00\x00\x00\x05I" + - "3\x00\x00\x00\x04" + - "Z\x00\x00\x00\x05I" - c := fakeConn(response, 0) - b.StartTimer() - - for i := 0; i < b.N; i++ { - benchMockQuery(b, c, selectStringQuery) - } -} - -var seriesRowData = func() string { - var buf bytes.Buffer - for i := 1; i <= 100; i++ { - digits := byte(2) - if i >= 100 { - digits = 3 - } else if i < 10 { - digits = 1 - } - buf.WriteString("D\x00\x00\x00") - buf.WriteByte(10 + digits) - buf.WriteString("\x00\x01\x00\x00\x00") - buf.WriteByte(digits) - buf.WriteString(strconv.Itoa(i)) - } - return buf.String() -}() - -func BenchmarkMockSelectSeries(b *testing.B) { - b.StopTimer() - var response = "1\x00\x00\x00\x04" + - "t\x00\x00\x00\x06\x00\x00" + - "T\x00\x00\x00!\x00\x01?column?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\xc1\xff\xfe\xff\xff\xff\xff\x00\x00" + - "Z\x00\x00\x00\x05I" + - "2\x00\x00\x00\x04" + - seriesRowData + - "C\x00\x00\x00\x0fSELECT 100\x00" + - "Z\x00\x00\x00\x05I" + - "3\x00\x00\x00\x04" + - "Z\x00\x00\x00\x05I" - c := fakeConn(response, 0) - b.StartTimer() - - for i := 0; i < b.N; i++ { - benchMockQuery(b, c, selectSeriesQuery) - } -} - -func benchMockQuery(b *testing.B, c *conn, query string) { - stmt, err := c.Prepare(query) - if err != nil { - b.Fatal(err) - } - defer stmt.Close() - rows, err := stmt.Query(nil) - if err != nil { - b.Fatal(err) - } - defer rows.Close() - var dest [1]driver.Value - for { - if err := rows.Next(dest[:]); err != nil { - if err == io.EOF { - break - } - b.Fatal(err) - } - } -} - -func BenchmarkPreparedSelectString(b *testing.B) { - var result string - benchPreparedQuery(b, selectStringQuery, &result) -} - -func BenchmarkPreparedSelectSeries(b *testing.B) { - var result int - benchPreparedQuery(b, selectSeriesQuery, &result) -} - -func benchPreparedQuery(b *testing.B, query string, result interface{}) { - b.StopTimer() - db := openTestConn(b) - defer db.Close() - stmt, err := db.Prepare(query) - if err != nil { - b.Fatal(err) - } - defer stmt.Close() - b.StartTimer() - - for i := 0; i < b.N; i++ { - benchPreparedQueryLoop(b, db, stmt, result) - } -} - -func benchPreparedQueryLoop(b *testing.B, db *sql.DB, stmt *sql.Stmt, result interface{}) { - rows, err := stmt.Query() - if err != nil { - b.Fatal(err) - } - if !rows.Next() { - rows.Close() - b.Fatal("no rows") - } - defer rows.Close() - for rows.Next() { - err = rows.Scan(&result) - if err != nil { - b.Fatal("failed to scan") - } - } -} - -// See the comment for BenchmarkMockSelectString. -func BenchmarkMockPreparedSelectString(b *testing.B) { - b.StopTimer() - const parseResponse = "1\x00\x00\x00\x04" + - "t\x00\x00\x00\x06\x00\x00" + - "T\x00\x00\x00!\x00\x01?column?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\xc1\xff\xfe\xff\xff\xff\xff\x00\x00" + - "Z\x00\x00\x00\x05I" - const responses = parseResponse + - "2\x00\x00\x00\x04" + - "D\x00\x00\x00n\x00\x01\x00\x00\x00d0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789" + - "C\x00\x00\x00\rSELECT 1\x00" + - "Z\x00\x00\x00\x05I" - c := fakeConn(responses, len(parseResponse)) - - stmt, err := c.Prepare(selectStringQuery) - if err != nil { - b.Fatal(err) - } - b.StartTimer() - - for i := 0; i < b.N; i++ { - benchPreparedMockQuery(b, c, stmt) - } -} - -func BenchmarkMockPreparedSelectSeries(b *testing.B) { - b.StopTimer() - const parseResponse = "1\x00\x00\x00\x04" + - "t\x00\x00\x00\x06\x00\x00" + - "T\x00\x00\x00!\x00\x01?column?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\xc1\xff\xfe\xff\xff\xff\xff\x00\x00" + - "Z\x00\x00\x00\x05I" - var responses = parseResponse + - "2\x00\x00\x00\x04" + - seriesRowData + - "C\x00\x00\x00\x0fSELECT 100\x00" + - "Z\x00\x00\x00\x05I" - c := fakeConn(responses, len(parseResponse)) - - stmt, err := c.Prepare(selectSeriesQuery) - if err != nil { - b.Fatal(err) - } - b.StartTimer() - - for i := 0; i < b.N; i++ { - benchPreparedMockQuery(b, c, stmt) - } -} - -func benchPreparedMockQuery(b *testing.B, c *conn, stmt driver.Stmt) { - rows, err := stmt.Query(nil) - if err != nil { - b.Fatal(err) - } - defer rows.Close() - var dest [1]driver.Value - for { - if err := rows.Next(dest[:]); err != nil { - if err == io.EOF { - break - } - b.Fatal(err) - } - } -} - -func BenchmarkEncodeInt64(b *testing.B) { - for i := 0; i < b.N; i++ { - encode(¶meterStatus{}, int64(1234), oid.T_int8) - } -} - -func BenchmarkEncodeFloat64(b *testing.B) { - for i := 0; i < b.N; i++ { - encode(¶meterStatus{}, 3.14159, oid.T_float8) - } -} - -var testByteString = []byte("abcdefghijklmnopqrstuvwxyz") - -func BenchmarkEncodeByteaHex(b *testing.B) { - for i := 0; i < b.N; i++ { - encode(¶meterStatus{serverVersion: 90000}, testByteString, oid.T_bytea) - } -} -func BenchmarkEncodeByteaEscape(b *testing.B) { - for i := 0; i < b.N; i++ { - encode(¶meterStatus{serverVersion: 84000}, testByteString, oid.T_bytea) - } -} - -func BenchmarkEncodeBool(b *testing.B) { - for i := 0; i < b.N; i++ { - encode(¶meterStatus{}, true, oid.T_bool) - } -} - -var testTimestamptz = time.Date(2001, time.January, 1, 0, 0, 0, 0, time.Local) - -func BenchmarkEncodeTimestamptz(b *testing.B) { - for i := 0; i < b.N; i++ { - encode(¶meterStatus{}, testTimestamptz, oid.T_timestamptz) - } -} - -var testIntBytes = []byte("1234") - -func BenchmarkDecodeInt64(b *testing.B) { - for i := 0; i < b.N; i++ { - decode(¶meterStatus{}, testIntBytes, oid.T_int8) - } -} - -var testFloatBytes = []byte("3.14159") - -func BenchmarkDecodeFloat64(b *testing.B) { - for i := 0; i < b.N; i++ { - decode(¶meterStatus{}, testFloatBytes, oid.T_float8) - } -} - -var testBoolBytes = []byte{'t'} - -func BenchmarkDecodeBool(b *testing.B) { - for i := 0; i < b.N; i++ { - decode(¶meterStatus{}, testBoolBytes, oid.T_bool) - } -} - -func TestDecodeBool(t *testing.T) { - db := openTestConn(t) - rows, err := db.Query("select true") - if err != nil { - t.Fatal(err) - } - rows.Close() -} - -var testTimestamptzBytes = []byte("2013-09-17 22:15:32.360754-07") - -func BenchmarkDecodeTimestamptz(b *testing.B) { - for i := 0; i < b.N; i++ { - decode(¶meterStatus{}, testTimestamptzBytes, oid.T_timestamptz) - } -} - -func BenchmarkDecodeTimestamptzMultiThread(b *testing.B) { - oldProcs := runtime.GOMAXPROCS(0) - defer runtime.GOMAXPROCS(oldProcs) - runtime.GOMAXPROCS(runtime.NumCPU()) - globalLocationCache = newLocationCache() - - f := func(wg *sync.WaitGroup, loops int) { - defer wg.Done() - for i := 0; i < loops; i++ { - decode(¶meterStatus{}, testTimestamptzBytes, oid.T_timestamptz) - } - } - - wg := &sync.WaitGroup{} - b.ResetTimer() - for j := 0; j < 10; j++ { - wg.Add(1) - go f(wg, b.N/10) - } - wg.Wait() -} - -func BenchmarkLocationCache(b *testing.B) { - globalLocationCache = newLocationCache() - for i := 0; i < b.N; i++ { - globalLocationCache.getLocation(rand.Intn(10000)) - } -} - -func BenchmarkLocationCacheMultiThread(b *testing.B) { - oldProcs := runtime.GOMAXPROCS(0) - defer runtime.GOMAXPROCS(oldProcs) - runtime.GOMAXPROCS(runtime.NumCPU()) - globalLocationCache = newLocationCache() - - f := func(wg *sync.WaitGroup, loops int) { - defer wg.Done() - for i := 0; i < loops; i++ { - globalLocationCache.getLocation(rand.Intn(10000)) - } - } - - wg := &sync.WaitGroup{} - b.ResetTimer() - for j := 0; j < 10; j++ { - wg.Add(1) - go f(wg, b.N/10) - } - wg.Wait() -} - -// Stress test the performance of parsing results from the wire. -func BenchmarkResultParsing(b *testing.B) { - b.StopTimer() - - db := openTestConn(b) - defer db.Close() - _, err := db.Exec("BEGIN") - if err != nil { - b.Fatal(err) - } - - b.StartTimer() - for i := 0; i < b.N; i++ { - res, err := db.Query("SELECT generate_series(1, 50000)") - if err != nil { - b.Fatal(err) - } - res.Close() - } -} diff --git a/vendor/github.com/lib/pq/buf.go b/vendor/github.com/lib/pq/buf.go index 9f417a1ebae..666b0012a79 100644 --- a/vendor/github.com/lib/pq/buf.go +++ b/vendor/github.com/lib/pq/buf.go @@ -3,6 +3,7 @@ package pq import ( "bytes" "encoding/binary" + "github.com/lib/pq/oid" ) @@ -20,6 +21,7 @@ func (b *readBuf) oid() (n oid.Oid) { return } +// N.B: this is actually an unsigned 16-bit integer, unlike int32 func (b *readBuf) int16() (n int) { n = int(binary.BigEndian.Uint16(*b)) *b = (*b)[2:] @@ -46,28 +48,44 @@ func (b *readBuf) byte() byte { return b.next(1)[0] } -type writeBuf []byte +type writeBuf struct { + buf []byte + pos int +} func (b *writeBuf) int32(n int) { x := make([]byte, 4) binary.BigEndian.PutUint32(x, uint32(n)) - *b = append(*b, x...) + b.buf = append(b.buf, x...) } func (b *writeBuf) int16(n int) { x := make([]byte, 2) binary.BigEndian.PutUint16(x, uint16(n)) - *b = append(*b, x...) + b.buf = append(b.buf, x...) } func (b *writeBuf) string(s string) { - *b = append(*b, (s + "\000")...) + b.buf = append(b.buf, (s + "\000")...) } func (b *writeBuf) byte(c byte) { - *b = append(*b, c) + b.buf = append(b.buf, c) } func (b *writeBuf) bytes(v []byte) { - *b = append(*b, v...) + b.buf = append(b.buf, v...) +} + +func (b *writeBuf) wrap() []byte { + p := b.buf[b.pos:] + binary.BigEndian.PutUint32(p, uint32(len(p))) + return b.buf +} + +func (b *writeBuf) next(c byte) { + p := b.buf[b.pos:] + binary.BigEndian.PutUint32(p, uint32(len(p))) + b.pos = len(b.buf) + 1 + b.buf = append(b.buf, c, 0, 0, 0, 0) } diff --git a/vendor/github.com/lib/pq/conn.go b/vendor/github.com/lib/pq/conn.go index 44e4833921f..3747ffcaf84 100644 --- a/vendor/github.com/lib/pq/conn.go +++ b/vendor/github.com/lib/pq/conn.go @@ -3,16 +3,12 @@ package pq import ( "bufio" "crypto/md5" - "crypto/tls" - "crypto/x509" "database/sql" "database/sql/driver" "encoding/binary" "errors" "fmt" - "github.com/lib/pq/oid" "io" - "io/ioutil" "net" "os" "os/user" @@ -22,6 +18,8 @@ import ( "strings" "time" "unicode" + + "github.com/lib/pq/oid" ) // Common error types @@ -31,16 +29,20 @@ var ( ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server") ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key file has group or world access. Permissions should be u=rw (0600) or less.") ErrCouldNotDetectUsername = errors.New("pq: Could not detect default username. Please provide one explicitly.") + + errUnexpectedReady = errors.New("unexpected ReadyForQuery") + errNoRowsAffected = errors.New("no RowsAffected available after the empty statement") + errNoLastInsertId = errors.New("no LastInsertId available after the empty statement") ) -type drv struct{} +type Driver struct{} -func (d *drv) Open(name string) (driver.Conn, error) { +func (d *Driver) Open(name string) (driver.Conn, error) { return Open(name) } func init() { - sql.Register("postgres", &drv{}) + sql.Register("postgres", &Driver{}) } type parameterStatus struct { @@ -96,6 +98,15 @@ type conn struct { namei int scratch [512]byte txnStatus transactionStatus + txnFinish func() + + // Save connection arguments to use during CancelRequest. + dialer Dialer + opts values + + // Cancellation key data for use with CancelRequest messages. + processID int + secretKey int parameterStatus parameterStatus @@ -105,12 +116,125 @@ type conn struct { // If true, this connection is bad and all public-facing functions should // return ErrBadConn. bad bool + + // If set, this connection should never use the binary format when + // receiving query results from prepared statements. Only provided for + // debugging. + disablePreparedBinaryResult bool + + // Whether to always send []byte parameters over as binary. Enables single + // round-trip mode for non-prepared Query calls. + binaryParameters bool + + // If true this connection is in the middle of a COPY + inCopy bool +} + +// Handle driver-side settings in parsed connection string. +func (c *conn) handleDriverSettings(o values) (err error) { + boolSetting := func(key string, val *bool) error { + if value, ok := o[key]; ok { + if value == "yes" { + *val = true + } else if value == "no" { + *val = false + } else { + return fmt.Errorf("unrecognized value %q for %s", value, key) + } + } + return nil + } + + err = boolSetting("disable_prepared_binary_result", &c.disablePreparedBinaryResult) + if err != nil { + return err + } + err = boolSetting("binary_parameters", &c.binaryParameters) + if err != nil { + return err + } + return nil +} + +func (c *conn) handlePgpass(o values) { + // if a password was supplied, do not process .pgpass + if _, ok := o["password"]; ok { + return + } + filename := os.Getenv("PGPASSFILE") + if filename == "" { + // XXX this code doesn't work on Windows where the default filename is + // XXX %APPDATA%\postgresql\pgpass.conf + user, err := user.Current() + if err != nil { + return + } + filename = filepath.Join(user.HomeDir, ".pgpass") + } + fileinfo, err := os.Stat(filename) + if err != nil { + return + } + mode := fileinfo.Mode() + if mode&(0x77) != 0 { + // XXX should warn about incorrect .pgpass permissions as psql does + return + } + file, err := os.Open(filename) + if err != nil { + return + } + defer file.Close() + scanner := bufio.NewScanner(io.Reader(file)) + hostname := o["host"] + ntw, _ := network(o) + port := o["port"] + db := o["dbname"] + username := o["user"] + // From: https://github.com/tg/pgpass/blob/master/reader.go + getFields := func(s string) []string { + fs := make([]string, 0, 5) + f := make([]rune, 0, len(s)) + + var esc bool + for _, c := range s { + switch { + case esc: + f = append(f, c) + esc = false + case c == '\\': + esc = true + case c == ':': + fs = append(fs, string(f)) + f = f[:0] + default: + f = append(f, c) + } + } + return append(fs, string(f)) + } + for scanner.Scan() { + line := scanner.Text() + if len(line) == 0 || line[0] == '#' { + continue + } + split := getFields(line) + if len(split) != 5 { + continue + } + if (split[0] == "*" || split[0] == hostname || (split[0] == "localhost" && (hostname == "" || ntw == "unix"))) && (split[1] == "*" || split[1] == port) && (split[2] == "*" || split[2] == db) && (split[3] == "*" || split[3] == username) { + o["password"] = split[4] + return + } + } } func (c *conn) writeBuf(b byte) *writeBuf { c.scratch[0] = b - w := writeBuf(c.scratch[:5]) - return &w + return &writeBuf{ + buf: c.scratch[:5], + pos: 1, + } } func Open(name string) (_ driver.Conn, err error) { @@ -118,22 +242,11 @@ func Open(name string) (_ driver.Conn, err error) { } func DialOpen(d Dialer, name string) (_ driver.Conn, err error) { - defer func() { - // Handle any panics during connection initialization. Note that we - // specifically do *not* want to use errRecover(), as that would turn - // any connection errors into ErrBadConns, hiding the real error - // message from the user. - e := recover() - if e == nil { - // Do nothing - return - } - var ok bool - err, ok = e.(error) - if !ok { - err = fmt.Errorf("pq: unexpected error: %#v", e) - } - }() + // Handle any panics during connection initialization. Note that we + // specifically do *not* want to use errRecover(), as that would turn any + // connection errors into ErrBadConns, hiding the real error message from + // the user. + defer errRecoverNoErrBadConn(&err) o := make(values) @@ -142,16 +255,16 @@ func DialOpen(d Dialer, name string) (_ driver.Conn, err error) { // * Very low precedence defaults applied in every situation // * Environment variables // * Explicitly passed connection information - o.Set("host", "localhost") - o.Set("port", "5432") + o["host"] = "localhost" + o["port"] = "5432" // N.B.: Extra float digits should be set to 3, but that breaks // Postgres 8.4 and older, where the max is 2. - o.Set("extra_float_digits", "2") + o["extra_float_digits"] = "2" for k, v := range parseEnviron(os.Environ()) { - o.Set(k, v) + o[k] = v } - if strings.HasPrefix(name, "postgres://") { + if strings.HasPrefix(name, "postgres://") || strings.HasPrefix(name, "postgresql://") { name, err = ParseURL(name) if err != nil { return nil, err @@ -163,9 +276,9 @@ func DialOpen(d Dialer, name string) (_ driver.Conn, err error) { } // Use the "fallback" application name if necessary - if fallback := o.Get("fallback_application_name"); fallback != "" { - if !o.Isset("application_name") { - o.Set("application_name", fallback) + if fallback, ok := o["fallback_application_name"]; ok { + if _, ok := o["application_name"]; !ok { + o["application_name"] = fallback } } @@ -176,53 +289,66 @@ func DialOpen(d Dialer, name string) (_ driver.Conn, err error) { // parsing its value is not worth it. Instead, we always explicitly send // client_encoding as a separate run-time parameter, which should override // anything set in options. - if enc := o.Get("client_encoding"); enc != "" && !isUTF8(enc) { + if enc, ok := o["client_encoding"]; ok && !isUTF8(enc) { return nil, errors.New("client_encoding must be absent or 'UTF8'") } - o.Set("client_encoding", "UTF8") + o["client_encoding"] = "UTF8" // DateStyle needs a similar treatment. - if datestyle := o.Get("datestyle"); datestyle != "" { + if datestyle, ok := o["datestyle"]; ok { if datestyle != "ISO, MDY" { panic(fmt.Sprintf("setting datestyle must be absent or %v; got %v", "ISO, MDY", datestyle)) } } else { - o.Set("datestyle", "ISO, MDY") + o["datestyle"] = "ISO, MDY" } // If a user is not provided by any other means, the last // resort is to use the current operating system provided user // name. - if o.Get("user") == "" { + if _, ok := o["user"]; !ok { u, err := userCurrent() if err != nil { return nil, err } else { - o.Set("user", u) + o["user"] = u } } - c, err := dial(d, o) + cn := &conn{ + opts: o, + dialer: d, + } + err = cn.handleDriverSettings(o) if err != nil { return nil, err } + cn.handlePgpass(o) - cn := &conn{c: c} + cn.c, err = dial(d, o) + if err != nil { + return nil, err + } cn.ssl(o) cn.buf = bufio.NewReader(cn.c) cn.startup(o) + // reset the deadline, in case one was set (see dial) - err = cn.c.SetDeadline(time.Time{}) + if timeout, ok := o["connect_timeout"]; ok && timeout != "0" { + err = cn.c.SetDeadline(time.Time{}) + } return cn, err } func dial(d Dialer, o values) (net.Conn, error) { ntw, addr := network(o) - - timeout := o.Get("connect_timeout") + // SSL is not necessary or supported over UNIX domain sockets + if ntw == "unix" { + o["sslmode"] = "disable" + } // Zero or not specified means wait indefinitely. - if timeout != "" && timeout != "0" { + if timeout, ok := o["connect_timeout"]; ok && timeout != "0" { seconds, err := strconv.ParseInt(timeout, 10, 0) if err != nil { return nil, fmt.Errorf("invalid value for parameter connect_timeout: %s", err) @@ -244,31 +370,18 @@ func dial(d Dialer, o values) (net.Conn, error) { } func network(o values) (string, string) { - host := o.Get("host") + host := o["host"] if strings.HasPrefix(host, "/") { - sockPath := path.Join(host, ".s.PGSQL."+o.Get("port")) + sockPath := path.Join(host, ".s.PGSQL."+o["port"]) return "unix", sockPath } - return "tcp", host + ":" + o.Get("port") + return "tcp", net.JoinHostPort(host, o["port"]) } type values map[string]string -func (vs values) Set(k, v string) { - vs[k] = v -} - -func (vs values) Get(k string) (v string) { - return vs[k] -} - -func (vs values) Isset(k string) bool { - _, ok := vs[k] - return ok -} - // scanner implements a tokenizer for libpq-style option strings. type scanner struct { s []rune @@ -339,7 +452,7 @@ func parseOpts(name string, o values) error { // Skip any whitespace after the = if r, ok = s.SkipSpaces(); !ok { // If we reach the end here, the last value is just an empty string as per libpq. - o.Set(string(keyRunes), "") + o[string(keyRunes)] = "" break } @@ -374,7 +487,7 @@ func parseOpts(name string, o values) error { } } - o.Set(string(keyRunes), string(valRunes)) + o[string(keyRunes)] = string(valRunes) } return nil @@ -393,13 +506,17 @@ func (cn *conn) checkIsInTransaction(intxn bool) { } func (cn *conn) Begin() (_ driver.Tx, err error) { + return cn.begin("") +} + +func (cn *conn) begin(mode string) (_ driver.Tx, err error) { if cn.bad { return nil, driver.ErrBadConn } defer cn.errRecover(&err) cn.checkIsInTransaction(false) - _, commandTag, err := cn.simpleExec("BEGIN") + _, commandTag, err := cn.simpleExec("BEGIN" + mode) if err != nil { return nil, err } @@ -414,7 +531,14 @@ func (cn *conn) Begin() (_ driver.Tx, err error) { return cn, nil } +func (cn *conn) closeTxn() { + if finish := cn.txnFinish; finish != nil { + finish() + } +} + func (cn *conn) Commit() (err error) { + defer cn.closeTxn() if cn.bad { return driver.ErrBadConn } @@ -436,6 +560,9 @@ func (cn *conn) Commit() (err error) { _, commandTag, err := cn.simpleExec("COMMIT") if err != nil { + if cn.isInTransaction() { + cn.bad = true + } return err } if commandTag != "COMMIT" { @@ -447,6 +574,7 @@ func (cn *conn) Commit() (err error) { } func (cn *conn) Rollback() (err error) { + defer cn.closeTxn() if cn.bad { return driver.ErrBadConn } @@ -455,6 +583,9 @@ func (cn *conn) Rollback() (err error) { cn.checkIsInTransaction(true) _, commandTag, err := cn.simpleExec("ROLLBACK") if err != nil { + if cn.isInTransaction() { + cn.bad = true + } return err } if commandTag != "ROLLBACK" { @@ -481,11 +612,16 @@ func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err res, commandTag = cn.parseComplete(r.string()) case 'Z': cn.processReadyForQuery(r) + if res == nil && err == nil { + err = errUnexpectedReady + } // done return case 'E': err = parseError(r) - case 'T', 'D', 'I': + case 'I': + res = emptyRows + case 'T', 'D': // ignore any results default: cn.bad = true @@ -494,11 +630,9 @@ func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err } } -func (cn *conn) simpleQuery(q string) (res driver.Rows, err error) { +func (cn *conn) simpleQuery(q string) (res *rows, err error) { defer cn.errRecover(&err) - st := &stmt{cn: cn, name: ""} - b := cn.writeBuf('Q') b.string(q) cn.send(b) @@ -515,7 +649,18 @@ func (cn *conn) simpleQuery(q string) (res driver.Rows, err error) { cn.bad = true errorf("unexpected message %q in simple query execution", t) } - res = &rows{st: st, done: true} + if res == nil { + res = &rows{ + cn: cn, + } + } + // Set the result and tag to the last command complete if there wasn't a + // query already run. Although queries usually return from here and cede + // control to Next, a query with zero results does not. + if t == 'C' && res.colNames == nil { + res.result, res.tag = cn.parseComplete(r.string()) + } + res.done = true case 'Z': cn.processReadyForQuery(r) // done @@ -534,8 +679,8 @@ func (cn *conn) simpleQuery(q string) (res driver.Rows, err error) { case 'T': // res might be non-nil here if we received a previous // CommandComplete, but that's fine; just overwrite it - res = &rows{st: st} - st.cols, st.rowTyps = parseMeta(r) + res = &rows{cn: cn} + res.colNames, res.colFmts, res.colTyps = parsePortalRowDescribe(r) // To work around a bug in QueryRow in Go 1.2 and earlier, wait // until the first DataRow has been received. @@ -546,47 +691,90 @@ func (cn *conn) simpleQuery(q string) (res driver.Rows, err error) { } } -func (cn *conn) prepareTo(q, stmtName string) (_ *stmt, err error) { +type noRows struct{} + +var emptyRows noRows + +var _ driver.Result = noRows{} + +func (noRows) LastInsertId() (int64, error) { + return 0, errNoLastInsertId +} + +func (noRows) RowsAffected() (int64, error) { + return 0, errNoRowsAffected +} + +// Decides which column formats to use for a prepared statement. The input is +// an array of type oids, one element per result column. +func decideColumnFormats(colTyps []oid.Oid, forceText bool) (colFmts []format, colFmtData []byte) { + if len(colTyps) == 0 { + return nil, colFmtDataAllText + } + + colFmts = make([]format, len(colTyps)) + if forceText { + return colFmts, colFmtDataAllText + } + + allBinary := true + allText := true + for i, o := range colTyps { + switch o { + // This is the list of types to use binary mode for when receiving them + // through a prepared statement. If a type appears in this list, it + // must also be implemented in binaryDecode in encode.go. + case oid.T_bytea: + fallthrough + case oid.T_int8: + fallthrough + case oid.T_int4: + fallthrough + case oid.T_int2: + fallthrough + case oid.T_uuid: + colFmts[i] = formatBinary + allText = false + + default: + allBinary = false + } + } + + if allBinary { + return colFmts, colFmtDataAllBinary + } else if allText { + return colFmts, colFmtDataAllText + } else { + colFmtData = make([]byte, 2+len(colFmts)*2) + binary.BigEndian.PutUint16(colFmtData, uint16(len(colFmts))) + for i, v := range colFmts { + binary.BigEndian.PutUint16(colFmtData[2+i*2:], uint16(v)) + } + return colFmts, colFmtData + } +} + +func (cn *conn) prepareTo(q, stmtName string) *stmt { st := &stmt{cn: cn, name: stmtName} b := cn.writeBuf('P') b.string(st.name) b.string(q) b.int16(0) - cn.send(b) - b = cn.writeBuf('D') + b.next('D') b.byte('S') b.string(st.name) + + b.next('S') cn.send(b) - cn.send(cn.writeBuf('S')) - - for { - t, r := cn.recv1() - switch t { - case '1': - case 't': - nparams := r.int16() - st.paramTyps = make([]oid.Oid, nparams) - - for i := range st.paramTyps { - st.paramTyps[i] = r.oid() - } - case 'T': - st.cols, st.rowTyps = parseMeta(r) - case 'n': - // no data - case 'Z': - cn.processReadyForQuery(r) - return st, err - case 'E': - err = parseError(r) - default: - cn.bad = true - errorf("unexpected describe rows response: %q", t) - } - } + cn.readParseResponse() + st.paramTyps, st.colNames, st.colTyps = cn.readStatementDescribeResponse() + st.colFmts, st.colFmtData = decideColumnFormats(st.colTyps, cn.disablePreparedBinaryResult) + cn.readReadyForQuery() + return st } func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) { @@ -596,32 +784,45 @@ func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) { defer cn.errRecover(&err) if len(q) >= 4 && strings.EqualFold(q[:4], "COPY") { - return cn.prepareCopyIn(q) + s, err := cn.prepareCopyIn(q) + if err == nil { + cn.inCopy = true + } + return s, err } - return cn.prepareTo(q, cn.gname()) + return cn.prepareTo(q, cn.gname()), nil } func (cn *conn) Close() (err error) { - if cn.bad { - return driver.ErrBadConn - } + // Skip cn.bad return here because we always want to close a connection. defer cn.errRecover(&err) + // Ensure that cn.c.Close is always run. Since error handling is done with + // panics and cn.errRecover, the Close must be in a defer. + defer func() { + cerr := cn.c.Close() + if err == nil { + err = cerr + } + }() + // Don't go through send(); ListenerConn relies on us not scribbling on the // scratch buffer of this connection. - err = cn.sendSimpleMessage('X') - if err != nil { - return err - } - - return cn.c.Close() + return cn.sendSimpleMessage('X') } // Implement the "Queryer" interface -func (cn *conn) Query(query string, args []driver.Value) (_ driver.Rows, err error) { +func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) { + return cn.query(query, args) +} + +func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) { if cn.bad { return nil, driver.ErrBadConn } + if cn.inCopy { + return nil, errCopyInProgress + } defer cn.errRecover(&err) // Check to see if we can use the "simpleQuery" interface, which is @@ -630,17 +831,29 @@ func (cn *conn) Query(query string, args []driver.Value) (_ driver.Rows, err err return cn.simpleQuery(query) } - st, err := cn.prepareTo(query, "") - if err != nil { - panic(err) - } + if cn.binaryParameters { + cn.sendBinaryModeQuery(query, args) - st.exec(args) - return &rows{st: st}, nil + cn.readParseResponse() + cn.readBindResponse() + rows := &rows{cn: cn} + rows.colNames, rows.colFmts, rows.colTyps = cn.readPortalDescribeResponse() + cn.postExecuteWorkaround() + return rows, nil + } else { + st := cn.prepareTo(query, "") + st.exec(args) + return &rows{ + cn: cn, + colNames: st.colNames, + colTyps: st.colTyps, + colFmts: st.colFmts, + }, nil + } } // Implement the optional "Execer" interface for one-shot queries -func (cn *conn) Exec(query string, args []driver.Value) (_ driver.Result, err error) { +func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err error) { if cn.bad { return nil, driver.ErrBadConn } @@ -654,37 +867,40 @@ func (cn *conn) Exec(query string, args []driver.Value) (_ driver.Result, err er return r, err } - // Use the unnamed statement to defer planning until bind - // time, or else value-based selectivity estimates cannot be - // used. - st, err := cn.prepareTo(query, "") - if err != nil { - panic(err) - } + if cn.binaryParameters { + cn.sendBinaryModeQuery(query, args) - r, err := st.Exec(args) - if err != nil { - panic(err) + cn.readParseResponse() + cn.readBindResponse() + cn.readPortalDescribeResponse() + cn.postExecuteWorkaround() + res, _, err = cn.readExecuteResponse("Execute") + return res, err + } else { + // Use the unnamed statement to defer planning until bind + // time, or else value-based selectivity estimates cannot be + // used. + st := cn.prepareTo(query, "") + r, err := st.Exec(args) + if err != nil { + panic(err) + } + return r, err } - - return r, err } -// Assumes len(*m) is > 5 func (cn *conn) send(m *writeBuf) { - b := (*m)[1:] - binary.BigEndian.PutUint32(b, uint32(len(b))) - - if (*m)[0] == 0 { - *m = b - } - - _, err := cn.c.Write(*m) + _, err := cn.c.Write(m.wrap()) if err != nil { panic(err) } } +func (cn *conn) sendStartupPacket(m *writeBuf) error { + _, err := cn.c.Write((m.wrap())[1:]) + return err +} + // Send a message of type typ to the server on the other end of cn. The // message should have no payload. This method does not use the scratch // buffer. @@ -796,30 +1012,17 @@ func (cn *conn) recv1() (t byte, r *readBuf) { } func (cn *conn) ssl(o values) { - verifyCaOnly := false - tlsConf := tls.Config{} - switch mode := o.Get("sslmode"); mode { - case "require", "": - tlsConf.InsecureSkipVerify = true - case "verify-ca": - // We must skip TLS's own verification since it requires full - // verification since Go 1.3. - tlsConf.InsecureSkipVerify = true - verifyCaOnly = true - case "verify-full": - tlsConf.ServerName = o.Get("host") - case "disable": + upgrade := ssl(o) + if upgrade == nil { + // Nothing to do return - default: - errorf(`unsupported sslmode %q; only "require" (default), "verify-full", and "disable" supported`, mode) } - cn.setupSSLClientCertificates(&tlsConf, o) - cn.setupSSLCA(&tlsConf, o) - w := cn.writeBuf(0) w.int32(80877103) - cn.send(w) + if err := cn.sendStartupPacket(w); err != nil { + panic(err) + } b := cn.scratch[:1] _, err := io.ReadFull(cn.c, b) @@ -831,114 +1034,7 @@ func (cn *conn) ssl(o values) { panic(ErrSSLNotSupported) } - client := tls.Client(cn.c, &tlsConf) - if verifyCaOnly { - cn.verifyCA(client, &tlsConf) - } - cn.c = client -} - -// verifyCA carries out a TLS handshake to the server and verifies the -// presented certificate against the effective CA, i.e. the one specified in -// sslrootcert or the system CA if sslrootcert was not specified. -func (cn *conn) verifyCA(client *tls.Conn, tlsConf *tls.Config) { - err := client.Handshake() - if err != nil { - panic(err) - } - certs := client.ConnectionState().PeerCertificates - opts := x509.VerifyOptions{ - DNSName: client.ConnectionState().ServerName, - Intermediates: x509.NewCertPool(), - Roots: tlsConf.RootCAs, - } - for i, cert := range certs { - if i == 0 { - continue - } - opts.Intermediates.AddCert(cert) - } - _, err = certs[0].Verify(opts) - if err != nil { - panic(err) - } -} - -// This function sets up SSL client certificates based on either the "sslkey" -// and "sslcert" settings (possibly set via the environment variables PGSSLKEY -// and PGSSLCERT, respectively), or if they aren't set, from the .postgresql -// directory in the user's home directory. If the file paths are set -// explicitly, the files must exist. The key file must also not be -// world-readable, or this function will panic with -// ErrSSLKeyHasWorldPermissions. -func (cn *conn) setupSSLClientCertificates(tlsConf *tls.Config, o values) { - var missingOk bool - - sslkey := o.Get("sslkey") - sslcert := o.Get("sslcert") - if sslkey != "" && sslcert != "" { - // If the user has set an sslkey and sslcert, they *must* exist. - missingOk = false - } else { - // Automatically load certificates from ~/.postgresql. - user, err := user.Current() - if err != nil { - // user.Current() might fail when cross-compiling. We have to - // ignore the error and continue without client certificates, since - // we wouldn't know where to load them from. - return - } - - sslkey = filepath.Join(user.HomeDir, ".postgresql", "postgresql.key") - sslcert = filepath.Join(user.HomeDir, ".postgresql", "postgresql.crt") - missingOk = true - } - - // Check that both files exist, and report the error or stop, depending on - // which behaviour we want. Note that we don't do any more extensive - // checks than this (such as checking that the paths aren't directories); - // LoadX509KeyPair() will take care of the rest. - keyfinfo, err := os.Stat(sslkey) - if err != nil && missingOk { - return - } else if err != nil { - panic(err) - } - _, err = os.Stat(sslcert) - if err != nil && missingOk { - return - } else if err != nil { - panic(err) - } - - // If we got this far, the key file must also have the correct permissions - kmode := keyfinfo.Mode() - if kmode != kmode&0600 { - panic(ErrSSLKeyHasWorldPermissions) - } - - cert, err := tls.LoadX509KeyPair(sslcert, sslkey) - if err != nil { - panic(err) - } - tlsConf.Certificates = []tls.Certificate{cert} -} - -// Sets up RootCAs in the TLS configuration if sslrootcert is set. -func (cn *conn) setupSSLCA(tlsConf *tls.Config, o values) { - if sslrootcert := o.Get("sslrootcert"); sslrootcert != "" { - tlsConf.RootCAs = x509.NewCertPool() - - cert, err := ioutil.ReadFile(sslrootcert) - if err != nil { - panic(err) - } - - ok := tlsConf.RootCAs.AppendCertsFromPEM(cert) - if !ok { - errorf("couldn't parse pem in sslrootcert") - } - } + cn.c = upgrade(cn.c) } // isDriverSetting returns true iff a setting is purely for configuring the @@ -956,6 +1052,10 @@ func isDriverSetting(key string) bool { return true case "connect_timeout": return true + case "disable_prepared_binary_result": + return true + case "binary_parameters": + return true default: return false @@ -983,12 +1083,15 @@ func (cn *conn) startup(o values) { w.string(v) } w.string("") - cn.send(w) + if err := cn.sendStartupPacket(w); err != nil { + panic(err) + } for { t, r := cn.recv() switch t { case 'K': + cn.processBackendKeyData(r) case 'S': cn.processParameterStatus(r) case 'R': @@ -1008,7 +1111,7 @@ func (cn *conn) auth(r *readBuf, o values) { // OK case 3: w := cn.writeBuf('p') - w.string(o.Get("password")) + w.string(o["password"]) cn.send(w) t, r := cn.recv() @@ -1022,7 +1125,7 @@ func (cn *conn) auth(r *readBuf, o values) { case 5: s := string(r.next(4)) w := cn.writeBuf('p') - w.string("md5" + md5s(md5s(o.Get("password")+o.Get("user"))+s)) + w.string("md5" + md5s(md5s(o["password"]+o["user"])+s)) cn.send(w) t, r := cn.recv() @@ -1038,13 +1141,26 @@ func (cn *conn) auth(r *readBuf, o values) { } } +type format int + +const formatText format = 0 +const formatBinary format = 1 + +// One result-column format code with the value 1 (i.e. all binary). +var colFmtDataAllBinary []byte = []byte{0, 1, 0, 1} + +// No result-column format codes (i.e. all text). +var colFmtDataAllText []byte = []byte{0, 0} + type stmt struct { - cn *conn - name string - cols []string - rowTyps []oid.Oid - paramTyps []oid.Oid - closed bool + cn *conn + name string + colNames []string + colFmts []format + colFmtData []byte + colTyps []oid.Oid + paramTyps []oid.Oid + closed bool } func (st *stmt) Close() (err error) { @@ -1087,7 +1203,12 @@ func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) { defer st.cn.errRecover(&err) st.exec(v) - return &rows{st: st}, nil + return &rows{ + cn: st.cn, + colNames: st.colNames, + colTyps: st.colTyps, + colFmts: st.colFmts, + }, nil } func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) { @@ -1097,25 +1218,8 @@ func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) { defer st.cn.errRecover(&err) st.exec(v) - - for { - t, r := st.cn.recv1() - switch t { - case 'E': - err = parseError(r) - case 'C': - res, _ = st.cn.parseComplete(r.string()) - case 'Z': - st.cn.processReadyForQuery(r) - // done - return - case 'T', 'D', 'I': - // ignore any results - default: - st.cn.bad = true - errorf("unknown exec response: %q", t) - } - } + res, _, err = st.cn.readExecuteResponse("simple query") + return res, err } func (st *stmt) exec(v []driver.Value) { @@ -1126,84 +1230,38 @@ func (st *stmt) exec(v []driver.Value) { errorf("got %d parameters but the statement requires %d", len(v), len(st.paramTyps)) } - w := st.cn.writeBuf('B') - w.string("") + cn := st.cn + w := cn.writeBuf('B') + w.byte(0) // unnamed portal w.string(st.name) - w.int16(0) - w.int16(len(v)) - for i, x := range v { - if x == nil { - w.int32(-1) - } else { - b := encode(&st.cn.parameterStatus, x, st.paramTyps[i]) - w.int32(len(b)) - w.bytes(b) + + if cn.binaryParameters { + cn.sendBinaryParameters(w, v) + } else { + w.int16(0) + w.int16(len(v)) + for i, x := range v { + if x == nil { + w.int32(-1) + } else { + b := encode(&cn.parameterStatus, x, st.paramTyps[i]) + w.int32(len(b)) + w.bytes(b) + } } } - w.int16(0) - st.cn.send(w) + w.bytes(st.colFmtData) - w = st.cn.writeBuf('E') - w.string("") + w.next('E') + w.byte(0) w.int32(0) - st.cn.send(w) - st.cn.send(st.cn.writeBuf('S')) + w.next('S') + cn.send(w) - var err error - for { - t, r := st.cn.recv1() - switch t { - case 'E': - err = parseError(r) - case '2': - if err != nil { - panic(err) - } - goto workaround - case 'Z': - st.cn.processReadyForQuery(r) - if err != nil { - panic(err) - } - return - default: - st.cn.bad = true - errorf("unexpected bind response: %q", t) - } - } + cn.readBindResponse() + cn.postExecuteWorkaround() - // Work around a bug in sql.DB.QueryRow: in Go 1.2 and earlier it ignores - // any errors from rows.Next, which masks errors that happened during the - // execution of the query. To avoid the problem in common cases, we wait - // here for one more message from the database. If it's not an error the - // query will likely succeed (or perhaps has already, if it's a - // CommandComplete), so we push the message into the conn struct; recv1 - // will return it as the next message for rows.Next or rows.Close. - // However, if it's an error, we wait until ReadyForQuery and then return - // the error to our caller. -workaround: - for { - t, r := st.cn.recv1() - switch t { - case 'E': - err = parseError(r) - case 'C', 'D', 'I': - // the query didn't fail, but we can't process this message - st.cn.saveMessage(t, r) - return - case 'Z': - if err == nil { - st.cn.bad = true - errorf("unexpected ReadyForQuery during extended query execution") - } - st.cn.processReadyForQuery(r) - panic(err) - default: - st.cn.bad = true - errorf("unexpected message during query execution: %q", t) - } - } } func (st *stmt) NumInput() int { @@ -1260,19 +1318,33 @@ func (cn *conn) parseComplete(commandTag string) (driver.Result, string) { } type rows struct { - st *stmt - done bool - rb readBuf + cn *conn + finish func() + colNames []string + colTyps []oid.Oid + colFmts []format + done bool + rb readBuf + result driver.Result + tag string } func (rs *rows) Close() error { + if finish := rs.finish; finish != nil { + defer finish() + } // no need to look at cn.bad as Next() will for { err := rs.Next(nil) switch err { case nil: case io.EOF: - return nil + // rs.Next can return io.EOF on both 'Z' (ready for query) and 'T' (row + // description, used with HasNextResultSet). We need to fetch messages until + // we hit a 'Z', which is done by waiting for done to be set. + if rs.done { + return nil + } default: return err } @@ -1280,7 +1352,18 @@ func (rs *rows) Close() error { } func (rs *rows) Columns() []string { - return rs.st.cols + return rs.colNames +} + +func (rs *rows) Result() driver.Result { + if rs.result == nil { + return emptyRows + } + return rs.result +} + +func (rs *rows) Tag() string { + return rs.tag } func (rs *rows) Next(dest []driver.Value) (err error) { @@ -1288,7 +1371,7 @@ func (rs *rows) Next(dest []driver.Value) (err error) { return io.EOF } - conn := rs.st.cn + conn := rs.cn if conn.bad { return driver.ErrBadConn } @@ -1300,6 +1383,9 @@ func (rs *rows) Next(dest []driver.Value) (err error) { case 'E': err = parseError(&rs.rb) case 'C', 'I': + if t == 'C' { + rs.result, rs.tag = conn.parseComplete(rs.rb.string()) + } continue case 'Z': conn.processReadyForQuery(&rs.rb) @@ -1310,6 +1396,10 @@ func (rs *rows) Next(dest []driver.Value) (err error) { return io.EOF case 'D': n := rs.rb.int16() + if err != nil { + conn.bad = true + errorf("unexpected DataRow after error %s", err) + } if n < len(dest) { dest = dest[:n] } @@ -1319,15 +1409,26 @@ func (rs *rows) Next(dest []driver.Value) (err error) { dest[i] = nil continue } - dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.st.rowTyps[i]) + dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.colTyps[i], rs.colFmts[i]) } return + case 'T': + rs.colNames, rs.colFmts, rs.colTyps = parsePortalRowDescribe(&rs.rb) + return io.EOF default: errorf("unexpected message after execute: %q", t) } } } +func (rs *rows) HasNextResultSet() bool { + return !rs.done +} + +func (rs *rows) NextResultSet() error { + return nil +} + // QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be // used as part of an SQL statement. For example: // @@ -1352,6 +1453,68 @@ func md5s(s string) string { return fmt.Sprintf("%x", h.Sum(nil)) } +func (cn *conn) sendBinaryParameters(b *writeBuf, args []driver.Value) { + // Do one pass over the parameters to see if we're going to send any of + // them over in binary. If we are, create a paramFormats array at the + // same time. + var paramFormats []int + for i, x := range args { + _, ok := x.([]byte) + if ok { + if paramFormats == nil { + paramFormats = make([]int, len(args)) + } + paramFormats[i] = 1 + } + } + if paramFormats == nil { + b.int16(0) + } else { + b.int16(len(paramFormats)) + for _, x := range paramFormats { + b.int16(x) + } + } + + b.int16(len(args)) + for _, x := range args { + if x == nil { + b.int32(-1) + } else { + datum := binaryEncode(&cn.parameterStatus, x) + b.int32(len(datum)) + b.bytes(datum) + } + } +} + +func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) { + if len(args) >= 65536 { + errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(args)) + } + + b := cn.writeBuf('P') + b.byte(0) // unnamed statement + b.string(query) + b.int16(0) + + b.next('B') + b.int16(0) // unnamed portal and statement + cn.sendBinaryParameters(b, args) + b.bytes(colFmtDataAllText) + + b.next('D') + b.byte('P') + b.byte(0) // unnamed portal + + b.next('E') + b.byte(0) + b.int32(0) + + b.next('S') + cn.send(b) +} + func (c *conn) processParameterStatus(r *readBuf) { var err error @@ -1381,15 +1544,186 @@ func (c *conn) processReadyForQuery(r *readBuf) { c.txnStatus = transactionStatus(r.byte()) } -func parseMeta(r *readBuf) (cols []string, rowTyps []oid.Oid) { +func (cn *conn) readReadyForQuery() { + t, r := cn.recv1() + switch t { + case 'Z': + cn.processReadyForQuery(r) + return + default: + cn.bad = true + errorf("unexpected message %q; expected ReadyForQuery", t) + } +} + +func (c *conn) processBackendKeyData(r *readBuf) { + c.processID = r.int32() + c.secretKey = r.int32() +} + +func (cn *conn) readParseResponse() { + t, r := cn.recv1() + switch t { + case '1': + return + case 'E': + err := parseError(r) + cn.readReadyForQuery() + panic(err) + default: + cn.bad = true + errorf("unexpected Parse response %q", t) + } +} + +func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames []string, colTyps []oid.Oid) { + for { + t, r := cn.recv1() + switch t { + case 't': + nparams := r.int16() + paramTyps = make([]oid.Oid, nparams) + for i := range paramTyps { + paramTyps[i] = r.oid() + } + case 'n': + return paramTyps, nil, nil + case 'T': + colNames, colTyps = parseStatementRowDescribe(r) + return paramTyps, colNames, colTyps + case 'E': + err := parseError(r) + cn.readReadyForQuery() + panic(err) + default: + cn.bad = true + errorf("unexpected Describe statement response %q", t) + } + } +} + +func (cn *conn) readPortalDescribeResponse() (colNames []string, colFmts []format, colTyps []oid.Oid) { + t, r := cn.recv1() + switch t { + case 'T': + return parsePortalRowDescribe(r) + case 'n': + return nil, nil, nil + case 'E': + err := parseError(r) + cn.readReadyForQuery() + panic(err) + default: + cn.bad = true + errorf("unexpected Describe response %q", t) + } + panic("not reached") +} + +func (cn *conn) readBindResponse() { + t, r := cn.recv1() + switch t { + case '2': + return + case 'E': + err := parseError(r) + cn.readReadyForQuery() + panic(err) + default: + cn.bad = true + errorf("unexpected Bind response %q", t) + } +} + +func (cn *conn) postExecuteWorkaround() { + // Work around a bug in sql.DB.QueryRow: in Go 1.2 and earlier it ignores + // any errors from rows.Next, which masks errors that happened during the + // execution of the query. To avoid the problem in common cases, we wait + // here for one more message from the database. If it's not an error the + // query will likely succeed (or perhaps has already, if it's a + // CommandComplete), so we push the message into the conn struct; recv1 + // will return it as the next message for rows.Next or rows.Close. + // However, if it's an error, we wait until ReadyForQuery and then return + // the error to our caller. + for { + t, r := cn.recv1() + switch t { + case 'E': + err := parseError(r) + cn.readReadyForQuery() + panic(err) + case 'C', 'D', 'I': + // the query didn't fail, but we can't process this message + cn.saveMessage(t, r) + return + default: + cn.bad = true + errorf("unexpected message during extended query execution: %q", t) + } + } +} + +// Only for Exec(), since we ignore the returned data +func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, commandTag string, err error) { + for { + t, r := cn.recv1() + switch t { + case 'C': + if err != nil { + cn.bad = true + errorf("unexpected CommandComplete after error %s", err) + } + res, commandTag = cn.parseComplete(r.string()) + case 'Z': + cn.processReadyForQuery(r) + if res == nil && err == nil { + err = errUnexpectedReady + } + return res, commandTag, err + case 'E': + err = parseError(r) + case 'T', 'D', 'I': + if err != nil { + cn.bad = true + errorf("unexpected %q after error %s", t, err) + } + if t == 'I' { + res = emptyRows + } + // ignore any results + default: + cn.bad = true + errorf("unknown %s response: %q", protocolState, t) + } + } +} + +func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []oid.Oid) { n := r.int16() - cols = make([]string, n) - rowTyps = make([]oid.Oid, n) - for i := range cols { - cols[i] = r.string() + colNames = make([]string, n) + colTyps = make([]oid.Oid, n) + for i := range colNames { + colNames[i] = r.string() r.next(6) - rowTyps[i] = r.oid() - r.next(8) + colTyps[i] = r.oid() + r.next(6) + // format code not known when describing a statement; always 0 + r.next(2) + } + return +} + +func parsePortalRowDescribe(r *readBuf) (colNames []string, colFmts []format, colTyps []oid.Oid) { + n := r.int16() + colNames = make([]string, n) + colFmts = make([]format, n) + colTyps = make([]oid.Oid, n) + for i := range colNames { + colNames[i] = r.string() + r.next(6) + colTyps[i] = r.oid() + r.next(6) + colFmts[i] = format(r.int16()) } return } @@ -1434,7 +1768,7 @@ func parseEnviron(env []string) (out map[string]string) { accrue("user") case "PGPASSWORD": accrue("password") - case "PGPASSFILE", "PGSERVICE", "PGSERVICEFILE", "PGREALM": + case "PGSERVICE", "PGSERVICEFILE", "PGREALM": unsupported() case "PGOPTIONS": accrue("options") diff --git a/vendor/github.com/lib/pq/conn_go18.go b/vendor/github.com/lib/pq/conn_go18.go new file mode 100644 index 00000000000..ab97a104d83 --- /dev/null +++ b/vendor/github.com/lib/pq/conn_go18.go @@ -0,0 +1,128 @@ +// +build go1.8 + +package pq + +import ( + "context" + "database/sql" + "database/sql/driver" + "fmt" + "io" + "io/ioutil" +) + +// Implement the "QueryerContext" interface +func (cn *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + list := make([]driver.Value, len(args)) + for i, nv := range args { + list[i] = nv.Value + } + finish := cn.watchCancel(ctx) + r, err := cn.query(query, list) + if err != nil { + if finish != nil { + finish() + } + return nil, err + } + r.finish = finish + return r, nil +} + +// Implement the "ExecerContext" interface +func (cn *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + list := make([]driver.Value, len(args)) + for i, nv := range args { + list[i] = nv.Value + } + + if finish := cn.watchCancel(ctx); finish != nil { + defer finish() + } + + return cn.Exec(query, list) +} + +// Implement the "ConnBeginTx" interface +func (cn *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + var mode string + + switch sql.IsolationLevel(opts.Isolation) { + case sql.LevelDefault: + // Don't touch mode: use the server's default + case sql.LevelReadUncommitted: + mode = " ISOLATION LEVEL READ UNCOMMITTED" + case sql.LevelReadCommitted: + mode = " ISOLATION LEVEL READ COMMITTED" + case sql.LevelRepeatableRead: + mode = " ISOLATION LEVEL REPEATABLE READ" + case sql.LevelSerializable: + mode = " ISOLATION LEVEL SERIALIZABLE" + default: + return nil, fmt.Errorf("pq: isolation level not supported: %d", opts.Isolation) + } + + if opts.ReadOnly { + mode += " READ ONLY" + } else { + mode += " READ WRITE" + } + + tx, err := cn.begin(mode) + if err != nil { + return nil, err + } + cn.txnFinish = cn.watchCancel(ctx) + return tx, nil +} + +func (cn *conn) watchCancel(ctx context.Context) func() { + if done := ctx.Done(); done != nil { + finished := make(chan struct{}) + go func() { + select { + case <-done: + _ = cn.cancel() + finished <- struct{}{} + case <-finished: + } + }() + return func() { + select { + case <-finished: + case finished <- struct{}{}: + } + } + } + return nil +} + +func (cn *conn) cancel() error { + c, err := dial(cn.dialer, cn.opts) + if err != nil { + return err + } + defer c.Close() + + { + can := conn{ + c: c, + } + can.ssl(cn.opts) + + w := can.writeBuf(0) + w.int32(80877102) // cancel request code + w.int32(cn.processID) + w.int32(cn.secretKey) + + if err := can.sendStartupPacket(w); err != nil { + return err + } + } + + // Read until EOF to ensure that the server received the cancel. + { + _, err := io.Copy(ioutil.Discard, c) + return err + } +} diff --git a/vendor/github.com/lib/pq/conn_test.go b/vendor/github.com/lib/pq/conn_test.go deleted file mode 100644 index 6c3c6b54629..00000000000 --- a/vendor/github.com/lib/pq/conn_test.go +++ /dev/null @@ -1,1286 +0,0 @@ -package pq - -import ( - "database/sql" - "database/sql/driver" - "fmt" - "io" - "os" - "reflect" - "testing" - "time" -) - -type Fatalistic interface { - Fatal(args ...interface{}) -} - -func openTestConnConninfo(conninfo string) (*sql.DB, error) { - datname := os.Getenv("PGDATABASE") - sslmode := os.Getenv("PGSSLMODE") - timeout := os.Getenv("PGCONNECT_TIMEOUT") - - if datname == "" { - os.Setenv("PGDATABASE", "pqgotest") - } - - if sslmode == "" { - os.Setenv("PGSSLMODE", "disable") - } - - if timeout == "" { - os.Setenv("PGCONNECT_TIMEOUT", "20") - } - - return sql.Open("postgres", conninfo) -} - -func openTestConn(t Fatalistic) *sql.DB { - conn, err := openTestConnConninfo("") - if err != nil { - t.Fatal(err) - } - - return conn -} - -func getServerVersion(t *testing.T, db *sql.DB) int { - var version int - err := db.QueryRow("SHOW server_version_num").Scan(&version) - if err != nil { - t.Fatal(err) - } - return version -} - -func TestReconnect(t *testing.T) { - db1 := openTestConn(t) - defer db1.Close() - tx, err := db1.Begin() - if err != nil { - t.Fatal(err) - } - var pid1 int - err = tx.QueryRow("SELECT pg_backend_pid()").Scan(&pid1) - if err != nil { - t.Fatal(err) - } - db2 := openTestConn(t) - defer db2.Close() - _, err = db2.Exec("SELECT pg_terminate_backend($1)", pid1) - if err != nil { - t.Fatal(err) - } - // The rollback will probably "fail" because we just killed - // its connection above - _ = tx.Rollback() - - const expected int = 42 - var result int - err = db1.QueryRow(fmt.Sprintf("SELECT %d", expected)).Scan(&result) - if err != nil { - t.Fatal(err) - } - if result != expected { - t.Errorf("got %v; expected %v", result, expected) - } -} - -func TestCommitInFailedTransaction(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - txn, err := db.Begin() - if err != nil { - t.Fatal(err) - } - rows, err := txn.Query("SELECT error") - if err == nil { - rows.Close() - t.Fatal("expected failure") - } - err = txn.Commit() - if err != ErrInFailedTransaction { - t.Fatalf("expected ErrInFailedTransaction; got %#v", err) - } -} - -func TestOpenURL(t *testing.T) { - db, err := openTestConnConninfo("postgres://") - if err != nil { - t.Fatal(err) - } - defer db.Close() - // database/sql might not call our Open at all unless we do something with - // the connection - txn, err := db.Begin() - if err != nil { - t.Fatal(err) - } - txn.Rollback() -} - -func TestExec(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - _, err := db.Exec("CREATE TEMP TABLE temp (a int)") - if err != nil { - t.Fatal(err) - } - - r, err := db.Exec("INSERT INTO temp VALUES (1)") - if err != nil { - t.Fatal(err) - } - - if n, _ := r.RowsAffected(); n != 1 { - t.Fatalf("expected 1 row affected, not %d", n) - } - - r, err = db.Exec("INSERT INTO temp VALUES ($1), ($2), ($3)", 1, 2, 3) - if err != nil { - t.Fatal(err) - } - - if n, _ := r.RowsAffected(); n != 3 { - t.Fatalf("expected 3 rows affected, not %d", n) - } - - // SELECT doesn't send the number of returned rows in the command tag - // before 9.0 - if getServerVersion(t, db) >= 90000 { - r, err = db.Exec("SELECT g FROM generate_series(1, 2) g") - if err != nil { - t.Fatal(err) - } - if n, _ := r.RowsAffected(); n != 2 { - t.Fatalf("expected 2 rows affected, not %d", n) - } - - r, err = db.Exec("SELECT g FROM generate_series(1, $1) g", 3) - if err != nil { - t.Fatal(err) - } - if n, _ := r.RowsAffected(); n != 3 { - t.Fatalf("expected 3 rows affected, not %d", n) - } - } -} - -func TestStatment(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - st, err := db.Prepare("SELECT 1") - if err != nil { - t.Fatal(err) - } - - st1, err := db.Prepare("SELECT 2") - if err != nil { - t.Fatal(err) - } - - r, err := st.Query() - if err != nil { - t.Fatal(err) - } - defer r.Close() - - if !r.Next() { - t.Fatal("expected row") - } - - var i int - err = r.Scan(&i) - if err != nil { - t.Fatal(err) - } - - if i != 1 { - t.Fatalf("expected 1, got %d", i) - } - - // st1 - - r1, err := st1.Query() - if err != nil { - t.Fatal(err) - } - defer r1.Close() - - if !r1.Next() { - if r.Err() != nil { - t.Fatal(r1.Err()) - } - t.Fatal("expected row") - } - - err = r1.Scan(&i) - if err != nil { - t.Fatal(err) - } - - if i != 2 { - t.Fatalf("expected 2, got %d", i) - } -} - -func TestRowsCloseBeforeDone(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - r, err := db.Query("SELECT 1") - if err != nil { - t.Fatal(err) - } - - err = r.Close() - if err != nil { - t.Fatal(err) - } - - if r.Next() { - t.Fatal("unexpected row") - } - - if r.Err() != nil { - t.Fatal(r.Err()) - } -} - -func TestParameterCountMismatch(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - var notused int - err := db.QueryRow("SELECT false", 1).Scan(¬used) - if err == nil { - t.Fatal("expected err") - } - // make sure we clean up correctly - err = db.QueryRow("SELECT 1").Scan(¬used) - if err != nil { - t.Fatal(err) - } - - err = db.QueryRow("SELECT $1").Scan(¬used) - if err == nil { - t.Fatal("expected err") - } - // make sure we clean up correctly - err = db.QueryRow("SELECT 1").Scan(¬used) - if err != nil { - t.Fatal(err) - } -} - -// Test that EmptyQueryResponses are handled correctly. -func TestEmptyQuery(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - _, err := db.Exec("") - if err != nil { - t.Fatal(err) - } - rows, err := db.Query("") - if err != nil { - t.Fatal(err) - } - cols, err := rows.Columns() - if err != nil { - t.Fatal(err) - } - if len(cols) != 0 { - t.Fatalf("unexpected number of columns %d in response to an empty query", len(cols)) - } - if rows.Next() { - t.Fatal("unexpected row") - } - if rows.Err() != nil { - t.Fatal(rows.Err()) - } - - stmt, err := db.Prepare("") - if err != nil { - t.Fatal(err) - } - _, err = stmt.Exec() - if err != nil { - t.Fatal(err) - } - rows, err = stmt.Query() - if err != nil { - t.Fatal(err) - } - cols, err = rows.Columns() - if err != nil { - t.Fatal(err) - } - if len(cols) != 0 { - t.Fatalf("unexpected number of columns %d in response to an empty query", len(cols)) - } - if rows.Next() { - t.Fatal("unexpected row") - } - if rows.Err() != nil { - t.Fatal(rows.Err()) - } -} - -func TestEncodeDecode(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - q := ` - SELECT - E'\\000\\001\\002'::bytea, - 'foobar'::text, - NULL::integer, - '2000-1-1 01:02:03.04-7'::timestamptz, - 0::boolean, - 123, - 3.14::float8 - WHERE - E'\\000\\001\\002'::bytea = $1 - AND 'foobar'::text = $2 - AND $3::integer is NULL - ` - // AND '2000-1-1 12:00:00.000000-7'::timestamp = $3 - - exp1 := []byte{0, 1, 2} - exp2 := "foobar" - - r, err := db.Query(q, exp1, exp2, nil) - if err != nil { - t.Fatal(err) - } - defer r.Close() - - if !r.Next() { - if r.Err() != nil { - t.Fatal(r.Err()) - } - t.Fatal("expected row") - } - - var got1 []byte - var got2 string - var got3 = sql.NullInt64{Valid: true} - var got4 time.Time - var got5, got6, got7 interface{} - - err = r.Scan(&got1, &got2, &got3, &got4, &got5, &got6, &got7) - if err != nil { - t.Fatal(err) - } - - if !reflect.DeepEqual(exp1, got1) { - t.Errorf("expected %q byte: %q", exp1, got1) - } - - if !reflect.DeepEqual(exp2, got2) { - t.Errorf("expected %q byte: %q", exp2, got2) - } - - if got3.Valid { - t.Fatal("expected invalid") - } - - if got4.Year() != 2000 { - t.Fatal("wrong year") - } - - if got5 != false { - t.Fatalf("expected false, got %q", got5) - } - - if got6 != int64(123) { - t.Fatalf("expected 123, got %d", got6) - } - - if got7 != float64(3.14) { - t.Fatalf("expected 3.14, got %f", got7) - } -} - -func TestNoData(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - st, err := db.Prepare("SELECT 1 WHERE true = false") - if err != nil { - t.Fatal(err) - } - defer st.Close() - - r, err := st.Query() - if err != nil { - t.Fatal(err) - } - defer r.Close() - - if r.Next() { - if r.Err() != nil { - t.Fatal(r.Err()) - } - t.Fatal("unexpected row") - } - - _, err = db.Query("SELECT * FROM nonexistenttable WHERE age=$1", 20) - if err == nil { - t.Fatal("Should have raised an error on non existent table") - } - - _, err = db.Query("SELECT * FROM nonexistenttable") - if err == nil { - t.Fatal("Should have raised an error on non existent table") - } -} - -func TestErrorDuringStartup(t *testing.T) { - // Don't use the normal connection setup, this is intended to - // blow up in the startup packet from a non-existent user. - db, err := openTestConnConninfo("user=thisuserreallydoesntexist") - if err != nil { - t.Fatal(err) - } - defer db.Close() - - _, err = db.Begin() - if err == nil { - t.Fatal("expected error") - } - - e, ok := err.(*Error) - if !ok { - t.Fatalf("expected Error, got %#v", err) - } else if e.Code.Name() != "invalid_authorization_specification" && e.Code.Name() != "invalid_password" { - t.Fatalf("expected invalid_authorization_specification or invalid_password, got %s (%+v)", e.Code.Name(), err) - } -} - -func TestBadConn(t *testing.T) { - var err error - - cn := conn{} - func() { - defer cn.errRecover(&err) - panic(io.EOF) - }() - if err != driver.ErrBadConn { - t.Fatalf("expected driver.ErrBadConn, got: %#v", err) - } - if !cn.bad { - t.Fatalf("expected cn.bad") - } - - cn = conn{} - func() { - defer cn.errRecover(&err) - e := &Error{Severity: Efatal} - panic(e) - }() - if err != driver.ErrBadConn { - t.Fatalf("expected driver.ErrBadConn, got: %#v", err) - } - if !cn.bad { - t.Fatalf("expected cn.bad") - } -} - -func TestErrorOnExec(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - txn, err := db.Begin() - if err != nil { - t.Fatal(err) - } - defer txn.Rollback() - - _, err = txn.Exec("CREATE TEMPORARY TABLE foo(f1 int PRIMARY KEY)") - if err != nil { - t.Fatal(err) - } - - _, err = txn.Exec("INSERT INTO foo VALUES (0), (0)") - if err == nil { - t.Fatal("Should have raised error") - } - - e, ok := err.(*Error) - if !ok { - t.Fatalf("expected Error, got %#v", err) - } else if e.Code.Name() != "unique_violation" { - t.Fatalf("expected unique_violation, got %s (%+v)", e.Code.Name(), err) - } -} - -func TestErrorOnQuery(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - txn, err := db.Begin() - if err != nil { - t.Fatal(err) - } - defer txn.Rollback() - - _, err = txn.Exec("CREATE TEMPORARY TABLE foo(f1 int PRIMARY KEY)") - if err != nil { - t.Fatal(err) - } - - _, err = txn.Query("INSERT INTO foo VALUES (0), (0)") - if err == nil { - t.Fatal("Should have raised error") - } - - e, ok := err.(*Error) - if !ok { - t.Fatalf("expected Error, got %#v", err) - } else if e.Code.Name() != "unique_violation" { - t.Fatalf("expected unique_violation, got %s (%+v)", e.Code.Name(), err) - } -} - -func TestErrorOnQueryRowSimpleQuery(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - txn, err := db.Begin() - if err != nil { - t.Fatal(err) - } - defer txn.Rollback() - - _, err = txn.Exec("CREATE TEMPORARY TABLE foo(f1 int PRIMARY KEY)") - if err != nil { - t.Fatal(err) - } - - var v int - err = txn.QueryRow("INSERT INTO foo VALUES (0), (0)").Scan(&v) - if err == nil { - t.Fatal("Should have raised error") - } - - e, ok := err.(*Error) - if !ok { - t.Fatalf("expected Error, got %#v", err) - } else if e.Code.Name() != "unique_violation" { - t.Fatalf("expected unique_violation, got %s (%+v)", e.Code.Name(), err) - } -} - -// Test the QueryRow bug workarounds in stmt.exec() and simpleQuery() -func TestQueryRowBugWorkaround(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - // stmt.exec() - _, err := db.Exec("CREATE TEMP TABLE notnulltemp (a varchar(10) not null)") - if err != nil { - t.Fatal(err) - } - - var a string - err = db.QueryRow("INSERT INTO notnulltemp(a) values($1) RETURNING a", nil).Scan(&a) - if err == sql.ErrNoRows { - t.Fatalf("expected constraint violation error; got: %v", err) - } - pge, ok := err.(*Error) - if !ok { - t.Fatalf("expected *Error; got: %#v", err) - } - if pge.Code.Name() != "not_null_violation" { - t.Fatalf("expected not_null_violation; got: %s (%+v)", pge.Code.Name(), err) - } - - // Test workaround in simpleQuery() - tx, err := db.Begin() - if err != nil { - t.Fatalf("unexpected error %s in Begin", err) - } - defer tx.Rollback() - - _, err = tx.Exec("SET LOCAL check_function_bodies TO FALSE") - if err != nil { - t.Fatalf("could not disable check_function_bodies: %s", err) - } - _, err = tx.Exec(` -CREATE OR REPLACE FUNCTION bad_function() -RETURNS integer --- hack to prevent the function from being inlined -SET check_function_bodies TO TRUE -AS $$ - SELECT text 'bad' -$$ LANGUAGE sql`) - if err != nil { - t.Fatalf("could not create function: %s", err) - } - - err = tx.QueryRow("SELECT * FROM bad_function()").Scan(&a) - if err == nil { - t.Fatalf("expected error") - } - pge, ok = err.(*Error) - if !ok { - t.Fatalf("expected *Error; got: %#v", err) - } - if pge.Code.Name() != "invalid_function_definition" { - t.Fatalf("expected invalid_function_definition; got: %s (%+v)", pge.Code.Name(), err) - } - - err = tx.Rollback() - if err != nil { - t.Fatalf("unexpected error %s in Rollback", err) - } - - // Also test that simpleQuery()'s workaround works when the query fails - // after a row has been received. - rows, err := db.Query(` -select - (select generate_series(1, ss.i)) -from (select gs.i - from generate_series(1, 2) gs(i) - order by gs.i limit 2) ss`) - if err != nil { - t.Fatalf("query failed: %s", err) - } - if !rows.Next() { - t.Fatalf("expected at least one result row; got %s", rows.Err()) - } - var i int - err = rows.Scan(&i) - if err != nil { - t.Fatalf("rows.Scan() failed: %s", err) - } - if i != 1 { - t.Fatalf("unexpected value for i: %d", i) - } - if rows.Next() { - t.Fatalf("unexpected row") - } - pge, ok = rows.Err().(*Error) - if !ok { - t.Fatalf("expected *Error; got: %#v", err) - } - if pge.Code.Name() != "cardinality_violation" { - t.Fatalf("expected cardinality_violation; got: %s (%+v)", pge.Code.Name(), rows.Err()) - } -} - -func TestSimpleQuery(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - r, err := db.Query("select 1") - if err != nil { - t.Fatal(err) - } - defer r.Close() - - if !r.Next() { - t.Fatal("expected row") - } -} - -func TestBindError(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - _, err := db.Exec("create temp table test (i integer)") - if err != nil { - t.Fatal(err) - } - - _, err = db.Query("select * from test where i=$1", "hhh") - if err == nil { - t.Fatal("expected an error") - } - - // Should not get error here - r, err := db.Query("select * from test where i=$1", 1) - if err != nil { - t.Fatal(err) - } - defer r.Close() -} - -func TestParseErrorInExtendedQuery(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - rows, err := db.Query("PARSE_ERROR $1", 1) - if err == nil { - t.Fatal("expected error") - } - - rows, err = db.Query("SELECT 1") - if err != nil { - t.Fatal(err) - } - rows.Close() -} - -// TestReturning tests that an INSERT query using the RETURNING clause returns a row. -func TestReturning(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - _, err := db.Exec("CREATE TEMP TABLE distributors (did integer default 0, dname text)") - if err != nil { - t.Fatal(err) - } - - rows, err := db.Query("INSERT INTO distributors (did, dname) VALUES (DEFAULT, 'XYZ Widgets') " + - "RETURNING did;") - if err != nil { - t.Fatal(err) - } - if !rows.Next() { - t.Fatal("no rows") - } - var did int - err = rows.Scan(&did) - if err != nil { - t.Fatal(err) - } - if did != 0 { - t.Fatalf("bad value for did: got %d, want %d", did, 0) - } - - if rows.Next() { - t.Fatal("unexpected next row") - } - err = rows.Err() - if err != nil { - t.Fatal(err) - } -} - -func TestIssue186(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - // Exec() a query which returns results - _, err := db.Exec("VALUES (1), (2), (3)") - if err != nil { - t.Fatal(err) - } - - _, err = db.Exec("VALUES ($1), ($2), ($3)", 1, 2, 3) - if err != nil { - t.Fatal(err) - } - - // Query() a query which doesn't return any results - txn, err := db.Begin() - if err != nil { - t.Fatal(err) - } - defer txn.Rollback() - - rows, err := txn.Query("CREATE TEMP TABLE foo(f1 int)") - if err != nil { - t.Fatal(err) - } - if err = rows.Close(); err != nil { - t.Fatal(err) - } - - // small trick to get NoData from a parameterized query - _, err = txn.Exec("CREATE RULE nodata AS ON INSERT TO foo DO INSTEAD NOTHING") - if err != nil { - t.Fatal(err) - } - rows, err = txn.Query("INSERT INTO foo VALUES ($1)", 1) - if err != nil { - t.Fatal(err) - } - if err = rows.Close(); err != nil { - t.Fatal(err) - } -} - -func TestIssue196(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - row := db.QueryRow("SELECT float4 '0.10000122' = $1, float8 '35.03554004971999' = $2", - float32(0.10000122), float64(35.03554004971999)) - - var float4match, float8match bool - err := row.Scan(&float4match, &float8match) - if err != nil { - t.Fatal(err) - } - if !float4match { - t.Errorf("Expected float4 fidelity to be maintained; got no match") - } - if !float8match { - t.Errorf("Expected float8 fidelity to be maintained; got no match") - } -} - -// Test that any CommandComplete messages sent before the query results are -// ignored. -func TestIssue282(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - var search_path string - err := db.QueryRow(` - SET LOCAL search_path TO pg_catalog; - SET LOCAL search_path TO pg_catalog; - SHOW search_path`).Scan(&search_path) - if err != nil { - t.Fatal(err) - } - if search_path != "pg_catalog" { - t.Fatalf("unexpected search_path %s", search_path) - } -} - -func TestReadFloatPrecision(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - row := db.QueryRow("SELECT float4 '0.10000122', float8 '35.03554004971999'") - var float4val float32 - var float8val float64 - err := row.Scan(&float4val, &float8val) - if err != nil { - t.Fatal(err) - } - if float4val != float32(0.10000122) { - t.Errorf("Expected float4 fidelity to be maintained; got no match") - } - if float8val != float64(35.03554004971999) { - t.Errorf("Expected float8 fidelity to be maintained; got no match") - } -} - -func TestXactMultiStmt(t *testing.T) { - // minified test case based on bug reports from - // pico303@gmail.com and rangelspam@gmail.com - t.Skip("Skipping failing test") - db := openTestConn(t) - defer db.Close() - - tx, err := db.Begin() - if err != nil { - t.Fatal(err) - } - defer tx.Commit() - - rows, err := tx.Query("select 1") - if err != nil { - t.Fatal(err) - } - - if rows.Next() { - var val int32 - if err = rows.Scan(&val); err != nil { - t.Fatal(err) - } - } else { - t.Fatal("Expected at least one row in first query in xact") - } - - rows2, err := tx.Query("select 2") - if err != nil { - t.Fatal(err) - } - - if rows2.Next() { - var val2 int32 - if err := rows2.Scan(&val2); err != nil { - t.Fatal(err) - } - } else { - t.Fatal("Expected at least one row in second query in xact") - } - - if err = rows.Err(); err != nil { - t.Fatal(err) - } - - if err = rows2.Err(); err != nil { - t.Fatal(err) - } - - if err = tx.Commit(); err != nil { - t.Fatal(err) - } -} - -var envParseTests = []struct { - Expected map[string]string - Env []string -}{ - { - Env: []string{"PGDATABASE=hello", "PGUSER=goodbye"}, - Expected: map[string]string{"dbname": "hello", "user": "goodbye"}, - }, - { - Env: []string{"PGDATESTYLE=ISO, MDY"}, - Expected: map[string]string{"datestyle": "ISO, MDY"}, - }, - { - Env: []string{"PGCONNECT_TIMEOUT=30"}, - Expected: map[string]string{"connect_timeout": "30"}, - }, -} - -func TestParseEnviron(t *testing.T) { - for i, tt := range envParseTests { - results := parseEnviron(tt.Env) - if !reflect.DeepEqual(tt.Expected, results) { - t.Errorf("%d: Expected: %#v Got: %#v", i, tt.Expected, results) - } - } -} - -func TestParseComplete(t *testing.T) { - tpc := func(commandTag string, command string, affectedRows int64, shouldFail bool) { - defer func() { - if p := recover(); p != nil { - if !shouldFail { - t.Error(p) - } - } - }() - cn := &conn{} - res, c := cn.parseComplete(commandTag) - if c != command { - t.Errorf("Expected %v, got %v", command, c) - } - n, err := res.RowsAffected() - if err != nil { - t.Fatal(err) - } - if n != affectedRows { - t.Errorf("Expected %d, got %d", affectedRows, n) - } - } - - tpc("ALTER TABLE", "ALTER TABLE", 0, false) - tpc("INSERT 0 1", "INSERT", 1, false) - tpc("UPDATE 100", "UPDATE", 100, false) - tpc("SELECT 100", "SELECT", 100, false) - tpc("FETCH 100", "FETCH", 100, false) - // allow COPY (and others) without row count - tpc("COPY", "COPY", 0, false) - // don't fail on command tags we don't recognize - tpc("UNKNOWNCOMMANDTAG", "UNKNOWNCOMMANDTAG", 0, false) - - // failure cases - tpc("INSERT 1", "", 0, true) // missing oid - tpc("UPDATE 0 1", "", 0, true) // too many numbers - tpc("SELECT foo", "", 0, true) // invalid row count -} - -func TestExecerInterface(t *testing.T) { - // Gin up a straw man private struct just for the type check - cn := &conn{c: nil} - var cni interface{} = cn - - _, ok := cni.(driver.Execer) - if !ok { - t.Fatal("Driver doesn't implement Execer") - } -} - -func TestNullAfterNonNull(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - r, err := db.Query("SELECT 9::integer UNION SELECT NULL::integer") - if err != nil { - t.Fatal(err) - } - - var n sql.NullInt64 - - if !r.Next() { - if r.Err() != nil { - t.Fatal(err) - } - t.Fatal("expected row") - } - - if err := r.Scan(&n); err != nil { - t.Fatal(err) - } - - if n.Int64 != 9 { - t.Fatalf("expected 2, not %d", n.Int64) - } - - if !r.Next() { - if r.Err() != nil { - t.Fatal(err) - } - t.Fatal("expected row") - } - - if err := r.Scan(&n); err != nil { - t.Fatal(err) - } - - if n.Valid { - t.Fatal("expected n to be invalid") - } - - if n.Int64 != 0 { - t.Fatalf("expected n to 2, not %d", n.Int64) - } -} - -func Test64BitErrorChecking(t *testing.T) { - defer func() { - if err := recover(); err != nil { - t.Fatal("panic due to 0xFFFFFFFF != -1 " + - "when int is 64 bits") - } - }() - - db := openTestConn(t) - defer db.Close() - - r, err := db.Query(`SELECT * -FROM (VALUES (0::integer, NULL::text), (1, 'test string')) AS t;`) - - if err != nil { - t.Fatal(err) - } - - defer r.Close() - - for r.Next() { - } -} - -func TestCommit(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - _, err := db.Exec("CREATE TEMP TABLE temp (a int)") - if err != nil { - t.Fatal(err) - } - sqlInsert := "INSERT INTO temp VALUES (1)" - sqlSelect := "SELECT * FROM temp" - tx, err := db.Begin() - if err != nil { - t.Fatal(err) - } - _, err = tx.Exec(sqlInsert) - if err != nil { - t.Fatal(err) - } - err = tx.Commit() - if err != nil { - t.Fatal(err) - } - var i int - err = db.QueryRow(sqlSelect).Scan(&i) - if err != nil { - t.Fatal(err) - } - if i != 1 { - t.Fatalf("expected 1, got %d", i) - } -} - -func TestErrorClass(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - _, err := db.Query("SELECT int 'notint'") - if err == nil { - t.Fatal("expected error") - } - pge, ok := err.(*Error) - if !ok { - t.Fatalf("expected *pq.Error, got %#+v", err) - } - if pge.Code.Class() != "22" { - t.Fatalf("expected class 28, got %v", pge.Code.Class()) - } - if pge.Code.Class().Name() != "data_exception" { - t.Fatalf("expected data_exception, got %v", pge.Code.Class().Name()) - } -} - -func TestParseOpts(t *testing.T) { - tests := []struct { - in string - expected values - valid bool - }{ - {"dbname=hello user=goodbye", values{"dbname": "hello", "user": "goodbye"}, true}, - {"dbname=hello user=goodbye ", values{"dbname": "hello", "user": "goodbye"}, true}, - {"dbname = hello user=goodbye", values{"dbname": "hello", "user": "goodbye"}, true}, - {"dbname=hello user =goodbye", values{"dbname": "hello", "user": "goodbye"}, true}, - {"dbname=hello user= goodbye", values{"dbname": "hello", "user": "goodbye"}, true}, - {"host=localhost password='correct horse battery staple'", values{"host": "localhost", "password": "correct horse battery staple"}, true}, - {"dbname=データベース password=パスワード", values{"dbname": "データベース", "password": "パスワード"}, true}, - {"dbname=hello user=''", values{"dbname": "hello", "user": ""}, true}, - {"user='' dbname=hello", values{"dbname": "hello", "user": ""}, true}, - // The last option value is an empty string if there's no non-whitespace after its = - {"dbname=hello user= ", values{"dbname": "hello", "user": ""}, true}, - - // The parser ignores spaces after = and interprets the next set of non-whitespace characters as the value. - {"user= password=foo", values{"user": "password=foo"}, true}, - - // Backslash escapes next char - {`user=a\ \'\\b`, values{"user": `a '\b`}, true}, - {`user='a \'b'`, values{"user": `a 'b`}, true}, - - // Incomplete escape - {`user=x\`, values{}, false}, - - // No '=' after the key - {"postgre://marko@internet", values{}, false}, - {"dbname user=goodbye", values{}, false}, - {"user=foo blah", values{}, false}, - {"user=foo blah ", values{}, false}, - - // Unterminated quoted value - {"dbname=hello user='unterminated", values{}, false}, - } - - for _, test := range tests { - o := make(values) - err := parseOpts(test.in, o) - - switch { - case err != nil && test.valid: - t.Errorf("%q got unexpected error: %s", test.in, err) - case err == nil && test.valid && !reflect.DeepEqual(test.expected, o): - t.Errorf("%q got: %#v want: %#v", test.in, o, test.expected) - case err == nil && !test.valid: - t.Errorf("%q expected an error", test.in) - } - } -} - -func TestRuntimeParameters(t *testing.T) { - type RuntimeTestResult int - const ( - ResultUnknown RuntimeTestResult = iota - ResultSuccess - ResultError // other error - ) - - tests := []struct { - conninfo string - param string - expected string - expectedOutcome RuntimeTestResult - }{ - // invalid parameter - {"DOESNOTEXIST=foo", "", "", ResultError}, - // we can only work with a specific value for these two - {"client_encoding=SQL_ASCII", "", "", ResultError}, - {"datestyle='ISO, YDM'", "", "", ResultError}, - // "options" should work exactly as it does in libpq - {"options='-c search_path=pqgotest'", "search_path", "pqgotest", ResultSuccess}, - // pq should override client_encoding in this case - {"options='-c client_encoding=SQL_ASCII'", "client_encoding", "UTF8", ResultSuccess}, - // allow client_encoding to be set explicitly - {"client_encoding=UTF8", "client_encoding", "UTF8", ResultSuccess}, - // test a runtime parameter not supported by libpq - {"work_mem='139kB'", "work_mem", "139kB", ResultSuccess}, - // test fallback_application_name - {"application_name=foo fallback_application_name=bar", "application_name", "foo", ResultSuccess}, - {"application_name='' fallback_application_name=bar", "application_name", "", ResultSuccess}, - {"fallback_application_name=bar", "application_name", "bar", ResultSuccess}, - } - - for _, test := range tests { - db, err := openTestConnConninfo(test.conninfo) - if err != nil { - t.Fatal(err) - } - - // application_name didn't exist before 9.0 - if test.param == "application_name" && getServerVersion(t, db) < 90000 { - db.Close() - continue - } - - tryGetParameterValue := func() (value string, outcome RuntimeTestResult) { - defer db.Close() - row := db.QueryRow("SELECT current_setting($1)", test.param) - err = row.Scan(&value) - if err != nil { - return "", ResultError - } - return value, ResultSuccess - } - - value, outcome := tryGetParameterValue() - if outcome != test.expectedOutcome && outcome == ResultError { - t.Fatalf("%v: unexpected error: %v", test.conninfo, err) - } - if outcome != test.expectedOutcome { - t.Fatalf("unexpected outcome %v (was expecting %v) for conninfo \"%s\"", - outcome, test.expectedOutcome, test.conninfo) - } - if value != test.expected { - t.Fatalf("bad value for %s: got %s, want %s with conninfo \"%s\"", - test.param, value, test.expected, test.conninfo) - } - } -} - -func TestIsUTF8(t *testing.T) { - var cases = []struct { - name string - want bool - }{ - {"unicode", true}, - {"utf-8", true}, - {"utf_8", true}, - {"UTF-8", true}, - {"UTF8", true}, - {"utf8", true}, - {"u n ic_ode", true}, - {"ut_f%8", true}, - {"ubf8", false}, - {"punycode", false}, - } - - for _, test := range cases { - if g := isUTF8(test.name); g != test.want { - t.Errorf("isUTF8(%q) = %v want %v", test.name, g, test.want) - } - } -} - -func TestQuoteIdentifier(t *testing.T) { - var cases = []struct { - input string - want string - }{ - {`foo`, `"foo"`}, - {`foo bar baz`, `"foo bar baz"`}, - {`foo"bar`, `"foo""bar"`}, - {"foo\x00bar", `"foo"`}, - {"\x00foo", `""`}, - } - - for _, test := range cases { - got := QuoteIdentifier(test.input) - if got != test.want { - t.Errorf("QuoteIdentifier(%q) = %v want %v", test.input, got, test.want) - } - } -} diff --git a/vendor/github.com/lib/pq/copy.go b/vendor/github.com/lib/pq/copy.go index 18c04e7407e..345c2398f6c 100644 --- a/vendor/github.com/lib/pq/copy.go +++ b/vendor/github.com/lib/pq/copy.go @@ -13,6 +13,7 @@ var ( errBinaryCopyNotSupported = errors.New("pq: only text format supported for COPY") errCopyToNotSupported = errors.New("pq: COPY TO is not supported") errCopyNotSupportedOutsideTxn = errors.New("pq: COPY is only allowed inside a transaction") + errCopyInProgress = errors.New("pq: COPY in progress") ) // CopyIn creates a COPY FROM statement which can be prepared with @@ -96,13 +97,13 @@ awaitCopyInResponse: err = parseError(r) case 'Z': if err == nil { - cn.bad = true + ci.setBad() errorf("unexpected ReadyForQuery in response to COPY") } cn.processReadyForQuery(r) return nil, err default: - cn.bad = true + ci.setBad() errorf("unknown response for copy query: %q", t) } } @@ -121,7 +122,7 @@ awaitCopyInResponse: cn.processReadyForQuery(r) return nil, err default: - cn.bad = true + ci.setBad() errorf("unknown response for CopyFail: %q", t) } } @@ -142,7 +143,7 @@ func (ci *copyin) resploop() { var r readBuf t, err := ci.cn.recvMessage(&r) if err != nil { - ci.cn.bad = true + ci.setBad() ci.setError(err) ci.done <- true return @@ -150,6 +151,8 @@ func (ci *copyin) resploop() { switch t { case 'C': // complete + case 'N': + // NoticeResponse case 'Z': ci.cn.processReadyForQuery(&r) ci.done <- true @@ -158,7 +161,7 @@ func (ci *copyin) resploop() { err := parseError(&r) ci.setError(err) default: - ci.cn.bad = true + ci.setBad() ci.setError(fmt.Errorf("unknown response during CopyIn: %q", t)) ci.done <- true return @@ -166,6 +169,19 @@ func (ci *copyin) resploop() { } } +func (ci *copyin) setBad() { + ci.Lock() + ci.cn.bad = true + ci.Unlock() +} + +func (ci *copyin) isBad() bool { + ci.Lock() + b := ci.cn.bad + ci.Unlock() + return b +} + func (ci *copyin) isErrorSet() bool { ci.Lock() isSet := (ci.err != nil) @@ -203,7 +219,7 @@ func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) { return nil, errCopyInClosed } - if ci.cn.bad { + if ci.isBad() { return nil, driver.ErrBadConn } defer ci.cn.errRecover(&err) @@ -213,9 +229,7 @@ func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) { } if len(v) == 0 { - err = ci.Close() - ci.closed = true - return nil, err + return nil, ci.Close() } numValues := len(v) @@ -238,11 +252,12 @@ func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) { } func (ci *copyin) Close() (err error) { - if ci.closed { - return errCopyInClosed + if ci.closed { // Don't do anything, we're already closed + return nil } + ci.closed = true - if ci.cn.bad { + if ci.isBad() { return driver.ErrBadConn } defer ci.cn.errRecover(&err) @@ -257,6 +272,7 @@ func (ci *copyin) Close() (err error) { } <-ci.done + ci.cn.inCopy = false if ci.isErrorSet() { err = ci.err diff --git a/vendor/github.com/lib/pq/copy_test.go b/vendor/github.com/lib/pq/copy_test.go deleted file mode 100644 index e489f22a86b..00000000000 --- a/vendor/github.com/lib/pq/copy_test.go +++ /dev/null @@ -1,380 +0,0 @@ -package pq - -import ( - "bytes" - "database/sql" - "strings" - "testing" -) - -func TestCopyInStmt(t *testing.T) { - var stmt string - stmt = CopyIn("table name") - if stmt != `COPY "table name" () FROM STDIN` { - t.Fatal(stmt) - } - - stmt = CopyIn("table name", "column 1", "column 2") - if stmt != `COPY "table name" ("column 1", "column 2") FROM STDIN` { - t.Fatal(stmt) - } - - stmt = CopyIn(`table " name """`, `co"lumn""`) - if stmt != `COPY "table "" name """"""" ("co""lumn""""") FROM STDIN` { - t.Fatal(stmt) - } -} - -func TestCopyInSchemaStmt(t *testing.T) { - var stmt string - stmt = CopyInSchema("schema name", "table name") - if stmt != `COPY "schema name"."table name" () FROM STDIN` { - t.Fatal(stmt) - } - - stmt = CopyInSchema("schema name", "table name", "column 1", "column 2") - if stmt != `COPY "schema name"."table name" ("column 1", "column 2") FROM STDIN` { - t.Fatal(stmt) - } - - stmt = CopyInSchema(`schema " name """`, `table " name """`, `co"lumn""`) - if stmt != `COPY "schema "" name """"""".`+ - `"table "" name """"""" ("co""lumn""""") FROM STDIN` { - t.Fatal(stmt) - } -} - -func TestCopyInMultipleValues(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - txn, err := db.Begin() - if err != nil { - t.Fatal(err) - } - defer txn.Rollback() - - _, err = txn.Exec("CREATE TEMP TABLE temp (a int, b varchar)") - if err != nil { - t.Fatal(err) - } - - stmt, err := txn.Prepare(CopyIn("temp", "a", "b")) - if err != nil { - t.Fatal(err) - } - - longString := strings.Repeat("#", 500) - - for i := 0; i < 500; i++ { - _, err = stmt.Exec(int64(i), longString) - if err != nil { - t.Fatal(err) - } - } - - _, err = stmt.Exec() - if err != nil { - t.Fatal(err) - } - - err = stmt.Close() - if err != nil { - t.Fatal(err) - } - - var num int - err = txn.QueryRow("SELECT COUNT(*) FROM temp").Scan(&num) - if err != nil { - t.Fatal(err) - } - - if num != 500 { - t.Fatalf("expected 500 items, not %d", num) - } -} - -func TestCopyInTypes(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - txn, err := db.Begin() - if err != nil { - t.Fatal(err) - } - defer txn.Rollback() - - _, err = txn.Exec("CREATE TEMP TABLE temp (num INTEGER, text VARCHAR, blob BYTEA, nothing VARCHAR)") - if err != nil { - t.Fatal(err) - } - - stmt, err := txn.Prepare(CopyIn("temp", "num", "text", "blob", "nothing")) - if err != nil { - t.Fatal(err) - } - - _, err = stmt.Exec(int64(1234567890), "Héllö\n ☃!\r\t\\", []byte{0, 255, 9, 10, 13}, nil) - if err != nil { - t.Fatal(err) - } - - _, err = stmt.Exec() - if err != nil { - t.Fatal(err) - } - - err = stmt.Close() - if err != nil { - t.Fatal(err) - } - - var num int - var text string - var blob []byte - var nothing sql.NullString - - err = txn.QueryRow("SELECT * FROM temp").Scan(&num, &text, &blob, ¬hing) - if err != nil { - t.Fatal(err) - } - - if num != 1234567890 { - t.Fatal("unexpected result", num) - } - if text != "Héllö\n ☃!\r\t\\" { - t.Fatal("unexpected result", text) - } - if bytes.Compare(blob, []byte{0, 255, 9, 10, 13}) != 0 { - t.Fatal("unexpected result", blob) - } - if nothing.Valid { - t.Fatal("unexpected result", nothing.String) - } -} - -func TestCopyInWrongType(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - txn, err := db.Begin() - if err != nil { - t.Fatal(err) - } - defer txn.Rollback() - - _, err = txn.Exec("CREATE TEMP TABLE temp (num INTEGER)") - if err != nil { - t.Fatal(err) - } - - stmt, err := txn.Prepare(CopyIn("temp", "num")) - if err != nil { - t.Fatal(err) - } - defer stmt.Close() - - _, err = stmt.Exec("Héllö\n ☃!\r\t\\") - if err != nil { - t.Fatal(err) - } - - _, err = stmt.Exec() - if err == nil { - t.Fatal("expected error") - } - if pge := err.(*Error); pge.Code.Name() != "invalid_text_representation" { - t.Fatalf("expected 'invalid input syntax for integer' error, got %s (%+v)", pge.Code.Name(), pge) - } -} - -func TestCopyOutsideOfTxnError(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - _, err := db.Prepare(CopyIn("temp", "num")) - if err == nil { - t.Fatal("COPY outside of transaction did not return an error") - } - if err != errCopyNotSupportedOutsideTxn { - t.Fatalf("expected %s, got %s", err, err.Error()) - } -} - -func TestCopyInBinaryError(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - txn, err := db.Begin() - if err != nil { - t.Fatal(err) - } - defer txn.Rollback() - - _, err = txn.Exec("CREATE TEMP TABLE temp (num INTEGER)") - if err != nil { - t.Fatal(err) - } - _, err = txn.Prepare("COPY temp (num) FROM STDIN WITH binary") - if err != errBinaryCopyNotSupported { - t.Fatalf("expected %s, got %+v", errBinaryCopyNotSupported, err) - } - // check that the protocol is in a valid state - err = txn.Rollback() - if err != nil { - t.Fatal(err) - } -} - -func TestCopyFromError(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - txn, err := db.Begin() - if err != nil { - t.Fatal(err) - } - defer txn.Rollback() - - _, err = txn.Exec("CREATE TEMP TABLE temp (num INTEGER)") - if err != nil { - t.Fatal(err) - } - _, err = txn.Prepare("COPY temp (num) TO STDOUT") - if err != errCopyToNotSupported { - t.Fatalf("expected %s, got %+v", errCopyToNotSupported, err) - } - // check that the protocol is in a valid state - err = txn.Rollback() - if err != nil { - t.Fatal(err) - } -} - -func TestCopySyntaxError(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - txn, err := db.Begin() - if err != nil { - t.Fatal(err) - } - defer txn.Rollback() - - _, err = txn.Prepare("COPY ") - if err == nil { - t.Fatal("expected error") - } - if pge := err.(*Error); pge.Code.Name() != "syntax_error" { - t.Fatalf("expected syntax error, got %s (%+v)", pge.Code.Name(), pge) - } - // check that the protocol is in a valid state - err = txn.Rollback() - if err != nil { - t.Fatal(err) - } -} - -// Tests for connection errors in copyin.resploop() -func TestCopyRespLoopConnectionError(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - txn, err := db.Begin() - if err != nil { - t.Fatal(err) - } - defer txn.Rollback() - - var pid int - err = txn.QueryRow("SELECT pg_backend_pid()").Scan(&pid) - if err != nil { - t.Fatal(err) - } - - _, err = txn.Exec("CREATE TEMP TABLE temp (a int)") - if err != nil { - t.Fatal(err) - } - - stmt, err := txn.Prepare(CopyIn("temp", "a")) - if err != nil { - t.Fatal(err) - } - - _, err = db.Exec("SELECT pg_terminate_backend($1)", pid) - if err != nil { - t.Fatal(err) - } - - // We have to try and send something over, since postgres won't process - // SIGTERMs while it's waiting for CopyData/CopyEnd messages; see - // tcop/postgres.c. - _, err = stmt.Exec(1) - if err != nil { - t.Fatal(err) - } - _, err = stmt.Exec() - if err == nil { - t.Fatalf("expected error") - } - pge, ok := err.(*Error) - if !ok { - t.Fatalf("expected *pq.Error, got %+#v", err) - } else if pge.Code.Name() != "admin_shutdown" { - t.Fatalf("expected admin_shutdown, got %s", pge.Code.Name()) - } - - err = stmt.Close() - if err != nil { - t.Fatal(err) - } -} - -func BenchmarkCopyIn(b *testing.B) { - db := openTestConn(b) - defer db.Close() - - txn, err := db.Begin() - if err != nil { - b.Fatal(err) - } - defer txn.Rollback() - - _, err = txn.Exec("CREATE TEMP TABLE temp (a int, b varchar)") - if err != nil { - b.Fatal(err) - } - - stmt, err := txn.Prepare(CopyIn("temp", "a", "b")) - if err != nil { - b.Fatal(err) - } - - for i := 0; i < b.N; i++ { - _, err = stmt.Exec(int64(i), "hello world!") - if err != nil { - b.Fatal(err) - } - } - - _, err = stmt.Exec() - if err != nil { - b.Fatal(err) - } - - err = stmt.Close() - if err != nil { - b.Fatal(err) - } - - var num int - err = txn.QueryRow("SELECT COUNT(*) FROM temp").Scan(&num) - if err != nil { - b.Fatal(err) - } - - if num != b.N { - b.Fatalf("expected %d items, not %d", b.N, num) - } -} diff --git a/vendor/github.com/lib/pq/doc.go b/vendor/github.com/lib/pq/doc.go index 4d7a0e39935..6d252ecee21 100644 --- a/vendor/github.com/lib/pq/doc.go +++ b/vendor/github.com/lib/pq/doc.go @@ -5,8 +5,9 @@ In most cases clients will use the database/sql package instead of using this package directly. For example: import ( - _ "github.com/lib/pq" "database/sql" + + _ "github.com/lib/pq" ) func main() { @@ -85,9 +86,13 @@ variables not supported by pq are set, pq will panic during connection establishment. Environment variables have a lower precedence than explicitly provided connection parameters. +The pgpass mechanism as described in http://www.postgresql.org/docs/current/static/libpq-pgpass.html +is supported, but on Windows PGPASSFILE must be specified explicitly. + Queries + database/sql does not dictate any specific format for parameter markers in query strings, and pq uses the Postgres-native ordinal markers, as shown above. The same marker can be reused for the same parameter: @@ -111,8 +116,29 @@ For more details on RETURNING, see the Postgres documentation: For additional instructions on querying see the documentation for the database/sql package. + +Data Types + + +Parameters pass through driver.DefaultParameterConverter before they are handled +by this package. When the binary_parameters connection option is enabled, +[]byte values are sent directly to the backend as data in binary format. + +This package returns the following types for values from the PostgreSQL backend: + + - integer types smallint, integer, and bigint are returned as int64 + - floating-point types real and double precision are returned as float64 + - character types char, varchar, and text are returned as string + - temporal types date, time, timetz, timestamp, and timestamptz are returned as time.Time + - the boolean type is returned as bool + - the bytea type is returned as []byte + +All other types are returned directly from the backend as []byte values in text format. + + Errors + pq may return errors of type *pq.Error which can be interrogated for error details: if err, ok := err.(*pq.Error); ok { diff --git a/vendor/github.com/lib/pq/encode.go b/vendor/github.com/lib/pq/encode.go index 3887f72aa29..88a322cda82 100644 --- a/vendor/github.com/lib/pq/encode.go +++ b/vendor/github.com/lib/pq/encode.go @@ -3,24 +3,34 @@ package pq import ( "bytes" "database/sql/driver" + "encoding/binary" "encoding/hex" + "errors" "fmt" - "github.com/lib/pq/oid" "math" "strconv" "strings" "sync" "time" + + "github.com/lib/pq/oid" ) +func binaryEncode(parameterStatus *parameterStatus, x interface{}) []byte { + switch v := x.(type) { + case []byte: + return v + default: + return encode(parameterStatus, x, oid.T_unknown) + } +} + func encode(parameterStatus *parameterStatus, x interface{}, pgtypOid oid.Oid) []byte { switch v := x.(type) { case int64: - return []byte(fmt.Sprintf("%d", v)) - case float32: - return []byte(fmt.Sprintf("%.9f", v)) + return strconv.AppendInt(nil, v, 10) case float64: - return []byte(fmt.Sprintf("%.17f", v)) + return strconv.AppendFloat(nil, v, 'f', -1, 64) case []byte: if pgtypOid == oid.T_bytea { return encodeBytea(parameterStatus.serverVersion, v) @@ -34,7 +44,7 @@ func encode(parameterStatus *parameterStatus, x interface{}, pgtypOid oid.Oid) [ return []byte(v) case bool: - return []byte(fmt.Sprintf("%t", v)) + return strconv.AppendBool(nil, v) case time.Time: return formatTs(v) @@ -45,10 +55,51 @@ func encode(parameterStatus *parameterStatus, x interface{}, pgtypOid oid.Oid) [ panic("not reached") } -func decode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) interface{} { +func decode(parameterStatus *parameterStatus, s []byte, typ oid.Oid, f format) interface{} { + switch f { + case formatBinary: + return binaryDecode(parameterStatus, s, typ) + case formatText: + return textDecode(parameterStatus, s, typ) + default: + panic("not reached") + } +} + +func binaryDecode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) interface{} { switch typ { case oid.T_bytea: - return parseBytea(s) + return s + case oid.T_int8: + return int64(binary.BigEndian.Uint64(s)) + case oid.T_int4: + return int64(int32(binary.BigEndian.Uint32(s))) + case oid.T_int2: + return int64(int16(binary.BigEndian.Uint16(s))) + case oid.T_uuid: + b, err := decodeUUIDBinary(s) + if err != nil { + panic(err) + } + return b + + default: + errorf("don't know how to decode binary parameter of type %d", uint32(typ)) + } + + panic("not reached") +} + +func textDecode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) interface{} { + switch typ { + case oid.T_char, oid.T_varchar, oid.T_text: + return string(s) + case oid.T_bytea: + b, err := parseBytea(s) + if err != nil { + errorf("%s", err) + } + return b case oid.T_timestamptz: return parseTs(parameterStatus.currentLocation, string(s)) case oid.T_timestamp, oid.T_date: @@ -59,7 +110,7 @@ func decode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) interface{} return mustParse("15:04:05-07", typ, s) case oid.T_bool: return s[0] == 't' - case oid.T_int8, oid.T_int2, oid.T_int4: + case oid.T_int8, oid.T_int4, oid.T_int2: i, err := strconv.ParseInt(string(s), 10, 64) if err != nil { errorf("%s", err) @@ -86,8 +137,6 @@ func appendEncodedText(parameterStatus *parameterStatus, buf []byte, x interface switch v := x.(type) { case int64: return strconv.AppendInt(buf, v, 10) - case float32: - return strconv.AppendFloat(buf, float64(v), 'f', -1, 32) case float64: return strconv.AppendFloat(buf, v, 'f', -1, 64) case []byte: @@ -149,12 +198,6 @@ func appendEscapedText(buf []byte, text string) []byte { func mustParse(f string, typ oid.Oid, s []byte) time.Time { str := string(s) - // Special case until time.Parse bug is fixed: - // http://code.google.com/p/go/issues/detail?id=3487 - if str[len(str)-2] == '.' { - str += "0" - } - // check for a 30-minute-offset timezone if (typ == oid.T_timestamptz || typ == oid.T_timetz) && str[len(str)-3] == ':' { @@ -167,16 +210,39 @@ func mustParse(f string, typ oid.Oid, s []byte) time.Time { return t } -func expect(str, char string, pos int) { - if c := str[pos : pos+1]; c != char { - errorf("expected '%v' at position %v; got '%v'", char, pos, c) +var errInvalidTimestamp = errors.New("invalid timestamp") + +type timestampParser struct { + err error +} + +func (p *timestampParser) expect(str string, char byte, pos int) { + if p.err != nil { + return + } + if pos+1 > len(str) { + p.err = errInvalidTimestamp + return + } + if c := str[pos]; c != char && p.err == nil { + p.err = fmt.Errorf("expected '%v' at position %v; got '%v'", char, pos, c) } } -func mustAtoi(str string) int { - result, err := strconv.Atoi(str) +func (p *timestampParser) mustAtoi(str string, begin int, end int) int { + if p.err != nil { + return 0 + } + if begin < 0 || end < 0 || begin > end || end > len(str) { + p.err = errInvalidTimestamp + return 0 + } + result, err := strconv.Atoi(str[begin:end]) if err != nil { - errorf("expected number; got '%v'", str) + if p.err == nil { + p.err = fmt.Errorf("expected number; got '%v'", str) + } + return 0 } return result } @@ -191,7 +257,7 @@ type locationCache struct { // about 5% speed could be gained by putting the cache in the connection and // losing the mutex, at the cost of a small amount of memory and a somewhat // significant increase in code complexity. -var globalLocationCache *locationCache = newLocationCache() +var globalLocationCache = newLocationCache() func newLocationCache() *locationCache { return &locationCache{cache: make(map[int]*time.Location)} @@ -212,32 +278,106 @@ func (c *locationCache) getLocation(offset int) *time.Location { return location } +var infinityTsEnabled = false +var infinityTsNegative time.Time +var infinityTsPositive time.Time + +const ( + infinityTsEnabledAlready = "pq: infinity timestamp enabled already" + infinityTsNegativeMustBeSmaller = "pq: infinity timestamp: negative value must be smaller (before) than positive" +) + +// EnableInfinityTs controls the handling of Postgres' "-infinity" and +// "infinity" "timestamp"s. +// +// If EnableInfinityTs is not called, "-infinity" and "infinity" will return +// []byte("-infinity") and []byte("infinity") respectively, and potentially +// cause error "sql: Scan error on column index 0: unsupported driver -> Scan +// pair: []uint8 -> *time.Time", when scanning into a time.Time value. +// +// Once EnableInfinityTs has been called, all connections created using this +// driver will decode Postgres' "-infinity" and "infinity" for "timestamp", +// "timestamp with time zone" and "date" types to the predefined minimum and +// maximum times, respectively. When encoding time.Time values, any time which +// equals or precedes the predefined minimum time will be encoded to +// "-infinity". Any values at or past the maximum time will similarly be +// encoded to "infinity". +// +// If EnableInfinityTs is called with negative >= positive, it will panic. +// Calling EnableInfinityTs after a connection has been established results in +// undefined behavior. If EnableInfinityTs is called more than once, it will +// panic. +func EnableInfinityTs(negative time.Time, positive time.Time) { + if infinityTsEnabled { + panic(infinityTsEnabledAlready) + } + if !negative.Before(positive) { + panic(infinityTsNegativeMustBeSmaller) + } + infinityTsEnabled = true + infinityTsNegative = negative + infinityTsPositive = positive +} + +/* + * Testing might want to toggle infinityTsEnabled + */ +func disableInfinityTs() { + infinityTsEnabled = false +} + // This is a time function specific to the Postgres default DateStyle // setting ("ISO, MDY"), the only one we currently support. This // accounts for the discrepancies between the parsing available with // time.Parse and the Postgres date formatting quirks. -func parseTs(currentLocation *time.Location, str string) (result time.Time) { +func parseTs(currentLocation *time.Location, str string) interface{} { + switch str { + case "-infinity": + if infinityTsEnabled { + return infinityTsNegative + } + return []byte(str) + case "infinity": + if infinityTsEnabled { + return infinityTsPositive + } + return []byte(str) + } + t, err := ParseTimestamp(currentLocation, str) + if err != nil { + panic(err) + } + return t +} + +// ParseTimestamp parses Postgres' text format. It returns a time.Time in +// currentLocation iff that time's offset agrees with the offset sent from the +// Postgres server. Otherwise, ParseTimestamp returns a time.Time with the +// fixed offset offset provided by the Postgres server. +func ParseTimestamp(currentLocation *time.Location, str string) (time.Time, error) { + p := timestampParser{} + monSep := strings.IndexRune(str, '-') // this is Gregorian year, not ISO Year // In Gregorian system, the year 1 BC is followed by AD 1 - year := mustAtoi(str[:monSep]) + year := p.mustAtoi(str, 0, monSep) daySep := monSep + 3 - month := mustAtoi(str[monSep+1 : daySep]) - expect(str, "-", daySep) + month := p.mustAtoi(str, monSep+1, daySep) + p.expect(str, '-', daySep) timeSep := daySep + 3 - day := mustAtoi(str[daySep+1 : timeSep]) + day := p.mustAtoi(str, daySep+1, timeSep) var hour, minute, second int if len(str) > monSep+len("01-01")+1 { - expect(str, " ", timeSep) + p.expect(str, ' ', timeSep) minSep := timeSep + 3 - expect(str, ":", minSep) - hour = mustAtoi(str[timeSep+1 : minSep]) + p.expect(str, ':', minSep) + hour = p.mustAtoi(str, timeSep+1, minSep) secSep := minSep + 3 - expect(str, ":", secSep) - minute = mustAtoi(str[minSep+1 : secSep]) + p.expect(str, ':', secSep) + minute = p.mustAtoi(str, minSep+1, secSep) secEnd := secSep + 3 - second = mustAtoi(str[secSep+1 : secEnd]) + second = p.mustAtoi(str, secSep+1, secEnd) } remainderIdx := monSep + len("01-01 00:00:00") + 1 // Three optional (but ordered) sections follow: the @@ -248,49 +388,50 @@ func parseTs(currentLocation *time.Location, str string) (result time.Time) { nanoSec := 0 tzOff := 0 - if remainderIdx < len(str) && str[remainderIdx:remainderIdx+1] == "." { + if remainderIdx < len(str) && str[remainderIdx] == '.' { fracStart := remainderIdx + 1 fracOff := strings.IndexAny(str[fracStart:], "-+ ") if fracOff < 0 { fracOff = len(str) - fracStart } - fracSec := mustAtoi(str[fracStart : fracStart+fracOff]) + fracSec := p.mustAtoi(str, fracStart, fracStart+fracOff) nanoSec = fracSec * (1000000000 / int(math.Pow(10, float64(fracOff)))) remainderIdx += fracOff + 1 } - if tzStart := remainderIdx; tzStart < len(str) && (str[tzStart:tzStart+1] == "-" || str[tzStart:tzStart+1] == "+") { + if tzStart := remainderIdx; tzStart < len(str) && (str[tzStart] == '-' || str[tzStart] == '+') { // time zone separator is always '-' or '+' (UTC is +00) var tzSign int - if c := str[tzStart : tzStart+1]; c == "-" { + switch c := str[tzStart]; c { + case '-': tzSign = -1 - } else if c == "+" { + case '+': tzSign = +1 - } else { - errorf("expected '-' or '+' at position %v; got %v", tzStart, c) + default: + return time.Time{}, fmt.Errorf("expected '-' or '+' at position %v; got %v", tzStart, c) } - tzHours := mustAtoi(str[tzStart+1 : tzStart+3]) + tzHours := p.mustAtoi(str, tzStart+1, tzStart+3) remainderIdx += 3 var tzMin, tzSec int - if tzStart+3 < len(str) && str[tzStart+3:tzStart+4] == ":" { - tzMin = mustAtoi(str[tzStart+4 : tzStart+6]) + if remainderIdx < len(str) && str[remainderIdx] == ':' { + tzMin = p.mustAtoi(str, remainderIdx+1, remainderIdx+3) remainderIdx += 3 } - if tzStart+6 < len(str) && str[tzStart+6:tzStart+7] == ":" { - tzSec = mustAtoi(str[tzStart+7 : tzStart+9]) + if remainderIdx < len(str) && str[remainderIdx] == ':' { + tzSec = p.mustAtoi(str, remainderIdx+1, remainderIdx+3) remainderIdx += 3 } tzOff = tzSign * ((tzHours * 60 * 60) + (tzMin * 60) + tzSec) } var isoYear int - if remainderIdx < len(str) && str[remainderIdx:remainderIdx+3] == " BC" { + if remainderIdx+3 <= len(str) && str[remainderIdx:remainderIdx+3] == " BC" { isoYear = 1 - year remainderIdx += 3 } else { isoYear = year } if remainderIdx < len(str) { - errorf("expected end of input, got %v", str[remainderIdx:]) + return time.Time{}, fmt.Errorf("expected end of input, got %v", str[remainderIdx:]) } t := time.Date(isoYear, time.Month(month), day, hour, minute, second, nanoSec, @@ -307,13 +448,26 @@ func parseTs(currentLocation *time.Location, str string) (result time.Time) { } } - return t + return t, p.err } -// formatTs formats t as time.RFC3339Nano and appends time zone seconds if -// needed. -func formatTs(t time.Time) (b []byte) { - b = []byte(t.Format(time.RFC3339Nano)) +// formatTs formats t into a format postgres understands. +func formatTs(t time.Time) []byte { + if infinityTsEnabled { + // t <= -infinity : ! (t > -infinity) + if !t.After(infinityTsNegative) { + return []byte("-infinity") + } + // t >= infinity : ! (!t < infinity) + if !t.Before(infinityTsPositive) { + return []byte("infinity") + } + } + return FormatTimestamp(t) +} + +// FormatTimestamp formats t into Postgres' text format for timestamps. +func FormatTimestamp(t time.Time) []byte { // Need to send dates before 0001 A.D. with " BC" suffix, instead of the // minus sign preferred by Go. // Beware, "0000" in ISO is "1 BC", "-0001" is "2 BC" and so on @@ -323,38 +477,39 @@ func formatTs(t time.Time) (b []byte) { t = t.AddDate((-t.Year())*2+1, 0, 0) bc = true } - b = []byte(t.Format(time.RFC3339Nano)) - if bc { - b = append(b, " BC"...) - } + b := []byte(t.Format("2006-01-02 15:04:05.999999999Z07:00")) _, offset := t.Zone() offset = offset % 60 - if offset == 0 { - return b + if offset != 0 { + // RFC3339Nano already printed the minus sign + if offset < 0 { + offset = -offset + } + + b = append(b, ':') + if offset < 10 { + b = append(b, '0') + } + b = strconv.AppendInt(b, int64(offset), 10) } - if offset < 0 { - offset = -offset + if bc { + b = append(b, " BC"...) } - - b = append(b, ':') - if offset < 10 { - b = append(b, '0') - } - return strconv.AppendInt(b, int64(offset), 10) + return b } // Parse a bytea value received from the server. Both "hex" and the legacy // "escape" format are supported. -func parseBytea(s []byte) (result []byte) { +func parseBytea(s []byte) (result []byte, err error) { if len(s) >= 2 && bytes.Equal(s[:2], []byte("\\x")) { // bytea_output = hex s = s[2:] // trim off leading "\\x" result = make([]byte, hex.DecodedLen(len(s))) _, err := hex.Decode(result, s) if err != nil { - errorf("%s", err) + return nil, err } } else { // bytea_output = escape @@ -369,11 +524,11 @@ func parseBytea(s []byte) (result []byte) { // '\\' followed by an octal number if len(s) < 4 { - errorf("invalid bytea sequence %v", s) + return nil, fmt.Errorf("invalid bytea sequence %v", s) } r, err := strconv.ParseInt(string(s[1:4]), 8, 9) if err != nil { - errorf("could not parse bytea value: %s", err.Error()) + return nil, fmt.Errorf("could not parse bytea value: %s", err.Error()) } result = append(result, byte(r)) s = s[4:] @@ -391,13 +546,16 @@ func parseBytea(s []byte) (result []byte) { } } - return result + return result, nil } func encodeBytea(serverVersion int, v []byte) (result []byte) { if serverVersion >= 90000 { // Use the hex format if we know that the server supports it - result = []byte(fmt.Sprintf("\\x%x", v)) + result = make([]byte, 2+hex.EncodedLen(len(v))) + result[0] = '\\' + result[1] = 'x' + hex.Encode(result[2:], v) } else { // .. or resort to "escape" for _, b := range v { diff --git a/vendor/github.com/lib/pq/encode_test.go b/vendor/github.com/lib/pq/encode_test.go deleted file mode 100644 index e835f2b36fd..00000000000 --- a/vendor/github.com/lib/pq/encode_test.go +++ /dev/null @@ -1,433 +0,0 @@ -package pq - -import ( - "github.com/lib/pq/oid" - - "bytes" - "fmt" - "testing" - "time" -) - -func TestScanTimestamp(t *testing.T) { - var nt NullTime - tn := time.Now() - nt.Scan(tn) - if !nt.Valid { - t.Errorf("Expected Valid=false") - } - if nt.Time != tn { - t.Errorf("Time value mismatch") - } -} - -func TestScanNilTimestamp(t *testing.T) { - var nt NullTime - nt.Scan(nil) - if nt.Valid { - t.Errorf("Expected Valid=false") - } -} - -var timeTests = []struct { - str string - timeval time.Time -}{ - {"22001-02-03", time.Date(22001, time.February, 3, 0, 0, 0, 0, time.FixedZone("", 0))}, - {"2001-02-03", time.Date(2001, time.February, 3, 0, 0, 0, 0, time.FixedZone("", 0))}, - {"2001-02-03 04:05:06", time.Date(2001, time.February, 3, 4, 5, 6, 0, time.FixedZone("", 0))}, - {"2001-02-03 04:05:06.000001", time.Date(2001, time.February, 3, 4, 5, 6, 1000, time.FixedZone("", 0))}, - {"2001-02-03 04:05:06.00001", time.Date(2001, time.February, 3, 4, 5, 6, 10000, time.FixedZone("", 0))}, - {"2001-02-03 04:05:06.0001", time.Date(2001, time.February, 3, 4, 5, 6, 100000, time.FixedZone("", 0))}, - {"2001-02-03 04:05:06.001", time.Date(2001, time.February, 3, 4, 5, 6, 1000000, time.FixedZone("", 0))}, - {"2001-02-03 04:05:06.01", time.Date(2001, time.February, 3, 4, 5, 6, 10000000, time.FixedZone("", 0))}, - {"2001-02-03 04:05:06.1", time.Date(2001, time.February, 3, 4, 5, 6, 100000000, time.FixedZone("", 0))}, - {"2001-02-03 04:05:06.12", time.Date(2001, time.February, 3, 4, 5, 6, 120000000, time.FixedZone("", 0))}, - {"2001-02-03 04:05:06.123", time.Date(2001, time.February, 3, 4, 5, 6, 123000000, time.FixedZone("", 0))}, - {"2001-02-03 04:05:06.1234", time.Date(2001, time.February, 3, 4, 5, 6, 123400000, time.FixedZone("", 0))}, - {"2001-02-03 04:05:06.12345", time.Date(2001, time.February, 3, 4, 5, 6, 123450000, time.FixedZone("", 0))}, - {"2001-02-03 04:05:06.123456", time.Date(2001, time.February, 3, 4, 5, 6, 123456000, time.FixedZone("", 0))}, - {"2001-02-03 04:05:06.123-07", time.Date(2001, time.February, 3, 4, 5, 6, 123000000, - time.FixedZone("", -7*60*60))}, - {"2001-02-03 04:05:06-07", time.Date(2001, time.February, 3, 4, 5, 6, 0, - time.FixedZone("", -7*60*60))}, - {"2001-02-03 04:05:06-07:42", time.Date(2001, time.February, 3, 4, 5, 6, 0, - time.FixedZone("", -(7*60*60+42*60)))}, - {"2001-02-03 04:05:06-07:30:09", time.Date(2001, time.February, 3, 4, 5, 6, 0, - time.FixedZone("", -(7*60*60+30*60+9)))}, - {"2001-02-03 04:05:06+07", time.Date(2001, time.February, 3, 4, 5, 6, 0, - time.FixedZone("", 7*60*60))}, - {"0011-02-03 04:05:06 BC", time.Date(-10, time.February, 3, 4, 5, 6, 0, time.FixedZone("", 0))}, - {"0011-02-03 04:05:06.123 BC", time.Date(-10, time.February, 3, 4, 5, 6, 123000000, time.FixedZone("", 0))}, - {"0011-02-03 04:05:06.123-07 BC", time.Date(-10, time.February, 3, 4, 5, 6, 123000000, - time.FixedZone("", -7*60*60))}, - {"0001-02-03 04:05:06.123", time.Date(1, time.February, 3, 4, 5, 6, 123000000, time.FixedZone("", 0))}, - {"0001-02-03 04:05:06.123 BC", time.Date(1, time.February, 3, 4, 5, 6, 123000000, time.FixedZone("", 0)).AddDate(-1, 0, 0)}, - {"0001-02-03 04:05:06.123 BC", time.Date(0, time.February, 3, 4, 5, 6, 123000000, time.FixedZone("", 0))}, - {"0002-02-03 04:05:06.123 BC", time.Date(0, time.February, 3, 4, 5, 6, 123000000, time.FixedZone("", 0)).AddDate(-1, 0, 0)}, - {"0002-02-03 04:05:06.123 BC", time.Date(-1, time.February, 3, 4, 5, 6, 123000000, time.FixedZone("", 0))}, - {"12345-02-03 04:05:06.1", time.Date(12345, time.February, 3, 4, 5, 6, 100000000, time.FixedZone("", 0))}, - {"123456-02-03 04:05:06.1", time.Date(123456, time.February, 3, 4, 5, 6, 100000000, time.FixedZone("", 0))}, -} - -// Helper function for the two tests below -func tryParse(str string) (t time.Time, err error) { - defer func() { - if p := recover(); p != nil { - err = fmt.Errorf("%v", p) - return - } - }() - t = parseTs(nil, str) - return -} - -// Test that parsing the string results in the expected value. -func TestParseTs(t *testing.T) { - for i, tt := range timeTests { - val, err := tryParse(tt.str) - if err != nil { - t.Errorf("%d: got error: %v", i, err) - } else if val.String() != tt.timeval.String() { - t.Errorf("%d: expected to parse %q into %q; got %q", - i, tt.str, tt.timeval, val) - } - } -} - -// Now test that sending the value into the database and parsing it back -// returns the same time.Time value. -func TestEncodeAndParseTs(t *testing.T) { - db, err := openTestConnConninfo("timezone='Etc/UTC'") - if err != nil { - t.Fatal(err) - } - defer db.Close() - - for i, tt := range timeTests { - var dbstr string - err = db.QueryRow("SELECT ($1::timestamptz)::text", tt.timeval).Scan(&dbstr) - if err != nil { - t.Errorf("%d: could not send value %q to the database: %s", i, tt.timeval, err) - continue - } - - val, err := tryParse(dbstr) - if err != nil { - t.Errorf("%d: could not parse value %q: %s", i, dbstr, err) - continue - } - val = val.In(tt.timeval.Location()) - if val.String() != tt.timeval.String() { - t.Errorf("%d: expected to parse %q into %q; got %q", i, dbstr, tt.timeval, val) - } - } -} - -var formatTimeTests = []struct { - time time.Time - expected string -}{ - {time.Time{}, "0001-01-01T00:00:00Z"}, - {time.Date(2001, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 0)), "2001-02-03T04:05:06.123456789Z"}, - {time.Date(2001, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 2*60*60)), "2001-02-03T04:05:06.123456789+02:00"}, - {time.Date(2001, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", -6*60*60)), "2001-02-03T04:05:06.123456789-06:00"}, - {time.Date(1, time.January, 1, 0, 0, 0, 0, time.FixedZone("", 19*60+32)), "0001-01-01T00:00:00+00:19:32"}, - {time.Date(2001, time.February, 3, 4, 5, 6, 0, time.FixedZone("", -(7*60*60+30*60+9))), "2001-02-03T04:05:06-07:30:09"}, -} - -func TestFormatTs(t *testing.T) { - for i, tt := range formatTimeTests { - val := string(formatTs(tt.time)) - if val != tt.expected { - t.Errorf("%d: incorrect time format %q, want %q", i, val, tt.expected) - } - } -} - -func TestTimestampWithTimeZone(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - tx, err := db.Begin() - if err != nil { - t.Fatal(err) - } - defer tx.Rollback() - - // try several different locations, all included in Go's zoneinfo.zip - for _, locName := range []string{ - "UTC", - "America/Chicago", - "America/New_York", - "Australia/Darwin", - "Australia/Perth", - } { - loc, err := time.LoadLocation(locName) - if err != nil { - t.Logf("Could not load time zone %s - skipping", locName) - continue - } - - // Postgres timestamps have a resolution of 1 microsecond, so don't - // use the full range of the Nanosecond argument - refTime := time.Date(2012, 11, 6, 10, 23, 42, 123456000, loc) - - for _, pgTimeZone := range []string{"US/Eastern", "Australia/Darwin"} { - // Switch Postgres's timezone to test different output timestamp formats - _, err = tx.Exec(fmt.Sprintf("set time zone '%s'", pgTimeZone)) - if err != nil { - t.Fatal(err) - } - - var gotTime time.Time - row := tx.QueryRow("select $1::timestamp with time zone", refTime) - err = row.Scan(&gotTime) - if err != nil { - t.Fatal(err) - } - - if !refTime.Equal(gotTime) { - t.Errorf("timestamps not equal: %s != %s", refTime, gotTime) - } - - // check that the time zone is set correctly based on TimeZone - pgLoc, err := time.LoadLocation(pgTimeZone) - if err != nil { - t.Logf("Could not load time zone %s - skipping", pgLoc) - continue - } - translated := refTime.In(pgLoc) - if translated.String() != gotTime.String() { - t.Errorf("timestamps not equal: %s != %s", translated, gotTime) - } - } - } -} - -func TestTimestampWithOutTimezone(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - test := func(ts, pgts string) { - r, err := db.Query("SELECT $1::timestamp", pgts) - if err != nil { - t.Fatalf("Could not run query: %v", err) - } - - n := r.Next() - - if n != true { - t.Fatal("Expected at least one row") - } - - var result time.Time - err = r.Scan(&result) - if err != nil { - t.Fatalf("Did not expect error scanning row: %v", err) - } - - expected, err := time.Parse(time.RFC3339, ts) - if err != nil { - t.Fatalf("Could not parse test time literal: %v", err) - } - - if !result.Equal(expected) { - t.Fatalf("Expected time to match %v: got mismatch %v", - expected, result) - } - - n = r.Next() - if n != false { - t.Fatal("Expected only one row") - } - } - - test("2000-01-01T00:00:00Z", "2000-01-01T00:00:00") - - // Test higher precision time - test("2013-01-04T20:14:58.80033Z", "2013-01-04 20:14:58.80033") -} - -func TestStringWithNul(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - hello0world := string("hello\x00world") - _, err := db.Query("SELECT $1::text", &hello0world) - if err == nil { - t.Fatal("Postgres accepts a string with nul in it; " + - "injection attacks may be plausible") - } -} - -func TestByteaToText(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - b := []byte("hello world") - row := db.QueryRow("SELECT $1::text", b) - - var result []byte - err := row.Scan(&result) - if err != nil { - t.Fatal(err) - } - - if string(result) != string(b) { - t.Fatalf("expected %v but got %v", b, result) - } -} - -func TestTextToBytea(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - b := "hello world" - row := db.QueryRow("SELECT $1::bytea", b) - - var result []byte - err := row.Scan(&result) - if err != nil { - t.Fatal(err) - } - - if !bytes.Equal(result, []byte(b)) { - t.Fatalf("expected %v but got %v", b, result) - } -} - -func TestByteaOutputFormatEncoding(t *testing.T) { - input := []byte("\\x\x00\x01\x02\xFF\xFEabcdefg0123") - want := []byte("\\x5c78000102fffe6162636465666730313233") - got := encode(¶meterStatus{serverVersion: 90000}, input, oid.T_bytea) - if !bytes.Equal(want, got) { - t.Errorf("invalid hex bytea output, got %v but expected %v", got, want) - } - - want = []byte("\\\\x\\000\\001\\002\\377\\376abcdefg0123") - got = encode(¶meterStatus{serverVersion: 84000}, input, oid.T_bytea) - if !bytes.Equal(want, got) { - t.Errorf("invalid escape bytea output, got %v but expected %v", got, want) - } -} - -func TestByteaOutputFormats(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - if getServerVersion(t, db) < 90000 { - // skip - return - } - - testByteaOutputFormat := func(f string) { - expectedData := []byte("\x5c\x78\x00\xff\x61\x62\x63\x01\x08") - sqlQuery := "SELECT decode('5c7800ff6162630108', 'hex')" - - var data []byte - - // use a txn to avoid relying on getting the same connection - txn, err := db.Begin() - if err != nil { - t.Fatal(err) - } - defer txn.Rollback() - - _, err = txn.Exec("SET LOCAL bytea_output TO " + f) - if err != nil { - t.Fatal(err) - } - // use Query; QueryRow would hide the actual error - rows, err := txn.Query(sqlQuery) - if err != nil { - t.Fatal(err) - } - if !rows.Next() { - if rows.Err() != nil { - t.Fatal(rows.Err()) - } - t.Fatal("shouldn't happen") - } - err = rows.Scan(&data) - if err != nil { - t.Fatal(err) - } - err = rows.Close() - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(data, expectedData) { - t.Errorf("unexpected bytea value %v for format %s; expected %v", data, f, expectedData) - } - } - - testByteaOutputFormat("hex") - testByteaOutputFormat("escape") -} - -func TestAppendEncodedText(t *testing.T) { - var buf []byte - - buf = appendEncodedText(¶meterStatus{serverVersion: 90000}, buf, int64(10)) - buf = append(buf, '\t') - buf = appendEncodedText(¶meterStatus{serverVersion: 90000}, buf, float32(42.0000000001)) - buf = append(buf, '\t') - buf = appendEncodedText(¶meterStatus{serverVersion: 90000}, buf, 42.0000000001) - buf = append(buf, '\t') - buf = appendEncodedText(¶meterStatus{serverVersion: 90000}, buf, "hello\tworld") - buf = append(buf, '\t') - buf = appendEncodedText(¶meterStatus{serverVersion: 90000}, buf, []byte{0, 128, 255}) - - if string(buf) != "10\t42\t42.0000000001\thello\\tworld\t\\\\x0080ff" { - t.Fatal(string(buf)) - } -} - -func TestAppendEscapedText(t *testing.T) { - if esc := appendEscapedText(nil, "hallo\tescape"); string(esc) != "hallo\\tescape" { - t.Fatal(string(esc)) - } - if esc := appendEscapedText(nil, "hallo\\tescape\n"); string(esc) != "hallo\\\\tescape\\n" { - t.Fatal(string(esc)) - } - if esc := appendEscapedText(nil, "\n\r\t\f"); string(esc) != "\\n\\r\\t\f" { - t.Fatal(string(esc)) - } -} - -func TestAppendEscapedTextExistingBuffer(t *testing.T) { - var buf []byte - buf = []byte("123\t") - if esc := appendEscapedText(buf, "hallo\tescape"); string(esc) != "123\thallo\\tescape" { - t.Fatal(string(esc)) - } - buf = []byte("123\t") - if esc := appendEscapedText(buf, "hallo\\tescape\n"); string(esc) != "123\thallo\\\\tescape\\n" { - t.Fatal(string(esc)) - } - buf = []byte("123\t") - if esc := appendEscapedText(buf, "\n\r\t\f"); string(esc) != "123\t\\n\\r\\t\f" { - t.Fatal(string(esc)) - } -} - -func BenchmarkAppendEscapedText(b *testing.B) { - longString := "" - for i := 0; i < 100; i++ { - longString += "123456789\n" - } - for i := 0; i < b.N; i++ { - appendEscapedText(nil, longString) - } -} - -func BenchmarkAppendEscapedTextNoEscape(b *testing.B) { - longString := "" - for i := 0; i < 100; i++ { - longString += "1234567890" - } - for i := 0; i < b.N; i++ { - appendEscapedText(nil, longString) - } -} diff --git a/vendor/github.com/lib/pq/error.go b/vendor/github.com/lib/pq/error.go index 0a49364d9ec..b4bb44cee39 100644 --- a/vendor/github.com/lib/pq/error.go +++ b/vendor/github.com/lib/pq/error.go @@ -459,6 +459,19 @@ func errorf(s string, args ...interface{}) { panic(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...))) } +func errRecoverNoErrBadConn(err *error) { + e := recover() + if e == nil { + // Do nothing + return + } + var ok bool + *err, ok = e.(error) + if !ok { + *err = fmt.Errorf("pq: unexpected error: %#v", e) + } +} + func (c *conn) errRecover(err *error) { e := recover() switch v := e.(type) { diff --git a/vendor/github.com/lib/pq/notify.go b/vendor/github.com/lib/pq/notify.go index e3b08d59a2f..09f94244b9b 100644 --- a/vendor/github.com/lib/pq/notify.go +++ b/vendor/github.com/lib/pq/notify.go @@ -6,7 +6,6 @@ package pq import ( "errors" "fmt" - "io" "sync" "sync/atomic" "time" @@ -63,14 +62,18 @@ type ListenerConn struct { // Creates a new ListenerConn. Use NewListener instead. func NewListenerConn(name string, notificationChan chan<- *Notification) (*ListenerConn, error) { - cn, err := Open(name) + return newDialListenerConn(defaultDialer{}, name, notificationChan) +} + +func newDialListenerConn(d Dialer, name string, c chan<- *Notification) (*ListenerConn, error) { + cn, err := DialOpen(d, name) if err != nil { return nil, err } l := &ListenerConn{ cn: cn.(*conn), - notificationChan: notificationChan, + notificationChan: c, connState: connStateIdle, replyChan: make(chan message, 2), } @@ -87,12 +90,16 @@ func NewListenerConn(name string, notificationChan chan<- *Notification) (*Liste // Returns an error if an unrecoverable error has occurred and the ListenerConn // should be abandoned. func (l *ListenerConn) acquireSenderLock() error { - l.connectionLock.Lock() - defer l.connectionLock.Unlock() - if l.err != nil { - return l.err - } + // we must acquire senderLock first to avoid deadlocks; see ExecSimpleQuery l.senderLock.Lock() + + l.connectionLock.Lock() + err := l.err + l.connectionLock.Unlock() + if err != nil { + l.senderLock.Unlock() + return err + } return nil } @@ -125,7 +132,7 @@ func (l *ListenerConn) setState(newState int32) bool { // away or should be discarded because we couldn't agree on the state with the // server backend. func (l *ListenerConn) listenerConnLoop() (err error) { - defer l.cn.errRecover(&err) + defer errRecoverNoErrBadConn(&err) r := &readBuf{} for { @@ -140,6 +147,9 @@ func (l *ListenerConn) listenerConnLoop() (err error) { // about the scratch buffer being overwritten. l.notificationChan <- recvNotification(r) + case 'T', 'D': + // only used by tests; ignore + case 'E': // We might receive an ErrorResponse even when not in a query; it // is expected that the server will close the connection after @@ -238,7 +248,7 @@ func (l *ListenerConn) Ping() error { // The caller must be holding senderLock (see acquireSenderLock and // releaseSenderLock). func (l *ListenerConn) sendSimpleQuery(q string) (err error) { - defer l.cn.errRecover(&err) + defer errRecoverNoErrBadConn(&err) // must set connection state before sending the query if !l.setState(connStateExpectResponse) { @@ -247,8 +257,10 @@ func (l *ListenerConn) sendSimpleQuery(q string) (err error) { // Can't use l.cn.writeBuf here because it uses the scratch buffer which // might get overwritten by listenerConnLoop. - data := writeBuf([]byte("Q\x00\x00\x00\x00")) - b := &data + b := &writeBuf{ + buf: []byte("Q\x00\x00\x00\x00"), + pos: 1, + } b.string(q) l.cn.send(b) @@ -277,13 +289,13 @@ func (l *ListenerConn) ExecSimpleQuery(q string) (executed bool, err error) { // We can't know what state the protocol is in, so we need to abandon // this connection. l.connectionLock.Lock() - defer l.connectionLock.Unlock() // Set the error pointer if it hasn't been set already; see // listenerConnMain. if l.err == nil { l.err = err } - l.cn.Close() + l.connectionLock.Unlock() + l.cn.c.Close() return false, err } @@ -292,8 +304,11 @@ func (l *ListenerConn) ExecSimpleQuery(q string) (executed bool, err error) { m, ok := <-l.replyChan if !ok { // We lost the connection to server, don't bother waiting for a - // a response. - return false, io.EOF + // a response. err should have been set already. + l.connectionLock.Lock() + err := l.err + l.connectionLock.Unlock() + return false, err } switch m.typ { case 'Z': @@ -320,12 +335,15 @@ func (l *ListenerConn) ExecSimpleQuery(q string) (executed bool, err error) { func (l *ListenerConn) Close() error { l.connectionLock.Lock() - defer l.connectionLock.Unlock() if l.err != nil { + l.connectionLock.Unlock() return errListenerConnClosed } l.err = errListenerConnClosed - return l.cn.Close() + l.connectionLock.Unlock() + // We can't send anything on the connection without holding senderLock. + // Simply close the net.Conn to wake up everyone operating on it. + return l.cn.c.Close() } // Err() returns the reason the connection was closed. It is not safe to call @@ -377,6 +395,7 @@ type Listener struct { name string minReconnectInterval time.Duration maxReconnectInterval time.Duration + dialer Dialer eventCallback EventCallbackType lock sync.Mutex @@ -407,10 +426,21 @@ func NewListener(name string, minReconnectInterval time.Duration, maxReconnectInterval time.Duration, eventCallback EventCallbackType) *Listener { + return NewDialListener(defaultDialer{}, name, minReconnectInterval, maxReconnectInterval, eventCallback) +} + +// NewDialListener is like NewListener but it takes a Dialer. +func NewDialListener(d Dialer, + name string, + minReconnectInterval time.Duration, + maxReconnectInterval time.Duration, + eventCallback EventCallbackType) *Listener { + l := &Listener{ name: name, minReconnectInterval: minReconnectInterval, maxReconnectInterval: maxReconnectInterval, + dialer: d, eventCallback: eventCallback, channels: make(map[string]struct{}), @@ -646,7 +676,7 @@ func (l *Listener) closed() bool { func (l *Listener) connect() error { notificationChan := make(chan *Notification, 32) - cn, err := NewListenerConn(l.name, notificationChan) + cn, err := newDialListenerConn(l.dialer, l.name, notificationChan) if err != nil { return err } diff --git a/vendor/github.com/lib/pq/notify_test.go b/vendor/github.com/lib/pq/notify_test.go deleted file mode 100644 index ae1208d160f..00000000000 --- a/vendor/github.com/lib/pq/notify_test.go +++ /dev/null @@ -1,502 +0,0 @@ -package pq - -import ( - "errors" - "fmt" - "io" - "os" - "testing" - "time" -) - -var errNilNotification = errors.New("nil notification") - -func expectNotification(t *testing.T, ch <-chan *Notification, relname string, extra string) error { - select { - case n := <-ch: - if n == nil { - return errNilNotification - } - if n.Channel != relname || n.Extra != extra { - return fmt.Errorf("unexpected notification %v", n) - } - return nil - case <-time.After(1500 * time.Millisecond): - return fmt.Errorf("timeout") - } -} - -func expectNoNotification(t *testing.T, ch <-chan *Notification) error { - select { - case n := <-ch: - return fmt.Errorf("unexpected notification %v", n) - case <-time.After(100 * time.Millisecond): - return nil - } -} - -func expectEvent(t *testing.T, eventch <-chan ListenerEventType, et ListenerEventType) error { - select { - case e := <-eventch: - if e != et { - return fmt.Errorf("unexpected event %v", e) - } - return nil - case <-time.After(1500 * time.Millisecond): - return fmt.Errorf("timeout") - } -} - -func expectNoEvent(t *testing.T, eventch <-chan ListenerEventType) error { - select { - case e := <-eventch: - return fmt.Errorf("unexpected event %v", e) - case <-time.After(100 * time.Millisecond): - return nil - } -} - -func newTestListenerConn(t *testing.T) (*ListenerConn, <-chan *Notification) { - datname := os.Getenv("PGDATABASE") - sslmode := os.Getenv("PGSSLMODE") - - if datname == "" { - os.Setenv("PGDATABASE", "pqgotest") - } - - if sslmode == "" { - os.Setenv("PGSSLMODE", "disable") - } - - notificationChan := make(chan *Notification) - l, err := NewListenerConn("", notificationChan) - if err != nil { - t.Fatal(err) - } - - return l, notificationChan -} - -func TestNewListenerConn(t *testing.T) { - l, _ := newTestListenerConn(t) - - defer l.Close() -} - -func TestConnListen(t *testing.T) { - l, channel := newTestListenerConn(t) - - defer l.Close() - - db := openTestConn(t) - defer db.Close() - - ok, err := l.Listen("notify_test") - if !ok || err != nil { - t.Fatal(err) - } - - _, err = db.Exec("NOTIFY notify_test") - if err != nil { - t.Fatal(err) - } - - err = expectNotification(t, channel, "notify_test", "") - if err != nil { - t.Fatal(err) - } -} - -func TestConnUnlisten(t *testing.T) { - l, channel := newTestListenerConn(t) - - defer l.Close() - - db := openTestConn(t) - defer db.Close() - - ok, err := l.Listen("notify_test") - if !ok || err != nil { - t.Fatal(err) - } - - _, err = db.Exec("NOTIFY notify_test") - - err = expectNotification(t, channel, "notify_test", "") - if err != nil { - t.Fatal(err) - } - - ok, err = l.Unlisten("notify_test") - if !ok || err != nil { - t.Fatal(err) - } - - _, err = db.Exec("NOTIFY notify_test") - if err != nil { - t.Fatal(err) - } - - err = expectNoNotification(t, channel) - if err != nil { - t.Fatal(err) - } -} - -func TestConnUnlistenAll(t *testing.T) { - l, channel := newTestListenerConn(t) - - defer l.Close() - - db := openTestConn(t) - defer db.Close() - - ok, err := l.Listen("notify_test") - if !ok || err != nil { - t.Fatal(err) - } - - _, err = db.Exec("NOTIFY notify_test") - - err = expectNotification(t, channel, "notify_test", "") - if err != nil { - t.Fatal(err) - } - - ok, err = l.UnlistenAll() - if !ok || err != nil { - t.Fatal(err) - } - - _, err = db.Exec("NOTIFY notify_test") - if err != nil { - t.Fatal(err) - } - - err = expectNoNotification(t, channel) - if err != nil { - t.Fatal(err) - } -} - -func TestConnClose(t *testing.T) { - l, _ := newTestListenerConn(t) - defer l.Close() - - err := l.Close() - if err != nil { - t.Fatal(err) - } - err = l.Close() - if err != errListenerConnClosed { - t.Fatalf("expected errListenerConnClosed; got %v", err) - } -} - -func TestConnPing(t *testing.T) { - l, _ := newTestListenerConn(t) - defer l.Close() - err := l.Ping() - if err != nil { - t.Fatal(err) - } - err = l.Close() - if err != nil { - t.Fatal(err) - } - err = l.Ping() - if err != errListenerConnClosed { - t.Fatalf("expected errListenerConnClosed; got %v", err) - } -} - -func TestNotifyExtra(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - if getServerVersion(t, db) < 90000 { - t.Skip("skipping NOTIFY payload test since the server does not appear to support it") - } - - l, channel := newTestListenerConn(t) - defer l.Close() - - ok, err := l.Listen("notify_test") - if !ok || err != nil { - t.Fatal(err) - } - - _, err = db.Exec("NOTIFY notify_test, 'something'") - if err != nil { - t.Fatal(err) - } - - err = expectNotification(t, channel, "notify_test", "something") - if err != nil { - t.Fatal(err) - } -} - -// create a new test listener and also set the timeouts -func newTestListenerTimeout(t *testing.T, min time.Duration, max time.Duration) (*Listener, <-chan ListenerEventType) { - datname := os.Getenv("PGDATABASE") - sslmode := os.Getenv("PGSSLMODE") - - if datname == "" { - os.Setenv("PGDATABASE", "pqgotest") - } - - if sslmode == "" { - os.Setenv("PGSSLMODE", "disable") - } - - eventch := make(chan ListenerEventType, 16) - l := NewListener("", min, max, func(t ListenerEventType, err error) { eventch <- t }) - err := expectEvent(t, eventch, ListenerEventConnected) - if err != nil { - t.Fatal(err) - } - return l, eventch -} - -func newTestListener(t *testing.T) (*Listener, <-chan ListenerEventType) { - return newTestListenerTimeout(t, time.Hour, time.Hour) -} - -func TestListenerListen(t *testing.T) { - l, _ := newTestListener(t) - defer l.Close() - - db := openTestConn(t) - defer db.Close() - - err := l.Listen("notify_listen_test") - if err != nil { - t.Fatal(err) - } - - _, err = db.Exec("NOTIFY notify_listen_test") - if err != nil { - t.Fatal(err) - } - - err = expectNotification(t, l.Notify, "notify_listen_test", "") - if err != nil { - t.Fatal(err) - } -} - -func TestListenerUnlisten(t *testing.T) { - l, _ := newTestListener(t) - defer l.Close() - - db := openTestConn(t) - defer db.Close() - - err := l.Listen("notify_listen_test") - if err != nil { - t.Fatal(err) - } - - _, err = db.Exec("NOTIFY notify_listen_test") - if err != nil { - t.Fatal(err) - } - - err = l.Unlisten("notify_listen_test") - if err != nil { - t.Fatal(err) - } - - err = expectNotification(t, l.Notify, "notify_listen_test", "") - if err != nil { - t.Fatal(err) - } - - _, err = db.Exec("NOTIFY notify_listen_test") - if err != nil { - t.Fatal(err) - } - - err = expectNoNotification(t, l.Notify) - if err != nil { - t.Fatal(err) - } -} - -func TestListenerUnlistenAll(t *testing.T) { - l, _ := newTestListener(t) - defer l.Close() - - db := openTestConn(t) - defer db.Close() - - err := l.Listen("notify_listen_test") - if err != nil { - t.Fatal(err) - } - - _, err = db.Exec("NOTIFY notify_listen_test") - if err != nil { - t.Fatal(err) - } - - err = l.UnlistenAll() - if err != nil { - t.Fatal(err) - } - - err = expectNotification(t, l.Notify, "notify_listen_test", "") - if err != nil { - t.Fatal(err) - } - - _, err = db.Exec("NOTIFY notify_listen_test") - if err != nil { - t.Fatal(err) - } - - err = expectNoNotification(t, l.Notify) - if err != nil { - t.Fatal(err) - } -} - -func TestListenerFailedQuery(t *testing.T) { - l, eventch := newTestListener(t) - defer l.Close() - - db := openTestConn(t) - defer db.Close() - - err := l.Listen("notify_listen_test") - if err != nil { - t.Fatal(err) - } - - _, err = db.Exec("NOTIFY notify_listen_test") - if err != nil { - t.Fatal(err) - } - - err = expectNotification(t, l.Notify, "notify_listen_test", "") - if err != nil { - t.Fatal(err) - } - - // shouldn't cause a disconnect - ok, err := l.cn.ExecSimpleQuery("SELECT error") - if !ok { - t.Fatalf("could not send query to server: %v", err) - } - _, ok = err.(PGError) - if !ok { - t.Fatalf("unexpected error %v", err) - } - err = expectNoEvent(t, eventch) - if err != nil { - t.Fatal(err) - } - - // should still work - _, err = db.Exec("NOTIFY notify_listen_test") - if err != nil { - t.Fatal(err) - } - - err = expectNotification(t, l.Notify, "notify_listen_test", "") - if err != nil { - t.Fatal(err) - } -} - -func TestListenerReconnect(t *testing.T) { - l, eventch := newTestListenerTimeout(t, 20*time.Millisecond, time.Hour) - defer l.Close() - - db := openTestConn(t) - defer db.Close() - - err := l.Listen("notify_listen_test") - if err != nil { - t.Fatal(err) - } - - _, err = db.Exec("NOTIFY notify_listen_test") - if err != nil { - t.Fatal(err) - } - - err = expectNotification(t, l.Notify, "notify_listen_test", "") - if err != nil { - t.Fatal(err) - } - - // kill the connection and make sure it comes back up - ok, err := l.cn.ExecSimpleQuery("SELECT pg_terminate_backend(pg_backend_pid())") - if ok { - t.Fatalf("could not kill the connection: %v", err) - } - if err != io.EOF { - t.Fatalf("unexpected error %v", err) - } - err = expectEvent(t, eventch, ListenerEventDisconnected) - if err != nil { - t.Fatal(err) - } - err = expectEvent(t, eventch, ListenerEventReconnected) - if err != nil { - t.Fatal(err) - } - - // should still work - _, err = db.Exec("NOTIFY notify_listen_test") - if err != nil { - t.Fatal(err) - } - - // should get nil after Reconnected - err = expectNotification(t, l.Notify, "", "") - if err != errNilNotification { - t.Fatal(err) - } - - err = expectNotification(t, l.Notify, "notify_listen_test", "") - if err != nil { - t.Fatal(err) - } -} - -func TestListenerClose(t *testing.T) { - l, _ := newTestListenerTimeout(t, 20*time.Millisecond, time.Hour) - defer l.Close() - - err := l.Close() - if err != nil { - t.Fatal(err) - } - err = l.Close() - if err != errListenerClosed { - t.Fatalf("expected errListenerClosed; got %v", err) - } -} - -func TestListenerPing(t *testing.T) { - l, _ := newTestListenerTimeout(t, 20*time.Millisecond, time.Hour) - defer l.Close() - - err := l.Ping() - if err != nil { - t.Fatal(err) - } - - err = l.Close() - if err != nil { - t.Fatal(err) - } - - err = l.Ping() - if err != errListenerClosed { - t.Fatalf("expected errListenerClosed; got %v", err) - } -} diff --git a/vendor/github.com/lib/pq/ssl.go b/vendor/github.com/lib/pq/ssl.go new file mode 100644 index 00000000000..7deb304366f --- /dev/null +++ b/vendor/github.com/lib/pq/ssl.go @@ -0,0 +1,158 @@ +package pq + +import ( + "crypto/tls" + "crypto/x509" + "io/ioutil" + "net" + "os" + "os/user" + "path/filepath" +) + +// ssl generates a function to upgrade a net.Conn based on the "sslmode" and +// related settings. The function is nil when no upgrade should take place. +func ssl(o values) func(net.Conn) net.Conn { + verifyCaOnly := false + tlsConf := tls.Config{} + switch mode := o["sslmode"]; mode { + // "require" is the default. + case "", "require": + // We must skip TLS's own verification since it requires full + // verification since Go 1.3. + tlsConf.InsecureSkipVerify = true + + // From http://www.postgresql.org/docs/current/static/libpq-ssl.html: + // + // Note: For backwards compatibility with earlier versions of + // PostgreSQL, if a root CA file exists, the behavior of + // sslmode=require will be the same as that of verify-ca, meaning the + // server certificate is validated against the CA. Relying on this + // behavior is discouraged, and applications that need certificate + // validation should always use verify-ca or verify-full. + if sslrootcert, ok := o["sslrootcert"]; ok { + if _, err := os.Stat(sslrootcert); err == nil { + verifyCaOnly = true + } else { + delete(o, "sslrootcert") + } + } + case "verify-ca": + // We must skip TLS's own verification since it requires full + // verification since Go 1.3. + tlsConf.InsecureSkipVerify = true + verifyCaOnly = true + case "verify-full": + tlsConf.ServerName = o["host"] + case "disable": + return nil + default: + errorf(`unsupported sslmode %q; only "require" (default), "verify-full", "verify-ca", and "disable" supported`, mode) + } + + sslClientCertificates(&tlsConf, o) + sslCertificateAuthority(&tlsConf, o) + sslRenegotiation(&tlsConf) + + return func(conn net.Conn) net.Conn { + client := tls.Client(conn, &tlsConf) + if verifyCaOnly { + sslVerifyCertificateAuthority(client, &tlsConf) + } + return client + } +} + +// sslClientCertificates adds the certificate specified in the "sslcert" and +// "sslkey" settings, or if they aren't set, from the .postgresql directory +// in the user's home directory. The configured files must exist and have +// the correct permissions. +func sslClientCertificates(tlsConf *tls.Config, o values) { + // user.Current() might fail when cross-compiling. We have to ignore the + // error and continue without home directory defaults, since we wouldn't + // know from where to load them. + user, _ := user.Current() + + // In libpq, the client certificate is only loaded if the setting is not blank. + // + // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L1036-L1037 + sslcert := o["sslcert"] + if len(sslcert) == 0 && user != nil { + sslcert = filepath.Join(user.HomeDir, ".postgresql", "postgresql.crt") + } + // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L1045 + if len(sslcert) == 0 { + return + } + // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L1050:L1054 + if _, err := os.Stat(sslcert); os.IsNotExist(err) { + return + } else if err != nil { + panic(err) + } + + // In libpq, the ssl key is only loaded if the setting is not blank. + // + // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L1123-L1222 + sslkey := o["sslkey"] + if len(sslkey) == 0 && user != nil { + sslkey = filepath.Join(user.HomeDir, ".postgresql", "postgresql.key") + } + + if len(sslkey) > 0 { + if err := sslKeyPermissions(sslkey); err != nil { + panic(err) + } + } + + cert, err := tls.LoadX509KeyPair(sslcert, sslkey) + if err != nil { + panic(err) + } + tlsConf.Certificates = []tls.Certificate{cert} +} + +// sslCertificateAuthority adds the RootCA specified in the "sslrootcert" setting. +func sslCertificateAuthority(tlsConf *tls.Config, o values) { + // In libpq, the root certificate is only loaded if the setting is not blank. + // + // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L950-L951 + if sslrootcert := o["sslrootcert"]; len(sslrootcert) > 0 { + tlsConf.RootCAs = x509.NewCertPool() + + cert, err := ioutil.ReadFile(sslrootcert) + if err != nil { + panic(err) + } + + if !tlsConf.RootCAs.AppendCertsFromPEM(cert) { + errorf("couldn't parse pem in sslrootcert") + } + } +} + +// sslVerifyCertificateAuthority carries out a TLS handshake to the server and +// verifies the presented certificate against the CA, i.e. the one specified in +// sslrootcert or the system CA if sslrootcert was not specified. +func sslVerifyCertificateAuthority(client *tls.Conn, tlsConf *tls.Config) { + err := client.Handshake() + if err != nil { + panic(err) + } + certs := client.ConnectionState().PeerCertificates + opts := x509.VerifyOptions{ + DNSName: client.ConnectionState().ServerName, + Intermediates: x509.NewCertPool(), + Roots: tlsConf.RootCAs, + } + for i, cert := range certs { + if i == 0 { + continue + } + opts.Intermediates.AddCert(cert) + } + _, err = certs[0].Verify(opts) + if err != nil { + panic(err) + } +} diff --git a/vendor/github.com/lib/pq/ssl_go1.7.go b/vendor/github.com/lib/pq/ssl_go1.7.go new file mode 100644 index 00000000000..d7ba43b32a1 --- /dev/null +++ b/vendor/github.com/lib/pq/ssl_go1.7.go @@ -0,0 +1,14 @@ +// +build go1.7 + +package pq + +import "crypto/tls" + +// Accept renegotiation requests initiated by the backend. +// +// Renegotiation was deprecated then removed from PostgreSQL 9.5, but +// the default configuration of older versions has it enabled. Redshift +// also initiates renegotiations and cannot be reconfigured. +func sslRenegotiation(conf *tls.Config) { + conf.Renegotiation = tls.RenegotiateFreelyAsClient +} diff --git a/vendor/github.com/lib/pq/ssl_permissions.go b/vendor/github.com/lib/pq/ssl_permissions.go new file mode 100644 index 00000000000..3b7c3a2a319 --- /dev/null +++ b/vendor/github.com/lib/pq/ssl_permissions.go @@ -0,0 +1,20 @@ +// +build !windows + +package pq + +import "os" + +// sslKeyPermissions checks the permissions on user-supplied ssl key files. +// The key file should have very little access. +// +// libpq does not check key file permissions on Windows. +func sslKeyPermissions(sslkey string) error { + info, err := os.Stat(sslkey) + if err != nil { + return err + } + if info.Mode().Perm()&0077 != 0 { + return ErrSSLKeyHasWorldPermissions + } + return nil +} diff --git a/vendor/github.com/lib/pq/ssl_renegotiation.go b/vendor/github.com/lib/pq/ssl_renegotiation.go new file mode 100644 index 00000000000..85ed5e437fb --- /dev/null +++ b/vendor/github.com/lib/pq/ssl_renegotiation.go @@ -0,0 +1,8 @@ +// +build !go1.7 + +package pq + +import "crypto/tls" + +// Renegotiation is not supported by crypto/tls until Go 1.7. +func sslRenegotiation(*tls.Config) {} diff --git a/vendor/github.com/lib/pq/ssl_test.go b/vendor/github.com/lib/pq/ssl_test.go deleted file mode 100644 index 932b336f530..00000000000 --- a/vendor/github.com/lib/pq/ssl_test.go +++ /dev/null @@ -1,226 +0,0 @@ -package pq - -// This file contains SSL tests - -import ( - _ "crypto/sha256" - "crypto/x509" - "database/sql" - "fmt" - "os" - "path/filepath" - "testing" -) - -func maybeSkipSSLTests(t *testing.T) { - // Require some special variables for testing certificates - if os.Getenv("PQSSLCERTTEST_PATH") == "" { - t.Skip("PQSSLCERTTEST_PATH not set, skipping SSL tests") - } - - value := os.Getenv("PQGOSSLTESTS") - if value == "" || value == "0" { - t.Skip("PQGOSSLTESTS not enabled, skipping SSL tests") - } else if value != "1" { - t.Fatalf("unexpected value %q for PQGOSSLTESTS", value) - } -} - -func openSSLConn(t *testing.T, conninfo string) (*sql.DB, error) { - db, err := openTestConnConninfo(conninfo) - if err != nil { - // should never fail - t.Fatal(err) - } - // Do something with the connection to see whether it's working or not. - tx, err := db.Begin() - if err == nil { - return db, tx.Rollback() - } - _ = db.Close() - return nil, err -} - -func checkSSLSetup(t *testing.T, conninfo string) { - db, err := openSSLConn(t, conninfo) - if err == nil { - db.Close() - t.Fatalf("expected error with conninfo=%q", conninfo) - } -} - -// Connect over SSL and run a simple query to test the basics -func TestSSLConnection(t *testing.T) { - maybeSkipSSLTests(t) - // Environment sanity check: should fail without SSL - checkSSLSetup(t, "sslmode=disable user=pqgossltest") - - db, err := openSSLConn(t, "sslmode=require user=pqgossltest") - if err != nil { - t.Fatal(err) - } - rows, err := db.Query("SELECT 1") - if err != nil { - t.Fatal(err) - } - rows.Close() -} - -// Test sslmode=verify-full -func TestSSLVerifyFull(t *testing.T) { - maybeSkipSSLTests(t) - // Environment sanity check: should fail without SSL - checkSSLSetup(t, "sslmode=disable user=pqgossltest") - - // Not OK according to the system CA - _, err := openSSLConn(t, "host=postgres sslmode=verify-full user=pqgossltest") - if err == nil { - t.Fatal("expected error") - } - _, ok := err.(x509.UnknownAuthorityError) - if !ok { - t.Fatalf("expected x509.UnknownAuthorityError, got %#+v", err) - } - - rootCertPath := filepath.Join(os.Getenv("PQSSLCERTTEST_PATH"), "root.crt") - rootCert := "sslrootcert=" + rootCertPath + " " - // No match on Common Name - _, err = openSSLConn(t, rootCert+"host=127.0.0.1 sslmode=verify-full user=pqgossltest") - if err == nil { - t.Fatal("expected error") - } - _, ok = err.(x509.HostnameError) - if !ok { - t.Fatalf("expected x509.HostnameError, got %#+v", err) - } - // OK - _, err = openSSLConn(t, rootCert+"host=postgres sslmode=verify-full user=pqgossltest") - if err != nil { - t.Fatal(err) - } -} - -// Test sslmode=verify-ca -func TestSSLVerifyCA(t *testing.T) { - maybeSkipSSLTests(t) - // Environment sanity check: should fail without SSL - checkSSLSetup(t, "sslmode=disable user=pqgossltest") - - // Not OK according to the system CA - _, err := openSSLConn(t, "host=postgres sslmode=verify-ca user=pqgossltest") - if err == nil { - t.Fatal("expected error") - } - _, ok := err.(x509.UnknownAuthorityError) - if !ok { - t.Fatalf("expected x509.UnknownAuthorityError, got %#+v", err) - } - - rootCertPath := filepath.Join(os.Getenv("PQSSLCERTTEST_PATH"), "root.crt") - rootCert := "sslrootcert=" + rootCertPath + " " - // No match on Common Name, but that's OK - _, err = openSSLConn(t, rootCert+"host=127.0.0.1 sslmode=verify-ca user=pqgossltest") - if err != nil { - t.Fatal(err) - } - // Everything OK - _, err = openSSLConn(t, rootCert+"host=postgres sslmode=verify-ca user=pqgossltest") - if err != nil { - t.Fatal(err) - } -} - -func getCertConninfo(t *testing.T, source string) string { - var sslkey string - var sslcert string - - certpath := os.Getenv("PQSSLCERTTEST_PATH") - - switch source { - case "missingkey": - sslkey = "/tmp/filedoesnotexist" - sslcert = filepath.Join(certpath, "postgresql.crt") - case "missingcert": - sslkey = filepath.Join(certpath, "postgresql.key") - sslcert = "/tmp/filedoesnotexist" - case "certtwice": - sslkey = filepath.Join(certpath, "postgresql.crt") - sslcert = filepath.Join(certpath, "postgresql.crt") - case "valid": - sslkey = filepath.Join(certpath, "postgresql.key") - sslcert = filepath.Join(certpath, "postgresql.crt") - default: - t.Fatalf("invalid source %q", source) - } - return fmt.Sprintf("sslmode=require user=pqgosslcert sslkey=%s sslcert=%s", sslkey, sslcert) -} - -// Authenticate over SSL using client certificates -func TestSSLClientCertificates(t *testing.T) { - maybeSkipSSLTests(t) - // Environment sanity check: should fail without SSL - checkSSLSetup(t, "sslmode=disable user=pqgossltest") - - // Should also fail without a valid certificate - db, err := openSSLConn(t, "sslmode=require user=pqgosslcert") - if err == nil { - db.Close() - t.Fatal("expected error") - } - pge, ok := err.(*Error) - if !ok { - t.Fatal("expected pq.Error") - } - if pge.Code.Name() != "invalid_authorization_specification" { - t.Fatalf("unexpected error code %q", pge.Code.Name()) - } - - // Should work - db, err = openSSLConn(t, getCertConninfo(t, "valid")) - if err != nil { - t.Fatal(err) - } - rows, err := db.Query("SELECT 1") - if err != nil { - t.Fatal(err) - } - rows.Close() -} - -// Test errors with ssl certificates -func TestSSLClientCertificatesMissingFiles(t *testing.T) { - maybeSkipSSLTests(t) - // Environment sanity check: should fail without SSL - checkSSLSetup(t, "sslmode=disable user=pqgossltest") - - // Key missing, should fail - _, err := openSSLConn(t, getCertConninfo(t, "missingkey")) - if err == nil { - t.Fatal("expected error") - } - // should be a PathError - _, ok := err.(*os.PathError) - if !ok { - t.Fatalf("expected PathError, got %#+v", err) - } - - // Cert missing, should fail - _, err = openSSLConn(t, getCertConninfo(t, "missingcert")) - if err == nil { - t.Fatal("expected error") - } - // should be a PathError - _, ok = err.(*os.PathError) - if !ok { - t.Fatalf("expected PathError, got %#+v", err) - } - - // Key has wrong permissions, should fail - _, err = openSSLConn(t, getCertConninfo(t, "certtwice")) - if err == nil { - t.Fatal("expected error") - } - if err != ErrSSLKeyHasWorldPermissions { - t.Fatalf("expected ErrSSLKeyHasWorldPermissions, got %#+v", err) - } -} diff --git a/vendor/github.com/lib/pq/ssl_windows.go b/vendor/github.com/lib/pq/ssl_windows.go new file mode 100644 index 00000000000..5d2c763cebc --- /dev/null +++ b/vendor/github.com/lib/pq/ssl_windows.go @@ -0,0 +1,9 @@ +// +build windows + +package pq + +// sslKeyPermissions checks the permissions on user-supplied ssl key files. +// The key file should have very little access. +// +// libpq does not check key file permissions on Windows. +func sslKeyPermissions(string) error { return nil } diff --git a/vendor/github.com/lib/pq/url.go b/vendor/github.com/lib/pq/url.go index b83e806b8e3..f4d8a7c2062 100644 --- a/vendor/github.com/lib/pq/url.go +++ b/vendor/github.com/lib/pq/url.go @@ -2,6 +2,7 @@ package pq import ( "fmt" + "net" nurl "net/url" "sort" "strings" @@ -34,7 +35,7 @@ func ParseURL(url string) (string, error) { return "", err } - if u.Scheme != "postgres" { + if u.Scheme != "postgres" && u.Scheme != "postgresql" { return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme) } @@ -54,12 +55,11 @@ func ParseURL(url string) (string, error) { accrue("password", v) } - i := strings.Index(u.Host, ":") - if i < 0 { + if host, port, err := net.SplitHostPort(u.Host); err != nil { accrue("host", u.Host) } else { - accrue("host", u.Host[:i]) - accrue("port", u.Host[i+1:]) + accrue("host", host) + accrue("port", port) } if u.Path != "" { diff --git a/vendor/github.com/lib/pq/url_test.go b/vendor/github.com/lib/pq/url_test.go deleted file mode 100644 index 29f4a7c7516..00000000000 --- a/vendor/github.com/lib/pq/url_test.go +++ /dev/null @@ -1,54 +0,0 @@ -package pq - -import ( - "testing" -) - -func TestSimpleParseURL(t *testing.T) { - expected := "host=hostname.remote" - str, err := ParseURL("postgres://hostname.remote") - if err != nil { - t.Fatal(err) - } - - if str != expected { - t.Fatalf("unexpected result from ParseURL:\n+ %v\n- %v", str, expected) - } -} - -func TestFullParseURL(t *testing.T) { - expected := `dbname=database host=hostname.remote password=top\ secret port=1234 user=username` - str, err := ParseURL("postgres://username:top%20secret@hostname.remote:1234/database") - if err != nil { - t.Fatal(err) - } - - if str != expected { - t.Fatalf("unexpected result from ParseURL:\n+ %s\n- %s", str, expected) - } -} - -func TestInvalidProtocolParseURL(t *testing.T) { - _, err := ParseURL("http://hostname.remote") - switch err { - case nil: - t.Fatal("Expected an error from parsing invalid protocol") - default: - msg := "invalid connection protocol: http" - if err.Error() != msg { - t.Fatalf("Unexpected error message:\n+ %s\n- %s", - err.Error(), msg) - } - } -} - -func TestMinimalURL(t *testing.T) { - cs, err := ParseURL("postgres://") - if err != nil { - t.Fatal(err) - } - - if cs != "" { - t.Fatalf("expected blank connection string, got: %q", cs) - } -} diff --git a/vendor/github.com/lib/pq/user_posix.go b/vendor/github.com/lib/pq/user_posix.go index e937d7d0872..bf982524f93 100644 --- a/vendor/github.com/lib/pq/user_posix.go +++ b/vendor/github.com/lib/pq/user_posix.go @@ -1,6 +1,6 @@ // Package pq is a pure Go Postgres driver for the database/sql package. -// +build darwin dragonfly freebsd linux nacl netbsd openbsd solaris +// +build darwin dragonfly freebsd linux nacl netbsd openbsd solaris rumprun package pq diff --git a/vendor/github.com/lib/pq/uuid.go b/vendor/github.com/lib/pq/uuid.go new file mode 100644 index 00000000000..9a1b9e0748e --- /dev/null +++ b/vendor/github.com/lib/pq/uuid.go @@ -0,0 +1,23 @@ +package pq + +import ( + "encoding/hex" + "fmt" +) + +// decodeUUIDBinary interprets the binary format of a uuid, returning it in text format. +func decodeUUIDBinary(src []byte) ([]byte, error) { + if len(src) != 16 { + return nil, fmt.Errorf("pq: unable to decode uuid; bad length: %d", len(src)) + } + + dst := make([]byte, 36) + dst[8], dst[13], dst[18], dst[23] = '-', '-', '-', '-' + hex.Encode(dst[0:], src[0:4]) + hex.Encode(dst[9:], src[4:6]) + hex.Encode(dst[14:], src[6:8]) + hex.Encode(dst[19:], src[8:10]) + hex.Encode(dst[24:], src[10:16]) + + return dst, nil +} diff --git a/vendor/vendor.json b/vendor/vendor.json index 5a60e70f653..07190f3fe94 100644 --- a/vendor/vendor.json +++ b/vendor/vendor.json @@ -448,6 +448,12 @@ "revision": "1bab8b35b6bb565f92cbc97939610af9369f942a", "revisionTime": "2017-02-10T14:05:23Z" }, + { + "checksumSHA1": "ZAj/o03zG8Ui4mZ4XmzU4yyKC04=", + "path": "github.com/lib/pq", + "revision": "dd1fe2071026ce53f36a39112e645b4d4f5793a4", + "revisionTime": "2017-07-07T05:36:02Z" + }, { "checksumSHA1": "8z32QKTSDusa4QQyunKE4kyYXZ8=", "path": "github.com/patrickmn/go-cache",