This commit is contained in:
Marcus Efraimsson 2018-03-19 14:06:05 +01:00
parent d14946a135
commit a2eaf3954a
34 changed files with 411 additions and 272 deletions

9
Gopkg.lock generated
View File

@ -105,8 +105,11 @@
[[projects]]
name = "github.com/denisenkom/go-mssqldb"
packages = ["."]
revision = "ee492709d4324cdcb051d2ac266b77ddc380f5c5"
packages = [
".",
"internal/cp"
]
revision = "270bc3860bb94dd3a3ffd047377d746c5e276726"
[[projects]]
name = "github.com/fatih/color"
@ -639,6 +642,6 @@
[solve-meta]
analyzer-name = "dep"
analyzer-version = 1
inputs-digest = "d2f67abb94028a388f051164896bfb69b1ff3a7255d285dc4d78d298f4793383"
inputs-digest = "5e65aeace832f1b4be17e7ff5d5714513c40f31b94b885f64f98f2332968d7c6"
solver-name = "gps-cdcl"
solver-version = 1

View File

@ -200,4 +200,4 @@ ignored = [
[[constraint]]
name = "github.com/denisenkom/go-mssqldb"
revision = "ee492709d4324cdcb051d2ac266b77ddc380f5c5"
revision = "270bc3860bb94dd3a3ffd047377d746c5e276726"

View File

@ -115,15 +115,23 @@ func (w *tdsBuffer) WriteByte(b byte) error {
return nil
}
func (w *tdsBuffer) BeginPacket(packetType packetType) {
w.wbuf[1] = 0 // Packet is incomplete. This byte is set again in FinishPacket.
func (w *tdsBuffer) BeginPacket(packetType packetType, resetSession bool) {
status := byte(0)
if resetSession {
switch packetType {
// Reset session can only be set on the following packet types.
case packSQLBatch, packRPCRequest, packTransMgrReq:
status = 0x8
}
}
w.wbuf[1] = status // Packet is incomplete. This byte is set again in FinishPacket.
w.wpos = 8
w.wPacketSeq = 1
w.wPacketType = packetType
}
func (w *tdsBuffer) FinishPacket() error {
w.wbuf[1] = 1 // Mark this as the last packet in the message.
w.wbuf[1] |= 1 // Mark this as the last packet in the message.
return w.flush()
}

View File

@ -12,8 +12,14 @@ import (
"time"
)
type MssqlBulk struct {
cn *MssqlConn
type Bulk struct {
// ctx is used only for AddRow and Done methods.
// This could be removed if AddRow and Done accepted
// a ctx field as well, which is available with the
// database/sql call.
ctx context.Context
cn *Conn
metadata []columnStruct
bulkColumns []columnStruct
columnsName []string
@ -21,10 +27,10 @@ type MssqlBulk struct {
numRows int
headerSent bool
Options MssqlBulkOptions
Options BulkOptions
Debug bool
}
type MssqlBulkOptions struct {
type BulkOptions struct {
CheckConstraints bool
FireTriggers bool
KeepNulls bool
@ -36,15 +42,21 @@ type MssqlBulkOptions struct {
type DataValue interface{}
func (cn *MssqlConn) CreateBulk(table string, columns []string) (_ *MssqlBulk) {
b := MssqlBulk{cn: cn, tablename: table, headerSent: false, columnsName: columns}
func (cn *Conn) CreateBulk(table string, columns []string) (_ *Bulk) {
b := Bulk{ctx: context.Background(), cn: cn, tablename: table, headerSent: false, columnsName: columns}
b.Debug = false
return &b
}
func (b *MssqlBulk) sendBulkCommand() (err error) {
func (cn *Conn) CreateBulkContext(ctx context.Context, table string, columns []string) (_ *Bulk) {
b := Bulk{ctx: ctx, cn: cn, tablename: table, headerSent: false, columnsName: columns}
b.Debug = false
return &b
}
func (b *Bulk) sendBulkCommand(ctx context.Context) (err error) {
//get table columns info
err = b.getMetadata()
err = b.getMetadata(ctx)
if err != nil {
return err
}
@ -114,13 +126,13 @@ func (b *MssqlBulk) sendBulkCommand() (err error) {
query := fmt.Sprintf("INSERT BULK %s (%s) %s", b.tablename, col_defs.String(), with_part)
stmt, err := b.cn.Prepare(query)
stmt, err := b.cn.PrepareContext(ctx, query)
if err != nil {
return fmt.Errorf("Prepare failed: %s", err.Error())
}
b.dlogf(query)
_, err = stmt.Exec(nil)
_, err = stmt.(*Stmt).ExecContext(ctx, nil)
if err != nil {
return err
}
@ -128,9 +140,9 @@ func (b *MssqlBulk) sendBulkCommand() (err error) {
b.headerSent = true
var buf = b.cn.sess.buf
buf.BeginPacket(packBulkLoadBCP)
buf.BeginPacket(packBulkLoadBCP, false)
// send the columns metadata
// Send the columns metadata.
columnMetadata := b.createColMetadata()
_, err = buf.Write(columnMetadata)
@ -139,9 +151,9 @@ func (b *MssqlBulk) sendBulkCommand() (err error) {
// AddRow immediately writes the row to the destination table.
// The arguments are the row values in the order they were specified.
func (b *MssqlBulk) AddRow(row []interface{}) (err error) {
func (b *Bulk) AddRow(row []interface{}) (err error) {
if !b.headerSent {
err = b.sendBulkCommand()
err = b.sendBulkCommand(b.ctx)
if err != nil {
return
}
@ -166,7 +178,7 @@ func (b *MssqlBulk) AddRow(row []interface{}) (err error) {
return
}
func (b *MssqlBulk) makeRowData(row []interface{}) ([]byte, error) {
func (b *Bulk) makeRowData(row []interface{}) ([]byte, error) {
buf := new(bytes.Buffer)
buf.WriteByte(byte(tokenRow))
@ -196,7 +208,7 @@ func (b *MssqlBulk) makeRowData(row []interface{}) ([]byte, error) {
return buf.Bytes(), nil
}
func (b *MssqlBulk) Done() (rowcount int64, err error) {
func (b *Bulk) Done() (rowcount int64, err error) {
if b.headerSent == false {
//no rows had been sent
return 0, nil
@ -216,7 +228,7 @@ func (b *MssqlBulk) Done() (rowcount int64, err error) {
buf.FinishPacket()
tokchan := make(chan tokenStruct, 5)
go processResponse(context.Background(), b.cn.sess, tokchan, nil)
go processResponse(b.ctx, b.cn.sess, tokchan, nil)
var rowCount int64
for token := range tokchan {
@ -235,7 +247,7 @@ func (b *MssqlBulk) Done() (rowcount int64, err error) {
return rowCount, nil
}
func (b *MssqlBulk) createColMetadata() []byte {
func (b *Bulk) createColMetadata() []byte {
buf := new(bytes.Buffer)
buf.WriteByte(byte(tokenColMetadata)) // token
binary.Write(buf, binary.LittleEndian, uint16(len(b.bulkColumns))) // column count
@ -267,64 +279,40 @@ func (b *MssqlBulk) createColMetadata() []byte {
return buf.Bytes()
}
func (b *MssqlBulk) getMetadata() (err error) {
stmt, err := b.cn.Prepare("SET FMTONLY ON")
func (b *Bulk) getMetadata(ctx context.Context) (err error) {
stmt, err := b.cn.prepareContext(ctx, "SET FMTONLY ON")
if err != nil {
return
}
_, err = stmt.Exec(nil)
_, err = stmt.ExecContext(ctx, nil)
if err != nil {
return
}
//get columns info
stmt, err = b.cn.Prepare(fmt.Sprintf("select * from %s SET FMTONLY OFF", b.tablename))
// Get columns info.
stmt, err = b.cn.prepareContext(ctx, fmt.Sprintf("select * from %s SET FMTONLY OFF", b.tablename))
if err != nil {
return
}
stmt2 := stmt.(*MssqlStmt)
cols, err := stmt2.QueryMeta()
rows, err := stmt.QueryContext(ctx, nil)
if err != nil {
return fmt.Errorf("get columns info failed: %v", err.Error())
return fmt.Errorf("get columns info failed: %v", err)
}
b.metadata = cols
b.metadata = rows.(*Rows).cols
if b.Debug {
for _, col := range b.metadata {
b.dlogf("col: %s typeId: %#x size: %d scale: %d prec: %d flags: %d lcid: %#x\n",
col.ColName, col.ti.TypeId, col.ti.Size, col.ti.Scale, col.ti.Prec,
col.Flags, col.ti.Collation.lcidAndFlags)
col.Flags, col.ti.Collation.LcidAndFlags)
}
}
return nil
return rows.Close()
}
// QueryMeta is almost the same as MssqlStmt.Query, but returns all the columns info.
func (s *MssqlStmt) QueryMeta() (cols []columnStruct, err error) {
if err = s.sendQuery(nil); err != nil {
return
}
tokchan := make(chan tokenStruct, 5)
go processResponse(context.Background(), s.c.sess, tokchan, s.c.outs)
s.c.clearOuts()
loop:
for tok := range tokchan {
switch token := tok.(type) {
case doneStruct:
break loop
case []columnStruct:
cols = token
break loop
case error:
return nil, s.c.checkBadConn(token)
}
}
return cols, nil
}
func (b *MssqlBulk) makeParam(val DataValue, col columnStruct) (res Param, err error) {
func (b *Bulk) makeParam(val DataValue, col columnStruct) (res Param, err error) {
res.ti.Size = col.ti.Size
res.ti.TypeId = col.ti.TypeId
@ -592,6 +580,15 @@ func (b *MssqlBulk) makeParam(val DataValue, col columnStruct) (res Param, err e
err = fmt.Errorf("mssql: invalid type for Binary column: %s", val)
return
}
case typeGuid:
switch val := val.(type) {
case []byte:
res.ti.Size = len(val)
res.buffer = val
default:
err = fmt.Errorf("mssql: invalid type for Guid column: %s", val)
return
}
default:
err = fmt.Errorf("mssql: type %x not implemented", col.ti.TypeId)
@ -600,7 +597,7 @@ func (b *MssqlBulk) makeParam(val DataValue, col columnStruct) (res Param, err e
}
func (b *MssqlBulk) dlogf(format string, v ...interface{}) {
func (b *Bulk) dlogf(format string, v ...interface{}) {
if b.Debug {
b.cn.sess.log.Printf(format, v...)
}

View File

@ -1,37 +1,38 @@
package mssql
import (
"context"
"database/sql/driver"
"encoding/json"
"errors"
)
type copyin struct {
cn *MssqlConn
bulkcopy *MssqlBulk
cn *Conn
bulkcopy *Bulk
closed bool
}
type SerializableBulkConfig struct {
type serializableBulkConfig struct {
TableName string
ColumnsName []string
Options MssqlBulkOptions
Options BulkOptions
}
func (d *MssqlDriver) OpenConnection(dsn string) (*MssqlConn, error) {
return d.open(dsn)
func (d *Driver) OpenConnection(dsn string) (*Conn, error) {
return d.open(context.Background(), dsn)
}
func (c *MssqlConn) prepareCopyIn(query string) (_ driver.Stmt, err error) {
func (c *Conn) prepareCopyIn(ctx context.Context, query string) (_ driver.Stmt, err error) {
config_json := query[11:]
bulkconfig := SerializableBulkConfig{}
bulkconfig := serializableBulkConfig{}
err = json.Unmarshal([]byte(config_json), &bulkconfig)
if err != nil {
return
}
bulkcopy := c.CreateBulk(bulkconfig.TableName, bulkconfig.ColumnsName)
bulkcopy := c.CreateBulkContext(ctx, bulkconfig.TableName, bulkconfig.ColumnsName)
bulkcopy.Options = bulkconfig.Options
ci := &copyin{
@ -42,8 +43,8 @@ func (c *MssqlConn) prepareCopyIn(query string) (_ driver.Stmt, err error) {
return ci, nil
}
func CopyIn(table string, options MssqlBulkOptions, columns ...string) string {
bulkconfig := &SerializableBulkConfig{TableName: table, Options: options, ColumnsName: columns}
func CopyIn(table string, options BulkOptions, columns ...string) string {
bulkconfig := &serializableBulkConfig{TableName: table, Options: options, ColumnsName: columns}
config_json, err := json.Marshal(bulkconfig)
if err != nil {
@ -60,7 +61,7 @@ func (ci *copyin) NumInput() int {
}
func (ci *copyin) Query(v []driver.Value) (r driver.Rows, err error) {
return nil, errors.New("ErrNotSupported")
panic("should never be called")
}
func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) {

View File

@ -1,39 +0,0 @@
package mssql
import (
"encoding/binary"
"io"
)
// http://msdn.microsoft.com/en-us/library/dd340437.aspx
type collation struct {
lcidAndFlags uint32
sortId uint8
}
func (c collation) getLcid() uint32 {
return c.lcidAndFlags & 0x000fffff
}
func (c collation) getFlags() uint32 {
return (c.lcidAndFlags & 0x0ff00000) >> 20
}
func (c collation) getVersion() uint32 {
return (c.lcidAndFlags & 0xf0000000) >> 28
}
func readCollation(r *tdsBuffer) (res collation) {
res.lcidAndFlags = r.uint32()
res.sortId = r.byte()
return
}
func writeCollation(w io.Writer, col collation) (err error) {
if err = binary.Write(w, binary.LittleEndian, col.lcidAndFlags); err != nil {
return
}
err = binary.Write(w, binary.LittleEndian, col.sortId)
return
}

View File

@ -1,14 +1,14 @@
package mssql
package cp
type charsetMap struct {
sb [256]rune // single byte runes, -1 for a double byte character lead byte
db map[int]rune // double byte runes
}
func collation2charset(col collation) *charsetMap {
func collation2charset(col Collation) *charsetMap {
// http://msdn.microsoft.com/en-us/library/ms144250.aspx
// http://msdn.microsoft.com/en-us/library/ms144250(v=sql.105).aspx
switch col.sortId {
switch col.SortId {
case 30, 31, 32, 33, 34:
return cp437
case 40, 41, 42, 44, 49, 55, 56, 57, 58, 59, 60, 61:
@ -86,7 +86,7 @@ func collation2charset(col collation) *charsetMap {
return cp1252
}
func charset2utf8(col collation, s []byte) string {
func CharsetToUTF8(col Collation, s []byte) string {
cm := collation2charset(col)
if cm == nil {
return string(s)

View File

@ -0,0 +1,20 @@
package cp
// http://msdn.microsoft.com/en-us/library/dd340437.aspx
type Collation struct {
LcidAndFlags uint32
SortId uint8
}
func (c Collation) getLcid() uint32 {
return c.LcidAndFlags & 0x000fffff
}
func (c Collation) getFlags() uint32 {
return (c.LcidAndFlags & 0x0ff00000) >> 20
}
func (c Collation) getVersion() uint32 {
return (c.LcidAndFlags & 0xf0000000) >> 28
}

View File

@ -1,4 +1,4 @@
package mssql
package cp
var cp1250 *charsetMap = &charsetMap{
sb: [256]rune{

View File

@ -1,4 +1,4 @@
package mssql
package cp
var cp1251 *charsetMap = &charsetMap{
sb: [256]rune{

View File

@ -1,4 +1,4 @@
package mssql
package cp
var cp1252 *charsetMap = &charsetMap{
sb: [256]rune{

View File

@ -1,4 +1,4 @@
package mssql
package cp
var cp1253 *charsetMap = &charsetMap{
sb: [256]rune{

View File

@ -1,4 +1,4 @@
package mssql
package cp
var cp1254 *charsetMap = &charsetMap{
sb: [256]rune{

View File

@ -1,4 +1,4 @@
package mssql
package cp
var cp1255 *charsetMap = &charsetMap{
sb: [256]rune{

View File

@ -1,4 +1,4 @@
package mssql
package cp
var cp1256 *charsetMap = &charsetMap{
sb: [256]rune{

View File

@ -1,4 +1,4 @@
package mssql
package cp
var cp1257 *charsetMap = &charsetMap{
sb: [256]rune{

View File

@ -1,4 +1,4 @@
package mssql
package cp
var cp1258 *charsetMap = &charsetMap{
sb: [256]rune{

View File

@ -1,4 +1,4 @@
package mssql
package cp
var cp437 *charsetMap = &charsetMap{
sb: [256]rune{

View File

@ -1,4 +1,4 @@
package mssql
package cp
var cp850 *charsetMap = &charsetMap{
sb: [256]rune{

View File

@ -1,4 +1,4 @@
package mssql
package cp
var cp874 *charsetMap = &charsetMap{
sb: [256]rune{

View File

@ -1,4 +1,4 @@
package mssql
package cp
var cp932 *charsetMap = &charsetMap{
sb: [256]rune{

View File

@ -1,4 +1,4 @@
package mssql
package cp
var cp936 *charsetMap = &charsetMap{
sb: [256]rune{

View File

@ -1,4 +1,4 @@
package mssql
package cp
var cp949 *charsetMap = &charsetMap{
sb: [256]rune{

View File

@ -1,4 +1,4 @@
package mssql
package cp
var cp950 *charsetMap = &charsetMap{
sb: [256]rune{

View File

@ -15,20 +15,20 @@ import (
"time"
)
var driverInstance = &MssqlDriver{processQueryText: true}
var driverInstanceNoProcess = &MssqlDriver{processQueryText: false}
var driverInstance = &Driver{processQueryText: true}
var driverInstanceNoProcess = &Driver{processQueryText: false}
func init() {
sql.Register("mssql", driverInstance)
sql.Register("sqlserver", driverInstanceNoProcess)
createDialer = func(p *connectParams) dialer {
return tcpDialer{&net.Dialer{Timeout: p.dial_timeout, KeepAlive: p.keepAlive}}
return tcpDialer{&net.Dialer{KeepAlive: p.keepAlive}}
}
}
// Abstract the dialer for testing and for non-TCP based connections.
type dialer interface {
Dial(addr string) (net.Conn, error)
Dial(ctx context.Context, addr string) (net.Conn, error)
}
var createDialer func(p *connectParams) dialer
@ -37,28 +37,75 @@ type tcpDialer struct {
nd *net.Dialer
}
func (d tcpDialer) Dial(addr string) (net.Conn, error) {
return d.nd.Dial("tcp", addr)
func (d tcpDialer) Dial(ctx context.Context, addr string) (net.Conn, error) {
return d.nd.DialContext(ctx, "tcp", addr)
}
type MssqlDriver struct {
type Driver struct {
log optionalLogger
processQueryText bool
}
// OpenConnector opens a new connector. Useful to dial with a context.
func (d *Driver) OpenConnector(dsn string) (*Connector, error) {
params, err := parseConnectParams(dsn)
if err != nil {
return nil, err
}
return &Connector{
params: params,
driver: d,
}, nil
}
func (d *Driver) Open(dsn string) (driver.Conn, error) {
return d.open(context.Background(), dsn)
}
func SetLogger(logger Logger) {
driverInstance.SetLogger(logger)
driverInstanceNoProcess.SetLogger(logger)
}
func (d *MssqlDriver) SetLogger(logger Logger) {
func (d *Driver) SetLogger(logger Logger) {
d.log = optionalLogger{logger}
}
type MssqlConn struct {
// Connector holds the parsed DSN and is ready to make a new connection
// at any time.
//
// In the future, settings that cannot be passed through a string DSN
// may be set directly on the connector.
type Connector struct {
params connectParams
driver *Driver
// ResetSQL is executed after marking a given connection to be reset.
// When not present, the next query will be reset to the database
// defaults.
// When present the connection will immediately mark the connection to
// be reset, then execute the ResetSQL text to setup the session
// that may be different from the base database defaults.
//
// For Example, the application relies on the following defaults
// but is not allowed to set them at the database system level.
//
// SET XACT_ABORT ON;
// SET TEXTSIZE -1;
// SET ANSI_NULLS ON;
// SET LOCK_TIMEOUT 10000;
//
// ResetSQL should not attempt to manually call sp_reset_connection.
// This will happen at the TDS layer.
ResetSQL string
}
type Conn struct {
connector *Connector
sess *tdsSession
transactionCtx context.Context
resetSession bool
processQueryText bool
connectionGood bool
@ -66,7 +113,7 @@ type MssqlConn struct {
outs map[string]interface{}
}
func (c *MssqlConn) checkBadConn(err error) error {
func (c *Conn) checkBadConn(err error) error {
// this is a hack to address Issue #275
// we set connectionGood flag to false if
// error indicates that connection is not usable
@ -81,11 +128,12 @@ func (c *MssqlConn) checkBadConn(err error) error {
case nil:
return nil
case io.EOF:
c.connectionGood = false
return driver.ErrBadConn
case driver.ErrBadConn:
// It is an internal programming error if driver.ErrBadConn
// is ever passed to this function. driver.ErrBadConn should
// only ever be returned in response to a *MssqlConn.connectionGood == false
// only ever be returned in response to a *mssql.Conn.connectionGood == false
// check in the external facing API.
panic("driver.ErrBadConn in checkBadConn. This should not happen.")
}
@ -102,11 +150,11 @@ func (c *MssqlConn) checkBadConn(err error) error {
}
}
func (c *MssqlConn) clearOuts() {
func (c *Conn) clearOuts() {
c.outs = nil
}
func (c *MssqlConn) simpleProcessResp(ctx context.Context) error {
func (c *Conn) simpleProcessResp(ctx context.Context) error {
tokchan := make(chan tokenStruct, 5)
go processResponse(ctx, c.sess, tokchan, c.outs)
c.clearOuts()
@ -123,7 +171,7 @@ func (c *MssqlConn) simpleProcessResp(ctx context.Context) error {
return nil
}
func (c *MssqlConn) Commit() error {
func (c *Conn) Commit() error {
if !c.connectionGood {
return driver.ErrBadConn
}
@ -133,12 +181,14 @@ func (c *MssqlConn) Commit() error {
return c.simpleProcessResp(c.transactionCtx)
}
func (c *MssqlConn) sendCommitRequest() error {
func (c *Conn) sendCommitRequest() error {
headers := []headerStruct{
{hdrtype: dataStmHdrTransDescr,
data: transDescrHdr{c.sess.tranid, 1}.pack()},
}
if err := sendCommitXact(c.sess.buf, headers, "", 0, 0, ""); err != nil {
reset := c.resetSession
c.resetSession = false
if err := sendCommitXact(c.sess.buf, headers, "", 0, 0, "", reset); err != nil {
if c.sess.logFlags&logErrors != 0 {
c.sess.log.Printf("Failed to send CommitXact with %v", err)
}
@ -148,7 +198,7 @@ func (c *MssqlConn) sendCommitRequest() error {
return nil
}
func (c *MssqlConn) Rollback() error {
func (c *Conn) Rollback() error {
if !c.connectionGood {
return driver.ErrBadConn
}
@ -158,12 +208,14 @@ func (c *MssqlConn) Rollback() error {
return c.simpleProcessResp(c.transactionCtx)
}
func (c *MssqlConn) sendRollbackRequest() error {
func (c *Conn) sendRollbackRequest() error {
headers := []headerStruct{
{hdrtype: dataStmHdrTransDescr,
data: transDescrHdr{c.sess.tranid, 1}.pack()},
}
if err := sendRollbackXact(c.sess.buf, headers, "", 0, 0, ""); err != nil {
reset := c.resetSession
c.resetSession = false
if err := sendRollbackXact(c.sess.buf, headers, "", 0, 0, "", reset); err != nil {
if c.sess.logFlags&logErrors != 0 {
c.sess.log.Printf("Failed to send RollbackXact with %v", err)
}
@ -173,11 +225,11 @@ func (c *MssqlConn) sendRollbackRequest() error {
return nil
}
func (c *MssqlConn) Begin() (driver.Tx, error) {
func (c *Conn) Begin() (driver.Tx, error) {
return c.begin(context.Background(), isolationUseCurrent)
}
func (c *MssqlConn) begin(ctx context.Context, tdsIsolation isoLevel) (tx driver.Tx, err error) {
func (c *Conn) begin(ctx context.Context, tdsIsolation isoLevel) (tx driver.Tx, err error) {
if !c.connectionGood {
return nil, driver.ErrBadConn
}
@ -192,13 +244,15 @@ func (c *MssqlConn) begin(ctx context.Context, tdsIsolation isoLevel) (tx driver
return
}
func (c *MssqlConn) sendBeginRequest(ctx context.Context, tdsIsolation isoLevel) error {
func (c *Conn) sendBeginRequest(ctx context.Context, tdsIsolation isoLevel) error {
c.transactionCtx = ctx
headers := []headerStruct{
{hdrtype: dataStmHdrTransDescr,
data: transDescrHdr{0, 1}.pack()},
}
if err := sendBeginXact(c.sess.buf, headers, tdsIsolation, ""); err != nil {
reset := c.resetSession
c.resetSession = false
if err := sendBeginXact(c.sess.buf, headers, tdsIsolation, "", reset); err != nil {
if c.sess.logFlags&logErrors != 0 {
c.sess.log.Printf("Failed to send BeginXact with %v", err)
}
@ -208,7 +262,7 @@ func (c *MssqlConn) sendBeginRequest(ctx context.Context, tdsIsolation isoLevel)
return nil
}
func (c *MssqlConn) processBeginResponse(ctx context.Context) (driver.Tx, error) {
func (c *Conn) processBeginResponse(ctx context.Context) (driver.Tx, error) {
if err := c.simpleProcessResp(ctx); err != nil {
return nil, err
}
@ -217,17 +271,17 @@ func (c *MssqlConn) processBeginResponse(ctx context.Context) (driver.Tx, error)
return c, nil
}
func (d *MssqlDriver) Open(dsn string) (driver.Conn, error) {
return d.open(dsn)
}
func (d *MssqlDriver) open(dsn string) (*MssqlConn, error) {
func (d *Driver) open(ctx context.Context, dsn string) (*Conn, error) {
params, err := parseConnectParams(dsn)
if err != nil {
return nil, err
}
return d.connect(ctx, params)
}
sess, err := connect(d.log, params)
// connect to the server, using the provided context for dialing only.
func (d *Driver) connect(ctx context.Context, params connectParams) (*Conn, error) {
sess, err := connect(ctx, d.log, params)
if err != nil {
// main server failed, try fail-over partner
if params.failOverPartner == "" {
@ -239,29 +293,30 @@ func (d *MssqlDriver) open(dsn string) (*MssqlConn, error) {
params.port = params.failOverPort
}
sess, err = connect(d.log, params)
sess, err = connect(ctx, d.log, params)
if err != nil {
// fail-over partner also failed, now fail
return nil, err
}
}
conn := &MssqlConn{
conn := &Conn{
sess: sess,
transactionCtx: context.Background(),
processQueryText: d.processQueryText,
connectionGood: true,
}
conn.sess.log = d.log
return conn, nil
}
func (c *MssqlConn) Close() error {
func (c *Conn) Close() error {
return c.sess.buf.transport.Close()
}
type MssqlStmt struct {
c *MssqlConn
type Stmt struct {
c *Conn
query string
paramCount int
notifSub *queryNotifSub
@ -273,30 +328,29 @@ type queryNotifSub struct {
timeout uint32
}
func (c *MssqlConn) Prepare(query string) (driver.Stmt, error) {
func (c *Conn) Prepare(query string) (driver.Stmt, error) {
if !c.connectionGood {
return nil, driver.ErrBadConn
}
if len(query) > 10 && strings.EqualFold(query[:10], "INSERTBULK") {
return c.prepareCopyIn(query)
return c.prepareCopyIn(context.Background(), query)
}
return c.prepareContext(context.Background(), query)
}
func (c *MssqlConn) prepareContext(ctx context.Context, query string) (*MssqlStmt, error) {
func (c *Conn) prepareContext(ctx context.Context, query string) (*Stmt, error) {
paramCount := -1
if c.processQueryText {
query, paramCount = parseParams(query)
}
return &MssqlStmt{c, query, paramCount, nil}, nil
return &Stmt{c, query, paramCount, nil}, nil
}
func (s *MssqlStmt) Close() error {
func (s *Stmt) Close() error {
return nil
}
func (s *MssqlStmt) SetQueryNotification(id, options string, timeout time.Duration) {
func (s *Stmt) SetQueryNotification(id, options string, timeout time.Duration) {
to := uint32(timeout / time.Second)
if to < 1 {
to = 1
@ -304,11 +358,11 @@ func (s *MssqlStmt) SetQueryNotification(id, options string, timeout time.Durati
s.notifSub = &queryNotifSub{id, options, to}
}
func (s *MssqlStmt) NumInput() int {
func (s *Stmt) NumInput() int {
return s.paramCount
}
func (s *MssqlStmt) sendQuery(args []namedValue) (err error) {
func (s *Stmt) sendQuery(args []namedValue) (err error) {
headers := []headerStruct{
{hdrtype: dataStmHdrTransDescr,
data: transDescrHdr{s.c.sess.tranid, 1}.pack()},
@ -326,11 +380,13 @@ func (s *MssqlStmt) sendQuery(args []namedValue) (err error) {
})
}
conn := s.c
// no need to check number of parameters here, it is checked by database/sql
if s.c.sess.logFlags&logSQL != 0 {
s.c.sess.log.Println(s.query)
if conn.sess.logFlags&logSQL != 0 {
conn.sess.log.Println(s.query)
}
if s.c.sess.logFlags&logParams != 0 && len(args) > 0 {
if conn.sess.logFlags&logParams != 0 && len(args) > 0 {
for i := 0; i < len(args); i++ {
if len(args[i].Name) > 0 {
s.c.sess.log.Printf("\t@%s\t%v\n", args[i].Name, args[i].Value)
@ -338,14 +394,16 @@ func (s *MssqlStmt) sendQuery(args []namedValue) (err error) {
s.c.sess.log.Printf("\t@p%d\t%v\n", i+1, args[i].Value)
}
}
}
reset := conn.resetSession
conn.resetSession = false
if len(args) == 0 {
if err = sendSqlBatch72(s.c.sess.buf, s.query, headers); err != nil {
if s.c.sess.logFlags&logErrors != 0 {
s.c.sess.log.Printf("Failed to send SqlBatch with %v", err)
if err = sendSqlBatch72(conn.sess.buf, s.query, headers, reset); err != nil {
if conn.sess.logFlags&logErrors != 0 {
conn.sess.log.Printf("Failed to send SqlBatch with %v", err)
}
s.c.connectionGood = false
conn.connectionGood = false
return fmt.Errorf("failed to send SQL Batch: %v", err)
}
} else {
@ -363,11 +421,11 @@ func (s *MssqlStmt) sendQuery(args []namedValue) (err error) {
params[0] = makeStrParam(s.query)
params[1] = makeStrParam(strings.Join(decls, ","))
}
if err = sendRpc(s.c.sess.buf, headers, proc, 0, params); err != nil {
if s.c.sess.logFlags&logErrors != 0 {
s.c.sess.log.Printf("Failed to send Rpc with %v", err)
if err = sendRpc(conn.sess.buf, headers, proc, 0, params, reset); err != nil {
if conn.sess.logFlags&logErrors != 0 {
conn.sess.log.Printf("Failed to send Rpc with %v", err)
}
s.c.connectionGood = false
conn.connectionGood = false
return fmt.Errorf("Failed to send RPC: %v", err)
}
}
@ -386,7 +444,7 @@ func isProc(s string) bool {
return !strings.ContainsAny(s, " \t\n\r;")
}
func (s *MssqlStmt) makeRPCParams(args []namedValue, offset int) ([]Param, []string, error) {
func (s *Stmt) makeRPCParams(args []namedValue, offset int) ([]Param, []string, error) {
var err error
params := make([]Param, len(args)+offset)
decls := make([]string, len(args))
@ -424,11 +482,11 @@ func convertOldArgs(args []driver.Value) []namedValue {
return list
}
func (s *MssqlStmt) Query(args []driver.Value) (driver.Rows, error) {
func (s *Stmt) Query(args []driver.Value) (driver.Rows, error) {
return s.queryContext(context.Background(), convertOldArgs(args))
}
func (s *MssqlStmt) queryContext(ctx context.Context, args []namedValue) (rows driver.Rows, err error) {
func (s *Stmt) queryContext(ctx context.Context, args []namedValue) (rows driver.Rows, err error) {
if !s.c.connectionGood {
return nil, driver.ErrBadConn
}
@ -438,7 +496,7 @@ func (s *MssqlStmt) queryContext(ctx context.Context, args []namedValue) (rows d
return s.processQueryResponse(ctx)
}
func (s *MssqlStmt) processQueryResponse(ctx context.Context) (res driver.Rows, err error) {
func (s *Stmt) processQueryResponse(ctx context.Context) (res driver.Rows, err error) {
tokchan := make(chan tokenStruct, 5)
ctx, cancel := context.WithCancel(ctx)
go processResponse(ctx, s.c.sess, tokchan, s.c.outs)
@ -466,15 +524,15 @@ loop:
return nil, s.c.checkBadConn(token)
}
}
res = &MssqlRows{stmt: s, tokchan: tokchan, cols: cols, cancel: cancel}
res = &Rows{stmt: s, tokchan: tokchan, cols: cols, cancel: cancel}
return
}
func (s *MssqlStmt) Exec(args []driver.Value) (driver.Result, error) {
func (s *Stmt) Exec(args []driver.Value) (driver.Result, error) {
return s.exec(context.Background(), convertOldArgs(args))
}
func (s *MssqlStmt) exec(ctx context.Context, args []namedValue) (res driver.Result, err error) {
func (s *Stmt) exec(ctx context.Context, args []namedValue) (res driver.Result, err error) {
if !s.c.connectionGood {
return nil, driver.ErrBadConn
}
@ -487,7 +545,7 @@ func (s *MssqlStmt) exec(ctx context.Context, args []namedValue) (res driver.Res
return
}
func (s *MssqlStmt) processExec(ctx context.Context) (res driver.Result, err error) {
func (s *Stmt) processExec(ctx context.Context) (res driver.Result, err error) {
tokchan := make(chan tokenStruct, 5)
go processResponse(ctx, s.c.sess, tokchan, s.c.outs)
s.c.clearOuts()
@ -509,11 +567,11 @@ func (s *MssqlStmt) processExec(ctx context.Context) (res driver.Result, err err
return nil, token
}
}
return &MssqlResult{s.c, rowCount}, nil
return &Result{s.c, rowCount}, nil
}
type MssqlRows struct {
stmt *MssqlStmt
type Rows struct {
stmt *Stmt
cols []columnStruct
tokchan chan tokenStruct
@ -522,7 +580,7 @@ type MssqlRows struct {
cancel func()
}
func (rc *MssqlRows) Close() error {
func (rc *Rows) Close() error {
rc.cancel()
for range rc.tokchan {
}
@ -530,7 +588,7 @@ func (rc *MssqlRows) Close() error {
return nil
}
func (rc *MssqlRows) Columns() (res []string) {
func (rc *Rows) Columns() (res []string) {
res = make([]string, len(rc.cols))
for i, col := range rc.cols {
res[i] = col.ColName
@ -538,7 +596,7 @@ func (rc *MssqlRows) Columns() (res []string) {
return
}
func (rc *MssqlRows) Next(dest []driver.Value) error {
func (rc *Rows) Next(dest []driver.Value) error {
if !rc.stmt.c.connectionGood {
return driver.ErrBadConn
}
@ -566,11 +624,11 @@ func (rc *MssqlRows) Next(dest []driver.Value) error {
return io.EOF
}
func (rc *MssqlRows) HasNextResultSet() bool {
func (rc *Rows) HasNextResultSet() bool {
return rc.nextCols != nil
}
func (rc *MssqlRows) NextResultSet() error {
func (rc *Rows) NextResultSet() error {
rc.cols = rc.nextCols
rc.nextCols = nil
if rc.cols == nil {
@ -582,7 +640,7 @@ func (rc *MssqlRows) NextResultSet() error {
// It should return
// the value type that can be used to scan types into. For example, the database
// column type "bigint" this should return "reflect.TypeOf(int64(0))".
func (r *MssqlRows) ColumnTypeScanType(index int) reflect.Type {
func (r *Rows) ColumnTypeScanType(index int) reflect.Type {
return makeGoLangScanType(r.cols[index].ti)
}
@ -591,7 +649,7 @@ func (r *MssqlRows) ColumnTypeScanType(index int) reflect.Type {
// Examples of returned types: "VARCHAR", "NVARCHAR", "VARCHAR2", "CHAR", "TEXT",
// "DECIMAL", "SMALLINT", "INT", "BIGINT", "BOOL", "[]BIGINT", "JSONB", "XML",
// "TIMESTAMP".
func (r *MssqlRows) ColumnTypeDatabaseTypeName(index int) string {
func (r *Rows) ColumnTypeDatabaseTypeName(index int) string {
return makeGoLangTypeName(r.cols[index].ti)
}
@ -606,7 +664,7 @@ func (r *MssqlRows) ColumnTypeDatabaseTypeName(index int) string {
// decimal (0, false)
// int (0, false)
// bytea(30) (30, true)
func (r *MssqlRows) ColumnTypeLength(index int) (int64, bool) {
func (r *Rows) ColumnTypeLength(index int) (int64, bool) {
return makeGoLangTypeLength(r.cols[index].ti)
}
@ -616,7 +674,7 @@ func (r *MssqlRows) ColumnTypeLength(index int) (int64, bool) {
// decimal(38, 4) (38, 4, true)
// int (0, 0, false)
// decimal (math.MaxInt64, math.MaxInt64, true)
func (r *MssqlRows) ColumnTypePrecisionScale(index int) (int64, int64, bool) {
func (r *Rows) ColumnTypePrecisionScale(index int) (int64, int64, bool) {
return makeGoLangTypePrecisionScale(r.cols[index].ti)
}
@ -624,7 +682,7 @@ func (r *MssqlRows) ColumnTypePrecisionScale(index int) (int64, int64, bool) {
// be true if it is known the column may be null, or false if the column is known
// to be not nullable.
// If the column nullability is unknown, ok should be false.
func (r *MssqlRows) ColumnTypeNullable(index int) (nullable, ok bool) {
func (r *Rows) ColumnTypeNullable(index int) (nullable, ok bool) {
nullable = r.cols[index].Flags&colFlagNullable != 0
ok = true
return
@ -637,7 +695,7 @@ func makeStrParam(val string) (res Param) {
return
}
func (s *MssqlStmt) makeParam(val driver.Value) (res Param, err error) {
func (s *Stmt) makeParam(val driver.Value) (res Param, err error) {
if val == nil {
res.ti.TypeId = typeNull
res.buffer = nil
@ -706,16 +764,16 @@ func (s *MssqlStmt) makeParam(val driver.Value) (res Param, err error) {
return
}
type MssqlResult struct {
c *MssqlConn
type Result struct {
c *Conn
rowsAffected int64
}
func (r *MssqlResult) RowsAffected() (int64, error) {
func (r *Result) RowsAffected() (int64, error) {
return r.rowsAffected, nil
}
func (r *MssqlResult) LastInsertId() (int64, error) {
func (r *Result) LastInsertId() (int64, error) {
s, err := r.c.Prepare("select cast(@@identity as bigint)")
if err != nil {
return 0, err

50
vendor/github.com/denisenkom/go-mssqldb/mssql_go110.go generated vendored Normal file
View File

@ -0,0 +1,50 @@
// +build go1.10
package mssql
import (
"context"
"database/sql/driver"
)
var _ driver.Connector = &Connector{}
var _ driver.SessionResetter = &Conn{}
func (c *Conn) ResetSession(ctx context.Context) error {
if !c.connectionGood {
return driver.ErrBadConn
}
c.resetSession = true
if c.connector == nil || len(c.connector.ResetSQL) == 0 {
return nil
}
s, err := c.prepareContext(ctx, c.connector.ResetSQL)
if err != nil {
return driver.ErrBadConn
}
_, err = s.exec(ctx, nil)
if err != nil {
return driver.ErrBadConn
}
return nil
}
// Connect to the server and return a TDS connection.
func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) {
conn, err := c.driver.connect(ctx, c.params)
if conn != nil {
conn.connector = c
}
if err == nil {
err = conn.ResetSession(ctx)
}
return conn, err
}
// Driver underlying the Connector.
func (c *Connector) Driver() driver.Driver {
return c.driver
}

View File

@ -10,22 +10,22 @@ import (
"strings"
)
var _ driver.Pinger = &MssqlConn{}
var _ driver.Pinger = &Conn{}
// Ping is used to check if the remote server is available and satisfies the Pinger interface.
func (c *MssqlConn) Ping(ctx context.Context) error {
func (c *Conn) Ping(ctx context.Context) error {
if !c.connectionGood {
return driver.ErrBadConn
}
stmt := &MssqlStmt{c, `select 1;`, 0, nil}
stmt := &Stmt{c, `select 1;`, 0, nil}
_, err := stmt.ExecContext(ctx, nil)
return err
}
var _ driver.ConnBeginTx = &MssqlConn{}
var _ driver.ConnBeginTx = &Conn{}
// BeginTx satisfies ConnBeginTx.
func (c *MssqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
if !c.connectionGood {
return nil, driver.ErrBadConn
}
@ -57,18 +57,18 @@ func (c *MssqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.
return c.begin(ctx, tdsIsolation)
}
func (c *MssqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
if !c.connectionGood {
return nil, driver.ErrBadConn
}
if len(query) > 10 && strings.EqualFold(query[:10], "INSERTBULK") {
return c.prepareCopyIn(query)
return c.prepareCopyIn(ctx, query)
}
return c.prepareContext(ctx, query)
}
func (s *MssqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
if !s.c.connectionGood {
return nil, driver.ErrBadConn
}
@ -79,7 +79,7 @@ func (s *MssqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue)
return s.queryContext(ctx, list)
}
func (s *MssqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
func (s *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
if !s.c.connectionGood {
return nil, driver.ErrBadConn
}

View File

@ -9,9 +9,20 @@ import (
// "github.com/cockroachdb/apd"
)
var _ driver.NamedValueChecker = &MssqlConn{}
// Type alias provided for compibility.
//
// Deprecated: users should transition to the new names when possible.
type MssqlDriver = Driver
type MssqlBulk = Bulk
type MssqlBulkOptions = BulkOptions
type MssqlConn = Conn
type MssqlResult = Result
type MssqlRows = Rows
type MssqlStmt = Stmt
func (c *MssqlConn) CheckNamedValue(nv *driver.NamedValue) error {
var _ driver.NamedValueChecker = &Conn{}
func (c *Conn) CheckNamedValue(nv *driver.NamedValue) error {
switch v := nv.Value.(type) {
case sql.Out:
if c.outs == nil {
@ -41,7 +52,7 @@ func (c *MssqlConn) CheckNamedValue(nv *driver.NamedValue) error {
}
}
func (s *MssqlStmt) makeParamExtra(val driver.Value) (res Param, err error) {
func (s *Stmt) makeParamExtra(val driver.Value) (res Param, err error) {
switch val := val.(type) {
case sql.Out:
res, err = s.makeParam(val.Dest)

View File

@ -7,6 +7,6 @@ import (
"fmt"
)
func (s *MssqlStmt) makeParamExtra(val driver.Value) (Param, error) {
func (s *Stmt) makeParamExtra(val driver.Value) (Param, error) {
return Param{}, fmt.Errorf("mssql: unknown type for %T", val)
}

View File

@ -48,9 +48,11 @@ func (c *timeoutConn) Read(b []byte) (n int, err error) {
n, err = c.buf.Read(b)
return
}
err = c.c.SetDeadline(time.Now().Add(c.timeout))
if err != nil {
return
if c.timeout > 0 {
err = c.c.SetDeadline(time.Now().Add(c.timeout))
if err != nil {
return
}
}
return c.c.Read(b)
}
@ -58,7 +60,7 @@ func (c *timeoutConn) Read(b []byte) (n int, err error) {
func (c *timeoutConn) Write(b []byte) (n int, err error) {
if c.buf != nil {
if !c.packetPending {
c.buf.BeginPacket(packPrelogin)
c.buf.BeginPacket(packPrelogin, false)
c.packetPending = true
}
n, err = c.buf.Write(b)
@ -67,9 +69,11 @@ func (c *timeoutConn) Write(b []byte) (n int, err error) {
}
return
}
err = c.c.SetDeadline(time.Now().Add(c.timeout))
if err != nil {
return
if c.timeout > 0 {
err = c.c.SetDeadline(time.Now().Add(c.timeout))
if err != nil {
return
}
}
return c.c.Write(b)
}

View File

@ -57,8 +57,8 @@ var (
)
// http://msdn.microsoft.com/en-us/library/dd357576.aspx
func sendRpc(buf *tdsBuffer, headers []headerStruct, proc ProcId, flags uint16, params []Param) (err error) {
buf.BeginPacket(packRPCRequest)
func sendRpc(buf *tdsBuffer, headers []headerStruct, proc ProcId, flags uint16, params []Param, resetSession bool) (err error) {
buf.BeginPacket(packRPCRequest, resetSession)
writeAllHeaders(buf, headers)
if len(proc.name) == 0 {
var idswitch uint16 = 0xffff

View File

@ -50,13 +50,17 @@ func parseInstances(msg []byte) map[string]map[string]string {
return results
}
func getInstances(address string) (map[string]map[string]string, error) {
conn, err := net.DialTimeout("udp", address+":1434", 5*time.Second)
func getInstances(ctx context.Context, address string) (map[string]map[string]string, error) {
maxTime := 5 * time.Second
dialer := &net.Dialer{
Timeout: maxTime,
}
conn, err := dialer.DialContext(ctx, "udp", address+":1434")
if err != nil {
return nil, err
}
defer conn.Close()
conn.SetDeadline(time.Now().Add(5 * time.Second))
conn.SetDeadline(time.Now().Add(maxTime))
_, err = conn.Write([]byte{3})
if err != nil {
return nil, err
@ -159,7 +163,7 @@ func (p KeySlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
func writePrelogin(w *tdsBuffer, fields map[uint8][]byte) error {
var err error
w.BeginPacket(packPrelogin)
w.BeginPacket(packPrelogin, false)
offset := uint16(5*len(fields) + 1)
keys := make(KeySlice, 0, len(fields))
for k := range fields {
@ -349,7 +353,7 @@ func manglePassword(password string) []byte {
// http://msdn.microsoft.com/en-us/library/dd304019.aspx
func sendLogin(w *tdsBuffer, login login) error {
w.BeginPacket(packLogin7)
w.BeginPacket(packLogin7, false)
hostname := str2ucs2(login.HostName)
username := str2ucs2(login.UserName)
password := manglePassword(login.Password)
@ -630,8 +634,8 @@ func writeAllHeaders(w io.Writer, headers []headerStruct) (err error) {
return nil
}
func sendSqlBatch72(buf *tdsBuffer, sqltext string, headers []headerStruct) (err error) {
buf.BeginPacket(packSQLBatch)
func sendSqlBatch72(buf *tdsBuffer, sqltext string, headers []headerStruct, resetSession bool) (err error) {
buf.BeginPacket(packSQLBatch, resetSession)
if err = writeAllHeaders(buf, headers); err != nil {
return
@ -647,7 +651,7 @@ func sendSqlBatch72(buf *tdsBuffer, sqltext string, headers []headerStruct) (err
// 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx
// 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx
func sendAttention(buf *tdsBuffer) error {
buf.BeginPacket(packAttention)
buf.BeginPacket(packAttention, false)
return buf.FinishPacket()
}
@ -935,13 +939,13 @@ func parseConnectParams(dsn string) (connectParams, error) {
strlog, ok := params["log"]
if ok {
var err error
p.logFlags, err = strconv.ParseUint(strlog, 10, 0)
p.logFlags, err = strconv.ParseUint(strlog, 10, 64)
if err != nil {
return p, fmt.Errorf("Invalid log parameter '%s': %s", strlog, err.Error())
}
}
server := params["server"]
parts := strings.SplitN(server, "\\", 2)
parts := strings.SplitN(server, `\`, 2)
p.host = parts[0]
if p.host == "." || strings.ToUpper(p.host) == "(LOCAL)" || p.host == "" {
p.host = "localhost"
@ -957,7 +961,7 @@ func parseConnectParams(dsn string) (connectParams, error) {
strport, ok := params["port"]
if ok {
var err error
p.port, err = strconv.ParseUint(strport, 0, 16)
p.port, err = strconv.ParseUint(strport, 10, 16)
if err != nil {
f := "Invalid tcp port '%v': %v"
return p, fmt.Errorf(f, strport, err.Error())
@ -993,7 +997,7 @@ func parseConnectParams(dsn string) (connectParams, error) {
p.conn_timeout = 30 * time.Second
strconntimeout, ok := params["connection timeout"]
if ok {
timeout, err := strconv.ParseUint(strconntimeout, 0, 16)
timeout, err := strconv.ParseUint(strconntimeout, 10, 64)
if err != nil {
f := "Invalid connection timeout '%v': %v"
return p, fmt.Errorf(f, strconntimeout, err.Error())
@ -1002,7 +1006,7 @@ func parseConnectParams(dsn string) (connectParams, error) {
}
strdialtimeout, ok := params["dial timeout"]
if ok {
timeout, err := strconv.ParseUint(strdialtimeout, 0, 16)
timeout, err := strconv.ParseUint(strdialtimeout, 10, 64)
if err != nil {
f := "Invalid dial timeout '%v': %v"
return p, fmt.Errorf(f, strdialtimeout, err.Error())
@ -1015,7 +1019,7 @@ func parseConnectParams(dsn string) (connectParams, error) {
p.keepAlive = 30 * time.Second
if keepAlive, ok := params["keepalive"]; ok {
timeout, err := strconv.ParseUint(keepAlive, 0, 16)
timeout, err := strconv.ParseUint(keepAlive, 10, 64)
if err != nil {
f := "Invalid keepAlive value '%s': %s"
return p, fmt.Errorf(f, keepAlive, err.Error())
@ -1109,7 +1113,7 @@ type auth interface {
// SQL Server AlwaysOn Availability Group Listeners are bound by DNS to a
// list of IP addresses. So if there is more than one, try them all and
// use the first one that allows a connection.
func dialConnection(p connectParams) (conn net.Conn, err error) {
func dialConnection(ctx context.Context, p connectParams) (conn net.Conn, err error) {
var ips []net.IP
ips, err = net.LookupIP(p.host)
if err != nil {
@ -1122,7 +1126,7 @@ func dialConnection(p connectParams) (conn net.Conn, err error) {
if len(ips) == 1 {
d := createDialer(&p)
addr := net.JoinHostPort(ips[0].String(), strconv.Itoa(int(p.port)))
conn, err = d.Dial(addr)
conn, err = d.Dial(ctx, addr)
} else {
//Try Dials in parallel to avoid waiting for timeouts.
@ -1133,7 +1137,7 @@ func dialConnection(p connectParams) (conn net.Conn, err error) {
go func(ip net.IP) {
d := createDialer(&p)
addr := net.JoinHostPort(ip.String(), portStr)
conn, err := d.Dial(addr)
conn, err := d.Dial(ctx, addr)
if err == nil {
connChan <- conn
} else {
@ -1171,12 +1175,17 @@ func dialConnection(p connectParams) (conn net.Conn, err error) {
return conn, err
}
func connect(log optionalLogger, p connectParams) (res *tdsSession, err error) {
res = nil
func connect(ctx context.Context, log optionalLogger, p connectParams) (res *tdsSession, err error) {
dialCtx := ctx
if p.dial_timeout > 0 {
var cancel func()
dialCtx, cancel = context.WithTimeout(ctx, p.dial_timeout)
defer cancel()
}
// if instance is specified use instance resolution service
if p.instance != "" {
p.instance = strings.ToUpper(p.instance)
instances, err := getInstances(p.host)
instances, err := getInstances(dialCtx, p.host)
if err != nil {
f := "Unable to get instances from Sql Server Browser on host %v: %v"
return nil, fmt.Errorf(f, p.host, err.Error())
@ -1194,7 +1203,7 @@ func connect(log optionalLogger, p connectParams) (res *tdsSession, err error) {
}
initiate_connection:
conn, err := dialConnection(p)
conn, err := dialConnection(dialCtx, p)
if err != nil {
return nil, err
}
@ -1334,7 +1343,7 @@ continue_login:
}
}
if sspi_msg != nil {
outbuf.BeginPacket(packSSPIMessage)
outbuf.BeginPacket(packSSPIMessage, false)
_, err = outbuf.Write(sspi_msg)
if err != nil {
return nil, err

View File

@ -28,9 +28,8 @@ const (
isolationSnapshot = 5
)
func sendBeginXact(buf *tdsBuffer, headers []headerStruct, isolation isoLevel,
name string) (err error) {
buf.BeginPacket(packTransMgrReq)
func sendBeginXact(buf *tdsBuffer, headers []headerStruct, isolation isoLevel, name string, resetSession bool) (err error) {
buf.BeginPacket(packTransMgrReq, resetSession)
writeAllHeaders(buf, headers)
var rqtype uint16 = tmBeginXact
err = binary.Write(buf, binary.LittleEndian, &rqtype)
@ -52,8 +51,8 @@ const (
fBeginXact = 1
)
func sendCommitXact(buf *tdsBuffer, headers []headerStruct, name string, flags uint8, isolation uint8, newname string) error {
buf.BeginPacket(packTransMgrReq)
func sendCommitXact(buf *tdsBuffer, headers []headerStruct, name string, flags uint8, isolation uint8, newname string, resetSession bool) error {
buf.BeginPacket(packTransMgrReq, resetSession)
writeAllHeaders(buf, headers)
var rqtype uint16 = tmCommitXact
err := binary.Write(buf, binary.LittleEndian, &rqtype)
@ -81,8 +80,8 @@ func sendCommitXact(buf *tdsBuffer, headers []headerStruct, name string, flags u
return buf.FinishPacket()
}
func sendRollbackXact(buf *tdsBuffer, headers []headerStruct, name string, flags uint8, isolation uint8, newname string) error {
buf.BeginPacket(packTransMgrReq)
func sendRollbackXact(buf *tdsBuffer, headers []headerStruct, name string, flags uint8, isolation uint8, newname string, resetSession bool) error {
buf.BeginPacket(packTransMgrReq, resetSession)
writeAllHeaders(buf, headers)
var rqtype uint16 = tmRollbackXact
err := binary.Write(buf, binary.LittleEndian, &rqtype)

View File

@ -9,6 +9,8 @@ import (
"reflect"
"strconv"
"time"
"github.com/denisenkom/go-mssqldb/internal/cp"
)
// fixed-length data types
@ -79,7 +81,7 @@ type typeInfo struct {
Scale uint8
Prec uint8
Buffer []byte
Collation collation
Collation cp.Collation
UdtInfo udtInfo
XmlInfo xmlInfo
Reader func(ti *typeInfo, r *tdsBuffer) (res interface{})
@ -487,6 +489,20 @@ func writeLongLenType(w io.Writer, ti typeInfo, buf []byte) (err error) {
return
}
func readCollation(r *tdsBuffer) (res cp.Collation) {
res.LcidAndFlags = r.uint32()
res.SortId = r.byte()
return
}
func writeCollation(w io.Writer, col cp.Collation) (err error) {
if err = binary.Write(w, binary.LittleEndian, col.LcidAndFlags); err != nil {
return
}
err = binary.Write(w, binary.LittleEndian, col.SortId)
return
}
// reads variant value
// http://msdn.microsoft.com/en-us/library/dd303302.aspx
func readVariantType(ti *typeInfo, r *tdsBuffer) interface{} {
@ -848,8 +864,8 @@ func dateTime2(t time.Time) (days int32, ns int64) {
return
}
func decodeChar(col collation, buf []byte) string {
return charset2utf8(col, buf)
func decodeChar(col cp.Collation, buf []byte) string {
return cp.CharsetToUTF8(col, buf)
}
func decodeUcs2(buf []byte) string {
@ -922,7 +938,7 @@ func makeGoLangScanType(ti typeInfo) reflect.Type {
return reflect.TypeOf(true)
case typeDecimalN, typeNumericN:
return reflect.TypeOf([]byte{})
case typeMoneyN:
case typeMoney, typeMoney4, typeMoneyN:
switch ti.Size {
case 4:
return reflect.TypeOf([]byte{})
@ -1083,6 +1099,8 @@ func makeDecl(ti typeInfo) string {
return "ntext"
case typeUdt:
return ti.UdtInfo.TypeName
case typeGuid:
return "uniqueidentifier"
default:
panic(fmt.Sprintf("not implemented makeDecl for type %#x", ti.TypeId))
}
@ -1140,7 +1158,7 @@ func makeGoLangTypeName(ti typeInfo) string {
return "BIT"
case typeDecimalN, typeNumericN:
return "DECIMAL"
case typeMoneyN:
case typeMoney, typeMoney4, typeMoneyN:
switch ti.Size {
case 4:
return "SMALLMONEY"
@ -1247,7 +1265,7 @@ func makeGoLangTypeLength(ti typeInfo) (int64, bool) {
return 0, false
case typeDecimalN, typeNumericN:
return 0, false
case typeMoneyN:
case typeMoney, typeMoney4, typeMoneyN:
switch ti.Size {
case 4:
return 0, false
@ -1370,7 +1388,7 @@ func makeGoLangTypePrecisionScale(ti typeInfo) (int64, int64, bool) {
return 0, 0, false
case typeDecimalN, typeNumericN:
return int64(ti.Prec), int64(ti.Scale), true
case typeMoneyN:
case typeMoney, typeMoney4, typeMoneyN:
switch ti.Size {
case 4:
return 0, 0, false