From a47b31ac62053a59fb27e17bed53367645104e6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Torkel=20=C3=96degaard?= Date: Tue, 8 Aug 2017 16:17:52 +0200 Subject: [PATCH] fix: MySQL/Postgress max_idle_conn default was wrongly set to zero, which does not mean unlimited but zero, which in practice disables connection pooling, not good. now max idle conn is set to golang's default which is 2, fixes #8513 --- conf/defaults.ini | 4 +- conf/sample.ini | 4 +- docker/blocks/postgres/fig | 1 + vendor/github.com/lib/pq/.gitignore | 4 - vendor/github.com/lib/pq/.travis.yml | 61 - vendor/github.com/lib/pq/README.md | 15 +- vendor/github.com/lib/pq/array.go | 756 ++++++++++ vendor/github.com/lib/pq/bench_test.go | 434 ------ vendor/github.com/lib/pq/buf.go | 30 +- vendor/github.com/lib/pq/conn.go | 1098 +++++++++----- vendor/github.com/lib/pq/conn_go18.go | 128 ++ vendor/github.com/lib/pq/conn_test.go | 1286 ----------------- vendor/github.com/lib/pq/copy.go | 40 +- vendor/github.com/lib/pq/copy_test.go | 380 ----- vendor/github.com/lib/pq/doc.go | 28 +- vendor/github.com/lib/pq/encode.go | 306 +++- vendor/github.com/lib/pq/encode_test.go | 433 ------ vendor/github.com/lib/pq/error.go | 13 + vendor/github.com/lib/pq/notify.go | 68 +- vendor/github.com/lib/pq/notify_test.go | 502 ------- vendor/github.com/lib/pq/ssl.go | 158 ++ vendor/github.com/lib/pq/ssl_go1.7.go | 14 + vendor/github.com/lib/pq/ssl_permissions.go | 20 + vendor/github.com/lib/pq/ssl_renegotiation.go | 8 + vendor/github.com/lib/pq/ssl_test.go | 226 --- vendor/github.com/lib/pq/ssl_windows.go | 9 + vendor/github.com/lib/pq/url.go | 10 +- vendor/github.com/lib/pq/url_test.go | 54 - vendor/github.com/lib/pq/user_posix.go | 2 +- vendor/github.com/lib/pq/uuid.go | 23 + vendor/vendor.json | 6 + 31 files changed, 2235 insertions(+), 3886 deletions(-) delete mode 100644 vendor/github.com/lib/pq/.gitignore delete mode 100644 vendor/github.com/lib/pq/.travis.yml create mode 100644 vendor/github.com/lib/pq/array.go delete mode 100644 vendor/github.com/lib/pq/bench_test.go create mode 100644 vendor/github.com/lib/pq/conn_go18.go delete mode 100644 vendor/github.com/lib/pq/conn_test.go delete mode 100644 vendor/github.com/lib/pq/copy_test.go delete mode 100644 vendor/github.com/lib/pq/encode_test.go delete mode 100644 vendor/github.com/lib/pq/notify_test.go create mode 100644 vendor/github.com/lib/pq/ssl.go create mode 100644 vendor/github.com/lib/pq/ssl_go1.7.go create mode 100644 vendor/github.com/lib/pq/ssl_permissions.go create mode 100644 vendor/github.com/lib/pq/ssl_renegotiation.go delete mode 100644 vendor/github.com/lib/pq/ssl_test.go create mode 100644 vendor/github.com/lib/pq/ssl_windows.go delete mode 100644 vendor/github.com/lib/pq/url_test.go create mode 100644 vendor/github.com/lib/pq/uuid.go diff --git a/conf/defaults.ini b/conf/defaults.ini index ceb916917ab..f0156b70511 100644 --- a/conf/defaults.ini +++ b/conf/defaults.ini @@ -76,8 +76,10 @@ password = # Example: mysql://user:secret@host:port/database url = +# Max idle conn setting default is 2 +max_idle_conn = 2 + # Max conn setting default is 0 (mean not set) -max_idle_conn = max_open_conn = # For "postgres", use either "disable", "require" or "verify-full" diff --git a/conf/sample.ini b/conf/sample.ini index e33f408ecad..80c7464f89c 100644 --- a/conf/sample.ini +++ b/conf/sample.ini @@ -85,8 +85,10 @@ # For "sqlite3" only, path relative to data_path setting ;path = grafana.db +# Max idle conn setting default is 2 +;max_idle_conn = 2 + # Max conn setting default is 0 (mean not set) -;max_idle_conn = ;max_open_conn = diff --git a/docker/blocks/postgres/fig b/docker/blocks/postgres/fig index 33499688810..049339a15c7 100644 --- a/docker/blocks/postgres/fig +++ b/docker/blocks/postgres/fig @@ -6,3 +6,4 @@ postgrestest: POSTGRES_DATABASE: grafana ports: - "5432:5432" + command: postgres -c log_connections=on -c logging_collector=on -c log_destination=stderr -c log_directory=/var/log/postgresql diff --git a/vendor/github.com/lib/pq/.gitignore b/vendor/github.com/lib/pq/.gitignore deleted file mode 100644 index 0f1d00e1196..00000000000 --- a/vendor/github.com/lib/pq/.gitignore +++ /dev/null @@ -1,4 +0,0 @@ -.db -*.test -*~ -*.swp diff --git a/vendor/github.com/lib/pq/.travis.yml b/vendor/github.com/lib/pq/.travis.yml deleted file mode 100644 index 82df2a0aabf..00000000000 --- a/vendor/github.com/lib/pq/.travis.yml +++ /dev/null @@ -1,61 +0,0 @@ -language: go - -go: - - 1.1 - - 1.2 - - 1.3 - - 1.4 - - tip - -before_install: - - psql --version - - sudo /etc/init.d/postgresql stop - - sudo apt-get -y --purge remove postgresql libpq-dev libpq5 postgresql-client-common postgresql-common - - sudo rm -rf /var/lib/postgresql - - wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add - - - sudo sh -c "echo deb http://apt.postgresql.org/pub/repos/apt/ $(lsb_release -cs)-pgdg main $PGVERSION >> /etc/apt/sources.list.d/postgresql.list" - - sudo apt-get update -qq - - sudo apt-get -y -o Dpkg::Options::=--force-confdef -o Dpkg::Options::="--force-confnew" install postgresql-$PGVERSION postgresql-server-dev-$PGVERSION postgresql-contrib-$PGVERSION - - sudo chmod 777 /etc/postgresql/$PGVERSION/main/pg_hba.conf - - echo "local all postgres trust" > /etc/postgresql/$PGVERSION/main/pg_hba.conf - - echo "local all all trust" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - - echo "hostnossl all pqgossltest 127.0.0.1/32 reject" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - - echo "hostnossl all pqgosslcert 127.0.0.1/32 reject" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - - echo "hostssl all pqgossltest 127.0.0.1/32 trust" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - - echo "hostssl all pqgosslcert 127.0.0.1/32 cert" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - - echo "host all all 127.0.0.1/32 trust" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - - echo "hostnossl all pqgossltest ::1/128 reject" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - - echo "hostnossl all pqgosslcert ::1/128 reject" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - - echo "hostssl all pqgossltest ::1/128 trust" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - - echo "hostssl all pqgosslcert ::1/128 cert" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - - echo "host all all ::1/128 trust" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - - sudo install -o postgres -g postgres -m 600 -t /var/lib/postgresql/$PGVERSION/main/ certs/server.key certs/server.crt certs/root.crt - - sudo bash -c "[[ '${PGVERSION}' < '9.2' ]] || (echo \"ssl_cert_file = 'server.crt'\" >> /etc/postgresql/$PGVERSION/main/postgresql.conf)" - - sudo bash -c "[[ '${PGVERSION}' < '9.2' ]] || (echo \"ssl_key_file = 'server.key'\" >> /etc/postgresql/$PGVERSION/main/postgresql.conf)" - - sudo bash -c "[[ '${PGVERSION}' < '9.2' ]] || (echo \"ssl_ca_file = 'root.crt'\" >> /etc/postgresql/$PGVERSION/main/postgresql.conf)" - - sudo sh -c "echo 127.0.0.1 postgres >> /etc/hosts" - - sudo ls -l /var/lib/postgresql/$PGVERSION/main/ - - sudo cat /etc/postgresql/$PGVERSION/main/postgresql.conf - - sudo chmod 600 $PQSSLCERTTEST_PATH/postgresql.key - - sudo /etc/init.d/postgresql restart - -env: - global: - - PGUSER=postgres - - PQGOSSLTESTS=1 - - PQSSLCERTTEST_PATH=$PWD/certs - matrix: - - PGVERSION=9.4 - - PGVERSION=9.3 - - PGVERSION=9.2 - - PGVERSION=9.1 - - PGVERSION=9.0 - - PGVERSION=8.4 - -script: - - go test -v ./... - -before_script: - - psql -c 'create database pqgotest' -U postgres - - psql -c 'create user pqgossltest' -U postgres - - psql -c 'create user pqgosslcert' -U postgres diff --git a/vendor/github.com/lib/pq/README.md b/vendor/github.com/lib/pq/README.md index cb15d6f0db0..7670fc87a51 100644 --- a/vendor/github.com/lib/pq/README.md +++ b/vendor/github.com/lib/pq/README.md @@ -1,6 +1,6 @@ # pq - A pure Go postgres driver for Go's database/sql package -[![Build Status](https://travis-ci.org/lib/pq.png?branch=master)](https://travis-ci.org/lib/pq) +[![Build Status](https://travis-ci.org/lib/pq.svg?branch=master)](https://travis-ci.org/lib/pq) ## Install @@ -20,11 +20,11 @@ variables. Example: - PGHOST=/var/run/postgresql go test github.com/lib/pq + PGHOST=/run/postgresql go test github.com/lib/pq Optionally, a benchmark suite can be run as part of the tests: - PGHOST=/var/run/postgresql go test -bench . + PGHOST=/run/postgresql go test -bench . ## Features @@ -38,6 +38,7 @@ Optionally, a benchmark suite can be run as part of the tests: * Many libpq compatible environment variables * Unix socket support * Notifications: `LISTEN`/`NOTIFY` +* pgpass support ## Future / Things you can help with @@ -57,13 +58,17 @@ code still exists in here. * Brad Fitzpatrick (bradfitz) * Charlie Melbye (cmelbye) * Chris Bandy (cbandy) +* Chris Gilling (cgilling) * Chris Walsh (cwds) * Dan Sosedoff (sosedoff) * Daniel Farina (fdr) * Eric Chlebek (echlebek) +* Eric Garrido (minusnine) +* Eric Urban (hydrogen18) * Everyone at The Go Team * Evan Shaw (edsrzf) * Ewan Chou (coocood) +* Fazal Majid (fazalmajid) * Federico Romero (federomero) * Fumin (fumin) * Gary Burd (garyburd) @@ -80,7 +85,7 @@ code still exists in here. * Keith Rarick (kr) * Kir Shatrov (kirs) * Lann Martin (lann) -* Maciek Sakrejda (deafbybeheading) +* Maciek Sakrejda (uhoh-itsmaciek) * Marc Brinkmann (mbr) * Marko Tiikkaja (johto) * Matt Newberry (MattNewberry) @@ -94,5 +99,7 @@ code still exists in here. * Ryan Smith (ryandotsmith) * Samuel Stauffer (samuel) * Timothée Peignier (cyberdelia) +* Travis Cline (tmc) * TruongSinh Tran-Nguyen (truongsinh) +* Yaismel Miranda (ympons) * notedit (notedit) diff --git a/vendor/github.com/lib/pq/array.go b/vendor/github.com/lib/pq/array.go new file mode 100644 index 00000000000..e7b2145d67b --- /dev/null +++ b/vendor/github.com/lib/pq/array.go @@ -0,0 +1,756 @@ +package pq + +import ( + "bytes" + "database/sql" + "database/sql/driver" + "encoding/hex" + "fmt" + "reflect" + "strconv" + "strings" +) + +var typeByteSlice = reflect.TypeOf([]byte{}) +var typeDriverValuer = reflect.TypeOf((*driver.Valuer)(nil)).Elem() +var typeSqlScanner = reflect.TypeOf((*sql.Scanner)(nil)).Elem() + +// Array returns the optimal driver.Valuer and sql.Scanner for an array or +// slice of any dimension. +// +// For example: +// db.Query(`SELECT * FROM t WHERE id = ANY($1)`, pq.Array([]int{235, 401})) +// +// var x []sql.NullInt64 +// db.QueryRow('SELECT ARRAY[235, 401]').Scan(pq.Array(&x)) +// +// Scanning multi-dimensional arrays is not supported. Arrays where the lower +// bound is not one (such as `[0:0]={1}') are not supported. +func Array(a interface{}) interface { + driver.Valuer + sql.Scanner +} { + switch a := a.(type) { + case []bool: + return (*BoolArray)(&a) + case []float64: + return (*Float64Array)(&a) + case []int64: + return (*Int64Array)(&a) + case []string: + return (*StringArray)(&a) + + case *[]bool: + return (*BoolArray)(a) + case *[]float64: + return (*Float64Array)(a) + case *[]int64: + return (*Int64Array)(a) + case *[]string: + return (*StringArray)(a) + } + + return GenericArray{a} +} + +// ArrayDelimiter may be optionally implemented by driver.Valuer or sql.Scanner +// to override the array delimiter used by GenericArray. +type ArrayDelimiter interface { + // ArrayDelimiter returns the delimiter character(s) for this element's type. + ArrayDelimiter() string +} + +// BoolArray represents a one-dimensional array of the PostgreSQL boolean type. +type BoolArray []bool + +// Scan implements the sql.Scanner interface. +func (a *BoolArray) Scan(src interface{}) error { + switch src := src.(type) { + case []byte: + return a.scanBytes(src) + case string: + return a.scanBytes([]byte(src)) + case nil: + *a = nil + return nil + } + + return fmt.Errorf("pq: cannot convert %T to BoolArray", src) +} + +func (a *BoolArray) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "BoolArray") + if err != nil { + return err + } + if *a != nil && len(elems) == 0 { + *a = (*a)[:0] + } else { + b := make(BoolArray, len(elems)) + for i, v := range elems { + if len(v) != 1 { + return fmt.Errorf("pq: could not parse boolean array index %d: invalid boolean %q", i, v) + } + switch v[0] { + case 't': + b[i] = true + case 'f': + b[i] = false + default: + return fmt.Errorf("pq: could not parse boolean array index %d: invalid boolean %q", i, v) + } + } + *a = b + } + return nil +} + +// Value implements the driver.Valuer interface. +func (a BoolArray) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be exactly two curly brackets, N bytes of values, + // and N-1 bytes of delimiters. + b := make([]byte, 1+2*n) + + for i := 0; i < n; i++ { + b[2*i] = ',' + if a[i] { + b[1+2*i] = 't' + } else { + b[1+2*i] = 'f' + } + } + + b[0] = '{' + b[2*n] = '}' + + return string(b), nil + } + + return "{}", nil +} + +// ByteaArray represents a one-dimensional array of the PostgreSQL bytea type. +type ByteaArray [][]byte + +// Scan implements the sql.Scanner interface. +func (a *ByteaArray) Scan(src interface{}) error { + switch src := src.(type) { + case []byte: + return a.scanBytes(src) + case string: + return a.scanBytes([]byte(src)) + case nil: + *a = nil + return nil + } + + return fmt.Errorf("pq: cannot convert %T to ByteaArray", src) +} + +func (a *ByteaArray) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "ByteaArray") + if err != nil { + return err + } + if *a != nil && len(elems) == 0 { + *a = (*a)[:0] + } else { + b := make(ByteaArray, len(elems)) + for i, v := range elems { + b[i], err = parseBytea(v) + if err != nil { + return fmt.Errorf("could not parse bytea array index %d: %s", i, err.Error()) + } + } + *a = b + } + return nil +} + +// Value implements the driver.Valuer interface. It uses the "hex" format which +// is only supported on PostgreSQL 9.0 or newer. +func (a ByteaArray) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be at least two curly brackets, 2*N bytes of quotes, + // 3*N bytes of hex formatting, and N-1 bytes of delimiters. + size := 1 + 6*n + for _, x := range a { + size += hex.EncodedLen(len(x)) + } + + b := make([]byte, size) + + for i, s := 0, b; i < n; i++ { + o := copy(s, `,"\\x`) + o += hex.Encode(s[o:], a[i]) + s[o] = '"' + s = s[o+1:] + } + + b[0] = '{' + b[size-1] = '}' + + return string(b), nil + } + + return "{}", nil +} + +// Float64Array represents a one-dimensional array of the PostgreSQL double +// precision type. +type Float64Array []float64 + +// Scan implements the sql.Scanner interface. +func (a *Float64Array) Scan(src interface{}) error { + switch src := src.(type) { + case []byte: + return a.scanBytes(src) + case string: + return a.scanBytes([]byte(src)) + case nil: + *a = nil + return nil + } + + return fmt.Errorf("pq: cannot convert %T to Float64Array", src) +} + +func (a *Float64Array) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "Float64Array") + if err != nil { + return err + } + if *a != nil && len(elems) == 0 { + *a = (*a)[:0] + } else { + b := make(Float64Array, len(elems)) + for i, v := range elems { + if b[i], err = strconv.ParseFloat(string(v), 64); err != nil { + return fmt.Errorf("pq: parsing array element index %d: %v", i, err) + } + } + *a = b + } + return nil +} + +// Value implements the driver.Valuer interface. +func (a Float64Array) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be at least two curly brackets, N bytes of values, + // and N-1 bytes of delimiters. + b := make([]byte, 1, 1+2*n) + b[0] = '{' + + b = strconv.AppendFloat(b, a[0], 'f', -1, 64) + for i := 1; i < n; i++ { + b = append(b, ',') + b = strconv.AppendFloat(b, a[i], 'f', -1, 64) + } + + return string(append(b, '}')), nil + } + + return "{}", nil +} + +// GenericArray implements the driver.Valuer and sql.Scanner interfaces for +// an array or slice of any dimension. +type GenericArray struct{ A interface{} } + +func (GenericArray) evaluateDestination(rt reflect.Type) (reflect.Type, func([]byte, reflect.Value) error, string) { + var assign func([]byte, reflect.Value) error + var del = "," + + // TODO calculate the assign function for other types + // TODO repeat this section on the element type of arrays or slices (multidimensional) + { + if reflect.PtrTo(rt).Implements(typeSqlScanner) { + // dest is always addressable because it is an element of a slice. + assign = func(src []byte, dest reflect.Value) (err error) { + ss := dest.Addr().Interface().(sql.Scanner) + if src == nil { + err = ss.Scan(nil) + } else { + err = ss.Scan(src) + } + return + } + goto FoundType + } + + assign = func([]byte, reflect.Value) error { + return fmt.Errorf("pq: scanning to %s is not implemented; only sql.Scanner", rt) + } + } + +FoundType: + + if ad, ok := reflect.Zero(rt).Interface().(ArrayDelimiter); ok { + del = ad.ArrayDelimiter() + } + + return rt, assign, del +} + +// Scan implements the sql.Scanner interface. +func (a GenericArray) Scan(src interface{}) error { + dpv := reflect.ValueOf(a.A) + switch { + case dpv.Kind() != reflect.Ptr: + return fmt.Errorf("pq: destination %T is not a pointer to array or slice", a.A) + case dpv.IsNil(): + return fmt.Errorf("pq: destination %T is nil", a.A) + } + + dv := dpv.Elem() + switch dv.Kind() { + case reflect.Slice: + case reflect.Array: + default: + return fmt.Errorf("pq: destination %T is not a pointer to array or slice", a.A) + } + + switch src := src.(type) { + case []byte: + return a.scanBytes(src, dv) + case string: + return a.scanBytes([]byte(src), dv) + case nil: + if dv.Kind() == reflect.Slice { + dv.Set(reflect.Zero(dv.Type())) + return nil + } + } + + return fmt.Errorf("pq: cannot convert %T to %s", src, dv.Type()) +} + +func (a GenericArray) scanBytes(src []byte, dv reflect.Value) error { + dtype, assign, del := a.evaluateDestination(dv.Type().Elem()) + dims, elems, err := parseArray(src, []byte(del)) + if err != nil { + return err + } + + // TODO allow multidimensional + + if len(dims) > 1 { + return fmt.Errorf("pq: scanning from multidimensional ARRAY%s is not implemented", + strings.Replace(fmt.Sprint(dims), " ", "][", -1)) + } + + // Treat a zero-dimensional array like an array with a single dimension of zero. + if len(dims) == 0 { + dims = append(dims, 0) + } + + for i, rt := 0, dv.Type(); i < len(dims); i, rt = i+1, rt.Elem() { + switch rt.Kind() { + case reflect.Slice: + case reflect.Array: + if rt.Len() != dims[i] { + return fmt.Errorf("pq: cannot convert ARRAY%s to %s", + strings.Replace(fmt.Sprint(dims), " ", "][", -1), dv.Type()) + } + default: + // TODO handle multidimensional + } + } + + values := reflect.MakeSlice(reflect.SliceOf(dtype), len(elems), len(elems)) + for i, e := range elems { + if err := assign(e, values.Index(i)); err != nil { + return fmt.Errorf("pq: parsing array element index %d: %v", i, err) + } + } + + // TODO handle multidimensional + + switch dv.Kind() { + case reflect.Slice: + dv.Set(values.Slice(0, dims[0])) + case reflect.Array: + for i := 0; i < dims[0]; i++ { + dv.Index(i).Set(values.Index(i)) + } + } + + return nil +} + +// Value implements the driver.Valuer interface. +func (a GenericArray) Value() (driver.Value, error) { + if a.A == nil { + return nil, nil + } + + rv := reflect.ValueOf(a.A) + + switch rv.Kind() { + case reflect.Slice: + if rv.IsNil() { + return nil, nil + } + case reflect.Array: + default: + return nil, fmt.Errorf("pq: Unable to convert %T to array", a.A) + } + + if n := rv.Len(); n > 0 { + // There will be at least two curly brackets, N bytes of values, + // and N-1 bytes of delimiters. + b := make([]byte, 0, 1+2*n) + + b, _, err := appendArray(b, rv, n) + return string(b), err + } + + return "{}", nil +} + +// Int64Array represents a one-dimensional array of the PostgreSQL integer types. +type Int64Array []int64 + +// Scan implements the sql.Scanner interface. +func (a *Int64Array) Scan(src interface{}) error { + switch src := src.(type) { + case []byte: + return a.scanBytes(src) + case string: + return a.scanBytes([]byte(src)) + case nil: + *a = nil + return nil + } + + return fmt.Errorf("pq: cannot convert %T to Int64Array", src) +} + +func (a *Int64Array) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "Int64Array") + if err != nil { + return err + } + if *a != nil && len(elems) == 0 { + *a = (*a)[:0] + } else { + b := make(Int64Array, len(elems)) + for i, v := range elems { + if b[i], err = strconv.ParseInt(string(v), 10, 64); err != nil { + return fmt.Errorf("pq: parsing array element index %d: %v", i, err) + } + } + *a = b + } + return nil +} + +// Value implements the driver.Valuer interface. +func (a Int64Array) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be at least two curly brackets, N bytes of values, + // and N-1 bytes of delimiters. + b := make([]byte, 1, 1+2*n) + b[0] = '{' + + b = strconv.AppendInt(b, a[0], 10) + for i := 1; i < n; i++ { + b = append(b, ',') + b = strconv.AppendInt(b, a[i], 10) + } + + return string(append(b, '}')), nil + } + + return "{}", nil +} + +// StringArray represents a one-dimensional array of the PostgreSQL character types. +type StringArray []string + +// Scan implements the sql.Scanner interface. +func (a *StringArray) Scan(src interface{}) error { + switch src := src.(type) { + case []byte: + return a.scanBytes(src) + case string: + return a.scanBytes([]byte(src)) + case nil: + *a = nil + return nil + } + + return fmt.Errorf("pq: cannot convert %T to StringArray", src) +} + +func (a *StringArray) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "StringArray") + if err != nil { + return err + } + if *a != nil && len(elems) == 0 { + *a = (*a)[:0] + } else { + b := make(StringArray, len(elems)) + for i, v := range elems { + if b[i] = string(v); v == nil { + return fmt.Errorf("pq: parsing array element index %d: cannot convert nil to string", i) + } + } + *a = b + } + return nil +} + +// Value implements the driver.Valuer interface. +func (a StringArray) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be at least two curly brackets, 2*N bytes of quotes, + // and N-1 bytes of delimiters. + b := make([]byte, 1, 1+3*n) + b[0] = '{' + + b = appendArrayQuotedBytes(b, []byte(a[0])) + for i := 1; i < n; i++ { + b = append(b, ',') + b = appendArrayQuotedBytes(b, []byte(a[i])) + } + + return string(append(b, '}')), nil + } + + return "{}", nil +} + +// appendArray appends rv to the buffer, returning the extended buffer and +// the delimiter used between elements. +// +// It panics when n <= 0 or rv's Kind is not reflect.Array nor reflect.Slice. +func appendArray(b []byte, rv reflect.Value, n int) ([]byte, string, error) { + var del string + var err error + + b = append(b, '{') + + if b, del, err = appendArrayElement(b, rv.Index(0)); err != nil { + return b, del, err + } + + for i := 1; i < n; i++ { + b = append(b, del...) + if b, del, err = appendArrayElement(b, rv.Index(i)); err != nil { + return b, del, err + } + } + + return append(b, '}'), del, nil +} + +// appendArrayElement appends rv to the buffer, returning the extended buffer +// and the delimiter to use before the next element. +// +// When rv's Kind is neither reflect.Array nor reflect.Slice, it is converted +// using driver.DefaultParameterConverter and the resulting []byte or string +// is double-quoted. +// +// See http://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO +func appendArrayElement(b []byte, rv reflect.Value) ([]byte, string, error) { + if k := rv.Kind(); k == reflect.Array || k == reflect.Slice { + if t := rv.Type(); t != typeByteSlice && !t.Implements(typeDriverValuer) { + if n := rv.Len(); n > 0 { + return appendArray(b, rv, n) + } + + return b, "", nil + } + } + + var del string = "," + var err error + var iv interface{} = rv.Interface() + + if ad, ok := iv.(ArrayDelimiter); ok { + del = ad.ArrayDelimiter() + } + + if iv, err = driver.DefaultParameterConverter.ConvertValue(iv); err != nil { + return b, del, err + } + + switch v := iv.(type) { + case nil: + return append(b, "NULL"...), del, nil + case []byte: + return appendArrayQuotedBytes(b, v), del, nil + case string: + return appendArrayQuotedBytes(b, []byte(v)), del, nil + } + + b, err = appendValue(b, iv) + return b, del, err +} + +func appendArrayQuotedBytes(b, v []byte) []byte { + b = append(b, '"') + for { + i := bytes.IndexAny(v, `"\`) + if i < 0 { + b = append(b, v...) + break + } + if i > 0 { + b = append(b, v[:i]...) + } + b = append(b, '\\', v[i]) + v = v[i+1:] + } + return append(b, '"') +} + +func appendValue(b []byte, v driver.Value) ([]byte, error) { + return append(b, encode(nil, v, 0)...), nil +} + +// parseArray extracts the dimensions and elements of an array represented in +// text format. Only representations emitted by the backend are supported. +// Notably, whitespace around brackets and delimiters is significant, and NULL +// is case-sensitive. +// +// See http://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO +func parseArray(src, del []byte) (dims []int, elems [][]byte, err error) { + var depth, i int + + if len(src) < 1 || src[0] != '{' { + return nil, nil, fmt.Errorf("pq: unable to parse array; expected %q at offset %d", '{', 0) + } + +Open: + for i < len(src) { + switch src[i] { + case '{': + depth++ + i++ + case '}': + elems = make([][]byte, 0) + goto Close + default: + break Open + } + } + dims = make([]int, i) + +Element: + for i < len(src) { + switch src[i] { + case '{': + if depth == len(dims) { + break Element + } + depth++ + dims[depth-1] = 0 + i++ + case '"': + var elem = []byte{} + var escape bool + for i++; i < len(src); i++ { + if escape { + elem = append(elem, src[i]) + escape = false + } else { + switch src[i] { + default: + elem = append(elem, src[i]) + case '\\': + escape = true + case '"': + elems = append(elems, elem) + i++ + break Element + } + } + } + default: + for start := i; i < len(src); i++ { + if bytes.HasPrefix(src[i:], del) || src[i] == '}' { + elem := src[start:i] + if len(elem) == 0 { + return nil, nil, fmt.Errorf("pq: unable to parse array; unexpected %q at offset %d", src[i], i) + } + if bytes.Equal(elem, []byte("NULL")) { + elem = nil + } + elems = append(elems, elem) + break Element + } + } + } + } + + for i < len(src) { + if bytes.HasPrefix(src[i:], del) && depth > 0 { + dims[depth-1]++ + i += len(del) + goto Element + } else if src[i] == '}' && depth > 0 { + dims[depth-1]++ + depth-- + i++ + } else { + return nil, nil, fmt.Errorf("pq: unable to parse array; unexpected %q at offset %d", src[i], i) + } + } + +Close: + for i < len(src) { + if src[i] == '}' && depth > 0 { + depth-- + i++ + } else { + return nil, nil, fmt.Errorf("pq: unable to parse array; unexpected %q at offset %d", src[i], i) + } + } + if depth > 0 { + err = fmt.Errorf("pq: unable to parse array; expected %q at offset %d", '}', i) + } + if err == nil { + for _, d := range dims { + if (len(elems) % d) != 0 { + err = fmt.Errorf("pq: multidimensional arrays must have elements with matching dimensions") + } + } + } + return +} + +func scanLinearArray(src, del []byte, typ string) (elems [][]byte, err error) { + dims, elems, err := parseArray(src, del) + if err != nil { + return nil, err + } + if len(dims) > 1 { + return nil, fmt.Errorf("pq: cannot convert ARRAY%s to %s", strings.Replace(fmt.Sprint(dims), " ", "][", -1), typ) + } + return elems, err +} diff --git a/vendor/github.com/lib/pq/bench_test.go b/vendor/github.com/lib/pq/bench_test.go deleted file mode 100644 index f2eb1dab367..00000000000 --- a/vendor/github.com/lib/pq/bench_test.go +++ /dev/null @@ -1,434 +0,0 @@ -// +build go1.1 - -package pq - -import ( - "bufio" - "bytes" - "database/sql" - "database/sql/driver" - "github.com/lib/pq/oid" - "io" - "math/rand" - "net" - "runtime" - "strconv" - "strings" - "sync" - "testing" - "time" -) - -var ( - selectStringQuery = "SELECT '" + strings.Repeat("0123456789", 10) + "'" - selectSeriesQuery = "SELECT generate_series(1, 100)" -) - -func BenchmarkSelectString(b *testing.B) { - var result string - benchQuery(b, selectStringQuery, &result) -} - -func BenchmarkSelectSeries(b *testing.B) { - var result int - benchQuery(b, selectSeriesQuery, &result) -} - -func benchQuery(b *testing.B, query string, result interface{}) { - b.StopTimer() - db := openTestConn(b) - defer db.Close() - b.StartTimer() - - for i := 0; i < b.N; i++ { - benchQueryLoop(b, db, query, result) - } -} - -func benchQueryLoop(b *testing.B, db *sql.DB, query string, result interface{}) { - rows, err := db.Query(query) - if err != nil { - b.Fatal(err) - } - defer rows.Close() - for rows.Next() { - err = rows.Scan(result) - if err != nil { - b.Fatal("failed to scan", err) - } - } -} - -// reading from circularConn yields content[:prefixLen] once, followed by -// content[prefixLen:] over and over again. It never returns EOF. -type circularConn struct { - content string - prefixLen int - pos int - net.Conn // for all other net.Conn methods that will never be called -} - -func (r *circularConn) Read(b []byte) (n int, err error) { - n = copy(b, r.content[r.pos:]) - r.pos += n - if r.pos >= len(r.content) { - r.pos = r.prefixLen - } - return -} - -func (r *circularConn) Write(b []byte) (n int, err error) { return len(b), nil } - -func (r *circularConn) Close() error { return nil } - -func fakeConn(content string, prefixLen int) *conn { - c := &circularConn{content: content, prefixLen: prefixLen} - return &conn{buf: bufio.NewReader(c), c: c} -} - -// This benchmark is meant to be the same as BenchmarkSelectString, but takes -// out some of the factors this package can't control. The numbers are less noisy, -// but also the costs of network communication aren't accurately represented. -func BenchmarkMockSelectString(b *testing.B) { - b.StopTimer() - // taken from a recorded run of BenchmarkSelectString - // See: http://www.postgresql.org/docs/current/static/protocol-message-formats.html - const response = "1\x00\x00\x00\x04" + - "t\x00\x00\x00\x06\x00\x00" + - "T\x00\x00\x00!\x00\x01?column?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\xc1\xff\xfe\xff\xff\xff\xff\x00\x00" + - "Z\x00\x00\x00\x05I" + - "2\x00\x00\x00\x04" + - "D\x00\x00\x00n\x00\x01\x00\x00\x00d0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789" + - "C\x00\x00\x00\rSELECT 1\x00" + - "Z\x00\x00\x00\x05I" + - "3\x00\x00\x00\x04" + - "Z\x00\x00\x00\x05I" - c := fakeConn(response, 0) - b.StartTimer() - - for i := 0; i < b.N; i++ { - benchMockQuery(b, c, selectStringQuery) - } -} - -var seriesRowData = func() string { - var buf bytes.Buffer - for i := 1; i <= 100; i++ { - digits := byte(2) - if i >= 100 { - digits = 3 - } else if i < 10 { - digits = 1 - } - buf.WriteString("D\x00\x00\x00") - buf.WriteByte(10 + digits) - buf.WriteString("\x00\x01\x00\x00\x00") - buf.WriteByte(digits) - buf.WriteString(strconv.Itoa(i)) - } - return buf.String() -}() - -func BenchmarkMockSelectSeries(b *testing.B) { - b.StopTimer() - var response = "1\x00\x00\x00\x04" + - "t\x00\x00\x00\x06\x00\x00" + - "T\x00\x00\x00!\x00\x01?column?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\xc1\xff\xfe\xff\xff\xff\xff\x00\x00" + - "Z\x00\x00\x00\x05I" + - "2\x00\x00\x00\x04" + - seriesRowData + - "C\x00\x00\x00\x0fSELECT 100\x00" + - "Z\x00\x00\x00\x05I" + - "3\x00\x00\x00\x04" + - "Z\x00\x00\x00\x05I" - c := fakeConn(response, 0) - b.StartTimer() - - for i := 0; i < b.N; i++ { - benchMockQuery(b, c, selectSeriesQuery) - } -} - -func benchMockQuery(b *testing.B, c *conn, query string) { - stmt, err := c.Prepare(query) - if err != nil { - b.Fatal(err) - } - defer stmt.Close() - rows, err := stmt.Query(nil) - if err != nil { - b.Fatal(err) - } - defer rows.Close() - var dest [1]driver.Value - for { - if err := rows.Next(dest[:]); err != nil { - if err == io.EOF { - break - } - b.Fatal(err) - } - } -} - -func BenchmarkPreparedSelectString(b *testing.B) { - var result string - benchPreparedQuery(b, selectStringQuery, &result) -} - -func BenchmarkPreparedSelectSeries(b *testing.B) { - var result int - benchPreparedQuery(b, selectSeriesQuery, &result) -} - -func benchPreparedQuery(b *testing.B, query string, result interface{}) { - b.StopTimer() - db := openTestConn(b) - defer db.Close() - stmt, err := db.Prepare(query) - if err != nil { - b.Fatal(err) - } - defer stmt.Close() - b.StartTimer() - - for i := 0; i < b.N; i++ { - benchPreparedQueryLoop(b, db, stmt, result) - } -} - -func benchPreparedQueryLoop(b *testing.B, db *sql.DB, stmt *sql.Stmt, result interface{}) { - rows, err := stmt.Query() - if err != nil { - b.Fatal(err) - } - if !rows.Next() { - rows.Close() - b.Fatal("no rows") - } - defer rows.Close() - for rows.Next() { - err = rows.Scan(&result) - if err != nil { - b.Fatal("failed to scan") - } - } -} - -// See the comment for BenchmarkMockSelectString. -func BenchmarkMockPreparedSelectString(b *testing.B) { - b.StopTimer() - const parseResponse = "1\x00\x00\x00\x04" + - "t\x00\x00\x00\x06\x00\x00" + - "T\x00\x00\x00!\x00\x01?column?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\xc1\xff\xfe\xff\xff\xff\xff\x00\x00" + - "Z\x00\x00\x00\x05I" - const responses = parseResponse + - "2\x00\x00\x00\x04" + - "D\x00\x00\x00n\x00\x01\x00\x00\x00d0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789" + - "C\x00\x00\x00\rSELECT 1\x00" + - "Z\x00\x00\x00\x05I" - c := fakeConn(responses, len(parseResponse)) - - stmt, err := c.Prepare(selectStringQuery) - if err != nil { - b.Fatal(err) - } - b.StartTimer() - - for i := 0; i < b.N; i++ { - benchPreparedMockQuery(b, c, stmt) - } -} - -func BenchmarkMockPreparedSelectSeries(b *testing.B) { - b.StopTimer() - const parseResponse = "1\x00\x00\x00\x04" + - "t\x00\x00\x00\x06\x00\x00" + - "T\x00\x00\x00!\x00\x01?column?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\xc1\xff\xfe\xff\xff\xff\xff\x00\x00" + - "Z\x00\x00\x00\x05I" - var responses = parseResponse + - "2\x00\x00\x00\x04" + - seriesRowData + - "C\x00\x00\x00\x0fSELECT 100\x00" + - "Z\x00\x00\x00\x05I" - c := fakeConn(responses, len(parseResponse)) - - stmt, err := c.Prepare(selectSeriesQuery) - if err != nil { - b.Fatal(err) - } - b.StartTimer() - - for i := 0; i < b.N; i++ { - benchPreparedMockQuery(b, c, stmt) - } -} - -func benchPreparedMockQuery(b *testing.B, c *conn, stmt driver.Stmt) { - rows, err := stmt.Query(nil) - if err != nil { - b.Fatal(err) - } - defer rows.Close() - var dest [1]driver.Value - for { - if err := rows.Next(dest[:]); err != nil { - if err == io.EOF { - break - } - b.Fatal(err) - } - } -} - -func BenchmarkEncodeInt64(b *testing.B) { - for i := 0; i < b.N; i++ { - encode(¶meterStatus{}, int64(1234), oid.T_int8) - } -} - -func BenchmarkEncodeFloat64(b *testing.B) { - for i := 0; i < b.N; i++ { - encode(¶meterStatus{}, 3.14159, oid.T_float8) - } -} - -var testByteString = []byte("abcdefghijklmnopqrstuvwxyz") - -func BenchmarkEncodeByteaHex(b *testing.B) { - for i := 0; i < b.N; i++ { - encode(¶meterStatus{serverVersion: 90000}, testByteString, oid.T_bytea) - } -} -func BenchmarkEncodeByteaEscape(b *testing.B) { - for i := 0; i < b.N; i++ { - encode(¶meterStatus{serverVersion: 84000}, testByteString, oid.T_bytea) - } -} - -func BenchmarkEncodeBool(b *testing.B) { - for i := 0; i < b.N; i++ { - encode(¶meterStatus{}, true, oid.T_bool) - } -} - -var testTimestamptz = time.Date(2001, time.January, 1, 0, 0, 0, 0, time.Local) - -func BenchmarkEncodeTimestamptz(b *testing.B) { - for i := 0; i < b.N; i++ { - encode(¶meterStatus{}, testTimestamptz, oid.T_timestamptz) - } -} - -var testIntBytes = []byte("1234") - -func BenchmarkDecodeInt64(b *testing.B) { - for i := 0; i < b.N; i++ { - decode(¶meterStatus{}, testIntBytes, oid.T_int8) - } -} - -var testFloatBytes = []byte("3.14159") - -func BenchmarkDecodeFloat64(b *testing.B) { - for i := 0; i < b.N; i++ { - decode(¶meterStatus{}, testFloatBytes, oid.T_float8) - } -} - -var testBoolBytes = []byte{'t'} - -func BenchmarkDecodeBool(b *testing.B) { - for i := 0; i < b.N; i++ { - decode(¶meterStatus{}, testBoolBytes, oid.T_bool) - } -} - -func TestDecodeBool(t *testing.T) { - db := openTestConn(t) - rows, err := db.Query("select true") - if err != nil { - t.Fatal(err) - } - rows.Close() -} - -var testTimestamptzBytes = []byte("2013-09-17 22:15:32.360754-07") - -func BenchmarkDecodeTimestamptz(b *testing.B) { - for i := 0; i < b.N; i++ { - decode(¶meterStatus{}, testTimestamptzBytes, oid.T_timestamptz) - } -} - -func BenchmarkDecodeTimestamptzMultiThread(b *testing.B) { - oldProcs := runtime.GOMAXPROCS(0) - defer runtime.GOMAXPROCS(oldProcs) - runtime.GOMAXPROCS(runtime.NumCPU()) - globalLocationCache = newLocationCache() - - f := func(wg *sync.WaitGroup, loops int) { - defer wg.Done() - for i := 0; i < loops; i++ { - decode(¶meterStatus{}, testTimestamptzBytes, oid.T_timestamptz) - } - } - - wg := &sync.WaitGroup{} - b.ResetTimer() - for j := 0; j < 10; j++ { - wg.Add(1) - go f(wg, b.N/10) - } - wg.Wait() -} - -func BenchmarkLocationCache(b *testing.B) { - globalLocationCache = newLocationCache() - for i := 0; i < b.N; i++ { - globalLocationCache.getLocation(rand.Intn(10000)) - } -} - -func BenchmarkLocationCacheMultiThread(b *testing.B) { - oldProcs := runtime.GOMAXPROCS(0) - defer runtime.GOMAXPROCS(oldProcs) - runtime.GOMAXPROCS(runtime.NumCPU()) - globalLocationCache = newLocationCache() - - f := func(wg *sync.WaitGroup, loops int) { - defer wg.Done() - for i := 0; i < loops; i++ { - globalLocationCache.getLocation(rand.Intn(10000)) - } - } - - wg := &sync.WaitGroup{} - b.ResetTimer() - for j := 0; j < 10; j++ { - wg.Add(1) - go f(wg, b.N/10) - } - wg.Wait() -} - -// Stress test the performance of parsing results from the wire. -func BenchmarkResultParsing(b *testing.B) { - b.StopTimer() - - db := openTestConn(b) - defer db.Close() - _, err := db.Exec("BEGIN") - if err != nil { - b.Fatal(err) - } - - b.StartTimer() - for i := 0; i < b.N; i++ { - res, err := db.Query("SELECT generate_series(1, 50000)") - if err != nil { - b.Fatal(err) - } - res.Close() - } -} diff --git a/vendor/github.com/lib/pq/buf.go b/vendor/github.com/lib/pq/buf.go index 9f417a1ebae..666b0012a79 100644 --- a/vendor/github.com/lib/pq/buf.go +++ b/vendor/github.com/lib/pq/buf.go @@ -3,6 +3,7 @@ package pq import ( "bytes" "encoding/binary" + "github.com/lib/pq/oid" ) @@ -20,6 +21,7 @@ func (b *readBuf) oid() (n oid.Oid) { return } +// N.B: this is actually an unsigned 16-bit integer, unlike int32 func (b *readBuf) int16() (n int) { n = int(binary.BigEndian.Uint16(*b)) *b = (*b)[2:] @@ -46,28 +48,44 @@ func (b *readBuf) byte() byte { return b.next(1)[0] } -type writeBuf []byte +type writeBuf struct { + buf []byte + pos int +} func (b *writeBuf) int32(n int) { x := make([]byte, 4) binary.BigEndian.PutUint32(x, uint32(n)) - *b = append(*b, x...) + b.buf = append(b.buf, x...) } func (b *writeBuf) int16(n int) { x := make([]byte, 2) binary.BigEndian.PutUint16(x, uint16(n)) - *b = append(*b, x...) + b.buf = append(b.buf, x...) } func (b *writeBuf) string(s string) { - *b = append(*b, (s + "\000")...) + b.buf = append(b.buf, (s + "\000")...) } func (b *writeBuf) byte(c byte) { - *b = append(*b, c) + b.buf = append(b.buf, c) } func (b *writeBuf) bytes(v []byte) { - *b = append(*b, v...) + b.buf = append(b.buf, v...) +} + +func (b *writeBuf) wrap() []byte { + p := b.buf[b.pos:] + binary.BigEndian.PutUint32(p, uint32(len(p))) + return b.buf +} + +func (b *writeBuf) next(c byte) { + p := b.buf[b.pos:] + binary.BigEndian.PutUint32(p, uint32(len(p))) + b.pos = len(b.buf) + 1 + b.buf = append(b.buf, c, 0, 0, 0, 0) } diff --git a/vendor/github.com/lib/pq/conn.go b/vendor/github.com/lib/pq/conn.go index 44e4833921f..3747ffcaf84 100644 --- a/vendor/github.com/lib/pq/conn.go +++ b/vendor/github.com/lib/pq/conn.go @@ -3,16 +3,12 @@ package pq import ( "bufio" "crypto/md5" - "crypto/tls" - "crypto/x509" "database/sql" "database/sql/driver" "encoding/binary" "errors" "fmt" - "github.com/lib/pq/oid" "io" - "io/ioutil" "net" "os" "os/user" @@ -22,6 +18,8 @@ import ( "strings" "time" "unicode" + + "github.com/lib/pq/oid" ) // Common error types @@ -31,16 +29,20 @@ var ( ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server") ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key file has group or world access. Permissions should be u=rw (0600) or less.") ErrCouldNotDetectUsername = errors.New("pq: Could not detect default username. Please provide one explicitly.") + + errUnexpectedReady = errors.New("unexpected ReadyForQuery") + errNoRowsAffected = errors.New("no RowsAffected available after the empty statement") + errNoLastInsertId = errors.New("no LastInsertId available after the empty statement") ) -type drv struct{} +type Driver struct{} -func (d *drv) Open(name string) (driver.Conn, error) { +func (d *Driver) Open(name string) (driver.Conn, error) { return Open(name) } func init() { - sql.Register("postgres", &drv{}) + sql.Register("postgres", &Driver{}) } type parameterStatus struct { @@ -96,6 +98,15 @@ type conn struct { namei int scratch [512]byte txnStatus transactionStatus + txnFinish func() + + // Save connection arguments to use during CancelRequest. + dialer Dialer + opts values + + // Cancellation key data for use with CancelRequest messages. + processID int + secretKey int parameterStatus parameterStatus @@ -105,12 +116,125 @@ type conn struct { // If true, this connection is bad and all public-facing functions should // return ErrBadConn. bad bool + + // If set, this connection should never use the binary format when + // receiving query results from prepared statements. Only provided for + // debugging. + disablePreparedBinaryResult bool + + // Whether to always send []byte parameters over as binary. Enables single + // round-trip mode for non-prepared Query calls. + binaryParameters bool + + // If true this connection is in the middle of a COPY + inCopy bool +} + +// Handle driver-side settings in parsed connection string. +func (c *conn) handleDriverSettings(o values) (err error) { + boolSetting := func(key string, val *bool) error { + if value, ok := o[key]; ok { + if value == "yes" { + *val = true + } else if value == "no" { + *val = false + } else { + return fmt.Errorf("unrecognized value %q for %s", value, key) + } + } + return nil + } + + err = boolSetting("disable_prepared_binary_result", &c.disablePreparedBinaryResult) + if err != nil { + return err + } + err = boolSetting("binary_parameters", &c.binaryParameters) + if err != nil { + return err + } + return nil +} + +func (c *conn) handlePgpass(o values) { + // if a password was supplied, do not process .pgpass + if _, ok := o["password"]; ok { + return + } + filename := os.Getenv("PGPASSFILE") + if filename == "" { + // XXX this code doesn't work on Windows where the default filename is + // XXX %APPDATA%\postgresql\pgpass.conf + user, err := user.Current() + if err != nil { + return + } + filename = filepath.Join(user.HomeDir, ".pgpass") + } + fileinfo, err := os.Stat(filename) + if err != nil { + return + } + mode := fileinfo.Mode() + if mode&(0x77) != 0 { + // XXX should warn about incorrect .pgpass permissions as psql does + return + } + file, err := os.Open(filename) + if err != nil { + return + } + defer file.Close() + scanner := bufio.NewScanner(io.Reader(file)) + hostname := o["host"] + ntw, _ := network(o) + port := o["port"] + db := o["dbname"] + username := o["user"] + // From: https://github.com/tg/pgpass/blob/master/reader.go + getFields := func(s string) []string { + fs := make([]string, 0, 5) + f := make([]rune, 0, len(s)) + + var esc bool + for _, c := range s { + switch { + case esc: + f = append(f, c) + esc = false + case c == '\\': + esc = true + case c == ':': + fs = append(fs, string(f)) + f = f[:0] + default: + f = append(f, c) + } + } + return append(fs, string(f)) + } + for scanner.Scan() { + line := scanner.Text() + if len(line) == 0 || line[0] == '#' { + continue + } + split := getFields(line) + if len(split) != 5 { + continue + } + if (split[0] == "*" || split[0] == hostname || (split[0] == "localhost" && (hostname == "" || ntw == "unix"))) && (split[1] == "*" || split[1] == port) && (split[2] == "*" || split[2] == db) && (split[3] == "*" || split[3] == username) { + o["password"] = split[4] + return + } + } } func (c *conn) writeBuf(b byte) *writeBuf { c.scratch[0] = b - w := writeBuf(c.scratch[:5]) - return &w + return &writeBuf{ + buf: c.scratch[:5], + pos: 1, + } } func Open(name string) (_ driver.Conn, err error) { @@ -118,22 +242,11 @@ func Open(name string) (_ driver.Conn, err error) { } func DialOpen(d Dialer, name string) (_ driver.Conn, err error) { - defer func() { - // Handle any panics during connection initialization. Note that we - // specifically do *not* want to use errRecover(), as that would turn - // any connection errors into ErrBadConns, hiding the real error - // message from the user. - e := recover() - if e == nil { - // Do nothing - return - } - var ok bool - err, ok = e.(error) - if !ok { - err = fmt.Errorf("pq: unexpected error: %#v", e) - } - }() + // Handle any panics during connection initialization. Note that we + // specifically do *not* want to use errRecover(), as that would turn any + // connection errors into ErrBadConns, hiding the real error message from + // the user. + defer errRecoverNoErrBadConn(&err) o := make(values) @@ -142,16 +255,16 @@ func DialOpen(d Dialer, name string) (_ driver.Conn, err error) { // * Very low precedence defaults applied in every situation // * Environment variables // * Explicitly passed connection information - o.Set("host", "localhost") - o.Set("port", "5432") + o["host"] = "localhost" + o["port"] = "5432" // N.B.: Extra float digits should be set to 3, but that breaks // Postgres 8.4 and older, where the max is 2. - o.Set("extra_float_digits", "2") + o["extra_float_digits"] = "2" for k, v := range parseEnviron(os.Environ()) { - o.Set(k, v) + o[k] = v } - if strings.HasPrefix(name, "postgres://") { + if strings.HasPrefix(name, "postgres://") || strings.HasPrefix(name, "postgresql://") { name, err = ParseURL(name) if err != nil { return nil, err @@ -163,9 +276,9 @@ func DialOpen(d Dialer, name string) (_ driver.Conn, err error) { } // Use the "fallback" application name if necessary - if fallback := o.Get("fallback_application_name"); fallback != "" { - if !o.Isset("application_name") { - o.Set("application_name", fallback) + if fallback, ok := o["fallback_application_name"]; ok { + if _, ok := o["application_name"]; !ok { + o["application_name"] = fallback } } @@ -176,53 +289,66 @@ func DialOpen(d Dialer, name string) (_ driver.Conn, err error) { // parsing its value is not worth it. Instead, we always explicitly send // client_encoding as a separate run-time parameter, which should override // anything set in options. - if enc := o.Get("client_encoding"); enc != "" && !isUTF8(enc) { + if enc, ok := o["client_encoding"]; ok && !isUTF8(enc) { return nil, errors.New("client_encoding must be absent or 'UTF8'") } - o.Set("client_encoding", "UTF8") + o["client_encoding"] = "UTF8" // DateStyle needs a similar treatment. - if datestyle := o.Get("datestyle"); datestyle != "" { + if datestyle, ok := o["datestyle"]; ok { if datestyle != "ISO, MDY" { panic(fmt.Sprintf("setting datestyle must be absent or %v; got %v", "ISO, MDY", datestyle)) } } else { - o.Set("datestyle", "ISO, MDY") + o["datestyle"] = "ISO, MDY" } // If a user is not provided by any other means, the last // resort is to use the current operating system provided user // name. - if o.Get("user") == "" { + if _, ok := o["user"]; !ok { u, err := userCurrent() if err != nil { return nil, err } else { - o.Set("user", u) + o["user"] = u } } - c, err := dial(d, o) + cn := &conn{ + opts: o, + dialer: d, + } + err = cn.handleDriverSettings(o) if err != nil { return nil, err } + cn.handlePgpass(o) - cn := &conn{c: c} + cn.c, err = dial(d, o) + if err != nil { + return nil, err + } cn.ssl(o) cn.buf = bufio.NewReader(cn.c) cn.startup(o) + // reset the deadline, in case one was set (see dial) - err = cn.c.SetDeadline(time.Time{}) + if timeout, ok := o["connect_timeout"]; ok && timeout != "0" { + err = cn.c.SetDeadline(time.Time{}) + } return cn, err } func dial(d Dialer, o values) (net.Conn, error) { ntw, addr := network(o) - - timeout := o.Get("connect_timeout") + // SSL is not necessary or supported over UNIX domain sockets + if ntw == "unix" { + o["sslmode"] = "disable" + } // Zero or not specified means wait indefinitely. - if timeout != "" && timeout != "0" { + if timeout, ok := o["connect_timeout"]; ok && timeout != "0" { seconds, err := strconv.ParseInt(timeout, 10, 0) if err != nil { return nil, fmt.Errorf("invalid value for parameter connect_timeout: %s", err) @@ -244,31 +370,18 @@ func dial(d Dialer, o values) (net.Conn, error) { } func network(o values) (string, string) { - host := o.Get("host") + host := o["host"] if strings.HasPrefix(host, "/") { - sockPath := path.Join(host, ".s.PGSQL."+o.Get("port")) + sockPath := path.Join(host, ".s.PGSQL."+o["port"]) return "unix", sockPath } - return "tcp", host + ":" + o.Get("port") + return "tcp", net.JoinHostPort(host, o["port"]) } type values map[string]string -func (vs values) Set(k, v string) { - vs[k] = v -} - -func (vs values) Get(k string) (v string) { - return vs[k] -} - -func (vs values) Isset(k string) bool { - _, ok := vs[k] - return ok -} - // scanner implements a tokenizer for libpq-style option strings. type scanner struct { s []rune @@ -339,7 +452,7 @@ func parseOpts(name string, o values) error { // Skip any whitespace after the = if r, ok = s.SkipSpaces(); !ok { // If we reach the end here, the last value is just an empty string as per libpq. - o.Set(string(keyRunes), "") + o[string(keyRunes)] = "" break } @@ -374,7 +487,7 @@ func parseOpts(name string, o values) error { } } - o.Set(string(keyRunes), string(valRunes)) + o[string(keyRunes)] = string(valRunes) } return nil @@ -393,13 +506,17 @@ func (cn *conn) checkIsInTransaction(intxn bool) { } func (cn *conn) Begin() (_ driver.Tx, err error) { + return cn.begin("") +} + +func (cn *conn) begin(mode string) (_ driver.Tx, err error) { if cn.bad { return nil, driver.ErrBadConn } defer cn.errRecover(&err) cn.checkIsInTransaction(false) - _, commandTag, err := cn.simpleExec("BEGIN") + _, commandTag, err := cn.simpleExec("BEGIN" + mode) if err != nil { return nil, err } @@ -414,7 +531,14 @@ func (cn *conn) Begin() (_ driver.Tx, err error) { return cn, nil } +func (cn *conn) closeTxn() { + if finish := cn.txnFinish; finish != nil { + finish() + } +} + func (cn *conn) Commit() (err error) { + defer cn.closeTxn() if cn.bad { return driver.ErrBadConn } @@ -436,6 +560,9 @@ func (cn *conn) Commit() (err error) { _, commandTag, err := cn.simpleExec("COMMIT") if err != nil { + if cn.isInTransaction() { + cn.bad = true + } return err } if commandTag != "COMMIT" { @@ -447,6 +574,7 @@ func (cn *conn) Commit() (err error) { } func (cn *conn) Rollback() (err error) { + defer cn.closeTxn() if cn.bad { return driver.ErrBadConn } @@ -455,6 +583,9 @@ func (cn *conn) Rollback() (err error) { cn.checkIsInTransaction(true) _, commandTag, err := cn.simpleExec("ROLLBACK") if err != nil { + if cn.isInTransaction() { + cn.bad = true + } return err } if commandTag != "ROLLBACK" { @@ -481,11 +612,16 @@ func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err res, commandTag = cn.parseComplete(r.string()) case 'Z': cn.processReadyForQuery(r) + if res == nil && err == nil { + err = errUnexpectedReady + } // done return case 'E': err = parseError(r) - case 'T', 'D', 'I': + case 'I': + res = emptyRows + case 'T', 'D': // ignore any results default: cn.bad = true @@ -494,11 +630,9 @@ func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err } } -func (cn *conn) simpleQuery(q string) (res driver.Rows, err error) { +func (cn *conn) simpleQuery(q string) (res *rows, err error) { defer cn.errRecover(&err) - st := &stmt{cn: cn, name: ""} - b := cn.writeBuf('Q') b.string(q) cn.send(b) @@ -515,7 +649,18 @@ func (cn *conn) simpleQuery(q string) (res driver.Rows, err error) { cn.bad = true errorf("unexpected message %q in simple query execution", t) } - res = &rows{st: st, done: true} + if res == nil { + res = &rows{ + cn: cn, + } + } + // Set the result and tag to the last command complete if there wasn't a + // query already run. Although queries usually return from here and cede + // control to Next, a query with zero results does not. + if t == 'C' && res.colNames == nil { + res.result, res.tag = cn.parseComplete(r.string()) + } + res.done = true case 'Z': cn.processReadyForQuery(r) // done @@ -534,8 +679,8 @@ func (cn *conn) simpleQuery(q string) (res driver.Rows, err error) { case 'T': // res might be non-nil here if we received a previous // CommandComplete, but that's fine; just overwrite it - res = &rows{st: st} - st.cols, st.rowTyps = parseMeta(r) + res = &rows{cn: cn} + res.colNames, res.colFmts, res.colTyps = parsePortalRowDescribe(r) // To work around a bug in QueryRow in Go 1.2 and earlier, wait // until the first DataRow has been received. @@ -546,47 +691,90 @@ func (cn *conn) simpleQuery(q string) (res driver.Rows, err error) { } } -func (cn *conn) prepareTo(q, stmtName string) (_ *stmt, err error) { +type noRows struct{} + +var emptyRows noRows + +var _ driver.Result = noRows{} + +func (noRows) LastInsertId() (int64, error) { + return 0, errNoLastInsertId +} + +func (noRows) RowsAffected() (int64, error) { + return 0, errNoRowsAffected +} + +// Decides which column formats to use for a prepared statement. The input is +// an array of type oids, one element per result column. +func decideColumnFormats(colTyps []oid.Oid, forceText bool) (colFmts []format, colFmtData []byte) { + if len(colTyps) == 0 { + return nil, colFmtDataAllText + } + + colFmts = make([]format, len(colTyps)) + if forceText { + return colFmts, colFmtDataAllText + } + + allBinary := true + allText := true + for i, o := range colTyps { + switch o { + // This is the list of types to use binary mode for when receiving them + // through a prepared statement. If a type appears in this list, it + // must also be implemented in binaryDecode in encode.go. + case oid.T_bytea: + fallthrough + case oid.T_int8: + fallthrough + case oid.T_int4: + fallthrough + case oid.T_int2: + fallthrough + case oid.T_uuid: + colFmts[i] = formatBinary + allText = false + + default: + allBinary = false + } + } + + if allBinary { + return colFmts, colFmtDataAllBinary + } else if allText { + return colFmts, colFmtDataAllText + } else { + colFmtData = make([]byte, 2+len(colFmts)*2) + binary.BigEndian.PutUint16(colFmtData, uint16(len(colFmts))) + for i, v := range colFmts { + binary.BigEndian.PutUint16(colFmtData[2+i*2:], uint16(v)) + } + return colFmts, colFmtData + } +} + +func (cn *conn) prepareTo(q, stmtName string) *stmt { st := &stmt{cn: cn, name: stmtName} b := cn.writeBuf('P') b.string(st.name) b.string(q) b.int16(0) - cn.send(b) - b = cn.writeBuf('D') + b.next('D') b.byte('S') b.string(st.name) + + b.next('S') cn.send(b) - cn.send(cn.writeBuf('S')) - - for { - t, r := cn.recv1() - switch t { - case '1': - case 't': - nparams := r.int16() - st.paramTyps = make([]oid.Oid, nparams) - - for i := range st.paramTyps { - st.paramTyps[i] = r.oid() - } - case 'T': - st.cols, st.rowTyps = parseMeta(r) - case 'n': - // no data - case 'Z': - cn.processReadyForQuery(r) - return st, err - case 'E': - err = parseError(r) - default: - cn.bad = true - errorf("unexpected describe rows response: %q", t) - } - } + cn.readParseResponse() + st.paramTyps, st.colNames, st.colTyps = cn.readStatementDescribeResponse() + st.colFmts, st.colFmtData = decideColumnFormats(st.colTyps, cn.disablePreparedBinaryResult) + cn.readReadyForQuery() + return st } func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) { @@ -596,32 +784,45 @@ func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) { defer cn.errRecover(&err) if len(q) >= 4 && strings.EqualFold(q[:4], "COPY") { - return cn.prepareCopyIn(q) + s, err := cn.prepareCopyIn(q) + if err == nil { + cn.inCopy = true + } + return s, err } - return cn.prepareTo(q, cn.gname()) + return cn.prepareTo(q, cn.gname()), nil } func (cn *conn) Close() (err error) { - if cn.bad { - return driver.ErrBadConn - } + // Skip cn.bad return here because we always want to close a connection. defer cn.errRecover(&err) + // Ensure that cn.c.Close is always run. Since error handling is done with + // panics and cn.errRecover, the Close must be in a defer. + defer func() { + cerr := cn.c.Close() + if err == nil { + err = cerr + } + }() + // Don't go through send(); ListenerConn relies on us not scribbling on the // scratch buffer of this connection. - err = cn.sendSimpleMessage('X') - if err != nil { - return err - } - - return cn.c.Close() + return cn.sendSimpleMessage('X') } // Implement the "Queryer" interface -func (cn *conn) Query(query string, args []driver.Value) (_ driver.Rows, err error) { +func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) { + return cn.query(query, args) +} + +func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) { if cn.bad { return nil, driver.ErrBadConn } + if cn.inCopy { + return nil, errCopyInProgress + } defer cn.errRecover(&err) // Check to see if we can use the "simpleQuery" interface, which is @@ -630,17 +831,29 @@ func (cn *conn) Query(query string, args []driver.Value) (_ driver.Rows, err err return cn.simpleQuery(query) } - st, err := cn.prepareTo(query, "") - if err != nil { - panic(err) - } + if cn.binaryParameters { + cn.sendBinaryModeQuery(query, args) - st.exec(args) - return &rows{st: st}, nil + cn.readParseResponse() + cn.readBindResponse() + rows := &rows{cn: cn} + rows.colNames, rows.colFmts, rows.colTyps = cn.readPortalDescribeResponse() + cn.postExecuteWorkaround() + return rows, nil + } else { + st := cn.prepareTo(query, "") + st.exec(args) + return &rows{ + cn: cn, + colNames: st.colNames, + colTyps: st.colTyps, + colFmts: st.colFmts, + }, nil + } } // Implement the optional "Execer" interface for one-shot queries -func (cn *conn) Exec(query string, args []driver.Value) (_ driver.Result, err error) { +func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err error) { if cn.bad { return nil, driver.ErrBadConn } @@ -654,37 +867,40 @@ func (cn *conn) Exec(query string, args []driver.Value) (_ driver.Result, err er return r, err } - // Use the unnamed statement to defer planning until bind - // time, or else value-based selectivity estimates cannot be - // used. - st, err := cn.prepareTo(query, "") - if err != nil { - panic(err) - } + if cn.binaryParameters { + cn.sendBinaryModeQuery(query, args) - r, err := st.Exec(args) - if err != nil { - panic(err) + cn.readParseResponse() + cn.readBindResponse() + cn.readPortalDescribeResponse() + cn.postExecuteWorkaround() + res, _, err = cn.readExecuteResponse("Execute") + return res, err + } else { + // Use the unnamed statement to defer planning until bind + // time, or else value-based selectivity estimates cannot be + // used. + st := cn.prepareTo(query, "") + r, err := st.Exec(args) + if err != nil { + panic(err) + } + return r, err } - - return r, err } -// Assumes len(*m) is > 5 func (cn *conn) send(m *writeBuf) { - b := (*m)[1:] - binary.BigEndian.PutUint32(b, uint32(len(b))) - - if (*m)[0] == 0 { - *m = b - } - - _, err := cn.c.Write(*m) + _, err := cn.c.Write(m.wrap()) if err != nil { panic(err) } } +func (cn *conn) sendStartupPacket(m *writeBuf) error { + _, err := cn.c.Write((m.wrap())[1:]) + return err +} + // Send a message of type typ to the server on the other end of cn. The // message should have no payload. This method does not use the scratch // buffer. @@ -796,30 +1012,17 @@ func (cn *conn) recv1() (t byte, r *readBuf) { } func (cn *conn) ssl(o values) { - verifyCaOnly := false - tlsConf := tls.Config{} - switch mode := o.Get("sslmode"); mode { - case "require", "": - tlsConf.InsecureSkipVerify = true - case "verify-ca": - // We must skip TLS's own verification since it requires full - // verification since Go 1.3. - tlsConf.InsecureSkipVerify = true - verifyCaOnly = true - case "verify-full": - tlsConf.ServerName = o.Get("host") - case "disable": + upgrade := ssl(o) + if upgrade == nil { + // Nothing to do return - default: - errorf(`unsupported sslmode %q; only "require" (default), "verify-full", and "disable" supported`, mode) } - cn.setupSSLClientCertificates(&tlsConf, o) - cn.setupSSLCA(&tlsConf, o) - w := cn.writeBuf(0) w.int32(80877103) - cn.send(w) + if err := cn.sendStartupPacket(w); err != nil { + panic(err) + } b := cn.scratch[:1] _, err := io.ReadFull(cn.c, b) @@ -831,114 +1034,7 @@ func (cn *conn) ssl(o values) { panic(ErrSSLNotSupported) } - client := tls.Client(cn.c, &tlsConf) - if verifyCaOnly { - cn.verifyCA(client, &tlsConf) - } - cn.c = client -} - -// verifyCA carries out a TLS handshake to the server and verifies the -// presented certificate against the effective CA, i.e. the one specified in -// sslrootcert or the system CA if sslrootcert was not specified. -func (cn *conn) verifyCA(client *tls.Conn, tlsConf *tls.Config) { - err := client.Handshake() - if err != nil { - panic(err) - } - certs := client.ConnectionState().PeerCertificates - opts := x509.VerifyOptions{ - DNSName: client.ConnectionState().ServerName, - Intermediates: x509.NewCertPool(), - Roots: tlsConf.RootCAs, - } - for i, cert := range certs { - if i == 0 { - continue - } - opts.Intermediates.AddCert(cert) - } - _, err = certs[0].Verify(opts) - if err != nil { - panic(err) - } -} - -// This function sets up SSL client certificates based on either the "sslkey" -// and "sslcert" settings (possibly set via the environment variables PGSSLKEY -// and PGSSLCERT, respectively), or if they aren't set, from the .postgresql -// directory in the user's home directory. If the file paths are set -// explicitly, the files must exist. The key file must also not be -// world-readable, or this function will panic with -// ErrSSLKeyHasWorldPermissions. -func (cn *conn) setupSSLClientCertificates(tlsConf *tls.Config, o values) { - var missingOk bool - - sslkey := o.Get("sslkey") - sslcert := o.Get("sslcert") - if sslkey != "" && sslcert != "" { - // If the user has set an sslkey and sslcert, they *must* exist. - missingOk = false - } else { - // Automatically load certificates from ~/.postgresql. - user, err := user.Current() - if err != nil { - // user.Current() might fail when cross-compiling. We have to - // ignore the error and continue without client certificates, since - // we wouldn't know where to load them from. - return - } - - sslkey = filepath.Join(user.HomeDir, ".postgresql", "postgresql.key") - sslcert = filepath.Join(user.HomeDir, ".postgresql", "postgresql.crt") - missingOk = true - } - - // Check that both files exist, and report the error or stop, depending on - // which behaviour we want. Note that we don't do any more extensive - // checks than this (such as checking that the paths aren't directories); - // LoadX509KeyPair() will take care of the rest. - keyfinfo, err := os.Stat(sslkey) - if err != nil && missingOk { - return - } else if err != nil { - panic(err) - } - _, err = os.Stat(sslcert) - if err != nil && missingOk { - return - } else if err != nil { - panic(err) - } - - // If we got this far, the key file must also have the correct permissions - kmode := keyfinfo.Mode() - if kmode != kmode&0600 { - panic(ErrSSLKeyHasWorldPermissions) - } - - cert, err := tls.LoadX509KeyPair(sslcert, sslkey) - if err != nil { - panic(err) - } - tlsConf.Certificates = []tls.Certificate{cert} -} - -// Sets up RootCAs in the TLS configuration if sslrootcert is set. -func (cn *conn) setupSSLCA(tlsConf *tls.Config, o values) { - if sslrootcert := o.Get("sslrootcert"); sslrootcert != "" { - tlsConf.RootCAs = x509.NewCertPool() - - cert, err := ioutil.ReadFile(sslrootcert) - if err != nil { - panic(err) - } - - ok := tlsConf.RootCAs.AppendCertsFromPEM(cert) - if !ok { - errorf("couldn't parse pem in sslrootcert") - } - } + cn.c = upgrade(cn.c) } // isDriverSetting returns true iff a setting is purely for configuring the @@ -956,6 +1052,10 @@ func isDriverSetting(key string) bool { return true case "connect_timeout": return true + case "disable_prepared_binary_result": + return true + case "binary_parameters": + return true default: return false @@ -983,12 +1083,15 @@ func (cn *conn) startup(o values) { w.string(v) } w.string("") - cn.send(w) + if err := cn.sendStartupPacket(w); err != nil { + panic(err) + } for { t, r := cn.recv() switch t { case 'K': + cn.processBackendKeyData(r) case 'S': cn.processParameterStatus(r) case 'R': @@ -1008,7 +1111,7 @@ func (cn *conn) auth(r *readBuf, o values) { // OK case 3: w := cn.writeBuf('p') - w.string(o.Get("password")) + w.string(o["password"]) cn.send(w) t, r := cn.recv() @@ -1022,7 +1125,7 @@ func (cn *conn) auth(r *readBuf, o values) { case 5: s := string(r.next(4)) w := cn.writeBuf('p') - w.string("md5" + md5s(md5s(o.Get("password")+o.Get("user"))+s)) + w.string("md5" + md5s(md5s(o["password"]+o["user"])+s)) cn.send(w) t, r := cn.recv() @@ -1038,13 +1141,26 @@ func (cn *conn) auth(r *readBuf, o values) { } } +type format int + +const formatText format = 0 +const formatBinary format = 1 + +// One result-column format code with the value 1 (i.e. all binary). +var colFmtDataAllBinary []byte = []byte{0, 1, 0, 1} + +// No result-column format codes (i.e. all text). +var colFmtDataAllText []byte = []byte{0, 0} + type stmt struct { - cn *conn - name string - cols []string - rowTyps []oid.Oid - paramTyps []oid.Oid - closed bool + cn *conn + name string + colNames []string + colFmts []format + colFmtData []byte + colTyps []oid.Oid + paramTyps []oid.Oid + closed bool } func (st *stmt) Close() (err error) { @@ -1087,7 +1203,12 @@ func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) { defer st.cn.errRecover(&err) st.exec(v) - return &rows{st: st}, nil + return &rows{ + cn: st.cn, + colNames: st.colNames, + colTyps: st.colTyps, + colFmts: st.colFmts, + }, nil } func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) { @@ -1097,25 +1218,8 @@ func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) { defer st.cn.errRecover(&err) st.exec(v) - - for { - t, r := st.cn.recv1() - switch t { - case 'E': - err = parseError(r) - case 'C': - res, _ = st.cn.parseComplete(r.string()) - case 'Z': - st.cn.processReadyForQuery(r) - // done - return - case 'T', 'D', 'I': - // ignore any results - default: - st.cn.bad = true - errorf("unknown exec response: %q", t) - } - } + res, _, err = st.cn.readExecuteResponse("simple query") + return res, err } func (st *stmt) exec(v []driver.Value) { @@ -1126,84 +1230,38 @@ func (st *stmt) exec(v []driver.Value) { errorf("got %d parameters but the statement requires %d", len(v), len(st.paramTyps)) } - w := st.cn.writeBuf('B') - w.string("") + cn := st.cn + w := cn.writeBuf('B') + w.byte(0) // unnamed portal w.string(st.name) - w.int16(0) - w.int16(len(v)) - for i, x := range v { - if x == nil { - w.int32(-1) - } else { - b := encode(&st.cn.parameterStatus, x, st.paramTyps[i]) - w.int32(len(b)) - w.bytes(b) + + if cn.binaryParameters { + cn.sendBinaryParameters(w, v) + } else { + w.int16(0) + w.int16(len(v)) + for i, x := range v { + if x == nil { + w.int32(-1) + } else { + b := encode(&cn.parameterStatus, x, st.paramTyps[i]) + w.int32(len(b)) + w.bytes(b) + } } } - w.int16(0) - st.cn.send(w) + w.bytes(st.colFmtData) - w = st.cn.writeBuf('E') - w.string("") + w.next('E') + w.byte(0) w.int32(0) - st.cn.send(w) - st.cn.send(st.cn.writeBuf('S')) + w.next('S') + cn.send(w) - var err error - for { - t, r := st.cn.recv1() - switch t { - case 'E': - err = parseError(r) - case '2': - if err != nil { - panic(err) - } - goto workaround - case 'Z': - st.cn.processReadyForQuery(r) - if err != nil { - panic(err) - } - return - default: - st.cn.bad = true - errorf("unexpected bind response: %q", t) - } - } + cn.readBindResponse() + cn.postExecuteWorkaround() - // Work around a bug in sql.DB.QueryRow: in Go 1.2 and earlier it ignores - // any errors from rows.Next, which masks errors that happened during the - // execution of the query. To avoid the problem in common cases, we wait - // here for one more message from the database. If it's not an error the - // query will likely succeed (or perhaps has already, if it's a - // CommandComplete), so we push the message into the conn struct; recv1 - // will return it as the next message for rows.Next or rows.Close. - // However, if it's an error, we wait until ReadyForQuery and then return - // the error to our caller. -workaround: - for { - t, r := st.cn.recv1() - switch t { - case 'E': - err = parseError(r) - case 'C', 'D', 'I': - // the query didn't fail, but we can't process this message - st.cn.saveMessage(t, r) - return - case 'Z': - if err == nil { - st.cn.bad = true - errorf("unexpected ReadyForQuery during extended query execution") - } - st.cn.processReadyForQuery(r) - panic(err) - default: - st.cn.bad = true - errorf("unexpected message during query execution: %q", t) - } - } } func (st *stmt) NumInput() int { @@ -1260,19 +1318,33 @@ func (cn *conn) parseComplete(commandTag string) (driver.Result, string) { } type rows struct { - st *stmt - done bool - rb readBuf + cn *conn + finish func() + colNames []string + colTyps []oid.Oid + colFmts []format + done bool + rb readBuf + result driver.Result + tag string } func (rs *rows) Close() error { + if finish := rs.finish; finish != nil { + defer finish() + } // no need to look at cn.bad as Next() will for { err := rs.Next(nil) switch err { case nil: case io.EOF: - return nil + // rs.Next can return io.EOF on both 'Z' (ready for query) and 'T' (row + // description, used with HasNextResultSet). We need to fetch messages until + // we hit a 'Z', which is done by waiting for done to be set. + if rs.done { + return nil + } default: return err } @@ -1280,7 +1352,18 @@ func (rs *rows) Close() error { } func (rs *rows) Columns() []string { - return rs.st.cols + return rs.colNames +} + +func (rs *rows) Result() driver.Result { + if rs.result == nil { + return emptyRows + } + return rs.result +} + +func (rs *rows) Tag() string { + return rs.tag } func (rs *rows) Next(dest []driver.Value) (err error) { @@ -1288,7 +1371,7 @@ func (rs *rows) Next(dest []driver.Value) (err error) { return io.EOF } - conn := rs.st.cn + conn := rs.cn if conn.bad { return driver.ErrBadConn } @@ -1300,6 +1383,9 @@ func (rs *rows) Next(dest []driver.Value) (err error) { case 'E': err = parseError(&rs.rb) case 'C', 'I': + if t == 'C' { + rs.result, rs.tag = conn.parseComplete(rs.rb.string()) + } continue case 'Z': conn.processReadyForQuery(&rs.rb) @@ -1310,6 +1396,10 @@ func (rs *rows) Next(dest []driver.Value) (err error) { return io.EOF case 'D': n := rs.rb.int16() + if err != nil { + conn.bad = true + errorf("unexpected DataRow after error %s", err) + } if n < len(dest) { dest = dest[:n] } @@ -1319,15 +1409,26 @@ func (rs *rows) Next(dest []driver.Value) (err error) { dest[i] = nil continue } - dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.st.rowTyps[i]) + dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.colTyps[i], rs.colFmts[i]) } return + case 'T': + rs.colNames, rs.colFmts, rs.colTyps = parsePortalRowDescribe(&rs.rb) + return io.EOF default: errorf("unexpected message after execute: %q", t) } } } +func (rs *rows) HasNextResultSet() bool { + return !rs.done +} + +func (rs *rows) NextResultSet() error { + return nil +} + // QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be // used as part of an SQL statement. For example: // @@ -1352,6 +1453,68 @@ func md5s(s string) string { return fmt.Sprintf("%x", h.Sum(nil)) } +func (cn *conn) sendBinaryParameters(b *writeBuf, args []driver.Value) { + // Do one pass over the parameters to see if we're going to send any of + // them over in binary. If we are, create a paramFormats array at the + // same time. + var paramFormats []int + for i, x := range args { + _, ok := x.([]byte) + if ok { + if paramFormats == nil { + paramFormats = make([]int, len(args)) + } + paramFormats[i] = 1 + } + } + if paramFormats == nil { + b.int16(0) + } else { + b.int16(len(paramFormats)) + for _, x := range paramFormats { + b.int16(x) + } + } + + b.int16(len(args)) + for _, x := range args { + if x == nil { + b.int32(-1) + } else { + datum := binaryEncode(&cn.parameterStatus, x) + b.int32(len(datum)) + b.bytes(datum) + } + } +} + +func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) { + if len(args) >= 65536 { + errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(args)) + } + + b := cn.writeBuf('P') + b.byte(0) // unnamed statement + b.string(query) + b.int16(0) + + b.next('B') + b.int16(0) // unnamed portal and statement + cn.sendBinaryParameters(b, args) + b.bytes(colFmtDataAllText) + + b.next('D') + b.byte('P') + b.byte(0) // unnamed portal + + b.next('E') + b.byte(0) + b.int32(0) + + b.next('S') + cn.send(b) +} + func (c *conn) processParameterStatus(r *readBuf) { var err error @@ -1381,15 +1544,186 @@ func (c *conn) processReadyForQuery(r *readBuf) { c.txnStatus = transactionStatus(r.byte()) } -func parseMeta(r *readBuf) (cols []string, rowTyps []oid.Oid) { +func (cn *conn) readReadyForQuery() { + t, r := cn.recv1() + switch t { + case 'Z': + cn.processReadyForQuery(r) + return + default: + cn.bad = true + errorf("unexpected message %q; expected ReadyForQuery", t) + } +} + +func (c *conn) processBackendKeyData(r *readBuf) { + c.processID = r.int32() + c.secretKey = r.int32() +} + +func (cn *conn) readParseResponse() { + t, r := cn.recv1() + switch t { + case '1': + return + case 'E': + err := parseError(r) + cn.readReadyForQuery() + panic(err) + default: + cn.bad = true + errorf("unexpected Parse response %q", t) + } +} + +func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames []string, colTyps []oid.Oid) { + for { + t, r := cn.recv1() + switch t { + case 't': + nparams := r.int16() + paramTyps = make([]oid.Oid, nparams) + for i := range paramTyps { + paramTyps[i] = r.oid() + } + case 'n': + return paramTyps, nil, nil + case 'T': + colNames, colTyps = parseStatementRowDescribe(r) + return paramTyps, colNames, colTyps + case 'E': + err := parseError(r) + cn.readReadyForQuery() + panic(err) + default: + cn.bad = true + errorf("unexpected Describe statement response %q", t) + } + } +} + +func (cn *conn) readPortalDescribeResponse() (colNames []string, colFmts []format, colTyps []oid.Oid) { + t, r := cn.recv1() + switch t { + case 'T': + return parsePortalRowDescribe(r) + case 'n': + return nil, nil, nil + case 'E': + err := parseError(r) + cn.readReadyForQuery() + panic(err) + default: + cn.bad = true + errorf("unexpected Describe response %q", t) + } + panic("not reached") +} + +func (cn *conn) readBindResponse() { + t, r := cn.recv1() + switch t { + case '2': + return + case 'E': + err := parseError(r) + cn.readReadyForQuery() + panic(err) + default: + cn.bad = true + errorf("unexpected Bind response %q", t) + } +} + +func (cn *conn) postExecuteWorkaround() { + // Work around a bug in sql.DB.QueryRow: in Go 1.2 and earlier it ignores + // any errors from rows.Next, which masks errors that happened during the + // execution of the query. To avoid the problem in common cases, we wait + // here for one more message from the database. If it's not an error the + // query will likely succeed (or perhaps has already, if it's a + // CommandComplete), so we push the message into the conn struct; recv1 + // will return it as the next message for rows.Next or rows.Close. + // However, if it's an error, we wait until ReadyForQuery and then return + // the error to our caller. + for { + t, r := cn.recv1() + switch t { + case 'E': + err := parseError(r) + cn.readReadyForQuery() + panic(err) + case 'C', 'D', 'I': + // the query didn't fail, but we can't process this message + cn.saveMessage(t, r) + return + default: + cn.bad = true + errorf("unexpected message during extended query execution: %q", t) + } + } +} + +// Only for Exec(), since we ignore the returned data +func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, commandTag string, err error) { + for { + t, r := cn.recv1() + switch t { + case 'C': + if err != nil { + cn.bad = true + errorf("unexpected CommandComplete after error %s", err) + } + res, commandTag = cn.parseComplete(r.string()) + case 'Z': + cn.processReadyForQuery(r) + if res == nil && err == nil { + err = errUnexpectedReady + } + return res, commandTag, err + case 'E': + err = parseError(r) + case 'T', 'D', 'I': + if err != nil { + cn.bad = true + errorf("unexpected %q after error %s", t, err) + } + if t == 'I' { + res = emptyRows + } + // ignore any results + default: + cn.bad = true + errorf("unknown %s response: %q", protocolState, t) + } + } +} + +func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []oid.Oid) { n := r.int16() - cols = make([]string, n) - rowTyps = make([]oid.Oid, n) - for i := range cols { - cols[i] = r.string() + colNames = make([]string, n) + colTyps = make([]oid.Oid, n) + for i := range colNames { + colNames[i] = r.string() r.next(6) - rowTyps[i] = r.oid() - r.next(8) + colTyps[i] = r.oid() + r.next(6) + // format code not known when describing a statement; always 0 + r.next(2) + } + return +} + +func parsePortalRowDescribe(r *readBuf) (colNames []string, colFmts []format, colTyps []oid.Oid) { + n := r.int16() + colNames = make([]string, n) + colFmts = make([]format, n) + colTyps = make([]oid.Oid, n) + for i := range colNames { + colNames[i] = r.string() + r.next(6) + colTyps[i] = r.oid() + r.next(6) + colFmts[i] = format(r.int16()) } return } @@ -1434,7 +1768,7 @@ func parseEnviron(env []string) (out map[string]string) { accrue("user") case "PGPASSWORD": accrue("password") - case "PGPASSFILE", "PGSERVICE", "PGSERVICEFILE", "PGREALM": + case "PGSERVICE", "PGSERVICEFILE", "PGREALM": unsupported() case "PGOPTIONS": accrue("options") diff --git a/vendor/github.com/lib/pq/conn_go18.go b/vendor/github.com/lib/pq/conn_go18.go new file mode 100644 index 00000000000..ab97a104d83 --- /dev/null +++ b/vendor/github.com/lib/pq/conn_go18.go @@ -0,0 +1,128 @@ +// +build go1.8 + +package pq + +import ( + "context" + "database/sql" + "database/sql/driver" + "fmt" + "io" + "io/ioutil" +) + +// Implement the "QueryerContext" interface +func (cn *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + list := make([]driver.Value, len(args)) + for i, nv := range args { + list[i] = nv.Value + } + finish := cn.watchCancel(ctx) + r, err := cn.query(query, list) + if err != nil { + if finish != nil { + finish() + } + return nil, err + } + r.finish = finish + return r, nil +} + +// Implement the "ExecerContext" interface +func (cn *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + list := make([]driver.Value, len(args)) + for i, nv := range args { + list[i] = nv.Value + } + + if finish := cn.watchCancel(ctx); finish != nil { + defer finish() + } + + return cn.Exec(query, list) +} + +// Implement the "ConnBeginTx" interface +func (cn *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + var mode string + + switch sql.IsolationLevel(opts.Isolation) { + case sql.LevelDefault: + // Don't touch mode: use the server's default + case sql.LevelReadUncommitted: + mode = " ISOLATION LEVEL READ UNCOMMITTED" + case sql.LevelReadCommitted: + mode = " ISOLATION LEVEL READ COMMITTED" + case sql.LevelRepeatableRead: + mode = " ISOLATION LEVEL REPEATABLE READ" + case sql.LevelSerializable: + mode = " ISOLATION LEVEL SERIALIZABLE" + default: + return nil, fmt.Errorf("pq: isolation level not supported: %d", opts.Isolation) + } + + if opts.ReadOnly { + mode += " READ ONLY" + } else { + mode += " READ WRITE" + } + + tx, err := cn.begin(mode) + if err != nil { + return nil, err + } + cn.txnFinish = cn.watchCancel(ctx) + return tx, nil +} + +func (cn *conn) watchCancel(ctx context.Context) func() { + if done := ctx.Done(); done != nil { + finished := make(chan struct{}) + go func() { + select { + case <-done: + _ = cn.cancel() + finished <- struct{}{} + case <-finished: + } + }() + return func() { + select { + case <-finished: + case finished <- struct{}{}: + } + } + } + return nil +} + +func (cn *conn) cancel() error { + c, err := dial(cn.dialer, cn.opts) + if err != nil { + return err + } + defer c.Close() + + { + can := conn{ + c: c, + } + can.ssl(cn.opts) + + w := can.writeBuf(0) + w.int32(80877102) // cancel request code + w.int32(cn.processID) + w.int32(cn.secretKey) + + if err := can.sendStartupPacket(w); err != nil { + return err + } + } + + // Read until EOF to ensure that the server received the cancel. + { + _, err := io.Copy(ioutil.Discard, c) + return err + } +} diff --git a/vendor/github.com/lib/pq/conn_test.go b/vendor/github.com/lib/pq/conn_test.go deleted file mode 100644 index 6c3c6b54629..00000000000 --- a/vendor/github.com/lib/pq/conn_test.go +++ /dev/null @@ -1,1286 +0,0 @@ -package pq - -import ( - "database/sql" - "database/sql/driver" - "fmt" - "io" - "os" - "reflect" - "testing" - "time" -) - -type Fatalistic interface { - Fatal(args ...interface{}) -} - -func openTestConnConninfo(conninfo string) (*sql.DB, error) { - datname := os.Getenv("PGDATABASE") - sslmode := os.Getenv("PGSSLMODE") - timeout := os.Getenv("PGCONNECT_TIMEOUT") - - if datname == "" { - os.Setenv("PGDATABASE", "pqgotest") - } - - if sslmode == "" { - os.Setenv("PGSSLMODE", "disable") - } - - if timeout == "" { - os.Setenv("PGCONNECT_TIMEOUT", "20") - } - - return sql.Open("postgres", conninfo) -} - -func openTestConn(t Fatalistic) *sql.DB { - conn, err := openTestConnConninfo("") - if err != nil { - t.Fatal(err) - } - - return conn -} - -func getServerVersion(t *testing.T, db *sql.DB) int { - var version int - err := db.QueryRow("SHOW server_version_num").Scan(&version) - if err != nil { - t.Fatal(err) - } - return version -} - -func TestReconnect(t *testing.T) { - db1 := openTestConn(t) - defer db1.Close() - tx, err := db1.Begin() - if err != nil { - t.Fatal(err) - } - var pid1 int - err = tx.QueryRow("SELECT pg_backend_pid()").Scan(&pid1) - if err != nil { - t.Fatal(err) - } - db2 := openTestConn(t) - defer db2.Close() - _, err = db2.Exec("SELECT pg_terminate_backend($1)", pid1) - if err != nil { - t.Fatal(err) - } - // The rollback will probably "fail" because we just killed - // its connection above - _ = tx.Rollback() - - const expected int = 42 - var result int - err = db1.QueryRow(fmt.Sprintf("SELECT %d", expected)).Scan(&result) - if err != nil { - t.Fatal(err) - } - if result != expected { - t.Errorf("got %v; expected %v", result, expected) - } -} - -func TestCommitInFailedTransaction(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - txn, err := db.Begin() - if err != nil { - t.Fatal(err) - } - rows, err := txn.Query("SELECT error") - if err == nil { - rows.Close() - t.Fatal("expected failure") - } - err = txn.Commit() - if err != ErrInFailedTransaction { - t.Fatalf("expected ErrInFailedTransaction; got %#v", err) - } -} - -func TestOpenURL(t *testing.T) { - db, err := openTestConnConninfo("postgres://") - if err != nil { - t.Fatal(err) - } - defer db.Close() - // database/sql might not call our Open at all unless we do something with - // the connection - txn, err := db.Begin() - if err != nil { - t.Fatal(err) - } - txn.Rollback() -} - -func TestExec(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - _, err := db.Exec("CREATE TEMP TABLE temp (a int)") - if err != nil { - t.Fatal(err) - } - - r, err := db.Exec("INSERT INTO temp VALUES (1)") - if err != nil { - t.Fatal(err) - } - - if n, _ := r.RowsAffected(); n != 1 { - t.Fatalf("expected 1 row affected, not %d", n) - } - - r, err = db.Exec("INSERT INTO temp VALUES ($1), ($2), ($3)", 1, 2, 3) - if err != nil { - t.Fatal(err) - } - - if n, _ := r.RowsAffected(); n != 3 { - t.Fatalf("expected 3 rows affected, not %d", n) - } - - // SELECT doesn't send the number of returned rows in the command tag - // before 9.0 - if getServerVersion(t, db) >= 90000 { - r, err = db.Exec("SELECT g FROM generate_series(1, 2) g") - if err != nil { - t.Fatal(err) - } - if n, _ := r.RowsAffected(); n != 2 { - t.Fatalf("expected 2 rows affected, not %d", n) - } - - r, err = db.Exec("SELECT g FROM generate_series(1, $1) g", 3) - if err != nil { - t.Fatal(err) - } - if n, _ := r.RowsAffected(); n != 3 { - t.Fatalf("expected 3 rows affected, not %d", n) - } - } -} - -func TestStatment(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - st, err := db.Prepare("SELECT 1") - if err != nil { - t.Fatal(err) - } - - st1, err := db.Prepare("SELECT 2") - if err != nil { - t.Fatal(err) - } - - r, err := st.Query() - if err != nil { - t.Fatal(err) - } - defer r.Close() - - if !r.Next() { - t.Fatal("expected row") - } - - var i int - err = r.Scan(&i) - if err != nil { - t.Fatal(err) - } - - if i != 1 { - t.Fatalf("expected 1, got %d", i) - } - - // st1 - - r1, err := st1.Query() - if err != nil { - t.Fatal(err) - } - defer r1.Close() - - if !r1.Next() { - if r.Err() != nil { - t.Fatal(r1.Err()) - } - t.Fatal("expected row") - } - - err = r1.Scan(&i) - if err != nil { - t.Fatal(err) - } - - if i != 2 { - t.Fatalf("expected 2, got %d", i) - } -} - -func TestRowsCloseBeforeDone(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - r, err := db.Query("SELECT 1") - if err != nil { - t.Fatal(err) - } - - err = r.Close() - if err != nil { - t.Fatal(err) - } - - if r.Next() { - t.Fatal("unexpected row") - } - - if r.Err() != nil { - t.Fatal(r.Err()) - } -} - -func TestParameterCountMismatch(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - var notused int - err := db.QueryRow("SELECT false", 1).Scan(¬used) - if err == nil { - t.Fatal("expected err") - } - // make sure we clean up correctly - err = db.QueryRow("SELECT 1").Scan(¬used) - if err != nil { - t.Fatal(err) - } - - err = db.QueryRow("SELECT $1").Scan(¬used) - if err == nil { - t.Fatal("expected err") - } - // make sure we clean up correctly - err = db.QueryRow("SELECT 1").Scan(¬used) - if err != nil { - t.Fatal(err) - } -} - -// Test that EmptyQueryResponses are handled correctly. -func TestEmptyQuery(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - _, err := db.Exec("") - if err != nil { - t.Fatal(err) - } - rows, err := db.Query("") - if err != nil { - t.Fatal(err) - } - cols, err := rows.Columns() - if err != nil { - t.Fatal(err) - } - if len(cols) != 0 { - t.Fatalf("unexpected number of columns %d in response to an empty query", len(cols)) - } - if rows.Next() { - t.Fatal("unexpected row") - } - if rows.Err() != nil { - t.Fatal(rows.Err()) - } - - stmt, err := db.Prepare("") - if err != nil { - t.Fatal(err) - } - _, err = stmt.Exec() - if err != nil { - t.Fatal(err) - } - rows, err = stmt.Query() - if err != nil { - t.Fatal(err) - } - cols, err = rows.Columns() - if err != nil { - t.Fatal(err) - } - if len(cols) != 0 { - t.Fatalf("unexpected number of columns %d in response to an empty query", len(cols)) - } - if rows.Next() { - t.Fatal("unexpected row") - } - if rows.Err() != nil { - t.Fatal(rows.Err()) - } -} - -func TestEncodeDecode(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - q := ` - SELECT - E'\\000\\001\\002'::bytea, - 'foobar'::text, - NULL::integer, - '2000-1-1 01:02:03.04-7'::timestamptz, - 0::boolean, - 123, - 3.14::float8 - WHERE - E'\\000\\001\\002'::bytea = $1 - AND 'foobar'::text = $2 - AND $3::integer is NULL - ` - // AND '2000-1-1 12:00:00.000000-7'::timestamp = $3 - - exp1 := []byte{0, 1, 2} - exp2 := "foobar" - - r, err := db.Query(q, exp1, exp2, nil) - if err != nil { - t.Fatal(err) - } - defer r.Close() - - if !r.Next() { - if r.Err() != nil { - t.Fatal(r.Err()) - } - t.Fatal("expected row") - } - - var got1 []byte - var got2 string - var got3 = sql.NullInt64{Valid: true} - var got4 time.Time - var got5, got6, got7 interface{} - - err = r.Scan(&got1, &got2, &got3, &got4, &got5, &got6, &got7) - if err != nil { - t.Fatal(err) - } - - if !reflect.DeepEqual(exp1, got1) { - t.Errorf("expected %q byte: %q", exp1, got1) - } - - if !reflect.DeepEqual(exp2, got2) { - t.Errorf("expected %q byte: %q", exp2, got2) - } - - if got3.Valid { - t.Fatal("expected invalid") - } - - if got4.Year() != 2000 { - t.Fatal("wrong year") - } - - if got5 != false { - t.Fatalf("expected false, got %q", got5) - } - - if got6 != int64(123) { - t.Fatalf("expected 123, got %d", got6) - } - - if got7 != float64(3.14) { - t.Fatalf("expected 3.14, got %f", got7) - } -} - -func TestNoData(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - st, err := db.Prepare("SELECT 1 WHERE true = false") - if err != nil { - t.Fatal(err) - } - defer st.Close() - - r, err := st.Query() - if err != nil { - t.Fatal(err) - } - defer r.Close() - - if r.Next() { - if r.Err() != nil { - t.Fatal(r.Err()) - } - t.Fatal("unexpected row") - } - - _, err = db.Query("SELECT * FROM nonexistenttable WHERE age=$1", 20) - if err == nil { - t.Fatal("Should have raised an error on non existent table") - } - - _, err = db.Query("SELECT * FROM nonexistenttable") - if err == nil { - t.Fatal("Should have raised an error on non existent table") - } -} - -func TestErrorDuringStartup(t *testing.T) { - // Don't use the normal connection setup, this is intended to - // blow up in the startup packet from a non-existent user. - db, err := openTestConnConninfo("user=thisuserreallydoesntexist") - if err != nil { - t.Fatal(err) - } - defer db.Close() - - _, err = db.Begin() - if err == nil { - t.Fatal("expected error") - } - - e, ok := err.(*Error) - if !ok { - t.Fatalf("expected Error, got %#v", err) - } else if e.Code.Name() != "invalid_authorization_specification" && e.Code.Name() != "invalid_password" { - t.Fatalf("expected invalid_authorization_specification or invalid_password, got %s (%+v)", e.Code.Name(), err) - } -} - -func TestBadConn(t *testing.T) { - var err error - - cn := conn{} - func() { - defer cn.errRecover(&err) - panic(io.EOF) - }() - if err != driver.ErrBadConn { - t.Fatalf("expected driver.ErrBadConn, got: %#v", err) - } - if !cn.bad { - t.Fatalf("expected cn.bad") - } - - cn = conn{} - func() { - defer cn.errRecover(&err) - e := &Error{Severity: Efatal} - panic(e) - }() - if err != driver.ErrBadConn { - t.Fatalf("expected driver.ErrBadConn, got: %#v", err) - } - if !cn.bad { - t.Fatalf("expected cn.bad") - } -} - -func TestErrorOnExec(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - txn, err := db.Begin() - if err != nil { - t.Fatal(err) - } - defer txn.Rollback() - - _, err = txn.Exec("CREATE TEMPORARY TABLE foo(f1 int PRIMARY KEY)") - if err != nil { - t.Fatal(err) - } - - _, err = txn.Exec("INSERT INTO foo VALUES (0), (0)") - if err == nil { - t.Fatal("Should have raised error") - } - - e, ok := err.(*Error) - if !ok { - t.Fatalf("expected Error, got %#v", err) - } else if e.Code.Name() != "unique_violation" { - t.Fatalf("expected unique_violation, got %s (%+v)", e.Code.Name(), err) - } -} - -func TestErrorOnQuery(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - txn, err := db.Begin() - if err != nil { - t.Fatal(err) - } - defer txn.Rollback() - - _, err = txn.Exec("CREATE TEMPORARY TABLE foo(f1 int PRIMARY KEY)") - if err != nil { - t.Fatal(err) - } - - _, err = txn.Query("INSERT INTO foo VALUES (0), (0)") - if err == nil { - t.Fatal("Should have raised error") - } - - e, ok := err.(*Error) - if !ok { - t.Fatalf("expected Error, got %#v", err) - } else if e.Code.Name() != "unique_violation" { - t.Fatalf("expected unique_violation, got %s (%+v)", e.Code.Name(), err) - } -} - -func TestErrorOnQueryRowSimpleQuery(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - txn, err := db.Begin() - if err != nil { - t.Fatal(err) - } - defer txn.Rollback() - - _, err = txn.Exec("CREATE TEMPORARY TABLE foo(f1 int PRIMARY KEY)") - if err != nil { - t.Fatal(err) - } - - var v int - err = txn.QueryRow("INSERT INTO foo VALUES (0), (0)").Scan(&v) - if err == nil { - t.Fatal("Should have raised error") - } - - e, ok := err.(*Error) - if !ok { - t.Fatalf("expected Error, got %#v", err) - } else if e.Code.Name() != "unique_violation" { - t.Fatalf("expected unique_violation, got %s (%+v)", e.Code.Name(), err) - } -} - -// Test the QueryRow bug workarounds in stmt.exec() and simpleQuery() -func TestQueryRowBugWorkaround(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - // stmt.exec() - _, err := db.Exec("CREATE TEMP TABLE notnulltemp (a varchar(10) not null)") - if err != nil { - t.Fatal(err) - } - - var a string - err = db.QueryRow("INSERT INTO notnulltemp(a) values($1) RETURNING a", nil).Scan(&a) - if err == sql.ErrNoRows { - t.Fatalf("expected constraint violation error; got: %v", err) - } - pge, ok := err.(*Error) - if !ok { - t.Fatalf("expected *Error; got: %#v", err) - } - if pge.Code.Name() != "not_null_violation" { - t.Fatalf("expected not_null_violation; got: %s (%+v)", pge.Code.Name(), err) - } - - // Test workaround in simpleQuery() - tx, err := db.Begin() - if err != nil { - t.Fatalf("unexpected error %s in Begin", err) - } - defer tx.Rollback() - - _, err = tx.Exec("SET LOCAL check_function_bodies TO FALSE") - if err != nil { - t.Fatalf("could not disable check_function_bodies: %s", err) - } - _, err = tx.Exec(` -CREATE OR REPLACE FUNCTION bad_function() -RETURNS integer --- hack to prevent the function from being inlined -SET check_function_bodies TO TRUE -AS $$ - SELECT text 'bad' -$$ LANGUAGE sql`) - if err != nil { - t.Fatalf("could not create function: %s", err) - } - - err = tx.QueryRow("SELECT * FROM bad_function()").Scan(&a) - if err == nil { - t.Fatalf("expected error") - } - pge, ok = err.(*Error) - if !ok { - t.Fatalf("expected *Error; got: %#v", err) - } - if pge.Code.Name() != "invalid_function_definition" { - t.Fatalf("expected invalid_function_definition; got: %s (%+v)", pge.Code.Name(), err) - } - - err = tx.Rollback() - if err != nil { - t.Fatalf("unexpected error %s in Rollback", err) - } - - // Also test that simpleQuery()'s workaround works when the query fails - // after a row has been received. - rows, err := db.Query(` -select - (select generate_series(1, ss.i)) -from (select gs.i - from generate_series(1, 2) gs(i) - order by gs.i limit 2) ss`) - if err != nil { - t.Fatalf("query failed: %s", err) - } - if !rows.Next() { - t.Fatalf("expected at least one result row; got %s", rows.Err()) - } - var i int - err = rows.Scan(&i) - if err != nil { - t.Fatalf("rows.Scan() failed: %s", err) - } - if i != 1 { - t.Fatalf("unexpected value for i: %d", i) - } - if rows.Next() { - t.Fatalf("unexpected row") - } - pge, ok = rows.Err().(*Error) - if !ok { - t.Fatalf("expected *Error; got: %#v", err) - } - if pge.Code.Name() != "cardinality_violation" { - t.Fatalf("expected cardinality_violation; got: %s (%+v)", pge.Code.Name(), rows.Err()) - } -} - -func TestSimpleQuery(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - r, err := db.Query("select 1") - if err != nil { - t.Fatal(err) - } - defer r.Close() - - if !r.Next() { - t.Fatal("expected row") - } -} - -func TestBindError(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - _, err := db.Exec("create temp table test (i integer)") - if err != nil { - t.Fatal(err) - } - - _, err = db.Query("select * from test where i=$1", "hhh") - if err == nil { - t.Fatal("expected an error") - } - - // Should not get error here - r, err := db.Query("select * from test where i=$1", 1) - if err != nil { - t.Fatal(err) - } - defer r.Close() -} - -func TestParseErrorInExtendedQuery(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - rows, err := db.Query("PARSE_ERROR $1", 1) - if err == nil { - t.Fatal("expected error") - } - - rows, err = db.Query("SELECT 1") - if err != nil { - t.Fatal(err) - } - rows.Close() -} - -// TestReturning tests that an INSERT query using the RETURNING clause returns a row. -func TestReturning(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - _, err := db.Exec("CREATE TEMP TABLE distributors (did integer default 0, dname text)") - if err != nil { - t.Fatal(err) - } - - rows, err := db.Query("INSERT INTO distributors (did, dname) VALUES (DEFAULT, 'XYZ Widgets') " + - "RETURNING did;") - if err != nil { - t.Fatal(err) - } - if !rows.Next() { - t.Fatal("no rows") - } - var did int - err = rows.Scan(&did) - if err != nil { - t.Fatal(err) - } - if did != 0 { - t.Fatalf("bad value for did: got %d, want %d", did, 0) - } - - if rows.Next() { - t.Fatal("unexpected next row") - } - err = rows.Err() - if err != nil { - t.Fatal(err) - } -} - -func TestIssue186(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - // Exec() a query which returns results - _, err := db.Exec("VALUES (1), (2), (3)") - if err != nil { - t.Fatal(err) - } - - _, err = db.Exec("VALUES ($1), ($2), ($3)", 1, 2, 3) - if err != nil { - t.Fatal(err) - } - - // Query() a query which doesn't return any results - txn, err := db.Begin() - if err != nil { - t.Fatal(err) - } - defer txn.Rollback() - - rows, err := txn.Query("CREATE TEMP TABLE foo(f1 int)") - if err != nil { - t.Fatal(err) - } - if err = rows.Close(); err != nil { - t.Fatal(err) - } - - // small trick to get NoData from a parameterized query - _, err = txn.Exec("CREATE RULE nodata AS ON INSERT TO foo DO INSTEAD NOTHING") - if err != nil { - t.Fatal(err) - } - rows, err = txn.Query("INSERT INTO foo VALUES ($1)", 1) - if err != nil { - t.Fatal(err) - } - if err = rows.Close(); err != nil { - t.Fatal(err) - } -} - -func TestIssue196(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - row := db.QueryRow("SELECT float4 '0.10000122' = $1, float8 '35.03554004971999' = $2", - float32(0.10000122), float64(35.03554004971999)) - - var float4match, float8match bool - err := row.Scan(&float4match, &float8match) - if err != nil { - t.Fatal(err) - } - if !float4match { - t.Errorf("Expected float4 fidelity to be maintained; got no match") - } - if !float8match { - t.Errorf("Expected float8 fidelity to be maintained; got no match") - } -} - -// Test that any CommandComplete messages sent before the query results are -// ignored. -func TestIssue282(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - var search_path string - err := db.QueryRow(` - SET LOCAL search_path TO pg_catalog; - SET LOCAL search_path TO pg_catalog; - SHOW search_path`).Scan(&search_path) - if err != nil { - t.Fatal(err) - } - if search_path != "pg_catalog" { - t.Fatalf("unexpected search_path %s", search_path) - } -} - -func TestReadFloatPrecision(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - row := db.QueryRow("SELECT float4 '0.10000122', float8 '35.03554004971999'") - var float4val float32 - var float8val float64 - err := row.Scan(&float4val, &float8val) - if err != nil { - t.Fatal(err) - } - if float4val != float32(0.10000122) { - t.Errorf("Expected float4 fidelity to be maintained; got no match") - } - if float8val != float64(35.03554004971999) { - t.Errorf("Expected float8 fidelity to be maintained; got no match") - } -} - -func TestXactMultiStmt(t *testing.T) { - // minified test case based on bug reports from - // pico303@gmail.com and rangelspam@gmail.com - t.Skip("Skipping failing test") - db := openTestConn(t) - defer db.Close() - - tx, err := db.Begin() - if err != nil { - t.Fatal(err) - } - defer tx.Commit() - - rows, err := tx.Query("select 1") - if err != nil { - t.Fatal(err) - } - - if rows.Next() { - var val int32 - if err = rows.Scan(&val); err != nil { - t.Fatal(err) - } - } else { - t.Fatal("Expected at least one row in first query in xact") - } - - rows2, err := tx.Query("select 2") - if err != nil { - t.Fatal(err) - } - - if rows2.Next() { - var val2 int32 - if err := rows2.Scan(&val2); err != nil { - t.Fatal(err) - } - } else { - t.Fatal("Expected at least one row in second query in xact") - } - - if err = rows.Err(); err != nil { - t.Fatal(err) - } - - if err = rows2.Err(); err != nil { - t.Fatal(err) - } - - if err = tx.Commit(); err != nil { - t.Fatal(err) - } -} - -var envParseTests = []struct { - Expected map[string]string - Env []string -}{ - { - Env: []string{"PGDATABASE=hello", "PGUSER=goodbye"}, - Expected: map[string]string{"dbname": "hello", "user": "goodbye"}, - }, - { - Env: []string{"PGDATESTYLE=ISO, MDY"}, - Expected: map[string]string{"datestyle": "ISO, MDY"}, - }, - { - Env: []string{"PGCONNECT_TIMEOUT=30"}, - Expected: map[string]string{"connect_timeout": "30"}, - }, -} - -func TestParseEnviron(t *testing.T) { - for i, tt := range envParseTests { - results := parseEnviron(tt.Env) - if !reflect.DeepEqual(tt.Expected, results) { - t.Errorf("%d: Expected: %#v Got: %#v", i, tt.Expected, results) - } - } -} - -func TestParseComplete(t *testing.T) { - tpc := func(commandTag string, command string, affectedRows int64, shouldFail bool) { - defer func() { - if p := recover(); p != nil { - if !shouldFail { - t.Error(p) - } - } - }() - cn := &conn{} - res, c := cn.parseComplete(commandTag) - if c != command { - t.Errorf("Expected %v, got %v", command, c) - } - n, err := res.RowsAffected() - if err != nil { - t.Fatal(err) - } - if n != affectedRows { - t.Errorf("Expected %d, got %d", affectedRows, n) - } - } - - tpc("ALTER TABLE", "ALTER TABLE", 0, false) - tpc("INSERT 0 1", "INSERT", 1, false) - tpc("UPDATE 100", "UPDATE", 100, false) - tpc("SELECT 100", "SELECT", 100, false) - tpc("FETCH 100", "FETCH", 100, false) - // allow COPY (and others) without row count - tpc("COPY", "COPY", 0, false) - // don't fail on command tags we don't recognize - tpc("UNKNOWNCOMMANDTAG", "UNKNOWNCOMMANDTAG", 0, false) - - // failure cases - tpc("INSERT 1", "", 0, true) // missing oid - tpc("UPDATE 0 1", "", 0, true) // too many numbers - tpc("SELECT foo", "", 0, true) // invalid row count -} - -func TestExecerInterface(t *testing.T) { - // Gin up a straw man private struct just for the type check - cn := &conn{c: nil} - var cni interface{} = cn - - _, ok := cni.(driver.Execer) - if !ok { - t.Fatal("Driver doesn't implement Execer") - } -} - -func TestNullAfterNonNull(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - r, err := db.Query("SELECT 9::integer UNION SELECT NULL::integer") - if err != nil { - t.Fatal(err) - } - - var n sql.NullInt64 - - if !r.Next() { - if r.Err() != nil { - t.Fatal(err) - } - t.Fatal("expected row") - } - - if err := r.Scan(&n); err != nil { - t.Fatal(err) - } - - if n.Int64 != 9 { - t.Fatalf("expected 2, not %d", n.Int64) - } - - if !r.Next() { - if r.Err() != nil { - t.Fatal(err) - } - t.Fatal("expected row") - } - - if err := r.Scan(&n); err != nil { - t.Fatal(err) - } - - if n.Valid { - t.Fatal("expected n to be invalid") - } - - if n.Int64 != 0 { - t.Fatalf("expected n to 2, not %d", n.Int64) - } -} - -func Test64BitErrorChecking(t *testing.T) { - defer func() { - if err := recover(); err != nil { - t.Fatal("panic due to 0xFFFFFFFF != -1 " + - "when int is 64 bits") - } - }() - - db := openTestConn(t) - defer db.Close() - - r, err := db.Query(`SELECT * -FROM (VALUES (0::integer, NULL::text), (1, 'test string')) AS t;`) - - if err != nil { - t.Fatal(err) - } - - defer r.Close() - - for r.Next() { - } -} - -func TestCommit(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - _, err := db.Exec("CREATE TEMP TABLE temp (a int)") - if err != nil { - t.Fatal(err) - } - sqlInsert := "INSERT INTO temp VALUES (1)" - sqlSelect := "SELECT * FROM temp" - tx, err := db.Begin() - if err != nil { - t.Fatal(err) - } - _, err = tx.Exec(sqlInsert) - if err != nil { - t.Fatal(err) - } - err = tx.Commit() - if err != nil { - t.Fatal(err) - } - var i int - err = db.QueryRow(sqlSelect).Scan(&i) - if err != nil { - t.Fatal(err) - } - if i != 1 { - t.Fatalf("expected 1, got %d", i) - } -} - -func TestErrorClass(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - _, err := db.Query("SELECT int 'notint'") - if err == nil { - t.Fatal("expected error") - } - pge, ok := err.(*Error) - if !ok { - t.Fatalf("expected *pq.Error, got %#+v", err) - } - if pge.Code.Class() != "22" { - t.Fatalf("expected class 28, got %v", pge.Code.Class()) - } - if pge.Code.Class().Name() != "data_exception" { - t.Fatalf("expected data_exception, got %v", pge.Code.Class().Name()) - } -} - -func TestParseOpts(t *testing.T) { - tests := []struct { - in string - expected values - valid bool - }{ - {"dbname=hello user=goodbye", values{"dbname": "hello", "user": "goodbye"}, true}, - {"dbname=hello user=goodbye ", values{"dbname": "hello", "user": "goodbye"}, true}, - {"dbname = hello user=goodbye", values{"dbname": "hello", "user": "goodbye"}, true}, - {"dbname=hello user =goodbye", values{"dbname": "hello", "user": "goodbye"}, true}, - {"dbname=hello user= goodbye", values{"dbname": "hello", "user": "goodbye"}, true}, - {"host=localhost password='correct horse battery staple'", values{"host": "localhost", "password": "correct horse battery staple"}, true}, - {"dbname=データベース password=パスワード", values{"dbname": "データベース", "password": "パスワード"}, true}, - {"dbname=hello user=''", values{"dbname": "hello", "user": ""}, true}, - {"user='' dbname=hello", values{"dbname": "hello", "user": ""}, true}, - // The last option value is an empty string if there's no non-whitespace after its = - {"dbname=hello user= ", values{"dbname": "hello", "user": ""}, true}, - - // The parser ignores spaces after = and interprets the next set of non-whitespace characters as the value. - {"user= password=foo", values{"user": "password=foo"}, true}, - - // Backslash escapes next char - {`user=a\ \'\\b`, values{"user": `a '\b`}, true}, - {`user='a \'b'`, values{"user": `a 'b`}, true}, - - // Incomplete escape - {`user=x\`, values{}, false}, - - // No '=' after the key - {"postgre://marko@internet", values{}, false}, - {"dbname user=goodbye", values{}, false}, - {"user=foo blah", values{}, false}, - {"user=foo blah ", values{}, false}, - - // Unterminated quoted value - {"dbname=hello user='unterminated", values{}, false}, - } - - for _, test := range tests { - o := make(values) - err := parseOpts(test.in, o) - - switch { - case err != nil && test.valid: - t.Errorf("%q got unexpected error: %s", test.in, err) - case err == nil && test.valid && !reflect.DeepEqual(test.expected, o): - t.Errorf("%q got: %#v want: %#v", test.in, o, test.expected) - case err == nil && !test.valid: - t.Errorf("%q expected an error", test.in) - } - } -} - -func TestRuntimeParameters(t *testing.T) { - type RuntimeTestResult int - const ( - ResultUnknown RuntimeTestResult = iota - ResultSuccess - ResultError // other error - ) - - tests := []struct { - conninfo string - param string - expected string - expectedOutcome RuntimeTestResult - }{ - // invalid parameter - {"DOESNOTEXIST=foo", "", "", ResultError}, - // we can only work with a specific value for these two - {"client_encoding=SQL_ASCII", "", "", ResultError}, - {"datestyle='ISO, YDM'", "", "", ResultError}, - // "options" should work exactly as it does in libpq - {"options='-c search_path=pqgotest'", "search_path", "pqgotest", ResultSuccess}, - // pq should override client_encoding in this case - {"options='-c client_encoding=SQL_ASCII'", "client_encoding", "UTF8", ResultSuccess}, - // allow client_encoding to be set explicitly - {"client_encoding=UTF8", "client_encoding", "UTF8", ResultSuccess}, - // test a runtime parameter not supported by libpq - {"work_mem='139kB'", "work_mem", "139kB", ResultSuccess}, - // test fallback_application_name - {"application_name=foo fallback_application_name=bar", "application_name", "foo", ResultSuccess}, - {"application_name='' fallback_application_name=bar", "application_name", "", ResultSuccess}, - {"fallback_application_name=bar", "application_name", "bar", ResultSuccess}, - } - - for _, test := range tests { - db, err := openTestConnConninfo(test.conninfo) - if err != nil { - t.Fatal(err) - } - - // application_name didn't exist before 9.0 - if test.param == "application_name" && getServerVersion(t, db) < 90000 { - db.Close() - continue - } - - tryGetParameterValue := func() (value string, outcome RuntimeTestResult) { - defer db.Close() - row := db.QueryRow("SELECT current_setting($1)", test.param) - err = row.Scan(&value) - if err != nil { - return "", ResultError - } - return value, ResultSuccess - } - - value, outcome := tryGetParameterValue() - if outcome != test.expectedOutcome && outcome == ResultError { - t.Fatalf("%v: unexpected error: %v", test.conninfo, err) - } - if outcome != test.expectedOutcome { - t.Fatalf("unexpected outcome %v (was expecting %v) for conninfo \"%s\"", - outcome, test.expectedOutcome, test.conninfo) - } - if value != test.expected { - t.Fatalf("bad value for %s: got %s, want %s with conninfo \"%s\"", - test.param, value, test.expected, test.conninfo) - } - } -} - -func TestIsUTF8(t *testing.T) { - var cases = []struct { - name string - want bool - }{ - {"unicode", true}, - {"utf-8", true}, - {"utf_8", true}, - {"UTF-8", true}, - {"UTF8", true}, - {"utf8", true}, - {"u n ic_ode", true}, - {"ut_f%8", true}, - {"ubf8", false}, - {"punycode", false}, - } - - for _, test := range cases { - if g := isUTF8(test.name); g != test.want { - t.Errorf("isUTF8(%q) = %v want %v", test.name, g, test.want) - } - } -} - -func TestQuoteIdentifier(t *testing.T) { - var cases = []struct { - input string - want string - }{ - {`foo`, `"foo"`}, - {`foo bar baz`, `"foo bar baz"`}, - {`foo"bar`, `"foo""bar"`}, - {"foo\x00bar", `"foo"`}, - {"\x00foo", `""`}, - } - - for _, test := range cases { - got := QuoteIdentifier(test.input) - if got != test.want { - t.Errorf("QuoteIdentifier(%q) = %v want %v", test.input, got, test.want) - } - } -} diff --git a/vendor/github.com/lib/pq/copy.go b/vendor/github.com/lib/pq/copy.go index 18c04e7407e..345c2398f6c 100644 --- a/vendor/github.com/lib/pq/copy.go +++ b/vendor/github.com/lib/pq/copy.go @@ -13,6 +13,7 @@ var ( errBinaryCopyNotSupported = errors.New("pq: only text format supported for COPY") errCopyToNotSupported = errors.New("pq: COPY TO is not supported") errCopyNotSupportedOutsideTxn = errors.New("pq: COPY is only allowed inside a transaction") + errCopyInProgress = errors.New("pq: COPY in progress") ) // CopyIn creates a COPY FROM statement which can be prepared with @@ -96,13 +97,13 @@ awaitCopyInResponse: err = parseError(r) case 'Z': if err == nil { - cn.bad = true + ci.setBad() errorf("unexpected ReadyForQuery in response to COPY") } cn.processReadyForQuery(r) return nil, err default: - cn.bad = true + ci.setBad() errorf("unknown response for copy query: %q", t) } } @@ -121,7 +122,7 @@ awaitCopyInResponse: cn.processReadyForQuery(r) return nil, err default: - cn.bad = true + ci.setBad() errorf("unknown response for CopyFail: %q", t) } } @@ -142,7 +143,7 @@ func (ci *copyin) resploop() { var r readBuf t, err := ci.cn.recvMessage(&r) if err != nil { - ci.cn.bad = true + ci.setBad() ci.setError(err) ci.done <- true return @@ -150,6 +151,8 @@ func (ci *copyin) resploop() { switch t { case 'C': // complete + case 'N': + // NoticeResponse case 'Z': ci.cn.processReadyForQuery(&r) ci.done <- true @@ -158,7 +161,7 @@ func (ci *copyin) resploop() { err := parseError(&r) ci.setError(err) default: - ci.cn.bad = true + ci.setBad() ci.setError(fmt.Errorf("unknown response during CopyIn: %q", t)) ci.done <- true return @@ -166,6 +169,19 @@ func (ci *copyin) resploop() { } } +func (ci *copyin) setBad() { + ci.Lock() + ci.cn.bad = true + ci.Unlock() +} + +func (ci *copyin) isBad() bool { + ci.Lock() + b := ci.cn.bad + ci.Unlock() + return b +} + func (ci *copyin) isErrorSet() bool { ci.Lock() isSet := (ci.err != nil) @@ -203,7 +219,7 @@ func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) { return nil, errCopyInClosed } - if ci.cn.bad { + if ci.isBad() { return nil, driver.ErrBadConn } defer ci.cn.errRecover(&err) @@ -213,9 +229,7 @@ func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) { } if len(v) == 0 { - err = ci.Close() - ci.closed = true - return nil, err + return nil, ci.Close() } numValues := len(v) @@ -238,11 +252,12 @@ func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) { } func (ci *copyin) Close() (err error) { - if ci.closed { - return errCopyInClosed + if ci.closed { // Don't do anything, we're already closed + return nil } + ci.closed = true - if ci.cn.bad { + if ci.isBad() { return driver.ErrBadConn } defer ci.cn.errRecover(&err) @@ -257,6 +272,7 @@ func (ci *copyin) Close() (err error) { } <-ci.done + ci.cn.inCopy = false if ci.isErrorSet() { err = ci.err diff --git a/vendor/github.com/lib/pq/copy_test.go b/vendor/github.com/lib/pq/copy_test.go deleted file mode 100644 index e489f22a86b..00000000000 --- a/vendor/github.com/lib/pq/copy_test.go +++ /dev/null @@ -1,380 +0,0 @@ -package pq - -import ( - "bytes" - "database/sql" - "strings" - "testing" -) - -func TestCopyInStmt(t *testing.T) { - var stmt string - stmt = CopyIn("table name") - if stmt != `COPY "table name" () FROM STDIN` { - t.Fatal(stmt) - } - - stmt = CopyIn("table name", "column 1", "column 2") - if stmt != `COPY "table name" ("column 1", "column 2") FROM STDIN` { - t.Fatal(stmt) - } - - stmt = CopyIn(`table " name """`, `co"lumn""`) - if stmt != `COPY "table "" name """"""" ("co""lumn""""") FROM STDIN` { - t.Fatal(stmt) - } -} - -func TestCopyInSchemaStmt(t *testing.T) { - var stmt string - stmt = CopyInSchema("schema name", "table name") - if stmt != `COPY "schema name"."table name" () FROM STDIN` { - t.Fatal(stmt) - } - - stmt = CopyInSchema("schema name", "table name", "column 1", "column 2") - if stmt != `COPY "schema name"."table name" ("column 1", "column 2") FROM STDIN` { - t.Fatal(stmt) - } - - stmt = CopyInSchema(`schema " name """`, `table " name """`, `co"lumn""`) - if stmt != `COPY "schema "" name """"""".`+ - `"table "" name """"""" ("co""lumn""""") FROM STDIN` { - t.Fatal(stmt) - } -} - -func TestCopyInMultipleValues(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - txn, err := db.Begin() - if err != nil { - t.Fatal(err) - } - defer txn.Rollback() - - _, err = txn.Exec("CREATE TEMP TABLE temp (a int, b varchar)") - if err != nil { - t.Fatal(err) - } - - stmt, err := txn.Prepare(CopyIn("temp", "a", "b")) - if err != nil { - t.Fatal(err) - } - - longString := strings.Repeat("#", 500) - - for i := 0; i < 500; i++ { - _, err = stmt.Exec(int64(i), longString) - if err != nil { - t.Fatal(err) - } - } - - _, err = stmt.Exec() - if err != nil { - t.Fatal(err) - } - - err = stmt.Close() - if err != nil { - t.Fatal(err) - } - - var num int - err = txn.QueryRow("SELECT COUNT(*) FROM temp").Scan(&num) - if err != nil { - t.Fatal(err) - } - - if num != 500 { - t.Fatalf("expected 500 items, not %d", num) - } -} - -func TestCopyInTypes(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - txn, err := db.Begin() - if err != nil { - t.Fatal(err) - } - defer txn.Rollback() - - _, err = txn.Exec("CREATE TEMP TABLE temp (num INTEGER, text VARCHAR, blob BYTEA, nothing VARCHAR)") - if err != nil { - t.Fatal(err) - } - - stmt, err := txn.Prepare(CopyIn("temp", "num", "text", "blob", "nothing")) - if err != nil { - t.Fatal(err) - } - - _, err = stmt.Exec(int64(1234567890), "Héllö\n ☃!\r\t\\", []byte{0, 255, 9, 10, 13}, nil) - if err != nil { - t.Fatal(err) - } - - _, err = stmt.Exec() - if err != nil { - t.Fatal(err) - } - - err = stmt.Close() - if err != nil { - t.Fatal(err) - } - - var num int - var text string - var blob []byte - var nothing sql.NullString - - err = txn.QueryRow("SELECT * FROM temp").Scan(&num, &text, &blob, ¬hing) - if err != nil { - t.Fatal(err) - } - - if num != 1234567890 { - t.Fatal("unexpected result", num) - } - if text != "Héllö\n ☃!\r\t\\" { - t.Fatal("unexpected result", text) - } - if bytes.Compare(blob, []byte{0, 255, 9, 10, 13}) != 0 { - t.Fatal("unexpected result", blob) - } - if nothing.Valid { - t.Fatal("unexpected result", nothing.String) - } -} - -func TestCopyInWrongType(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - txn, err := db.Begin() - if err != nil { - t.Fatal(err) - } - defer txn.Rollback() - - _, err = txn.Exec("CREATE TEMP TABLE temp (num INTEGER)") - if err != nil { - t.Fatal(err) - } - - stmt, err := txn.Prepare(CopyIn("temp", "num")) - if err != nil { - t.Fatal(err) - } - defer stmt.Close() - - _, err = stmt.Exec("Héllö\n ☃!\r\t\\") - if err != nil { - t.Fatal(err) - } - - _, err = stmt.Exec() - if err == nil { - t.Fatal("expected error") - } - if pge := err.(*Error); pge.Code.Name() != "invalid_text_representation" { - t.Fatalf("expected 'invalid input syntax for integer' error, got %s (%+v)", pge.Code.Name(), pge) - } -} - -func TestCopyOutsideOfTxnError(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - _, err := db.Prepare(CopyIn("temp", "num")) - if err == nil { - t.Fatal("COPY outside of transaction did not return an error") - } - if err != errCopyNotSupportedOutsideTxn { - t.Fatalf("expected %s, got %s", err, err.Error()) - } -} - -func TestCopyInBinaryError(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - txn, err := db.Begin() - if err != nil { - t.Fatal(err) - } - defer txn.Rollback() - - _, err = txn.Exec("CREATE TEMP TABLE temp (num INTEGER)") - if err != nil { - t.Fatal(err) - } - _, err = txn.Prepare("COPY temp (num) FROM STDIN WITH binary") - if err != errBinaryCopyNotSupported { - t.Fatalf("expected %s, got %+v", errBinaryCopyNotSupported, err) - } - // check that the protocol is in a valid state - err = txn.Rollback() - if err != nil { - t.Fatal(err) - } -} - -func TestCopyFromError(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - txn, err := db.Begin() - if err != nil { - t.Fatal(err) - } - defer txn.Rollback() - - _, err = txn.Exec("CREATE TEMP TABLE temp (num INTEGER)") - if err != nil { - t.Fatal(err) - } - _, err = txn.Prepare("COPY temp (num) TO STDOUT") - if err != errCopyToNotSupported { - t.Fatalf("expected %s, got %+v", errCopyToNotSupported, err) - } - // check that the protocol is in a valid state - err = txn.Rollback() - if err != nil { - t.Fatal(err) - } -} - -func TestCopySyntaxError(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - txn, err := db.Begin() - if err != nil { - t.Fatal(err) - } - defer txn.Rollback() - - _, err = txn.Prepare("COPY ") - if err == nil { - t.Fatal("expected error") - } - if pge := err.(*Error); pge.Code.Name() != "syntax_error" { - t.Fatalf("expected syntax error, got %s (%+v)", pge.Code.Name(), pge) - } - // check that the protocol is in a valid state - err = txn.Rollback() - if err != nil { - t.Fatal(err) - } -} - -// Tests for connection errors in copyin.resploop() -func TestCopyRespLoopConnectionError(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - txn, err := db.Begin() - if err != nil { - t.Fatal(err) - } - defer txn.Rollback() - - var pid int - err = txn.QueryRow("SELECT pg_backend_pid()").Scan(&pid) - if err != nil { - t.Fatal(err) - } - - _, err = txn.Exec("CREATE TEMP TABLE temp (a int)") - if err != nil { - t.Fatal(err) - } - - stmt, err := txn.Prepare(CopyIn("temp", "a")) - if err != nil { - t.Fatal(err) - } - - _, err = db.Exec("SELECT pg_terminate_backend($1)", pid) - if err != nil { - t.Fatal(err) - } - - // We have to try and send something over, since postgres won't process - // SIGTERMs while it's waiting for CopyData/CopyEnd messages; see - // tcop/postgres.c. - _, err = stmt.Exec(1) - if err != nil { - t.Fatal(err) - } - _, err = stmt.Exec() - if err == nil { - t.Fatalf("expected error") - } - pge, ok := err.(*Error) - if !ok { - t.Fatalf("expected *pq.Error, got %+#v", err) - } else if pge.Code.Name() != "admin_shutdown" { - t.Fatalf("expected admin_shutdown, got %s", pge.Code.Name()) - } - - err = stmt.Close() - if err != nil { - t.Fatal(err) - } -} - -func BenchmarkCopyIn(b *testing.B) { - db := openTestConn(b) - defer db.Close() - - txn, err := db.Begin() - if err != nil { - b.Fatal(err) - } - defer txn.Rollback() - - _, err = txn.Exec("CREATE TEMP TABLE temp (a int, b varchar)") - if err != nil { - b.Fatal(err) - } - - stmt, err := txn.Prepare(CopyIn("temp", "a", "b")) - if err != nil { - b.Fatal(err) - } - - for i := 0; i < b.N; i++ { - _, err = stmt.Exec(int64(i), "hello world!") - if err != nil { - b.Fatal(err) - } - } - - _, err = stmt.Exec() - if err != nil { - b.Fatal(err) - } - - err = stmt.Close() - if err != nil { - b.Fatal(err) - } - - var num int - err = txn.QueryRow("SELECT COUNT(*) FROM temp").Scan(&num) - if err != nil { - b.Fatal(err) - } - - if num != b.N { - b.Fatalf("expected %d items, not %d", b.N, num) - } -} diff --git a/vendor/github.com/lib/pq/doc.go b/vendor/github.com/lib/pq/doc.go index 4d7a0e39935..6d252ecee21 100644 --- a/vendor/github.com/lib/pq/doc.go +++ b/vendor/github.com/lib/pq/doc.go @@ -5,8 +5,9 @@ In most cases clients will use the database/sql package instead of using this package directly. For example: import ( - _ "github.com/lib/pq" "database/sql" + + _ "github.com/lib/pq" ) func main() { @@ -85,9 +86,13 @@ variables not supported by pq are set, pq will panic during connection establishment. Environment variables have a lower precedence than explicitly provided connection parameters. +The pgpass mechanism as described in http://www.postgresql.org/docs/current/static/libpq-pgpass.html +is supported, but on Windows PGPASSFILE must be specified explicitly. + Queries + database/sql does not dictate any specific format for parameter markers in query strings, and pq uses the Postgres-native ordinal markers, as shown above. The same marker can be reused for the same parameter: @@ -111,8 +116,29 @@ For more details on RETURNING, see the Postgres documentation: For additional instructions on querying see the documentation for the database/sql package. + +Data Types + + +Parameters pass through driver.DefaultParameterConverter before they are handled +by this package. When the binary_parameters connection option is enabled, +[]byte values are sent directly to the backend as data in binary format. + +This package returns the following types for values from the PostgreSQL backend: + + - integer types smallint, integer, and bigint are returned as int64 + - floating-point types real and double precision are returned as float64 + - character types char, varchar, and text are returned as string + - temporal types date, time, timetz, timestamp, and timestamptz are returned as time.Time + - the boolean type is returned as bool + - the bytea type is returned as []byte + +All other types are returned directly from the backend as []byte values in text format. + + Errors + pq may return errors of type *pq.Error which can be interrogated for error details: if err, ok := err.(*pq.Error); ok { diff --git a/vendor/github.com/lib/pq/encode.go b/vendor/github.com/lib/pq/encode.go index 3887f72aa29..88a322cda82 100644 --- a/vendor/github.com/lib/pq/encode.go +++ b/vendor/github.com/lib/pq/encode.go @@ -3,24 +3,34 @@ package pq import ( "bytes" "database/sql/driver" + "encoding/binary" "encoding/hex" + "errors" "fmt" - "github.com/lib/pq/oid" "math" "strconv" "strings" "sync" "time" + + "github.com/lib/pq/oid" ) +func binaryEncode(parameterStatus *parameterStatus, x interface{}) []byte { + switch v := x.(type) { + case []byte: + return v + default: + return encode(parameterStatus, x, oid.T_unknown) + } +} + func encode(parameterStatus *parameterStatus, x interface{}, pgtypOid oid.Oid) []byte { switch v := x.(type) { case int64: - return []byte(fmt.Sprintf("%d", v)) - case float32: - return []byte(fmt.Sprintf("%.9f", v)) + return strconv.AppendInt(nil, v, 10) case float64: - return []byte(fmt.Sprintf("%.17f", v)) + return strconv.AppendFloat(nil, v, 'f', -1, 64) case []byte: if pgtypOid == oid.T_bytea { return encodeBytea(parameterStatus.serverVersion, v) @@ -34,7 +44,7 @@ func encode(parameterStatus *parameterStatus, x interface{}, pgtypOid oid.Oid) [ return []byte(v) case bool: - return []byte(fmt.Sprintf("%t", v)) + return strconv.AppendBool(nil, v) case time.Time: return formatTs(v) @@ -45,10 +55,51 @@ func encode(parameterStatus *parameterStatus, x interface{}, pgtypOid oid.Oid) [ panic("not reached") } -func decode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) interface{} { +func decode(parameterStatus *parameterStatus, s []byte, typ oid.Oid, f format) interface{} { + switch f { + case formatBinary: + return binaryDecode(parameterStatus, s, typ) + case formatText: + return textDecode(parameterStatus, s, typ) + default: + panic("not reached") + } +} + +func binaryDecode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) interface{} { switch typ { case oid.T_bytea: - return parseBytea(s) + return s + case oid.T_int8: + return int64(binary.BigEndian.Uint64(s)) + case oid.T_int4: + return int64(int32(binary.BigEndian.Uint32(s))) + case oid.T_int2: + return int64(int16(binary.BigEndian.Uint16(s))) + case oid.T_uuid: + b, err := decodeUUIDBinary(s) + if err != nil { + panic(err) + } + return b + + default: + errorf("don't know how to decode binary parameter of type %d", uint32(typ)) + } + + panic("not reached") +} + +func textDecode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) interface{} { + switch typ { + case oid.T_char, oid.T_varchar, oid.T_text: + return string(s) + case oid.T_bytea: + b, err := parseBytea(s) + if err != nil { + errorf("%s", err) + } + return b case oid.T_timestamptz: return parseTs(parameterStatus.currentLocation, string(s)) case oid.T_timestamp, oid.T_date: @@ -59,7 +110,7 @@ func decode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) interface{} return mustParse("15:04:05-07", typ, s) case oid.T_bool: return s[0] == 't' - case oid.T_int8, oid.T_int2, oid.T_int4: + case oid.T_int8, oid.T_int4, oid.T_int2: i, err := strconv.ParseInt(string(s), 10, 64) if err != nil { errorf("%s", err) @@ -86,8 +137,6 @@ func appendEncodedText(parameterStatus *parameterStatus, buf []byte, x interface switch v := x.(type) { case int64: return strconv.AppendInt(buf, v, 10) - case float32: - return strconv.AppendFloat(buf, float64(v), 'f', -1, 32) case float64: return strconv.AppendFloat(buf, v, 'f', -1, 64) case []byte: @@ -149,12 +198,6 @@ func appendEscapedText(buf []byte, text string) []byte { func mustParse(f string, typ oid.Oid, s []byte) time.Time { str := string(s) - // Special case until time.Parse bug is fixed: - // http://code.google.com/p/go/issues/detail?id=3487 - if str[len(str)-2] == '.' { - str += "0" - } - // check for a 30-minute-offset timezone if (typ == oid.T_timestamptz || typ == oid.T_timetz) && str[len(str)-3] == ':' { @@ -167,16 +210,39 @@ func mustParse(f string, typ oid.Oid, s []byte) time.Time { return t } -func expect(str, char string, pos int) { - if c := str[pos : pos+1]; c != char { - errorf("expected '%v' at position %v; got '%v'", char, pos, c) +var errInvalidTimestamp = errors.New("invalid timestamp") + +type timestampParser struct { + err error +} + +func (p *timestampParser) expect(str string, char byte, pos int) { + if p.err != nil { + return + } + if pos+1 > len(str) { + p.err = errInvalidTimestamp + return + } + if c := str[pos]; c != char && p.err == nil { + p.err = fmt.Errorf("expected '%v' at position %v; got '%v'", char, pos, c) } } -func mustAtoi(str string) int { - result, err := strconv.Atoi(str) +func (p *timestampParser) mustAtoi(str string, begin int, end int) int { + if p.err != nil { + return 0 + } + if begin < 0 || end < 0 || begin > end || end > len(str) { + p.err = errInvalidTimestamp + return 0 + } + result, err := strconv.Atoi(str[begin:end]) if err != nil { - errorf("expected number; got '%v'", str) + if p.err == nil { + p.err = fmt.Errorf("expected number; got '%v'", str) + } + return 0 } return result } @@ -191,7 +257,7 @@ type locationCache struct { // about 5% speed could be gained by putting the cache in the connection and // losing the mutex, at the cost of a small amount of memory and a somewhat // significant increase in code complexity. -var globalLocationCache *locationCache = newLocationCache() +var globalLocationCache = newLocationCache() func newLocationCache() *locationCache { return &locationCache{cache: make(map[int]*time.Location)} @@ -212,32 +278,106 @@ func (c *locationCache) getLocation(offset int) *time.Location { return location } +var infinityTsEnabled = false +var infinityTsNegative time.Time +var infinityTsPositive time.Time + +const ( + infinityTsEnabledAlready = "pq: infinity timestamp enabled already" + infinityTsNegativeMustBeSmaller = "pq: infinity timestamp: negative value must be smaller (before) than positive" +) + +// EnableInfinityTs controls the handling of Postgres' "-infinity" and +// "infinity" "timestamp"s. +// +// If EnableInfinityTs is not called, "-infinity" and "infinity" will return +// []byte("-infinity") and []byte("infinity") respectively, and potentially +// cause error "sql: Scan error on column index 0: unsupported driver -> Scan +// pair: []uint8 -> *time.Time", when scanning into a time.Time value. +// +// Once EnableInfinityTs has been called, all connections created using this +// driver will decode Postgres' "-infinity" and "infinity" for "timestamp", +// "timestamp with time zone" and "date" types to the predefined minimum and +// maximum times, respectively. When encoding time.Time values, any time which +// equals or precedes the predefined minimum time will be encoded to +// "-infinity". Any values at or past the maximum time will similarly be +// encoded to "infinity". +// +// If EnableInfinityTs is called with negative >= positive, it will panic. +// Calling EnableInfinityTs after a connection has been established results in +// undefined behavior. If EnableInfinityTs is called more than once, it will +// panic. +func EnableInfinityTs(negative time.Time, positive time.Time) { + if infinityTsEnabled { + panic(infinityTsEnabledAlready) + } + if !negative.Before(positive) { + panic(infinityTsNegativeMustBeSmaller) + } + infinityTsEnabled = true + infinityTsNegative = negative + infinityTsPositive = positive +} + +/* + * Testing might want to toggle infinityTsEnabled + */ +func disableInfinityTs() { + infinityTsEnabled = false +} + // This is a time function specific to the Postgres default DateStyle // setting ("ISO, MDY"), the only one we currently support. This // accounts for the discrepancies between the parsing available with // time.Parse and the Postgres date formatting quirks. -func parseTs(currentLocation *time.Location, str string) (result time.Time) { +func parseTs(currentLocation *time.Location, str string) interface{} { + switch str { + case "-infinity": + if infinityTsEnabled { + return infinityTsNegative + } + return []byte(str) + case "infinity": + if infinityTsEnabled { + return infinityTsPositive + } + return []byte(str) + } + t, err := ParseTimestamp(currentLocation, str) + if err != nil { + panic(err) + } + return t +} + +// ParseTimestamp parses Postgres' text format. It returns a time.Time in +// currentLocation iff that time's offset agrees with the offset sent from the +// Postgres server. Otherwise, ParseTimestamp returns a time.Time with the +// fixed offset offset provided by the Postgres server. +func ParseTimestamp(currentLocation *time.Location, str string) (time.Time, error) { + p := timestampParser{} + monSep := strings.IndexRune(str, '-') // this is Gregorian year, not ISO Year // In Gregorian system, the year 1 BC is followed by AD 1 - year := mustAtoi(str[:monSep]) + year := p.mustAtoi(str, 0, monSep) daySep := monSep + 3 - month := mustAtoi(str[monSep+1 : daySep]) - expect(str, "-", daySep) + month := p.mustAtoi(str, monSep+1, daySep) + p.expect(str, '-', daySep) timeSep := daySep + 3 - day := mustAtoi(str[daySep+1 : timeSep]) + day := p.mustAtoi(str, daySep+1, timeSep) var hour, minute, second int if len(str) > monSep+len("01-01")+1 { - expect(str, " ", timeSep) + p.expect(str, ' ', timeSep) minSep := timeSep + 3 - expect(str, ":", minSep) - hour = mustAtoi(str[timeSep+1 : minSep]) + p.expect(str, ':', minSep) + hour = p.mustAtoi(str, timeSep+1, minSep) secSep := minSep + 3 - expect(str, ":", secSep) - minute = mustAtoi(str[minSep+1 : secSep]) + p.expect(str, ':', secSep) + minute = p.mustAtoi(str, minSep+1, secSep) secEnd := secSep + 3 - second = mustAtoi(str[secSep+1 : secEnd]) + second = p.mustAtoi(str, secSep+1, secEnd) } remainderIdx := monSep + len("01-01 00:00:00") + 1 // Three optional (but ordered) sections follow: the @@ -248,49 +388,50 @@ func parseTs(currentLocation *time.Location, str string) (result time.Time) { nanoSec := 0 tzOff := 0 - if remainderIdx < len(str) && str[remainderIdx:remainderIdx+1] == "." { + if remainderIdx < len(str) && str[remainderIdx] == '.' { fracStart := remainderIdx + 1 fracOff := strings.IndexAny(str[fracStart:], "-+ ") if fracOff < 0 { fracOff = len(str) - fracStart } - fracSec := mustAtoi(str[fracStart : fracStart+fracOff]) + fracSec := p.mustAtoi(str, fracStart, fracStart+fracOff) nanoSec = fracSec * (1000000000 / int(math.Pow(10, float64(fracOff)))) remainderIdx += fracOff + 1 } - if tzStart := remainderIdx; tzStart < len(str) && (str[tzStart:tzStart+1] == "-" || str[tzStart:tzStart+1] == "+") { + if tzStart := remainderIdx; tzStart < len(str) && (str[tzStart] == '-' || str[tzStart] == '+') { // time zone separator is always '-' or '+' (UTC is +00) var tzSign int - if c := str[tzStart : tzStart+1]; c == "-" { + switch c := str[tzStart]; c { + case '-': tzSign = -1 - } else if c == "+" { + case '+': tzSign = +1 - } else { - errorf("expected '-' or '+' at position %v; got %v", tzStart, c) + default: + return time.Time{}, fmt.Errorf("expected '-' or '+' at position %v; got %v", tzStart, c) } - tzHours := mustAtoi(str[tzStart+1 : tzStart+3]) + tzHours := p.mustAtoi(str, tzStart+1, tzStart+3) remainderIdx += 3 var tzMin, tzSec int - if tzStart+3 < len(str) && str[tzStart+3:tzStart+4] == ":" { - tzMin = mustAtoi(str[tzStart+4 : tzStart+6]) + if remainderIdx < len(str) && str[remainderIdx] == ':' { + tzMin = p.mustAtoi(str, remainderIdx+1, remainderIdx+3) remainderIdx += 3 } - if tzStart+6 < len(str) && str[tzStart+6:tzStart+7] == ":" { - tzSec = mustAtoi(str[tzStart+7 : tzStart+9]) + if remainderIdx < len(str) && str[remainderIdx] == ':' { + tzSec = p.mustAtoi(str, remainderIdx+1, remainderIdx+3) remainderIdx += 3 } tzOff = tzSign * ((tzHours * 60 * 60) + (tzMin * 60) + tzSec) } var isoYear int - if remainderIdx < len(str) && str[remainderIdx:remainderIdx+3] == " BC" { + if remainderIdx+3 <= len(str) && str[remainderIdx:remainderIdx+3] == " BC" { isoYear = 1 - year remainderIdx += 3 } else { isoYear = year } if remainderIdx < len(str) { - errorf("expected end of input, got %v", str[remainderIdx:]) + return time.Time{}, fmt.Errorf("expected end of input, got %v", str[remainderIdx:]) } t := time.Date(isoYear, time.Month(month), day, hour, minute, second, nanoSec, @@ -307,13 +448,26 @@ func parseTs(currentLocation *time.Location, str string) (result time.Time) { } } - return t + return t, p.err } -// formatTs formats t as time.RFC3339Nano and appends time zone seconds if -// needed. -func formatTs(t time.Time) (b []byte) { - b = []byte(t.Format(time.RFC3339Nano)) +// formatTs formats t into a format postgres understands. +func formatTs(t time.Time) []byte { + if infinityTsEnabled { + // t <= -infinity : ! (t > -infinity) + if !t.After(infinityTsNegative) { + return []byte("-infinity") + } + // t >= infinity : ! (!t < infinity) + if !t.Before(infinityTsPositive) { + return []byte("infinity") + } + } + return FormatTimestamp(t) +} + +// FormatTimestamp formats t into Postgres' text format for timestamps. +func FormatTimestamp(t time.Time) []byte { // Need to send dates before 0001 A.D. with " BC" suffix, instead of the // minus sign preferred by Go. // Beware, "0000" in ISO is "1 BC", "-0001" is "2 BC" and so on @@ -323,38 +477,39 @@ func formatTs(t time.Time) (b []byte) { t = t.AddDate((-t.Year())*2+1, 0, 0) bc = true } - b = []byte(t.Format(time.RFC3339Nano)) - if bc { - b = append(b, " BC"...) - } + b := []byte(t.Format("2006-01-02 15:04:05.999999999Z07:00")) _, offset := t.Zone() offset = offset % 60 - if offset == 0 { - return b + if offset != 0 { + // RFC3339Nano already printed the minus sign + if offset < 0 { + offset = -offset + } + + b = append(b, ':') + if offset < 10 { + b = append(b, '0') + } + b = strconv.AppendInt(b, int64(offset), 10) } - if offset < 0 { - offset = -offset + if bc { + b = append(b, " BC"...) } - - b = append(b, ':') - if offset < 10 { - b = append(b, '0') - } - return strconv.AppendInt(b, int64(offset), 10) + return b } // Parse a bytea value received from the server. Both "hex" and the legacy // "escape" format are supported. -func parseBytea(s []byte) (result []byte) { +func parseBytea(s []byte) (result []byte, err error) { if len(s) >= 2 && bytes.Equal(s[:2], []byte("\\x")) { // bytea_output = hex s = s[2:] // trim off leading "\\x" result = make([]byte, hex.DecodedLen(len(s))) _, err := hex.Decode(result, s) if err != nil { - errorf("%s", err) + return nil, err } } else { // bytea_output = escape @@ -369,11 +524,11 @@ func parseBytea(s []byte) (result []byte) { // '\\' followed by an octal number if len(s) < 4 { - errorf("invalid bytea sequence %v", s) + return nil, fmt.Errorf("invalid bytea sequence %v", s) } r, err := strconv.ParseInt(string(s[1:4]), 8, 9) if err != nil { - errorf("could not parse bytea value: %s", err.Error()) + return nil, fmt.Errorf("could not parse bytea value: %s", err.Error()) } result = append(result, byte(r)) s = s[4:] @@ -391,13 +546,16 @@ func parseBytea(s []byte) (result []byte) { } } - return result + return result, nil } func encodeBytea(serverVersion int, v []byte) (result []byte) { if serverVersion >= 90000 { // Use the hex format if we know that the server supports it - result = []byte(fmt.Sprintf("\\x%x", v)) + result = make([]byte, 2+hex.EncodedLen(len(v))) + result[0] = '\\' + result[1] = 'x' + hex.Encode(result[2:], v) } else { // .. or resort to "escape" for _, b := range v { diff --git a/vendor/github.com/lib/pq/encode_test.go b/vendor/github.com/lib/pq/encode_test.go deleted file mode 100644 index e835f2b36fd..00000000000 --- a/vendor/github.com/lib/pq/encode_test.go +++ /dev/null @@ -1,433 +0,0 @@ -package pq - -import ( - "github.com/lib/pq/oid" - - "bytes" - "fmt" - "testing" - "time" -) - -func TestScanTimestamp(t *testing.T) { - var nt NullTime - tn := time.Now() - nt.Scan(tn) - if !nt.Valid { - t.Errorf("Expected Valid=false") - } - if nt.Time != tn { - t.Errorf("Time value mismatch") - } -} - -func TestScanNilTimestamp(t *testing.T) { - var nt NullTime - nt.Scan(nil) - if nt.Valid { - t.Errorf("Expected Valid=false") - } -} - -var timeTests = []struct { - str string - timeval time.Time -}{ - {"22001-02-03", time.Date(22001, time.February, 3, 0, 0, 0, 0, time.FixedZone("", 0))}, - {"2001-02-03", time.Date(2001, time.February, 3, 0, 0, 0, 0, time.FixedZone("", 0))}, - {"2001-02-03 04:05:06", time.Date(2001, time.February, 3, 4, 5, 6, 0, time.FixedZone("", 0))}, - {"2001-02-03 04:05:06.000001", time.Date(2001, time.February, 3, 4, 5, 6, 1000, time.FixedZone("", 0))}, - {"2001-02-03 04:05:06.00001", time.Date(2001, time.February, 3, 4, 5, 6, 10000, time.FixedZone("", 0))}, - {"2001-02-03 04:05:06.0001", time.Date(2001, time.February, 3, 4, 5, 6, 100000, time.FixedZone("", 0))}, - {"2001-02-03 04:05:06.001", time.Date(2001, time.February, 3, 4, 5, 6, 1000000, time.FixedZone("", 0))}, - {"2001-02-03 04:05:06.01", time.Date(2001, time.February, 3, 4, 5, 6, 10000000, time.FixedZone("", 0))}, - {"2001-02-03 04:05:06.1", time.Date(2001, time.February, 3, 4, 5, 6, 100000000, time.FixedZone("", 0))}, - {"2001-02-03 04:05:06.12", time.Date(2001, time.February, 3, 4, 5, 6, 120000000, time.FixedZone("", 0))}, - {"2001-02-03 04:05:06.123", time.Date(2001, time.February, 3, 4, 5, 6, 123000000, time.FixedZone("", 0))}, - {"2001-02-03 04:05:06.1234", time.Date(2001, time.February, 3, 4, 5, 6, 123400000, time.FixedZone("", 0))}, - {"2001-02-03 04:05:06.12345", time.Date(2001, time.February, 3, 4, 5, 6, 123450000, time.FixedZone("", 0))}, - {"2001-02-03 04:05:06.123456", time.Date(2001, time.February, 3, 4, 5, 6, 123456000, time.FixedZone("", 0))}, - {"2001-02-03 04:05:06.123-07", time.Date(2001, time.February, 3, 4, 5, 6, 123000000, - time.FixedZone("", -7*60*60))}, - {"2001-02-03 04:05:06-07", time.Date(2001, time.February, 3, 4, 5, 6, 0, - time.FixedZone("", -7*60*60))}, - {"2001-02-03 04:05:06-07:42", time.Date(2001, time.February, 3, 4, 5, 6, 0, - time.FixedZone("", -(7*60*60+42*60)))}, - {"2001-02-03 04:05:06-07:30:09", time.Date(2001, time.February, 3, 4, 5, 6, 0, - time.FixedZone("", -(7*60*60+30*60+9)))}, - {"2001-02-03 04:05:06+07", time.Date(2001, time.February, 3, 4, 5, 6, 0, - time.FixedZone("", 7*60*60))}, - {"0011-02-03 04:05:06 BC", time.Date(-10, time.February, 3, 4, 5, 6, 0, time.FixedZone("", 0))}, - {"0011-02-03 04:05:06.123 BC", time.Date(-10, time.February, 3, 4, 5, 6, 123000000, time.FixedZone("", 0))}, - {"0011-02-03 04:05:06.123-07 BC", time.Date(-10, time.February, 3, 4, 5, 6, 123000000, - time.FixedZone("", -7*60*60))}, - {"0001-02-03 04:05:06.123", time.Date(1, time.February, 3, 4, 5, 6, 123000000, time.FixedZone("", 0))}, - {"0001-02-03 04:05:06.123 BC", time.Date(1, time.February, 3, 4, 5, 6, 123000000, time.FixedZone("", 0)).AddDate(-1, 0, 0)}, - {"0001-02-03 04:05:06.123 BC", time.Date(0, time.February, 3, 4, 5, 6, 123000000, time.FixedZone("", 0))}, - {"0002-02-03 04:05:06.123 BC", time.Date(0, time.February, 3, 4, 5, 6, 123000000, time.FixedZone("", 0)).AddDate(-1, 0, 0)}, - {"0002-02-03 04:05:06.123 BC", time.Date(-1, time.February, 3, 4, 5, 6, 123000000, time.FixedZone("", 0))}, - {"12345-02-03 04:05:06.1", time.Date(12345, time.February, 3, 4, 5, 6, 100000000, time.FixedZone("", 0))}, - {"123456-02-03 04:05:06.1", time.Date(123456, time.February, 3, 4, 5, 6, 100000000, time.FixedZone("", 0))}, -} - -// Helper function for the two tests below -func tryParse(str string) (t time.Time, err error) { - defer func() { - if p := recover(); p != nil { - err = fmt.Errorf("%v", p) - return - } - }() - t = parseTs(nil, str) - return -} - -// Test that parsing the string results in the expected value. -func TestParseTs(t *testing.T) { - for i, tt := range timeTests { - val, err := tryParse(tt.str) - if err != nil { - t.Errorf("%d: got error: %v", i, err) - } else if val.String() != tt.timeval.String() { - t.Errorf("%d: expected to parse %q into %q; got %q", - i, tt.str, tt.timeval, val) - } - } -} - -// Now test that sending the value into the database and parsing it back -// returns the same time.Time value. -func TestEncodeAndParseTs(t *testing.T) { - db, err := openTestConnConninfo("timezone='Etc/UTC'") - if err != nil { - t.Fatal(err) - } - defer db.Close() - - for i, tt := range timeTests { - var dbstr string - err = db.QueryRow("SELECT ($1::timestamptz)::text", tt.timeval).Scan(&dbstr) - if err != nil { - t.Errorf("%d: could not send value %q to the database: %s", i, tt.timeval, err) - continue - } - - val, err := tryParse(dbstr) - if err != nil { - t.Errorf("%d: could not parse value %q: %s", i, dbstr, err) - continue - } - val = val.In(tt.timeval.Location()) - if val.String() != tt.timeval.String() { - t.Errorf("%d: expected to parse %q into %q; got %q", i, dbstr, tt.timeval, val) - } - } -} - -var formatTimeTests = []struct { - time time.Time - expected string -}{ - {time.Time{}, "0001-01-01T00:00:00Z"}, - {time.Date(2001, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 0)), "2001-02-03T04:05:06.123456789Z"}, - {time.Date(2001, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 2*60*60)), "2001-02-03T04:05:06.123456789+02:00"}, - {time.Date(2001, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", -6*60*60)), "2001-02-03T04:05:06.123456789-06:00"}, - {time.Date(1, time.January, 1, 0, 0, 0, 0, time.FixedZone("", 19*60+32)), "0001-01-01T00:00:00+00:19:32"}, - {time.Date(2001, time.February, 3, 4, 5, 6, 0, time.FixedZone("", -(7*60*60+30*60+9))), "2001-02-03T04:05:06-07:30:09"}, -} - -func TestFormatTs(t *testing.T) { - for i, tt := range formatTimeTests { - val := string(formatTs(tt.time)) - if val != tt.expected { - t.Errorf("%d: incorrect time format %q, want %q", i, val, tt.expected) - } - } -} - -func TestTimestampWithTimeZone(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - tx, err := db.Begin() - if err != nil { - t.Fatal(err) - } - defer tx.Rollback() - - // try several different locations, all included in Go's zoneinfo.zip - for _, locName := range []string{ - "UTC", - "America/Chicago", - "America/New_York", - "Australia/Darwin", - "Australia/Perth", - } { - loc, err := time.LoadLocation(locName) - if err != nil { - t.Logf("Could not load time zone %s - skipping", locName) - continue - } - - // Postgres timestamps have a resolution of 1 microsecond, so don't - // use the full range of the Nanosecond argument - refTime := time.Date(2012, 11, 6, 10, 23, 42, 123456000, loc) - - for _, pgTimeZone := range []string{"US/Eastern", "Australia/Darwin"} { - // Switch Postgres's timezone to test different output timestamp formats - _, err = tx.Exec(fmt.Sprintf("set time zone '%s'", pgTimeZone)) - if err != nil { - t.Fatal(err) - } - - var gotTime time.Time - row := tx.QueryRow("select $1::timestamp with time zone", refTime) - err = row.Scan(&gotTime) - if err != nil { - t.Fatal(err) - } - - if !refTime.Equal(gotTime) { - t.Errorf("timestamps not equal: %s != %s", refTime, gotTime) - } - - // check that the time zone is set correctly based on TimeZone - pgLoc, err := time.LoadLocation(pgTimeZone) - if err != nil { - t.Logf("Could not load time zone %s - skipping", pgLoc) - continue - } - translated := refTime.In(pgLoc) - if translated.String() != gotTime.String() { - t.Errorf("timestamps not equal: %s != %s", translated, gotTime) - } - } - } -} - -func TestTimestampWithOutTimezone(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - test := func(ts, pgts string) { - r, err := db.Query("SELECT $1::timestamp", pgts) - if err != nil { - t.Fatalf("Could not run query: %v", err) - } - - n := r.Next() - - if n != true { - t.Fatal("Expected at least one row") - } - - var result time.Time - err = r.Scan(&result) - if err != nil { - t.Fatalf("Did not expect error scanning row: %v", err) - } - - expected, err := time.Parse(time.RFC3339, ts) - if err != nil { - t.Fatalf("Could not parse test time literal: %v", err) - } - - if !result.Equal(expected) { - t.Fatalf("Expected time to match %v: got mismatch %v", - expected, result) - } - - n = r.Next() - if n != false { - t.Fatal("Expected only one row") - } - } - - test("2000-01-01T00:00:00Z", "2000-01-01T00:00:00") - - // Test higher precision time - test("2013-01-04T20:14:58.80033Z", "2013-01-04 20:14:58.80033") -} - -func TestStringWithNul(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - hello0world := string("hello\x00world") - _, err := db.Query("SELECT $1::text", &hello0world) - if err == nil { - t.Fatal("Postgres accepts a string with nul in it; " + - "injection attacks may be plausible") - } -} - -func TestByteaToText(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - b := []byte("hello world") - row := db.QueryRow("SELECT $1::text", b) - - var result []byte - err := row.Scan(&result) - if err != nil { - t.Fatal(err) - } - - if string(result) != string(b) { - t.Fatalf("expected %v but got %v", b, result) - } -} - -func TestTextToBytea(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - b := "hello world" - row := db.QueryRow("SELECT $1::bytea", b) - - var result []byte - err := row.Scan(&result) - if err != nil { - t.Fatal(err) - } - - if !bytes.Equal(result, []byte(b)) { - t.Fatalf("expected %v but got %v", b, result) - } -} - -func TestByteaOutputFormatEncoding(t *testing.T) { - input := []byte("\\x\x00\x01\x02\xFF\xFEabcdefg0123") - want := []byte("\\x5c78000102fffe6162636465666730313233") - got := encode(¶meterStatus{serverVersion: 90000}, input, oid.T_bytea) - if !bytes.Equal(want, got) { - t.Errorf("invalid hex bytea output, got %v but expected %v", got, want) - } - - want = []byte("\\\\x\\000\\001\\002\\377\\376abcdefg0123") - got = encode(¶meterStatus{serverVersion: 84000}, input, oid.T_bytea) - if !bytes.Equal(want, got) { - t.Errorf("invalid escape bytea output, got %v but expected %v", got, want) - } -} - -func TestByteaOutputFormats(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - if getServerVersion(t, db) < 90000 { - // skip - return - } - - testByteaOutputFormat := func(f string) { - expectedData := []byte("\x5c\x78\x00\xff\x61\x62\x63\x01\x08") - sqlQuery := "SELECT decode('5c7800ff6162630108', 'hex')" - - var data []byte - - // use a txn to avoid relying on getting the same connection - txn, err := db.Begin() - if err != nil { - t.Fatal(err) - } - defer txn.Rollback() - - _, err = txn.Exec("SET LOCAL bytea_output TO " + f) - if err != nil { - t.Fatal(err) - } - // use Query; QueryRow would hide the actual error - rows, err := txn.Query(sqlQuery) - if err != nil { - t.Fatal(err) - } - if !rows.Next() { - if rows.Err() != nil { - t.Fatal(rows.Err()) - } - t.Fatal("shouldn't happen") - } - err = rows.Scan(&data) - if err != nil { - t.Fatal(err) - } - err = rows.Close() - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(data, expectedData) { - t.Errorf("unexpected bytea value %v for format %s; expected %v", data, f, expectedData) - } - } - - testByteaOutputFormat("hex") - testByteaOutputFormat("escape") -} - -func TestAppendEncodedText(t *testing.T) { - var buf []byte - - buf = appendEncodedText(¶meterStatus{serverVersion: 90000}, buf, int64(10)) - buf = append(buf, '\t') - buf = appendEncodedText(¶meterStatus{serverVersion: 90000}, buf, float32(42.0000000001)) - buf = append(buf, '\t') - buf = appendEncodedText(¶meterStatus{serverVersion: 90000}, buf, 42.0000000001) - buf = append(buf, '\t') - buf = appendEncodedText(¶meterStatus{serverVersion: 90000}, buf, "hello\tworld") - buf = append(buf, '\t') - buf = appendEncodedText(¶meterStatus{serverVersion: 90000}, buf, []byte{0, 128, 255}) - - if string(buf) != "10\t42\t42.0000000001\thello\\tworld\t\\\\x0080ff" { - t.Fatal(string(buf)) - } -} - -func TestAppendEscapedText(t *testing.T) { - if esc := appendEscapedText(nil, "hallo\tescape"); string(esc) != "hallo\\tescape" { - t.Fatal(string(esc)) - } - if esc := appendEscapedText(nil, "hallo\\tescape\n"); string(esc) != "hallo\\\\tescape\\n" { - t.Fatal(string(esc)) - } - if esc := appendEscapedText(nil, "\n\r\t\f"); string(esc) != "\\n\\r\\t\f" { - t.Fatal(string(esc)) - } -} - -func TestAppendEscapedTextExistingBuffer(t *testing.T) { - var buf []byte - buf = []byte("123\t") - if esc := appendEscapedText(buf, "hallo\tescape"); string(esc) != "123\thallo\\tescape" { - t.Fatal(string(esc)) - } - buf = []byte("123\t") - if esc := appendEscapedText(buf, "hallo\\tescape\n"); string(esc) != "123\thallo\\\\tescape\\n" { - t.Fatal(string(esc)) - } - buf = []byte("123\t") - if esc := appendEscapedText(buf, "\n\r\t\f"); string(esc) != "123\t\\n\\r\\t\f" { - t.Fatal(string(esc)) - } -} - -func BenchmarkAppendEscapedText(b *testing.B) { - longString := "" - for i := 0; i < 100; i++ { - longString += "123456789\n" - } - for i := 0; i < b.N; i++ { - appendEscapedText(nil, longString) - } -} - -func BenchmarkAppendEscapedTextNoEscape(b *testing.B) { - longString := "" - for i := 0; i < 100; i++ { - longString += "1234567890" - } - for i := 0; i < b.N; i++ { - appendEscapedText(nil, longString) - } -} diff --git a/vendor/github.com/lib/pq/error.go b/vendor/github.com/lib/pq/error.go index 0a49364d9ec..b4bb44cee39 100644 --- a/vendor/github.com/lib/pq/error.go +++ b/vendor/github.com/lib/pq/error.go @@ -459,6 +459,19 @@ func errorf(s string, args ...interface{}) { panic(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...))) } +func errRecoverNoErrBadConn(err *error) { + e := recover() + if e == nil { + // Do nothing + return + } + var ok bool + *err, ok = e.(error) + if !ok { + *err = fmt.Errorf("pq: unexpected error: %#v", e) + } +} + func (c *conn) errRecover(err *error) { e := recover() switch v := e.(type) { diff --git a/vendor/github.com/lib/pq/notify.go b/vendor/github.com/lib/pq/notify.go index e3b08d59a2f..09f94244b9b 100644 --- a/vendor/github.com/lib/pq/notify.go +++ b/vendor/github.com/lib/pq/notify.go @@ -6,7 +6,6 @@ package pq import ( "errors" "fmt" - "io" "sync" "sync/atomic" "time" @@ -63,14 +62,18 @@ type ListenerConn struct { // Creates a new ListenerConn. Use NewListener instead. func NewListenerConn(name string, notificationChan chan<- *Notification) (*ListenerConn, error) { - cn, err := Open(name) + return newDialListenerConn(defaultDialer{}, name, notificationChan) +} + +func newDialListenerConn(d Dialer, name string, c chan<- *Notification) (*ListenerConn, error) { + cn, err := DialOpen(d, name) if err != nil { return nil, err } l := &ListenerConn{ cn: cn.(*conn), - notificationChan: notificationChan, + notificationChan: c, connState: connStateIdle, replyChan: make(chan message, 2), } @@ -87,12 +90,16 @@ func NewListenerConn(name string, notificationChan chan<- *Notification) (*Liste // Returns an error if an unrecoverable error has occurred and the ListenerConn // should be abandoned. func (l *ListenerConn) acquireSenderLock() error { - l.connectionLock.Lock() - defer l.connectionLock.Unlock() - if l.err != nil { - return l.err - } + // we must acquire senderLock first to avoid deadlocks; see ExecSimpleQuery l.senderLock.Lock() + + l.connectionLock.Lock() + err := l.err + l.connectionLock.Unlock() + if err != nil { + l.senderLock.Unlock() + return err + } return nil } @@ -125,7 +132,7 @@ func (l *ListenerConn) setState(newState int32) bool { // away or should be discarded because we couldn't agree on the state with the // server backend. func (l *ListenerConn) listenerConnLoop() (err error) { - defer l.cn.errRecover(&err) + defer errRecoverNoErrBadConn(&err) r := &readBuf{} for { @@ -140,6 +147,9 @@ func (l *ListenerConn) listenerConnLoop() (err error) { // about the scratch buffer being overwritten. l.notificationChan <- recvNotification(r) + case 'T', 'D': + // only used by tests; ignore + case 'E': // We might receive an ErrorResponse even when not in a query; it // is expected that the server will close the connection after @@ -238,7 +248,7 @@ func (l *ListenerConn) Ping() error { // The caller must be holding senderLock (see acquireSenderLock and // releaseSenderLock). func (l *ListenerConn) sendSimpleQuery(q string) (err error) { - defer l.cn.errRecover(&err) + defer errRecoverNoErrBadConn(&err) // must set connection state before sending the query if !l.setState(connStateExpectResponse) { @@ -247,8 +257,10 @@ func (l *ListenerConn) sendSimpleQuery(q string) (err error) { // Can't use l.cn.writeBuf here because it uses the scratch buffer which // might get overwritten by listenerConnLoop. - data := writeBuf([]byte("Q\x00\x00\x00\x00")) - b := &data + b := &writeBuf{ + buf: []byte("Q\x00\x00\x00\x00"), + pos: 1, + } b.string(q) l.cn.send(b) @@ -277,13 +289,13 @@ func (l *ListenerConn) ExecSimpleQuery(q string) (executed bool, err error) { // We can't know what state the protocol is in, so we need to abandon // this connection. l.connectionLock.Lock() - defer l.connectionLock.Unlock() // Set the error pointer if it hasn't been set already; see // listenerConnMain. if l.err == nil { l.err = err } - l.cn.Close() + l.connectionLock.Unlock() + l.cn.c.Close() return false, err } @@ -292,8 +304,11 @@ func (l *ListenerConn) ExecSimpleQuery(q string) (executed bool, err error) { m, ok := <-l.replyChan if !ok { // We lost the connection to server, don't bother waiting for a - // a response. - return false, io.EOF + // a response. err should have been set already. + l.connectionLock.Lock() + err := l.err + l.connectionLock.Unlock() + return false, err } switch m.typ { case 'Z': @@ -320,12 +335,15 @@ func (l *ListenerConn) ExecSimpleQuery(q string) (executed bool, err error) { func (l *ListenerConn) Close() error { l.connectionLock.Lock() - defer l.connectionLock.Unlock() if l.err != nil { + l.connectionLock.Unlock() return errListenerConnClosed } l.err = errListenerConnClosed - return l.cn.Close() + l.connectionLock.Unlock() + // We can't send anything on the connection without holding senderLock. + // Simply close the net.Conn to wake up everyone operating on it. + return l.cn.c.Close() } // Err() returns the reason the connection was closed. It is not safe to call @@ -377,6 +395,7 @@ type Listener struct { name string minReconnectInterval time.Duration maxReconnectInterval time.Duration + dialer Dialer eventCallback EventCallbackType lock sync.Mutex @@ -407,10 +426,21 @@ func NewListener(name string, minReconnectInterval time.Duration, maxReconnectInterval time.Duration, eventCallback EventCallbackType) *Listener { + return NewDialListener(defaultDialer{}, name, minReconnectInterval, maxReconnectInterval, eventCallback) +} + +// NewDialListener is like NewListener but it takes a Dialer. +func NewDialListener(d Dialer, + name string, + minReconnectInterval time.Duration, + maxReconnectInterval time.Duration, + eventCallback EventCallbackType) *Listener { + l := &Listener{ name: name, minReconnectInterval: minReconnectInterval, maxReconnectInterval: maxReconnectInterval, + dialer: d, eventCallback: eventCallback, channels: make(map[string]struct{}), @@ -646,7 +676,7 @@ func (l *Listener) closed() bool { func (l *Listener) connect() error { notificationChan := make(chan *Notification, 32) - cn, err := NewListenerConn(l.name, notificationChan) + cn, err := newDialListenerConn(l.dialer, l.name, notificationChan) if err != nil { return err } diff --git a/vendor/github.com/lib/pq/notify_test.go b/vendor/github.com/lib/pq/notify_test.go deleted file mode 100644 index ae1208d160f..00000000000 --- a/vendor/github.com/lib/pq/notify_test.go +++ /dev/null @@ -1,502 +0,0 @@ -package pq - -import ( - "errors" - "fmt" - "io" - "os" - "testing" - "time" -) - -var errNilNotification = errors.New("nil notification") - -func expectNotification(t *testing.T, ch <-chan *Notification, relname string, extra string) error { - select { - case n := <-ch: - if n == nil { - return errNilNotification - } - if n.Channel != relname || n.Extra != extra { - return fmt.Errorf("unexpected notification %v", n) - } - return nil - case <-time.After(1500 * time.Millisecond): - return fmt.Errorf("timeout") - } -} - -func expectNoNotification(t *testing.T, ch <-chan *Notification) error { - select { - case n := <-ch: - return fmt.Errorf("unexpected notification %v", n) - case <-time.After(100 * time.Millisecond): - return nil - } -} - -func expectEvent(t *testing.T, eventch <-chan ListenerEventType, et ListenerEventType) error { - select { - case e := <-eventch: - if e != et { - return fmt.Errorf("unexpected event %v", e) - } - return nil - case <-time.After(1500 * time.Millisecond): - return fmt.Errorf("timeout") - } -} - -func expectNoEvent(t *testing.T, eventch <-chan ListenerEventType) error { - select { - case e := <-eventch: - return fmt.Errorf("unexpected event %v", e) - case <-time.After(100 * time.Millisecond): - return nil - } -} - -func newTestListenerConn(t *testing.T) (*ListenerConn, <-chan *Notification) { - datname := os.Getenv("PGDATABASE") - sslmode := os.Getenv("PGSSLMODE") - - if datname == "" { - os.Setenv("PGDATABASE", "pqgotest") - } - - if sslmode == "" { - os.Setenv("PGSSLMODE", "disable") - } - - notificationChan := make(chan *Notification) - l, err := NewListenerConn("", notificationChan) - if err != nil { - t.Fatal(err) - } - - return l, notificationChan -} - -func TestNewListenerConn(t *testing.T) { - l, _ := newTestListenerConn(t) - - defer l.Close() -} - -func TestConnListen(t *testing.T) { - l, channel := newTestListenerConn(t) - - defer l.Close() - - db := openTestConn(t) - defer db.Close() - - ok, err := l.Listen("notify_test") - if !ok || err != nil { - t.Fatal(err) - } - - _, err = db.Exec("NOTIFY notify_test") - if err != nil { - t.Fatal(err) - } - - err = expectNotification(t, channel, "notify_test", "") - if err != nil { - t.Fatal(err) - } -} - -func TestConnUnlisten(t *testing.T) { - l, channel := newTestListenerConn(t) - - defer l.Close() - - db := openTestConn(t) - defer db.Close() - - ok, err := l.Listen("notify_test") - if !ok || err != nil { - t.Fatal(err) - } - - _, err = db.Exec("NOTIFY notify_test") - - err = expectNotification(t, channel, "notify_test", "") - if err != nil { - t.Fatal(err) - } - - ok, err = l.Unlisten("notify_test") - if !ok || err != nil { - t.Fatal(err) - } - - _, err = db.Exec("NOTIFY notify_test") - if err != nil { - t.Fatal(err) - } - - err = expectNoNotification(t, channel) - if err != nil { - t.Fatal(err) - } -} - -func TestConnUnlistenAll(t *testing.T) { - l, channel := newTestListenerConn(t) - - defer l.Close() - - db := openTestConn(t) - defer db.Close() - - ok, err := l.Listen("notify_test") - if !ok || err != nil { - t.Fatal(err) - } - - _, err = db.Exec("NOTIFY notify_test") - - err = expectNotification(t, channel, "notify_test", "") - if err != nil { - t.Fatal(err) - } - - ok, err = l.UnlistenAll() - if !ok || err != nil { - t.Fatal(err) - } - - _, err = db.Exec("NOTIFY notify_test") - if err != nil { - t.Fatal(err) - } - - err = expectNoNotification(t, channel) - if err != nil { - t.Fatal(err) - } -} - -func TestConnClose(t *testing.T) { - l, _ := newTestListenerConn(t) - defer l.Close() - - err := l.Close() - if err != nil { - t.Fatal(err) - } - err = l.Close() - if err != errListenerConnClosed { - t.Fatalf("expected errListenerConnClosed; got %v", err) - } -} - -func TestConnPing(t *testing.T) { - l, _ := newTestListenerConn(t) - defer l.Close() - err := l.Ping() - if err != nil { - t.Fatal(err) - } - err = l.Close() - if err != nil { - t.Fatal(err) - } - err = l.Ping() - if err != errListenerConnClosed { - t.Fatalf("expected errListenerConnClosed; got %v", err) - } -} - -func TestNotifyExtra(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - if getServerVersion(t, db) < 90000 { - t.Skip("skipping NOTIFY payload test since the server does not appear to support it") - } - - l, channel := newTestListenerConn(t) - defer l.Close() - - ok, err := l.Listen("notify_test") - if !ok || err != nil { - t.Fatal(err) - } - - _, err = db.Exec("NOTIFY notify_test, 'something'") - if err != nil { - t.Fatal(err) - } - - err = expectNotification(t, channel, "notify_test", "something") - if err != nil { - t.Fatal(err) - } -} - -// create a new test listener and also set the timeouts -func newTestListenerTimeout(t *testing.T, min time.Duration, max time.Duration) (*Listener, <-chan ListenerEventType) { - datname := os.Getenv("PGDATABASE") - sslmode := os.Getenv("PGSSLMODE") - - if datname == "" { - os.Setenv("PGDATABASE", "pqgotest") - } - - if sslmode == "" { - os.Setenv("PGSSLMODE", "disable") - } - - eventch := make(chan ListenerEventType, 16) - l := NewListener("", min, max, func(t ListenerEventType, err error) { eventch <- t }) - err := expectEvent(t, eventch, ListenerEventConnected) - if err != nil { - t.Fatal(err) - } - return l, eventch -} - -func newTestListener(t *testing.T) (*Listener, <-chan ListenerEventType) { - return newTestListenerTimeout(t, time.Hour, time.Hour) -} - -func TestListenerListen(t *testing.T) { - l, _ := newTestListener(t) - defer l.Close() - - db := openTestConn(t) - defer db.Close() - - err := l.Listen("notify_listen_test") - if err != nil { - t.Fatal(err) - } - - _, err = db.Exec("NOTIFY notify_listen_test") - if err != nil { - t.Fatal(err) - } - - err = expectNotification(t, l.Notify, "notify_listen_test", "") - if err != nil { - t.Fatal(err) - } -} - -func TestListenerUnlisten(t *testing.T) { - l, _ := newTestListener(t) - defer l.Close() - - db := openTestConn(t) - defer db.Close() - - err := l.Listen("notify_listen_test") - if err != nil { - t.Fatal(err) - } - - _, err = db.Exec("NOTIFY notify_listen_test") - if err != nil { - t.Fatal(err) - } - - err = l.Unlisten("notify_listen_test") - if err != nil { - t.Fatal(err) - } - - err = expectNotification(t, l.Notify, "notify_listen_test", "") - if err != nil { - t.Fatal(err) - } - - _, err = db.Exec("NOTIFY notify_listen_test") - if err != nil { - t.Fatal(err) - } - - err = expectNoNotification(t, l.Notify) - if err != nil { - t.Fatal(err) - } -} - -func TestListenerUnlistenAll(t *testing.T) { - l, _ := newTestListener(t) - defer l.Close() - - db := openTestConn(t) - defer db.Close() - - err := l.Listen("notify_listen_test") - if err != nil { - t.Fatal(err) - } - - _, err = db.Exec("NOTIFY notify_listen_test") - if err != nil { - t.Fatal(err) - } - - err = l.UnlistenAll() - if err != nil { - t.Fatal(err) - } - - err = expectNotification(t, l.Notify, "notify_listen_test", "") - if err != nil { - t.Fatal(err) - } - - _, err = db.Exec("NOTIFY notify_listen_test") - if err != nil { - t.Fatal(err) - } - - err = expectNoNotification(t, l.Notify) - if err != nil { - t.Fatal(err) - } -} - -func TestListenerFailedQuery(t *testing.T) { - l, eventch := newTestListener(t) - defer l.Close() - - db := openTestConn(t) - defer db.Close() - - err := l.Listen("notify_listen_test") - if err != nil { - t.Fatal(err) - } - - _, err = db.Exec("NOTIFY notify_listen_test") - if err != nil { - t.Fatal(err) - } - - err = expectNotification(t, l.Notify, "notify_listen_test", "") - if err != nil { - t.Fatal(err) - } - - // shouldn't cause a disconnect - ok, err := l.cn.ExecSimpleQuery("SELECT error") - if !ok { - t.Fatalf("could not send query to server: %v", err) - } - _, ok = err.(PGError) - if !ok { - t.Fatalf("unexpected error %v", err) - } - err = expectNoEvent(t, eventch) - if err != nil { - t.Fatal(err) - } - - // should still work - _, err = db.Exec("NOTIFY notify_listen_test") - if err != nil { - t.Fatal(err) - } - - err = expectNotification(t, l.Notify, "notify_listen_test", "") - if err != nil { - t.Fatal(err) - } -} - -func TestListenerReconnect(t *testing.T) { - l, eventch := newTestListenerTimeout(t, 20*time.Millisecond, time.Hour) - defer l.Close() - - db := openTestConn(t) - defer db.Close() - - err := l.Listen("notify_listen_test") - if err != nil { - t.Fatal(err) - } - - _, err = db.Exec("NOTIFY notify_listen_test") - if err != nil { - t.Fatal(err) - } - - err = expectNotification(t, l.Notify, "notify_listen_test", "") - if err != nil { - t.Fatal(err) - } - - // kill the connection and make sure it comes back up - ok, err := l.cn.ExecSimpleQuery("SELECT pg_terminate_backend(pg_backend_pid())") - if ok { - t.Fatalf("could not kill the connection: %v", err) - } - if err != io.EOF { - t.Fatalf("unexpected error %v", err) - } - err = expectEvent(t, eventch, ListenerEventDisconnected) - if err != nil { - t.Fatal(err) - } - err = expectEvent(t, eventch, ListenerEventReconnected) - if err != nil { - t.Fatal(err) - } - - // should still work - _, err = db.Exec("NOTIFY notify_listen_test") - if err != nil { - t.Fatal(err) - } - - // should get nil after Reconnected - err = expectNotification(t, l.Notify, "", "") - if err != errNilNotification { - t.Fatal(err) - } - - err = expectNotification(t, l.Notify, "notify_listen_test", "") - if err != nil { - t.Fatal(err) - } -} - -func TestListenerClose(t *testing.T) { - l, _ := newTestListenerTimeout(t, 20*time.Millisecond, time.Hour) - defer l.Close() - - err := l.Close() - if err != nil { - t.Fatal(err) - } - err = l.Close() - if err != errListenerClosed { - t.Fatalf("expected errListenerClosed; got %v", err) - } -} - -func TestListenerPing(t *testing.T) { - l, _ := newTestListenerTimeout(t, 20*time.Millisecond, time.Hour) - defer l.Close() - - err := l.Ping() - if err != nil { - t.Fatal(err) - } - - err = l.Close() - if err != nil { - t.Fatal(err) - } - - err = l.Ping() - if err != errListenerClosed { - t.Fatalf("expected errListenerClosed; got %v", err) - } -} diff --git a/vendor/github.com/lib/pq/ssl.go b/vendor/github.com/lib/pq/ssl.go new file mode 100644 index 00000000000..7deb304366f --- /dev/null +++ b/vendor/github.com/lib/pq/ssl.go @@ -0,0 +1,158 @@ +package pq + +import ( + "crypto/tls" + "crypto/x509" + "io/ioutil" + "net" + "os" + "os/user" + "path/filepath" +) + +// ssl generates a function to upgrade a net.Conn based on the "sslmode" and +// related settings. The function is nil when no upgrade should take place. +func ssl(o values) func(net.Conn) net.Conn { + verifyCaOnly := false + tlsConf := tls.Config{} + switch mode := o["sslmode"]; mode { + // "require" is the default. + case "", "require": + // We must skip TLS's own verification since it requires full + // verification since Go 1.3. + tlsConf.InsecureSkipVerify = true + + // From http://www.postgresql.org/docs/current/static/libpq-ssl.html: + // + // Note: For backwards compatibility with earlier versions of + // PostgreSQL, if a root CA file exists, the behavior of + // sslmode=require will be the same as that of verify-ca, meaning the + // server certificate is validated against the CA. Relying on this + // behavior is discouraged, and applications that need certificate + // validation should always use verify-ca or verify-full. + if sslrootcert, ok := o["sslrootcert"]; ok { + if _, err := os.Stat(sslrootcert); err == nil { + verifyCaOnly = true + } else { + delete(o, "sslrootcert") + } + } + case "verify-ca": + // We must skip TLS's own verification since it requires full + // verification since Go 1.3. + tlsConf.InsecureSkipVerify = true + verifyCaOnly = true + case "verify-full": + tlsConf.ServerName = o["host"] + case "disable": + return nil + default: + errorf(`unsupported sslmode %q; only "require" (default), "verify-full", "verify-ca", and "disable" supported`, mode) + } + + sslClientCertificates(&tlsConf, o) + sslCertificateAuthority(&tlsConf, o) + sslRenegotiation(&tlsConf) + + return func(conn net.Conn) net.Conn { + client := tls.Client(conn, &tlsConf) + if verifyCaOnly { + sslVerifyCertificateAuthority(client, &tlsConf) + } + return client + } +} + +// sslClientCertificates adds the certificate specified in the "sslcert" and +// "sslkey" settings, or if they aren't set, from the .postgresql directory +// in the user's home directory. The configured files must exist and have +// the correct permissions. +func sslClientCertificates(tlsConf *tls.Config, o values) { + // user.Current() might fail when cross-compiling. We have to ignore the + // error and continue without home directory defaults, since we wouldn't + // know from where to load them. + user, _ := user.Current() + + // In libpq, the client certificate is only loaded if the setting is not blank. + // + // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L1036-L1037 + sslcert := o["sslcert"] + if len(sslcert) == 0 && user != nil { + sslcert = filepath.Join(user.HomeDir, ".postgresql", "postgresql.crt") + } + // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L1045 + if len(sslcert) == 0 { + return + } + // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L1050:L1054 + if _, err := os.Stat(sslcert); os.IsNotExist(err) { + return + } else if err != nil { + panic(err) + } + + // In libpq, the ssl key is only loaded if the setting is not blank. + // + // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L1123-L1222 + sslkey := o["sslkey"] + if len(sslkey) == 0 && user != nil { + sslkey = filepath.Join(user.HomeDir, ".postgresql", "postgresql.key") + } + + if len(sslkey) > 0 { + if err := sslKeyPermissions(sslkey); err != nil { + panic(err) + } + } + + cert, err := tls.LoadX509KeyPair(sslcert, sslkey) + if err != nil { + panic(err) + } + tlsConf.Certificates = []tls.Certificate{cert} +} + +// sslCertificateAuthority adds the RootCA specified in the "sslrootcert" setting. +func sslCertificateAuthority(tlsConf *tls.Config, o values) { + // In libpq, the root certificate is only loaded if the setting is not blank. + // + // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L950-L951 + if sslrootcert := o["sslrootcert"]; len(sslrootcert) > 0 { + tlsConf.RootCAs = x509.NewCertPool() + + cert, err := ioutil.ReadFile(sslrootcert) + if err != nil { + panic(err) + } + + if !tlsConf.RootCAs.AppendCertsFromPEM(cert) { + errorf("couldn't parse pem in sslrootcert") + } + } +} + +// sslVerifyCertificateAuthority carries out a TLS handshake to the server and +// verifies the presented certificate against the CA, i.e. the one specified in +// sslrootcert or the system CA if sslrootcert was not specified. +func sslVerifyCertificateAuthority(client *tls.Conn, tlsConf *tls.Config) { + err := client.Handshake() + if err != nil { + panic(err) + } + certs := client.ConnectionState().PeerCertificates + opts := x509.VerifyOptions{ + DNSName: client.ConnectionState().ServerName, + Intermediates: x509.NewCertPool(), + Roots: tlsConf.RootCAs, + } + for i, cert := range certs { + if i == 0 { + continue + } + opts.Intermediates.AddCert(cert) + } + _, err = certs[0].Verify(opts) + if err != nil { + panic(err) + } +} diff --git a/vendor/github.com/lib/pq/ssl_go1.7.go b/vendor/github.com/lib/pq/ssl_go1.7.go new file mode 100644 index 00000000000..d7ba43b32a1 --- /dev/null +++ b/vendor/github.com/lib/pq/ssl_go1.7.go @@ -0,0 +1,14 @@ +// +build go1.7 + +package pq + +import "crypto/tls" + +// Accept renegotiation requests initiated by the backend. +// +// Renegotiation was deprecated then removed from PostgreSQL 9.5, but +// the default configuration of older versions has it enabled. Redshift +// also initiates renegotiations and cannot be reconfigured. +func sslRenegotiation(conf *tls.Config) { + conf.Renegotiation = tls.RenegotiateFreelyAsClient +} diff --git a/vendor/github.com/lib/pq/ssl_permissions.go b/vendor/github.com/lib/pq/ssl_permissions.go new file mode 100644 index 00000000000..3b7c3a2a319 --- /dev/null +++ b/vendor/github.com/lib/pq/ssl_permissions.go @@ -0,0 +1,20 @@ +// +build !windows + +package pq + +import "os" + +// sslKeyPermissions checks the permissions on user-supplied ssl key files. +// The key file should have very little access. +// +// libpq does not check key file permissions on Windows. +func sslKeyPermissions(sslkey string) error { + info, err := os.Stat(sslkey) + if err != nil { + return err + } + if info.Mode().Perm()&0077 != 0 { + return ErrSSLKeyHasWorldPermissions + } + return nil +} diff --git a/vendor/github.com/lib/pq/ssl_renegotiation.go b/vendor/github.com/lib/pq/ssl_renegotiation.go new file mode 100644 index 00000000000..85ed5e437fb --- /dev/null +++ b/vendor/github.com/lib/pq/ssl_renegotiation.go @@ -0,0 +1,8 @@ +// +build !go1.7 + +package pq + +import "crypto/tls" + +// Renegotiation is not supported by crypto/tls until Go 1.7. +func sslRenegotiation(*tls.Config) {} diff --git a/vendor/github.com/lib/pq/ssl_test.go b/vendor/github.com/lib/pq/ssl_test.go deleted file mode 100644 index 932b336f530..00000000000 --- a/vendor/github.com/lib/pq/ssl_test.go +++ /dev/null @@ -1,226 +0,0 @@ -package pq - -// This file contains SSL tests - -import ( - _ "crypto/sha256" - "crypto/x509" - "database/sql" - "fmt" - "os" - "path/filepath" - "testing" -) - -func maybeSkipSSLTests(t *testing.T) { - // Require some special variables for testing certificates - if os.Getenv("PQSSLCERTTEST_PATH") == "" { - t.Skip("PQSSLCERTTEST_PATH not set, skipping SSL tests") - } - - value := os.Getenv("PQGOSSLTESTS") - if value == "" || value == "0" { - t.Skip("PQGOSSLTESTS not enabled, skipping SSL tests") - } else if value != "1" { - t.Fatalf("unexpected value %q for PQGOSSLTESTS", value) - } -} - -func openSSLConn(t *testing.T, conninfo string) (*sql.DB, error) { - db, err := openTestConnConninfo(conninfo) - if err != nil { - // should never fail - t.Fatal(err) - } - // Do something with the connection to see whether it's working or not. - tx, err := db.Begin() - if err == nil { - return db, tx.Rollback() - } - _ = db.Close() - return nil, err -} - -func checkSSLSetup(t *testing.T, conninfo string) { - db, err := openSSLConn(t, conninfo) - if err == nil { - db.Close() - t.Fatalf("expected error with conninfo=%q", conninfo) - } -} - -// Connect over SSL and run a simple query to test the basics -func TestSSLConnection(t *testing.T) { - maybeSkipSSLTests(t) - // Environment sanity check: should fail without SSL - checkSSLSetup(t, "sslmode=disable user=pqgossltest") - - db, err := openSSLConn(t, "sslmode=require user=pqgossltest") - if err != nil { - t.Fatal(err) - } - rows, err := db.Query("SELECT 1") - if err != nil { - t.Fatal(err) - } - rows.Close() -} - -// Test sslmode=verify-full -func TestSSLVerifyFull(t *testing.T) { - maybeSkipSSLTests(t) - // Environment sanity check: should fail without SSL - checkSSLSetup(t, "sslmode=disable user=pqgossltest") - - // Not OK according to the system CA - _, err := openSSLConn(t, "host=postgres sslmode=verify-full user=pqgossltest") - if err == nil { - t.Fatal("expected error") - } - _, ok := err.(x509.UnknownAuthorityError) - if !ok { - t.Fatalf("expected x509.UnknownAuthorityError, got %#+v", err) - } - - rootCertPath := filepath.Join(os.Getenv("PQSSLCERTTEST_PATH"), "root.crt") - rootCert := "sslrootcert=" + rootCertPath + " " - // No match on Common Name - _, err = openSSLConn(t, rootCert+"host=127.0.0.1 sslmode=verify-full user=pqgossltest") - if err == nil { - t.Fatal("expected error") - } - _, ok = err.(x509.HostnameError) - if !ok { - t.Fatalf("expected x509.HostnameError, got %#+v", err) - } - // OK - _, err = openSSLConn(t, rootCert+"host=postgres sslmode=verify-full user=pqgossltest") - if err != nil { - t.Fatal(err) - } -} - -// Test sslmode=verify-ca -func TestSSLVerifyCA(t *testing.T) { - maybeSkipSSLTests(t) - // Environment sanity check: should fail without SSL - checkSSLSetup(t, "sslmode=disable user=pqgossltest") - - // Not OK according to the system CA - _, err := openSSLConn(t, "host=postgres sslmode=verify-ca user=pqgossltest") - if err == nil { - t.Fatal("expected error") - } - _, ok := err.(x509.UnknownAuthorityError) - if !ok { - t.Fatalf("expected x509.UnknownAuthorityError, got %#+v", err) - } - - rootCertPath := filepath.Join(os.Getenv("PQSSLCERTTEST_PATH"), "root.crt") - rootCert := "sslrootcert=" + rootCertPath + " " - // No match on Common Name, but that's OK - _, err = openSSLConn(t, rootCert+"host=127.0.0.1 sslmode=verify-ca user=pqgossltest") - if err != nil { - t.Fatal(err) - } - // Everything OK - _, err = openSSLConn(t, rootCert+"host=postgres sslmode=verify-ca user=pqgossltest") - if err != nil { - t.Fatal(err) - } -} - -func getCertConninfo(t *testing.T, source string) string { - var sslkey string - var sslcert string - - certpath := os.Getenv("PQSSLCERTTEST_PATH") - - switch source { - case "missingkey": - sslkey = "/tmp/filedoesnotexist" - sslcert = filepath.Join(certpath, "postgresql.crt") - case "missingcert": - sslkey = filepath.Join(certpath, "postgresql.key") - sslcert = "/tmp/filedoesnotexist" - case "certtwice": - sslkey = filepath.Join(certpath, "postgresql.crt") - sslcert = filepath.Join(certpath, "postgresql.crt") - case "valid": - sslkey = filepath.Join(certpath, "postgresql.key") - sslcert = filepath.Join(certpath, "postgresql.crt") - default: - t.Fatalf("invalid source %q", source) - } - return fmt.Sprintf("sslmode=require user=pqgosslcert sslkey=%s sslcert=%s", sslkey, sslcert) -} - -// Authenticate over SSL using client certificates -func TestSSLClientCertificates(t *testing.T) { - maybeSkipSSLTests(t) - // Environment sanity check: should fail without SSL - checkSSLSetup(t, "sslmode=disable user=pqgossltest") - - // Should also fail without a valid certificate - db, err := openSSLConn(t, "sslmode=require user=pqgosslcert") - if err == nil { - db.Close() - t.Fatal("expected error") - } - pge, ok := err.(*Error) - if !ok { - t.Fatal("expected pq.Error") - } - if pge.Code.Name() != "invalid_authorization_specification" { - t.Fatalf("unexpected error code %q", pge.Code.Name()) - } - - // Should work - db, err = openSSLConn(t, getCertConninfo(t, "valid")) - if err != nil { - t.Fatal(err) - } - rows, err := db.Query("SELECT 1") - if err != nil { - t.Fatal(err) - } - rows.Close() -} - -// Test errors with ssl certificates -func TestSSLClientCertificatesMissingFiles(t *testing.T) { - maybeSkipSSLTests(t) - // Environment sanity check: should fail without SSL - checkSSLSetup(t, "sslmode=disable user=pqgossltest") - - // Key missing, should fail - _, err := openSSLConn(t, getCertConninfo(t, "missingkey")) - if err == nil { - t.Fatal("expected error") - } - // should be a PathError - _, ok := err.(*os.PathError) - if !ok { - t.Fatalf("expected PathError, got %#+v", err) - } - - // Cert missing, should fail - _, err = openSSLConn(t, getCertConninfo(t, "missingcert")) - if err == nil { - t.Fatal("expected error") - } - // should be a PathError - _, ok = err.(*os.PathError) - if !ok { - t.Fatalf("expected PathError, got %#+v", err) - } - - // Key has wrong permissions, should fail - _, err = openSSLConn(t, getCertConninfo(t, "certtwice")) - if err == nil { - t.Fatal("expected error") - } - if err != ErrSSLKeyHasWorldPermissions { - t.Fatalf("expected ErrSSLKeyHasWorldPermissions, got %#+v", err) - } -} diff --git a/vendor/github.com/lib/pq/ssl_windows.go b/vendor/github.com/lib/pq/ssl_windows.go new file mode 100644 index 00000000000..5d2c763cebc --- /dev/null +++ b/vendor/github.com/lib/pq/ssl_windows.go @@ -0,0 +1,9 @@ +// +build windows + +package pq + +// sslKeyPermissions checks the permissions on user-supplied ssl key files. +// The key file should have very little access. +// +// libpq does not check key file permissions on Windows. +func sslKeyPermissions(string) error { return nil } diff --git a/vendor/github.com/lib/pq/url.go b/vendor/github.com/lib/pq/url.go index b83e806b8e3..f4d8a7c2062 100644 --- a/vendor/github.com/lib/pq/url.go +++ b/vendor/github.com/lib/pq/url.go @@ -2,6 +2,7 @@ package pq import ( "fmt" + "net" nurl "net/url" "sort" "strings" @@ -34,7 +35,7 @@ func ParseURL(url string) (string, error) { return "", err } - if u.Scheme != "postgres" { + if u.Scheme != "postgres" && u.Scheme != "postgresql" { return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme) } @@ -54,12 +55,11 @@ func ParseURL(url string) (string, error) { accrue("password", v) } - i := strings.Index(u.Host, ":") - if i < 0 { + if host, port, err := net.SplitHostPort(u.Host); err != nil { accrue("host", u.Host) } else { - accrue("host", u.Host[:i]) - accrue("port", u.Host[i+1:]) + accrue("host", host) + accrue("port", port) } if u.Path != "" { diff --git a/vendor/github.com/lib/pq/url_test.go b/vendor/github.com/lib/pq/url_test.go deleted file mode 100644 index 29f4a7c7516..00000000000 --- a/vendor/github.com/lib/pq/url_test.go +++ /dev/null @@ -1,54 +0,0 @@ -package pq - -import ( - "testing" -) - -func TestSimpleParseURL(t *testing.T) { - expected := "host=hostname.remote" - str, err := ParseURL("postgres://hostname.remote") - if err != nil { - t.Fatal(err) - } - - if str != expected { - t.Fatalf("unexpected result from ParseURL:\n+ %v\n- %v", str, expected) - } -} - -func TestFullParseURL(t *testing.T) { - expected := `dbname=database host=hostname.remote password=top\ secret port=1234 user=username` - str, err := ParseURL("postgres://username:top%20secret@hostname.remote:1234/database") - if err != nil { - t.Fatal(err) - } - - if str != expected { - t.Fatalf("unexpected result from ParseURL:\n+ %s\n- %s", str, expected) - } -} - -func TestInvalidProtocolParseURL(t *testing.T) { - _, err := ParseURL("http://hostname.remote") - switch err { - case nil: - t.Fatal("Expected an error from parsing invalid protocol") - default: - msg := "invalid connection protocol: http" - if err.Error() != msg { - t.Fatalf("Unexpected error message:\n+ %s\n- %s", - err.Error(), msg) - } - } -} - -func TestMinimalURL(t *testing.T) { - cs, err := ParseURL("postgres://") - if err != nil { - t.Fatal(err) - } - - if cs != "" { - t.Fatalf("expected blank connection string, got: %q", cs) - } -} diff --git a/vendor/github.com/lib/pq/user_posix.go b/vendor/github.com/lib/pq/user_posix.go index e937d7d0872..bf982524f93 100644 --- a/vendor/github.com/lib/pq/user_posix.go +++ b/vendor/github.com/lib/pq/user_posix.go @@ -1,6 +1,6 @@ // Package pq is a pure Go Postgres driver for the database/sql package. -// +build darwin dragonfly freebsd linux nacl netbsd openbsd solaris +// +build darwin dragonfly freebsd linux nacl netbsd openbsd solaris rumprun package pq diff --git a/vendor/github.com/lib/pq/uuid.go b/vendor/github.com/lib/pq/uuid.go new file mode 100644 index 00000000000..9a1b9e0748e --- /dev/null +++ b/vendor/github.com/lib/pq/uuid.go @@ -0,0 +1,23 @@ +package pq + +import ( + "encoding/hex" + "fmt" +) + +// decodeUUIDBinary interprets the binary format of a uuid, returning it in text format. +func decodeUUIDBinary(src []byte) ([]byte, error) { + if len(src) != 16 { + return nil, fmt.Errorf("pq: unable to decode uuid; bad length: %d", len(src)) + } + + dst := make([]byte, 36) + dst[8], dst[13], dst[18], dst[23] = '-', '-', '-', '-' + hex.Encode(dst[0:], src[0:4]) + hex.Encode(dst[9:], src[4:6]) + hex.Encode(dst[14:], src[6:8]) + hex.Encode(dst[19:], src[8:10]) + hex.Encode(dst[24:], src[10:16]) + + return dst, nil +} diff --git a/vendor/vendor.json b/vendor/vendor.json index 73c45eeb7cf..5f8a975c35c 100644 --- a/vendor/vendor.json +++ b/vendor/vendor.json @@ -440,6 +440,12 @@ "revision": "1bab8b35b6bb565f92cbc97939610af9369f942a", "revisionTime": "2017-02-10T14:05:23Z" }, + { + "checksumSHA1": "ZAj/o03zG8Ui4mZ4XmzU4yyKC04=", + "path": "github.com/lib/pq", + "revision": "dd1fe2071026ce53f36a39112e645b4d4f5793a4", + "revisionTime": "2017-07-07T05:36:02Z" + }, { "checksumSHA1": "8z32QKTSDusa4QQyunKE4kyYXZ8=", "path": "github.com/patrickmn/go-cache",