github.com/cellofellow/gopkg@v0.0.0-20140722061823-eec0544a62ad/database/mysql/driver.go (about)

     1  // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
     2  //
     3  // This Source Code Form is subject to the terms of the Mozilla Public
     4  // License, v. 2.0. If a copy of the MPL was not distributed with this file,
     5  // You can obtain one at http://mozilla.org/MPL/2.0/.
     6  
     7  // Go MySQL Driver - A MySQL-Driver for Go's database/sql package
     8  //
     9  // The driver should be used via the database/sql package:
    10  //
    11  //  import "database/sql"
    12  //  import _ "github.com/chai2010/gopkg/database/mysql"
    13  //
    14  //  db, err := sql.Open("mysql", "user:password@/dbname")
    15  //
    16  // See https://github.com/chai2010/gopkg/tree/master/database/mysql#usage for details
    17  package mysql
    18  
    19  import (
    20  	"database/sql"
    21  	"database/sql/driver"
    22  	"net"
    23  )
    24  
    25  // This struct is exported to make the driver directly accessible.
    26  // In general the driver is used via the database/sql package.
    27  type MySQLDriver struct{}
    28  
    29  // DialFunc is a function which can be used to establish the network connection.
    30  // Custom dial functions must be registered with RegisterDial
    31  type DialFunc func(addr string) (net.Conn, error)
    32  
    33  var dials map[string]DialFunc
    34  
    35  // RegisterDial registers a custom dial function. It can then be used by the
    36  // network address mynet(addr), where mynet is the registered new network.
    37  // addr is passed as a parameter to the dial function.
    38  func RegisterDial(net string, dial DialFunc) {
    39  	if dials == nil {
    40  		dials = make(map[string]DialFunc)
    41  	}
    42  	dials[net] = dial
    43  }
    44  
    45  // Open new Connection.
    46  // See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how
    47  // the DSN string is formated
    48  func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
    49  	var err error
    50  
    51  	// New mysqlConn
    52  	mc := &mysqlConn{
    53  		maxPacketAllowed: maxPacketSize,
    54  		maxWriteSize:     maxPacketSize - 1,
    55  	}
    56  	mc.cfg, err = parseDSN(dsn)
    57  	if err != nil {
    58  		return nil, err
    59  	}
    60  
    61  	// Connect to Server
    62  	if dial, ok := dials[mc.cfg.net]; ok {
    63  		mc.netConn, err = dial(mc.cfg.addr)
    64  	} else {
    65  		nd := net.Dialer{Timeout: mc.cfg.timeout}
    66  		mc.netConn, err = nd.Dial(mc.cfg.net, mc.cfg.addr)
    67  	}
    68  	if err != nil {
    69  		return nil, err
    70  	}
    71  
    72  	// Enable TCP Keepalives on TCP connections
    73  	if tc, ok := mc.netConn.(*net.TCPConn); ok {
    74  		if err := tc.SetKeepAlive(true); err != nil {
    75  			mc.Close()
    76  			return nil, err
    77  		}
    78  	}
    79  
    80  	mc.buf = newBuffer(mc.netConn)
    81  
    82  	// Reading Handshake Initialization Packet
    83  	cipher, err := mc.readInitPacket()
    84  	if err != nil {
    85  		mc.Close()
    86  		return nil, err
    87  	}
    88  
    89  	// Send Client Authentication Packet
    90  	if err = mc.writeAuthPacket(cipher); err != nil {
    91  		mc.Close()
    92  		return nil, err
    93  	}
    94  
    95  	// Read Result Packet
    96  	err = mc.readResultOK()
    97  	if err != nil {
    98  		// Retry with old authentication method, if allowed
    99  		if mc.cfg != nil && mc.cfg.allowOldPasswords && err == ErrOldPassword {
   100  			if err = mc.writeOldAuthPacket(cipher); err != nil {
   101  				mc.Close()
   102  				return nil, err
   103  			}
   104  			if err = mc.readResultOK(); err != nil {
   105  				mc.Close()
   106  				return nil, err
   107  			}
   108  		} else {
   109  			mc.Close()
   110  			return nil, err
   111  		}
   112  
   113  	}
   114  
   115  	// Get max allowed packet size
   116  	maxap, err := mc.getSystemVar("max_allowed_packet")
   117  	if err != nil {
   118  		mc.Close()
   119  		return nil, err
   120  	}
   121  	mc.maxPacketAllowed = stringToInt(maxap) - 1
   122  	if mc.maxPacketAllowed < maxPacketSize {
   123  		mc.maxWriteSize = mc.maxPacketAllowed
   124  	}
   125  
   126  	// Handle DSN Params
   127  	err = mc.handleParams()
   128  	if err != nil {
   129  		mc.Close()
   130  		return nil, err
   131  	}
   132  
   133  	return mc, nil
   134  }
   135  
   136  func init() {
   137  	sql.Register("mysql", &MySQLDriver{})
   138  }