gopkg.in/rethinkdb/rethinkdb-go.v6@v6.2.2/connection.go (about)

     1  package rethinkdb
     2  
     3  import (
     4  	"crypto/tls"
     5  	"encoding/binary"
     6  	"encoding/json"
     7  	"fmt"
     8  	"net"
     9  	"sync/atomic"
    10  	"time"
    11  
    12  	"bytes"
    13  	"github.com/opentracing/opentracing-go"
    14  	"github.com/opentracing/opentracing-go/ext"
    15  	"github.com/opentracing/opentracing-go/log"
    16  	"golang.org/x/net/context"
    17  	p "gopkg.in/rethinkdb/rethinkdb-go.v6/ql2"
    18  	"sync"
    19  )
    20  
const (
	// respHeaderLen is the size of the frame header that precedes every query
	// and response body: an 8-byte little-endian token followed by a 4-byte
	// little-endian body length (see sendQuery and readResponse).
	respHeaderLen          = 12
	// defaultKeepAlivePeriod is the TCP keep-alive used by NewConnection when
	// ConnectOpts.KeepAlivePeriod is not positive.
	defaultKeepAlivePeriod = time.Second * 30

	// Values stored in Connection.bad.
	connNotBad = 0
	connBad    = 1

	// Values stored in Connection.closed.
	connWorking = 0
	connClosed  = 1
)
    31  
// Response represents the raw response from a query, most of the time you
// should instead use a Cursor when reading from the database.
//
// Every field except Token is decoded from the JSON body of a response frame
// (keys "t", "e", "n", "r", "b", "p"). readResponse sets Token afterwards
// from the binary frame header.
type Response struct {
	// Token identifies the query this frame answers (from the frame header).
	Token     int64
	// Type selects how processResponse handles the frame.
	Type      p.Response_ResponseType   `json:"t"`
	// ErrorType is passed to createRuntimeError for RUNTIME_ERROR responses.
	ErrorType p.Response_ErrorType      `json:"e"`
	// Notes: the first note picks the cursor kind in processPartialResponse.
	Notes     []p.Response_ResponseNote `json:"n"`
	Responses []json.RawMessage         `json:"r"`
	Backtrace []interface{}             `json:"b"`
	// Profile is copied onto newly created cursors.
	Profile   interface{}               `json:"p"`
}
    43  
// Connection is a connection to a rethinkdb database. Connection is not thread
// safe and should only be accessed by a single goroutine.
//
// NewConnection starts two helper goroutines per connection: readSocket, which
// reads frames from the socket, and processResponses, which matches them to
// the queries waiting for them. Several fields below are shared with those
// goroutines through channels and atomics.
type Connection struct {
	net.Conn

	address string
	opts    *ConnectOpts

	// Padding before token. NOTE(review): presumably keeps token 64-bit
	// aligned for the atomic.AddInt64 in nextToken on 32-bit platforms; do not
	// move, remove, or add fields above token without re-checking alignment.
	_                  [4]byte
	token              int64                 // last query token handed out by nextToken
	cursors            map[int64]*Cursor     // cursors kept for later batches, keyed by query token
	bad                int32 // 0 - not bad, 1 - bad
	closed             int32 // 0 - working, 1 - closed
	stopReadChan       chan bool             // closed by Close to tell readSocket to stop
	readRequestsChan   chan tokenAndPromise  // Query -> processResponses: queries awaiting a response
	responseChan       chan responseAndError // readSocket -> processResponses: frames or read errors
	stopProcessingChan chan struct{}         // closed by processResponses when it exits
	mu                 sync.Mutex            // serializes Close
}
    63  
// responseAndError is one item sent by readSocket on responseChan: either a
// decoded response or the error readResponse returned.
type responseAndError struct {
	response *Response
	err      error
}
    68  
// responseAndCursor is the result delivered to a waiting Query call through
// its promise channel: the processed response and cursor, or an error.
type responseAndCursor struct {
	response *Response
	cursor   *Cursor
	err      error
}
    74  
