github.com/slspeek/camlistore_namedsearch@v0.0.0-20140519202248-ed6f70f7721a/third_party/labix.org/v2/mgo/socket.go (about)

     1  // mgo - MongoDB driver for Go
     2  //
     3  // Copyright (c) 2010-2012 - Gustavo Niemeyer <gustavo@niemeyer.net>
     4  //
     5  // All rights reserved.
     6  //
     7  // Redistribution and use in source and binary forms, with or without
     8  // modification, are permitted provided that the following conditions are met:
     9  //
    10  // 1. Redistributions of source code must retain the above copyright notice, this
    11  //    list of conditions and the following disclaimer.
    12  // 2. Redistributions in binary form must reproduce the above copyright notice,
    13  //    this list of conditions and the following disclaimer in the documentation
    14  //    and/or other materials provided with the distribution.
    15  //
    16  // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
    17  // ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
    18  // WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
    19  // DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
    20  // ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
    21  // (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
    22  // LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
    23  // ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
    24  // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
    25  // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    26  
    27  package mgo
    28  
    29  import (
    30  	"camlistore.org/third_party/labix.org/v2/mgo/bson"
    31  	"errors"
    32  	"net"
    33  	"sync"
    34  	"time"
    35  )
    36  
    37  type replyFunc func(err error, reply *replyOp, docNum int, docData []byte)
    38  
    39  type mongoSocket struct {
    40  	sync.Mutex
    41  	server        *mongoServer // nil when cached
    42  	conn          net.Conn
    43  	timeout       time.Duration
    44  	addr          string // For debugging only.
    45  	nextRequestId uint32
    46  	replyFuncs    map[uint32]replyFunc
    47  	references    int
    48  	auth          []authInfo
    49  	logout        []authInfo
    50  	cachedNonce   string
    51  	gotNonce      sync.Cond
    52  	dead          error
    53  	serverInfo    *mongoServerInfo
    54  }
    55  
    56  type queryOpFlags uint32
    57  
    58  const (
    59  	_ queryOpFlags = 1 << iota
    60  	flagTailable
    61  	flagSlaveOk
    62  	flagLogReplay
    63  	flagNoCursorTimeout
    64  	flagAwaitData
    65  )
    66  
    67  type queryOp struct {
    68  	collection string
    69  	query      interface{}
    70  	skip       int32
    71  	limit      int32
    72  	selector   interface{}
    73  	flags      queryOpFlags
    74  	replyFunc  replyFunc
    75  
    76  	options    queryWrapper
    77  	hasOptions bool
    78  	serverTags []bson.D
    79  }
    80  
    81  type queryWrapper struct {
    82  	Query          interface{} "$query"
    83  	OrderBy        interface{} "$orderby,omitempty"
    84  	Hint           interface{} "$hint,omitempty"
    85  	Explain        bool        "$explain,omitempty"
    86  	Snapshot       bool        "$snapshot,omitempty"
    87  	ReadPreference bson.D      "$readPreference,omitempty"
    88  }
    89  
    90  func (op *queryOp) finalQuery(socket *mongoSocket) interface{} {
    91  	if op.flags&flagSlaveOk != 0 && len(op.serverTags) > 0 && socket.ServerInfo().Mongos {
    92  		op.hasOptions = true
    93  		op.options.ReadPreference = bson.D{{"mode", "secondaryPreferred"}, {"tags", op.serverTags}}
    94  	}
    95  	if op.hasOptions {
    96  		if op.query == nil {
    97  			var empty bson.D
    98  			op.options.Query = empty
    99  		} else {
   100  			op.options.Query = op.query
   101  		}
   102  		debugf("final query is %#v\n", &op.options)
   103  		return &op.options
   104  	}
   105  	return op.query
   106  }
   107  
   108  type getMoreOp struct {
   109  	collection string
   110  	limit      int32
   111  	cursorId   int64
   112  	replyFunc  replyFunc
   113  }
   114  
   115  type replyOp struct {
   116  	flags     uint32
   117  	cursorId  int64
   118  	firstDoc  int32
   119  	replyDocs int32
   120  }
   121  
   122  type insertOp struct {
   123  	collection string        // "database.collection"
   124  	documents  []interface{} // One or more documents to insert
   125  }
   126  
   127  type updateOp struct {
   128  	collection string // "database.collection"
   129  	selector   interface{}
   130  	update     interface{}
   131  	flags      uint32
   132  }
   133  
   134  type deleteOp struct {
   135  	collection string // "database.collection"
   136  	selector   interface{}
   137  	flags      uint32
   138  }
   139  
   140  type killCursorsOp struct {
   141  	cursorIds []int64
   142  }
   143  
   144  type requestInfo struct {
   145  	bufferPos int
   146  	replyFunc replyFunc
   147  }
   148  
   149  func newSocket(server *mongoServer, conn net.Conn, timeout time.Duration) *mongoSocket {
   150  	socket := &mongoSocket{
   151  		conn:       conn,
   152  		addr:       server.Addr,
   153  		server:     server,
   154  		replyFuncs: make(map[uint32]replyFunc),
   155  	}
   156  	socket.gotNonce.L = &socket.Mutex
   157  	if err := socket.InitialAcquire(server.Info(), timeout); err != nil {
   158  		panic("newSocket: InitialAcquire returned error: " + err.Error())
   159  	}
   160  	stats.socketsAlive(+1)
   161  	debugf("Socket %p to %s: initialized", socket, socket.addr)
   162  	socket.resetNonce()
   163  	go socket.readLoop()
   164  	return socket
   165  }
   166  
   167  // Server returns the server that the socket is associated with.
   168  // It returns nil while the socket is cached in its respective server.
   169  func (socket *mongoSocket) Server() *mongoServer {
   170  	socket.Lock()
   171  	server := socket.server
   172  	socket.Unlock()
   173  	return server
   174  }
   175  
   176  // ServerInfo returns details for the server at the time the socket
   177  // was initially acquired.
   178  func (socket *mongoSocket) ServerInfo() *mongoServerInfo {
   179  	socket.Lock()
   180  	serverInfo := socket.serverInfo
   181  	socket.Unlock()
   182  	return serverInfo
   183  }
   184  
   185  // InitialAcquire obtains the first reference to the socket, either
   186  // right after the connection is made or once a recycled socket is
   187  // being put back in use.
   188  func (socket *mongoSocket) InitialAcquire(serverInfo *mongoServerInfo, timeout time.Duration) error {
   189  	socket.Lock()
   190  	if socket.references > 0 {
   191  		panic("Socket acquired out of cache with references")
   192  	}
   193  	if socket.dead != nil {
   194  		socket.Unlock()
   195  		return socket.dead
   196  	}
   197  	socket.references++
   198  	socket.serverInfo = serverInfo
   199  	socket.timeout = timeout
   200  	stats.socketsInUse(+1)
   201  	stats.socketRefs(+1)
   202  	socket.Unlock()
   203  	return nil
   204  }
   205  
   206  // Acquire obtains an additional reference to the socket.
   207  // The socket will only be recycled when it's released as many
   208  // times as it's been acquired.
   209  func (socket *mongoSocket) Acquire() (info *mongoServerInfo) {
   210  	socket.Lock()
   211  	if socket.references == 0 {
   212  		panic("Socket got non-initial acquire with references == 0")
   213  	}
   214  	// We'll track references to dead sockets as well.
   215  	// Caller is still supposed to release the socket.
   216  	socket.references++
   217  	stats.socketRefs(+1)
   218  	serverInfo := socket.serverInfo
   219  	socket.Unlock()
   220  	return serverInfo
   221  }
   222  
   223  // Release decrements a socket reference. The socket will be
   224  // recycled once its released as many times as it's been acquired.
   225  func (socket *mongoSocket) Release() {
   226  	socket.Lock()
   227  	if socket.references == 0 {
   228  		panic("socket.Release() with references == 0")
   229  	}
   230  	socket.references--
   231  	stats.socketRefs(-1)
   232  	if socket.references == 0 {
   233  		stats.socketsInUse(-1)
   234  		server := socket.server
   235  		socket.Unlock()
   236  		socket.LogoutAll()
   237  		// If the socket is dead server is nil.
   238  		if server != nil {
   239  			server.RecycleSocket(socket)
   240  		}
   241  	} else {
   242  		socket.Unlock()
   243  	}
   244  }
   245  
   246  // SetTimeout changes the timeout used on socket operations.
   247  func (socket *mongoSocket) SetTimeout(d time.Duration) {
   248  	socket.Lock()
   249  	socket.timeout = d
   250  	socket.Unlock()
   251  }
   252  
   253  type deadlineType int
   254  
   255  const (
   256  	readDeadline  deadlineType = 1
   257  	writeDeadline deadlineType = 2
   258  )
   259  
   260  func (socket *mongoSocket) updateDeadline(which deadlineType) {
   261  	var when time.Time
   262  	if socket.timeout > 0 {
   263  		when = time.Now().Add(socket.timeout)
   264  	}
   265  	whichstr := ""
   266  	switch which {
   267  	case readDeadline | writeDeadline:
   268  		whichstr = "read/write"
   269  		socket.conn.SetDeadline(when)
   270  	case readDeadline:
   271  		whichstr = "read"
   272  		socket.conn.SetReadDeadline(when)
   273  	case writeDeadline:
   274  		whichstr = "write"
   275  		socket.conn.SetWriteDeadline(when)
   276  	default:
   277  		panic("invalid parameter to updateDeadline")
   278  	}
   279  	debugf("Socket %p to %s: updated %s deadline to %s ahead (%s)", socket, socket.addr, whichstr, socket.timeout, when)
   280  }
   281  
   282  // Close terminates the socket use.
   283  func (socket *mongoSocket) Close() {
   284  	socket.kill(errors.New("Closed explicitly"), false)
   285  }
   286  
   287  func (socket *mongoSocket) kill(err error, abend bool) {
   288  	socket.Lock()
   289  	if socket.dead != nil {
   290  		debugf("Socket %p to %s: killed again: %s (previously: %s)", socket, socket.addr, err.Error(), socket.dead.Error())
   291  		socket.Unlock()
   292  		return
   293  	}
   294  	logf("Socket %p to %s: closing: %s (abend=%v)", socket, socket.addr, err.Error(), abend)
   295  	socket.dead = err
   296  	socket.conn.Close()
   297  	stats.socketsAlive(-1)
   298  	replyFuncs := socket.replyFuncs
   299  	socket.replyFuncs = make(map[uint32]replyFunc)
   300  	server := socket.server
   301  	socket.server = nil
   302  	socket.Unlock()
   303  	for _, f := range replyFuncs {
   304  		logf("Socket %p to %s: notifying replyFunc of closed socket: %s", socket, socket.addr, err.Error())
   305  		f(err, nil, -1, nil)
   306  	}
   307  	if abend {
   308  		server.AbendSocket(socket)
   309  	}
   310  }
   311  
   312  func (socket *mongoSocket) SimpleQuery(op *queryOp) (data []byte, err error) {
   313  	var mutex sync.Mutex
   314  	var replyData []byte
   315  	var replyErr error
   316  	mutex.Lock()
   317  	op.replyFunc = func(err error, reply *replyOp, docNum int, docData []byte) {
   318  		replyData = docData
   319  		replyErr = err
   320  		mutex.Unlock()
   321  	}
   322  	err = socket.Query(op)
   323  	if err != nil {
   324  		return nil, err
   325  	}
   326  	mutex.Lock() // Wait.
   327  	if replyErr != nil {
   328  		return nil, replyErr
   329  	}
   330  	return replyData, nil
   331  }
   332  
   333  func (socket *mongoSocket) Query(ops ...interface{}) (err error) {
   334  
   335  	if lops := socket.flushLogout(); len(lops) > 0 {
   336  		ops = append(lops, ops...)
   337  	}
   338  
   339  	buf := make([]byte, 0, 256)
   340  
   341  	// Serialize operations synchronously to avoid interrupting
   342  	// other goroutines while we can't really be sending data.
   343  	// Also, record id positions so that we can compute request
   344  	// ids at once later with the lock already held.
   345  	requests := make([]requestInfo, len(ops))
   346  	requestCount := 0
   347  
   348  	for _, op := range ops {
   349  		debugf("Socket %p to %s: serializing op: %#v", socket, socket.addr, op)
   350  		start := len(buf)
   351  		var replyFunc replyFunc
   352  		switch op := op.(type) {
   353  
   354  		case *updateOp:
   355  			buf = addHeader(buf, 2001)
   356  			buf = addInt32(buf, 0) // Reserved
   357  			buf = addCString(buf, op.collection)
   358  			buf = addInt32(buf, int32(op.flags))
   359  			debugf("Socket %p to %s: serializing selector document: %#v", socket, socket.addr, op.selector)
   360  			buf, err = addBSON(buf, op.selector)
   361  			if err != nil {
   362  				return err
   363  			}
   364  			debugf("Socket %p to %s: serializing update document: %#v", socket, socket.addr, op.update)
   365  			buf, err = addBSON(buf, op.update)
   366  			if err != nil {
   367  				return err
   368  			}
   369  
   370  		case *insertOp:
   371  			buf = addHeader(buf, 2002)
   372  			buf = addInt32(buf, 0) // Reserved
   373  			buf = addCString(buf, op.collection)
   374  			for _, doc := range op.documents {
   375  				debugf("Socket %p to %s: serializing document for insertion: %#v", socket, socket.addr, doc)
   376  				buf, err = addBSON(buf, doc)
   377  				if err != nil {
   378  					return err
   379  				}
   380  			}
   381  
   382  		case *queryOp:
   383  			buf = addHeader(buf, 2004)
   384  			buf = addInt32(buf, int32(op.flags))
   385  			buf = addCString(buf, op.collection)
   386  			buf = addInt32(buf, op.skip)
   387  			buf = addInt32(buf, op.limit)
   388  			buf, err = addBSON(buf, op.finalQuery(socket))
   389  			if err != nil {
   390  				return err
   391  			}
   392  			if op.selector != nil {
   393  				buf, err = addBSON(buf, op.selector)
   394  				if err != nil {
   395  					return err
   396  				}
   397  			}
   398  			replyFunc = op.replyFunc
   399  
   400  		case *getMoreOp:
   401  			buf = addHeader(buf, 2005)
   402  			buf = addInt32(buf, 0) // Reserved
   403  			buf = addCString(buf, op.collection)
   404  			buf = addInt32(buf, op.limit)
   405  			buf = addInt64(buf, op.cursorId)
   406  			replyFunc = op.replyFunc
   407  
   408  		case *deleteOp:
   409  			buf = addHeader(buf, 2006)
   410  			buf = addInt32(buf, 0) // Reserved
   411  			buf = addCString(buf, op.collection)
   412  			buf = addInt32(buf, int32(op.flags))
   413  			debugf("Socket %p to %s: serializing selector document: %#v", socket, socket.addr, op.selector)
   414  			buf, err = addBSON(buf, op.selector)
   415  			if err != nil {
   416  				return err
   417  			}
   418  
   419  		case *killCursorsOp:
   420  			buf = addHeader(buf, 2007)
   421  			buf = addInt32(buf, 0) // Reserved
   422  			buf = addInt32(buf, int32(len(op.cursorIds)))
   423  			for _, cursorId := range op.cursorIds {
   424  				buf = addInt64(buf, cursorId)
   425  			}
   426  
   427  		default:
   428  			panic("Internal error: unknown operation type")
   429  		}
   430  
   431  		setInt32(buf, start, int32(len(buf)-start))
   432  
   433  		if replyFunc != nil {
   434  			request := &requests[requestCount]
   435  			request.replyFunc = replyFunc
   436  			request.bufferPos = start
   437  			requestCount++
   438  		}
   439  	}
   440  
   441  	// Buffer is ready for the pipe.  Lock, allocate ids, and enqueue.
   442  
   443  	socket.Lock()
   444  	if socket.dead != nil {
   445  		socket.Unlock()
   446  		debugf("Socket %p to %s: failing query, already closed: %s", socket, socket.addr, socket.dead.Error())
   447  		// XXX This seems necessary in case the session is closed concurrently
   448  		// with a query being performed, but it's not yet tested:
   449  		for i := 0; i != requestCount; i++ {
   450  			request := &requests[i]
   451  			if request.replyFunc != nil {
   452  				request.replyFunc(socket.dead, nil, -1, nil)
   453  			}
   454  		}
   455  		return socket.dead
   456  	}
   457  
   458  	wasWaiting := len(socket.replyFuncs) > 0
   459  
   460  	// Reserve id 0 for requests which should have no responses.
   461  	requestId := socket.nextRequestId + 1
   462  	if requestId == 0 {
   463  		requestId++
   464  	}
   465  	socket.nextRequestId = requestId + uint32(requestCount)
   466  	for i := 0; i != requestCount; i++ {
   467  		request := &requests[i]
   468  		setInt32(buf, request.bufferPos+4, int32(requestId))
   469  		socket.replyFuncs[requestId] = request.replyFunc
   470  		requestId++
   471  	}
   472  
   473  	debugf("Socket %p to %s: sending %d op(s) (%d bytes)", socket, socket.addr, len(ops), len(buf))
   474  	stats.sentOps(len(ops))
   475  
   476  	socket.updateDeadline(writeDeadline)
   477  	_, err = socket.conn.Write(buf)
   478  	if !wasWaiting && requestCount > 0 {
   479  		socket.updateDeadline(readDeadline)
   480  	}
   481  	socket.Unlock()
   482  	return err
   483  }
   484  
   485  func fill(r net.Conn, b []byte) error {
   486  	l := len(b)
   487  	n, err := r.Read(b)
   488  	for n != l && err == nil {
   489  		var ni int
   490  		ni, err = r.Read(b[n:])
   491  		n += ni
   492  	}
   493  	return err
   494  }
   495  
   496  // Estimated minimum cost per socket: 1 goroutine + memory for the largest
   497  // document ever seen.
   498  func (socket *mongoSocket) readLoop() {
   499  	p := make([]byte, 36) // 16 from header + 20 from OP_REPLY fixed fields
   500  	s := make([]byte, 4)
   501  	conn := socket.conn // No locking, conn never changes.
   502  	for {
   503  		// XXX Handle timeouts, , etc
   504  		err := fill(conn, p)
   505  		if err != nil {
   506  			socket.kill(err, true)
   507  			return
   508  		}
   509  
   510  		totalLen := getInt32(p, 0)
   511  		responseTo := getInt32(p, 8)
   512  		opCode := getInt32(p, 12)
   513  
   514  		// Don't use socket.server.Addr here.  socket is not
   515  		// locked and socket.server may go away.
   516  		debugf("Socket %p to %s: got reply (%d bytes)", socket, socket.addr, totalLen)
   517  
   518  		_ = totalLen
   519  
   520  		if opCode != 1 {
   521  			socket.kill(errors.New("opcode != 1, corrupted data?"), true)
   522  			return
   523  		}
   524  
   525  		reply := replyOp{
   526  			flags:     uint32(getInt32(p, 16)),
   527  			cursorId:  getInt64(p, 20),
   528  			firstDoc:  getInt32(p, 28),
   529  			replyDocs: getInt32(p, 32),
   530  		}
   531  
   532  		stats.receivedOps(+1)
   533  		stats.receivedDocs(int(reply.replyDocs))
   534  
   535  		socket.Lock()
   536  		replyFunc, replyFuncFound := socket.replyFuncs[uint32(responseTo)]
   537  		socket.Unlock()
   538  
   539  		if replyFunc != nil && reply.replyDocs == 0 {
   540  			replyFunc(nil, &reply, -1, nil)
   541  		} else {
   542  			for i := 0; i != int(reply.replyDocs); i++ {
   543  				err := fill(conn, s)
   544  				if err != nil {
   545  					socket.kill(err, true)
   546  					return
   547  				}
   548  
   549  				b := make([]byte, int(getInt32(s, 0)))
   550  
   551  				// copy(b, s) in an efficient way.
   552  				b[0] = s[0]
   553  				b[1] = s[1]
   554  				b[2] = s[2]
   555  				b[3] = s[3]
   556  
   557  				err = fill(conn, b[4:])
   558  				if err != nil {
   559  					socket.kill(err, true)
   560  					return
   561  				}
   562  
   563  				if globalDebug && globalLogger != nil {
   564  					m := bson.M{}
   565  					if err := bson.Unmarshal(b, m); err == nil {
   566  						debugf("Socket %p to %s: received document: %#v", socket, socket.addr, m)
   567  					}
   568  				}
   569  
   570  				if replyFunc != nil {
   571  					replyFunc(nil, &reply, i, b)
   572  				}
   573  
   574  				// XXX Do bound checking against totalLen.
   575  			}
   576  		}
   577  
   578  		// Only remove replyFunc after iteration, so that kill() will see it.
   579  		socket.Lock()
   580  		if replyFuncFound {
   581  			delete(socket.replyFuncs, uint32(responseTo))
   582  		}
   583  		if len(socket.replyFuncs) == 0 {
   584  			// Nothing else to read for now. Disable deadline.
   585  			socket.conn.SetReadDeadline(time.Time{})
   586  		} else {
   587  			socket.updateDeadline(readDeadline)
   588  		}
   589  		socket.Unlock()
   590  
   591  		// XXX Do bound checking against totalLen.
   592  	}
   593  }
   594  
   595  var emptyHeader = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
   596  
   597  func addHeader(b []byte, opcode int) []byte {
   598  	i := len(b)
   599  	b = append(b, emptyHeader...)
   600  	// Enough for current opcodes.
   601  	b[i+12] = byte(opcode)
   602  	b[i+13] = byte(opcode >> 8)
   603  	return b
   604  }
   605  
   606  func addInt32(b []byte, i int32) []byte {
   607  	return append(b, byte(i), byte(i>>8), byte(i>>16), byte(i>>24))
   608  }
   609  
   610  func addInt64(b []byte, i int64) []byte {
   611  	return append(b, byte(i), byte(i>>8), byte(i>>16), byte(i>>24),
   612  		byte(i>>32), byte(i>>40), byte(i>>48), byte(i>>56))
   613  }
   614  
   615  func addCString(b []byte, s string) []byte {
   616  	b = append(b, []byte(s)...)
   617  	b = append(b, 0)
   618  	return b
   619  }
   620  
   621  func addBSON(b []byte, doc interface{}) ([]byte, error) {
   622  	if doc == nil {
   623  		return append(b, 5, 0, 0, 0, 0), nil
   624  	}
   625  	data, err := bson.Marshal(doc)
   626  	if err != nil {
   627  		return b, err
   628  	}
   629  	return append(b, data...), nil
   630  }
   631  
   632  func setInt32(b []byte, pos int, i int32) {
   633  	b[pos] = byte(i)
   634  	b[pos+1] = byte(i >> 8)
   635  	b[pos+2] = byte(i >> 16)
   636  	b[pos+3] = byte(i >> 24)
   637  }
   638  
   639  func getInt32(b []byte, pos int) int32 {
   640  	return (int32(b[pos+0])) |
   641  		(int32(b[pos+1]) << 8) |
   642  		(int32(b[pos+2]) << 16) |
   643  		(int32(b[pos+3]) << 24)
   644  }
   645  
   646  func getInt64(b []byte, pos int) int64 {
   647  	return (int64(b[pos+0])) |
   648  		(int64(b[pos+1]) << 8) |
   649  		(int64(b[pos+2]) << 16) |
   650  		(int64(b[pos+3]) << 24) |
   651  		(int64(b[pos+4]) << 32) |
   652  		(int64(b[pos+5]) << 40) |
   653  		(int64(b[pos+6]) << 48) |
   654  		(int64(b[pos+7]) << 56)
   655  }