// tokenAndPromise is a query waiting for its response. Query hands it to
// processResponses over readRequestsChan, and responses are matched to it by
// query.Token.
type tokenAndPromise struct {
	ctx     context.Context        // query context, passed to processResponse (and on to new cursors)
	query   *Query                 // the query that was sent; its Token is the match key
	promise chan responseAndCursor // receives one result and is then closed
	span    opentracing.Span       // fetching span finished by processResponse; may be nil
}
    81  
    82  // NewConnection creates a new connection to the database server
    83  func NewConnection(address string, opts *ConnectOpts) (*Connection, error) {
    84  	keepAlivePeriod := defaultKeepAlivePeriod
    85  	if opts.KeepAlivePeriod > 0 {
    86  		keepAlivePeriod = opts.KeepAlivePeriod
    87  	}
    88  
    89  	// Connect to Server
    90  	var err error
    91  	var conn net.Conn
    92  	nd := net.Dialer{Timeout: opts.Timeout, KeepAlive: keepAlivePeriod}
    93  	if opts.TLSConfig == nil {
    94  		conn, err = nd.Dial("tcp", address)
    95  	} else {
    96  		conn, err = tls.DialWithDialer(&nd, "tcp", address, opts.TLSConfig)
    97  	}
    98  	if err != nil {
    99  		return nil, RQLConnectionError{rqlError(err.Error())}
   100  	}
   101  
   102  	c := newConnection(conn, address, opts)
   103  
   104  	// Send handshake
   105  	handshake, err := c.handshake(opts.HandshakeVersion)
   106  	if err != nil {
   107  		return nil, err
   108  	}
   109  
   110  	if err = handshake.Send(); err != nil {
   111  		return nil, err
   112  	}
   113  
   114  	// NOTE: mock.go: Mock.Query()
   115  	// NOTE: connection_test.go: runConnection()
   116  	go c.readSocket()
   117  	go c.processResponses()
   118  
   119  	return c, nil
   120  }
   121  
   122  func newConnection(conn net.Conn, address string, opts *ConnectOpts) *Connection {
   123  	c := &Connection{
   124  		Conn:               conn,
   125  		address:            address,
   126  		opts:               opts,
   127  		cursors:            make(map[int64]*Cursor),
   128  		stopReadChan:       make(chan bool, 1),
   129  		bad:                connNotBad,
   130  		closed:             connWorking,
   131  		readRequestsChan:   make(chan tokenAndPromise, 16),
   132  		responseChan:       make(chan responseAndError, 16),
   133  		stopProcessingChan: make(chan struct{}),
   134  	}
   135  	return c
   136  }
   137  
   138  // Close closes the underlying net.Conn
   139  func (c *Connection) Close() error {
   140  	var err error
   141  
   142  	c.mu.Lock()
   143  	defer c.mu.Unlock()
   144  
   145  	if !c.isClosed() {
   146  		c.setClosed()
   147  		close(c.stopReadChan)
   148  		err = c.Conn.Close()
   149  	}
   150  
   151  	return err
   152  }
   153  
   154  // Query sends a Query to the database, returning both the raw Response and a
   155  // Cursor which should be used to view the query's response.
   156  //
   157  // This function is used internally by Run which should be used for most queries.
   158  func (c *Connection) Query(ctx context.Context, q Query) (*Response, *Cursor, error) {
   159  	if c == nil {
   160  		return nil, nil, ErrConnectionClosed
   161  	}
   162  	if c.Conn == nil || c.isClosed() {
   163  		c.setBad()
   164  		return nil, nil, ErrConnectionClosed
   165  	}
   166  	if ctx == nil {
   167  		ctx = c.contextFromConnectionOpts()
   168  	}
   169  
   170  	// Add token if query is a START/NOREPLY_WAIT
   171  	if q.Type == p.Query_START || q.Type == p.Query_NOREPLY_WAIT || q.Type == p.Query_SERVER_INFO {
   172  		q.Token = c.nextToken()
   173  	}
   174  	if q.Type == p.Query_START || q.Type == p.Query_NOREPLY_WAIT {
   175  		if c.opts.Database != "" {
   176  			var err error
   177  			q.Opts["db"], err = DB(c.opts.Database).Build()
   178  			if err != nil {
   179  				return nil, nil, RQLDriverError{rqlError(err.Error())}
   180  			}
   181  		}
   182  	}
   183  
   184  	var fetchingSpan opentracing.Span
   185  	if c.opts.UseOpentracing {
   186  		parentSpan := opentracing.SpanFromContext(ctx)
   187  		if parentSpan != nil {
   188  			if q.Type == p.Query_START {
   189  				querySpan := c.startTracingSpan(parentSpan, &q) // will be Finished when cursor connClosed
   190  				parentSpan = querySpan
   191  				ctx = opentracing.ContextWithSpan(ctx, querySpan)
   192  			}
   193  
   194  			fetchingSpan = c.startTracingSpan(parentSpan, &q) // will be Finished when response arrived
   195  		}
   196  	}
   197  
   198  	err := c.sendQuery(q)
   199  	if err != nil {
   200  		if fetchingSpan != nil {
   201  			ext.Error.Set(fetchingSpan, true)
   202  			fetchingSpan.LogFields(log.Error(err))
   203  			fetchingSpan.Finish()
   204  			if q.Type == p.Query_START {
   205  				opentracing.SpanFromContext(ctx).Finish()
   206  			}
   207  		}
   208  		return nil, nil, err
   209  	}
   210  
   211  	if noreply, ok := q.Opts["noreply"]; ok && noreply.(bool) {
   212  		return nil, nil, nil
   213  	}
   214  
   215  	promise := make(chan responseAndCursor, 1)
   216  	select {
   217  	case c.readRequestsChan <- tokenAndPromise{ctx: ctx, query: &q, span: fetchingSpan, promise: promise}:
   218  	case <-ctx.Done():
   219  		return c.stopQuery(&q)
   220  	}
   221  
   222  	select {
   223  	case future := <-promise:
   224  		return future.response, future.cursor, future.err
   225  	case <-ctx.Done():
   226  		return c.stopQuery(&q)
   227  	case <-c.stopProcessingChan: // connection readRequests processing stopped, promise can be never answered
   228  		return nil, nil, ErrConnectionClosed
   229  	}
   230  }
   231  
   232  func (c *Connection) stopQuery(q *Query) (*Response, *Cursor, error) {
   233  	if q.Type != p.Query_STOP && !c.isClosed() && !c.isBad() {
   234  		stopQuery := newStopQuery(q.Token)
   235  		_, _, _ = c.Query(c.contextFromConnectionOpts(), stopQuery)
   236  	}
   237  	return nil, nil, ErrQueryTimeout
   238  }
   239  
   240  func (c *Connection) startTracingSpan(parentSpan opentracing.Span, q *Query) opentracing.Span {
   241  	span := parentSpan.Tracer().StartSpan(
   242  		"Query_"+q.Type.String(),
   243  		opentracing.ChildOf(parentSpan.Context()),
   244  		ext.SpanKindRPCClient)
   245  
   246  	ext.PeerAddress.Set(span, c.address)
   247  	ext.Component.Set(span, "rethinkdb-go")
   248  
   249  	if q.Type == p.Query_START {
   250  		span.LogFields(log.String("query", q.Term.String()))
   251  	}
   252  
   253  	return span
   254  }
   255  
// readSocket runs in its own goroutine (started by NewConnection). It reads
// one frame at a time with readResponse and forwards the result, either the
// response or the error, to responseChan.
//
// After each frame it checks, without blocking, whether stopReadChan has been
// closed (Close does this). If so, it closes responseChan so processResponses
// can shut down, and returns.
//
// NOTE(review): the stop check only happens between reads, so the goroutine
// exits only after a blocked readResponse returns. Close also closes the
// socket, which should unblock it. After a read error the loop may call
// readResponse again until processResponses has reacted by calling Close.
func (c *Connection) readSocket() {
	for {
		response, err := c.readResponse()

		c.responseChan <- responseAndError{
			response: response,
			err:      err,
		}

		select {
		case <-c.stopReadChan:
			// Sole sender on responseChan, so closing it here is safe.
			close(c.responseChan)
			return
		default:
		}
	}
}
   273  
// processResponses runs in its own goroutine (started by NewConnection). It
// pairs responses read by readSocket with the queries waiting for them.
//
// It keeps two lists, both matched by token:
//   - readRequests: queries registered via readRequestsChan whose response has
//     not arrived yet;
//   - responses: responses that arrived before the matching request was
//     registered.
//
// When a pair is complete, the response is converted by processResponse and
// the result is delivered on the request's promise channel, which is then
// closed.
//
// Shutdown has two stages:
//   - On a read error from readSocket, every pending request receives that
//     error and the connection is closed.
//   - Once responseChan is closed, the loop closes stopProcessingChan, fails
//     the remaining requests with ErrConnectionClosed, drops the cursor map
//     and exits.
//
// NOTE(review): a response whose token never gets a matching read request
// stays in `responses` for the lifetime of the connection. This happens, for
// example, when Query's ctx is done before the request could be registered.
func (c *Connection) processResponses() {
	readRequests := make([]tokenAndPromise, 0, 16)
	responses := make([]*Response, 0, 16)
	for {
		var response *Response
		var readRequest tokenAndPromise
		var ok bool

		select {
		case respPair, openned := <-c.responseChan:
			// A receive from the closed channel yields a zero value (err == nil),
			// so it falls through to the !openned check below.
			if respPair.err != nil {
				// Transport or decoding error: the failing frame carries no usable
				// token, so every pending request gets the error.
				broadcastError(readRequests, respPair.err)
				readRequests = []tokenAndPromise{}
				// Close stops readSocket, which then closes responseChan; the
				// !openned branch handles the final shutdown on a later iteration.
				_ = c.Close() // next `if` will be called indirect cascade by closing chans
				continue
			}
			if !openned { // responseChan is closed (stopReadChan is closed too)
				close(c.stopProcessingChan)
				broadcastError(readRequests, ErrConnectionClosed)
				c.cursors = nil

				return
			}

			response = respPair.response

			readRequest, ok = getReadRequest(readRequests, respPair.response.Token)
			if !ok {
				// Nobody is waiting yet; park the response until a request arrives.
				responses = append(responses, respPair.response)
				continue
			}
			readRequests = removeReadRequest(readRequests, respPair.response.Token)

		case readRequest = <-c.readRequestsChan:
			response, ok = getResponse(responses, readRequest.query.Token)
			if !ok {
				// Response not here yet; wait for it.
				readRequests = append(readRequests, readRequest)
				continue
			}
			responses = removeResponse(responses, readRequest.query.Token)
		}

		response, cursor, err := c.processResponse(readRequest.ctx, *readRequest.query, response, readRequest.span)
		if readRequest.promise != nil {
			readRequest.promise <- responseAndCursor{response: response, cursor: cursor, err: err}
			close(readRequest.promise)
		}
	}
}
   325  
   326  func broadcastError(readRequests []tokenAndPromise, err error) {
   327  	for _, rr := range readRequests {
   328  		if rr.promise != nil {
   329  			rr.promise <- responseAndCursor{err: err}
   330  			close(rr.promise)
   331  		}
   332  	}
   333  }
   334  
// ServerResponse identifies the server a connection is talking to. Server
// decodes it from the result of a SERVER_INFO query.
type ServerResponse struct {
	ID   string `rethinkdb:"id"`
	Name string `rethinkdb:"name"`
}
   339  
   340  // Server returns the server name and server UUID being used by a connection.
   341  func (c *Connection) Server() (ServerResponse, error) {
   342  	var response ServerResponse
   343  
   344  	_, cur, err := c.Query(c.contextFromConnectionOpts(), Query{
   345  		Type: p.Query_SERVER_INFO,
   346  	})
   347  	if err != nil {
   348  		return response, err
   349  	}
   350  
   351  	if err = cur.One(&response); err != nil {
   352  		return response, err
   353  	}
   354  
   355  	if err = cur.Close(); err != nil {
   356  		return response, err
   357  	}
   358  
   359  	return response, nil
   360  }
   361  
   362  // sendQuery marshals the Query and sends the JSON to the server.
   363  func (c *Connection) sendQuery(q Query) error {
   364  	buf := &bytes.Buffer{}
   365  	buf.Grow(respHeaderLen)
   366  	buf.Write(buf.Bytes()[:respHeaderLen]) // reserve for header
   367  	enc := json.NewEncoder(buf)
   368  
   369  	// Build query
   370  	err := enc.Encode(q.Build())
   371  	if err != nil {
   372  		return RQLDriverError{rqlError(fmt.Sprintf("Error building query: %s", err.Error()))}
   373  	}
   374  
   375  	b := buf.Bytes()
   376  
   377  	// Write header
   378  	binary.LittleEndian.PutUint64(b, uint64(q.Token))
   379  	binary.LittleEndian.PutUint32(b[8:], uint32(len(b)-respHeaderLen))
   380  
   381  	// Send the JSON encoding of the query itself.
   382  	if err = c.writeData(b); err != nil {
   383  		c.setBad()
   384  		return RQLConnectionError{rqlError(err.Error())}
   385  	}
   386  
   387  	return nil
   388  }
   389  
   390  // getToken generates the next query token, used to number requests and match
   391  // responses with requests.
   392  func (c *Connection) nextToken() int64 {
   393  	// requires c.token to be 64-bit aligned on ARM
   394  	return atomic.AddInt64(&c.token, 1)
   395  }
   396  
   397  // readResponse attempts to read a Response from the server, if no response
   398  // could be read then an error is returned.
   399  func (c *Connection) readResponse() (*Response, error) {
   400  	// due to this is pooled connection, it always reads from socket even if idle
   401  	// timeouts should be only on query-level with context
   402  
   403  	// Read response header (token+length)
   404  	headerBuf := [respHeaderLen]byte{}
   405  	if _, err := c.read(headerBuf[:]); err != nil {
   406  		c.setBad()
   407  		return nil, RQLConnectionError{rqlError(err.Error())}
   408  	}
   409  
   410  	responseToken := int64(binary.LittleEndian.Uint64(headerBuf[:8]))
   411  	messageLength := binary.LittleEndian.Uint32(headerBuf[8:])
   412  
   413  	// Read the JSON encoding of the Response itself.
   414  	b := make([]byte, int(messageLength))
   415  
   416  	if _, err := c.read(b); err != nil {
   417  		c.setBad()
   418  		return nil, RQLConnectionError{rqlError(err.Error())}
   419  	}
   420  
   421  	// Decode the response
   422  	var response = new(Response)
   423  	if err := json.Unmarshal(b, response); err != nil {
   424  		c.setBad()
   425  		return nil, RQLDriverError{rqlError(err.Error())}
   426  	}
   427  	response.Token = responseToken
   428  
   429  	return response, nil
   430  }
   431  
   432  // Called to fill response for the query
   433  func (c *Connection) processResponse(ctx context.Context, q Query, response *Response, span opentracing.Span) (r *Response, cur *Cursor, err error) {
   434  	if span != nil {
   435  		defer func() {
   436  			if err != nil {
   437  				ext.Error.Set(span, true)
   438  				span.LogFields(log.Error(err))
   439  			}
   440  			span.Finish()
   441  		}()
   442  	}
   443  
   444  	switch response.Type {
   445  	case p.Response_CLIENT_ERROR:
   446  		return response, c.processErrorResponse(response), createClientError(response, q.Term)
   447  	case p.Response_COMPILE_ERROR:
   448  		return response, c.processErrorResponse(response), createCompileError(response, q.Term)
   449  	case p.Response_RUNTIME_ERROR:
   450  		return response, c.processErrorResponse(response), createRuntimeError(response.ErrorType, response, q.Term)
   451  	case p.Response_SUCCESS_ATOM, p.Response_SERVER_INFO:
   452  		return c.processAtomResponse(ctx, q, response)
   453  	case p.Response_SUCCESS_PARTIAL:
   454  		return c.processPartialResponse(ctx, q, response)
   455  	case p.Response_SUCCESS_SEQUENCE:
   456  		return c.processSequenceResponse(ctx, q, response)
   457  	case p.Response_WAIT_COMPLETE:
   458  		return c.processWaitResponse(response)
   459  	default:
   460  		return nil, nil, RQLDriverError{rqlError(fmt.Sprintf("Unexpected response type: %v", response.Type.String()))}
   461  	}
   462  }
   463  
   464  func (c *Connection) processErrorResponse(response *Response) *Cursor {
   465  	cursor := c.cursors[response.Token]
   466  	delete(c.cursors, response.Token)
   467  	return cursor
   468  }
   469  
   470  func (c *Connection) processAtomResponse(ctx context.Context, q Query, response *Response) (*Response, *Cursor, error) {
   471  	cursor := newCursor(ctx, c, "Cursor", response.Token, q.Term, q.Opts)
   472  	cursor.profile = response.Profile
   473  	cursor.extend(response)
   474  
   475  	return response, cursor, nil
   476  }
   477  
   478  func (c *Connection) processPartialResponse(ctx context.Context, q Query, response *Response) (*Response, *Cursor, error) {
   479  	cursorType := "Cursor"
   480  	if len(response.Notes) > 0 {
   481  		switch response.Notes[0] {
   482  		case p.Response_SEQUENCE_FEED:
   483  			cursorType = "Feed"
   484  		case p.Response_ATOM_FEED:
   485  			cursorType = "AtomFeed"
   486  		case p.Response_ORDER_BY_LIMIT_FEED:
   487  			cursorType = "OrderByLimitFeed"
   488  		case p.Response_UNIONED_FEED:
   489  			cursorType = "UnionedFeed"
   490  		case p.Response_INCLUDES_STATES:
   491  			cursorType = "IncludesFeed"
   492  		}
   493  	}
   494  
   495  	cursor, ok := c.cursors[response.Token]
   496  	if !ok {
   497  		// Create a new cursor if needed
   498  		cursor = newCursor(ctx, c, cursorType, response.Token, q.Term, q.Opts)
   499  		cursor.profile = response.Profile
   500  
   501  		c.cursors[response.Token] = cursor
   502  	}
   503  
   504  	cursor.extend(response)
   505  
   506  	return response, cursor, nil
   507  }
   508  
   509  func (c *Connection) processSequenceResponse(ctx context.Context, q Query, response *Response) (*Response, *Cursor, error) {
   510  	cursor, ok := c.cursors[response.Token]
   511  	if !ok {
   512  		// Create a new cursor if needed
   513  		cursor = newCursor(ctx, c, "Cursor", response.Token, q.Term, q.Opts)
   514  		cursor.profile = response.Profile
   515  	}
   516  	delete(c.cursors, response.Token)
   517  
   518  	cursor.extend(response)
   519  
   520  	return response, cursor, nil
   521  }
   522  
   523  func (c *Connection) processWaitResponse(response *Response) (*Response, *Cursor, error) {
   524  	delete(c.cursors, response.Token)
   525  	return response, nil, nil
   526  }
   527  
   528  func (c *Connection) setBad() {
   529  	atomic.StoreInt32(&c.bad, connBad)
   530  }
   531  
   532  func (c *Connection) isBad() bool {
   533  	return atomic.LoadInt32(&c.bad) == connBad
   534  }
   535  
   536  func (c *Connection) setClosed() {
   537  	atomic.StoreInt32(&c.closed, connClosed)
   538  }
   539  
   540  func (c *Connection) isClosed() bool {
   541  	return atomic.LoadInt32(&c.closed) == connClosed
   542  }
   543  
   544  func getReadRequest(list []tokenAndPromise, token int64) (tokenAndPromise, bool) {
   545  	for _, e := range list {
   546  		if e.query.Token == token {
   547  			return e, true
   548  		}
   549  	}
   550  	return tokenAndPromise{}, false
   551  }
   552  
   553  func getResponse(list []*Response, token int64) (*Response, bool) {
   554  	for _, e := range list {
   555  		if e.Token == token {
   556  			return e, true
   557  		}
   558  	}
   559  	return nil, false
   560  }
   561  
   562  func removeReadRequest(list []tokenAndPromise, token int64) []tokenAndPromise {
   563  	for i := range list {
   564  		if list[i].query.Token == token {
   565  			return append(list[:i], list[i+1:]...)
   566  		}
   567  	}
   568  	return list
   569  }
   570  
   571  func removeResponse(list []*Response, token int64) []*Response {
   572  	for i := range list {
   573  		if list[i].Token == token {
   574  			return append(list[:i], list[i+1:]...)
   575  		}
   576  	}
   577  	return list
   578  }