gopkg.in/rethinkdb/rethinkdb-go.v6@v6.2.2/mock.go (about)

     1  package rethinkdb
     2  
     3  import (
     4  	"encoding/binary"
     5  	"encoding/json"
     6  	"fmt"
     7  	"gopkg.in/rethinkdb/rethinkdb-go.v6/encoding"
     8  	"net"
     9  	"reflect"
    10  	"sync"
    11  	"time"
    12  
    13  	"golang.org/x/net/context"
    14  	p "gopkg.in/rethinkdb/rethinkdb-go.v6/ql2"
    15  )
    16  
    17  // Mocking is based on the amazing package github.com/stretchr/testify
    18  
    19  // testingT is an interface wrapper around *testing.T
    20  type testingT interface {
    21  	Logf(format string, args ...interface{})
    22  	Errorf(format string, args ...interface{})
    23  	FailNow()
    24  }
    25  
    26  // MockAnything can be used in place of any term, this is useful when you want
    27  // mock similar queries or queries that you don't quite know the exact structure
    28  // of.
    29  func MockAnything() Term {
    30  	t := constructRootTerm("MockAnything", p.Term_DATUM, nil, nil)
    31  	t.isMockAnything = true
    32  
    33  	return t
    34  }
    35  
    36  func (t Term) MockAnything() Term {
    37  	t = constructMethodTerm(t, "MockAnything", p.Term_DATUM, nil, nil)
    38  	t.isMockAnything = true
    39  
    40  	return t
    41  }
    42  
// MockQuery represents a mocked query and is used for setting expectations,
// as well as recording activity.
//
// The setters Return, Times and WaitUntil update these fields while holding
// the parent Mock's mutex, and Mock.Query updates the execution counters
// under that same mutex.
type MockQuery struct {
	// The Mock that owns this query; lock/unlock take its mutex.
	parent *Mock

	// Holds the query and term
	Query Query

	// Holds the JSON encoding of Query.Build(), computed by newMockQuery.
	BuiltQuery []byte

	// Holds the response that should be returned when this query is
	// executed. newMockQuery initialises it to an empty slice.
	Response interface{}

	// Holds the error that should be returned when this query is executed.
	Error error

	// The number of times to return the return arguments when setting
	// expectations. 0 means to always return the value. A positive value is
	// decremented on each execution; when the last one is used it becomes -1
	// and the expectation no longer matches incoming queries.
	Repeatability int

	// Holds a channel that blocks each execution of this query until it
	// either receives a value or is closed. nil means it returns immediately.
	WaitFor <-chan time.Time

	// Amount of times this query has been executed
	executed int
}
    71  
    72  func newMockQuery(parent *Mock, q Query) *MockQuery {
    73  	// Build and marshal term
    74  	builtQuery, err := json.Marshal(q.Build())
    75  	if err != nil {
    76  		panic(fmt.Sprintf("Failed to build query: %s", err))
    77  	}
    78  
    79  	return &MockQuery{
    80  		parent:        parent,
    81  		Query:         q,
    82  		BuiltQuery:    builtQuery,
    83  		Response:      make([]interface{}, 0),
    84  		Repeatability: 0,
    85  		WaitFor:       nil,
    86  	}
    87  }
    88  
    89  func newMockQueryFromTerm(parent *Mock, t Term, opts map[string]interface{}) *MockQuery {
    90  	q, err := parent.newQuery(t, opts)
    91  	if err != nil {
    92  		panic(fmt.Sprintf("Failed to build query: %s", err))
    93  	}
    94  
    95  	return newMockQuery(parent, q)
    96  }
    97  
    98  func (mq *MockQuery) lock() {
    99  	mq.parent.mu.Lock()
   100  }
   101  
   102  func (mq *MockQuery) unlock() {
   103  	mq.parent.mu.Unlock()
   104  }
   105  
   106  // Return specifies the return arguments for the expectation.
   107  //
   108  //    mock.On(r.Table("test")).Return(nil, errors.New("failed"))
   109  //
   110  // values of `chan []interface{}` type will turn to delayed data that produce data
   111  // when there is an elements available on the channel. These elements are chunk of responses.
   112  // Values of `func() []interface{}` type will produce data by calling the function. E.g.
   113  // Closing channel or returning nil from func means end of data.
   114  //
   115  //    f := func() []interface{} { return []interface{}{1, 2} }
   116  //    mock.On(r.Table("test1")).Return(f)
   117  //
   118  //    ch := make(chan []interface{})
   119  //    mock.On(r.Table("test1")).Return(ch)
   120  //
   121  //    Running the query above will block until a value is pushed onto ch.
   122  func (mq *MockQuery) Return(response interface{}, err error) *MockQuery {
   123  	mq.lock()
   124  	defer mq.unlock()
   125  
   126  	mq.Response = response
   127  	mq.Error = err
   128  
   129  	return mq
   130  }
   131  
   132  // Once indicates that that the mock should only return the value once.
   133  //
   134  //    mock.On(r.Table("test")).Return(result, nil).Once()
   135  func (mq *MockQuery) Once() *MockQuery {
   136  	return mq.Times(1)
   137  }
   138  
   139  // Twice indicates that that the mock should only return the value twice.
   140  //
   141  //    mock.On(r.Table("test")).Return(result, nil).Twice()
   142  func (mq *MockQuery) Twice() *MockQuery {
   143  	return mq.Times(2)
   144  }
   145  
   146  // Times indicates that that the mock should only return the indicated number
   147  // of times.
   148  //
   149  //    mock.On(r.Table("test")).Return(result, nil).Times(5)
   150  func (mq *MockQuery) Times(i int) *MockQuery {
   151  	mq.lock()
   152  	defer mq.unlock()
   153  	mq.Repeatability = i
   154  	return mq
   155  }
   156  
   157  // WaitUntil sets the channel that will block the mock's return until its connClosed
   158  // or a message is received.
   159  //
   160  //    mock.On(r.Table("test")).WaitUntil(time.After(time.Second))
   161  func (mq *MockQuery) WaitUntil(w <-chan time.Time) *MockQuery {
   162  	mq.lock()
   163  	defer mq.unlock()
   164  	mq.WaitFor = w
   165  	return mq
   166  }
   167  
   168  // After sets how long to block until the query returns
   169  //
   170  //    mock.On(r.Table("test")).After(time.Second)
   171  func (mq *MockQuery) After(d time.Duration) *MockQuery {
   172  	return mq.WaitUntil(time.After(d))
   173  }
   174  
   175  // On chains a new expectation description onto the mocked interface. This
   176  // allows syntax like.
   177  //
   178  //    Mock.
   179  //       On(r.Table("test")).Return(result, nil).
   180  //       On(r.Table("test2")).Return(nil, errors.New("Some Error"))
   181  func (mq *MockQuery) On(t Term) *MockQuery {
   182  	return mq.parent.On(t)
   183  }
   184  
// Mock is used to mock query execution and verify that the expected queries are
// being executed. Mocks are used by creating an instance using NewMock and then
// passing this when running your queries instead of a session. For example:
//
//     mock := r.NewMock()
//     mock.On(r.Table("test")).Return([]interface{}{data}, nil)
//
//     cursor, err := r.Table("test").Run(mock)
//
//     mock.AssertExpectations(t)
type Mock struct {
	// mu guards ExpectedQueries and Queries; MockQuery setters also take it
	// through their parent pointer.
	mu sync.Mutex
	// opts is passed to newQuery when building the queries registered with On.
	opts ConnectOpts

	// ExpectedQueries holds the expectations registered with On, in
	// registration order; Query uses the first one that matches and is not
	// exhausted.
	ExpectedQueries []*MockQuery
	// Queries records every query executed through Query, in execution order.
	Queries []MockQuery
}
   202  
   203  // NewMock creates an instance of Mock, you can optionally pass ConnectOpts to
   204  // the function, if passed any mocked query will be generated using those
   205  // options.
   206  func NewMock(opts ...ConnectOpts) *Mock {
   207  	m := &Mock{
   208  		ExpectedQueries: make([]*MockQuery, 0),
   209  		Queries:         make([]MockQuery, 0),
   210  	}
   211  
   212  	if len(opts) > 0 {
   213  		m.opts = opts[0]
   214  	}
   215  
   216  	return m
   217  }
   218  
   219  // On starts a description of an expectation of the specified query
   220  // being executed.
   221  //
   222  //     mock.On(r.Table("test"))
   223  func (m *Mock) On(t Term, opts ...map[string]interface{}) *MockQuery {
   224  	var qopts map[string]interface{}
   225  	if len(opts) > 0 {
   226  		qopts = opts[0]
   227  	}
   228  
   229  	m.mu.Lock()
   230  	defer m.mu.Unlock()
   231  	mq := newMockQueryFromTerm(m, t, qopts)
   232  	m.ExpectedQueries = append(m.ExpectedQueries, mq)
   233  	return mq
   234  }
   235  
   236  // AssertExpectations asserts that everything specified with On and Return was
   237  // in fact executed as expected. Queries may have been executed in any order.
   238  func (m *Mock) AssertExpectations(t testingT) bool {
   239  	var somethingMissing bool
   240  	var failedExpectations int
   241  
   242  	// iterate through each expectation
   243  	expectedQueries := m.expectedQueries()
   244  	for _, expectedQuery := range expectedQueries {
   245  		if !m.queryWasExecuted(expectedQuery) && expectedQuery.executed == 0 {
   246  			somethingMissing = true
   247  			failedExpectations++
   248  			t.Logf("❌\t%s", expectedQuery.Query.Term.String())
   249  		} else {
   250  			m.mu.Lock()
   251  			if expectedQuery.Repeatability > 0 {
   252  				somethingMissing = true
   253  				failedExpectations++
   254  			} else {
   255  				t.Logf("✅\t%s", expectedQuery.Query.Term.String())
   256  			}
   257  			m.mu.Unlock()
   258  		}
   259  	}
   260  
   261  	if somethingMissing {
   262  		t.Errorf("FAIL: %d out of %d expectation(s) were met.\n\tThe query you are testing needs to be executed %d more times(s).", len(expectedQueries)-failedExpectations, len(expectedQueries), failedExpectations)
   263  	}
   264  
   265  	return !somethingMissing
   266  }
   267  
   268  // AssertNumberOfExecutions asserts that the query was executed expectedExecutions times.
   269  func (m *Mock) AssertNumberOfExecutions(t testingT, expectedQuery *MockQuery, expectedExecutions int) bool {
   270  	var actualExecutions int
   271  	for _, query := range m.queries() {
   272  		if query.Query.Term.compare(*expectedQuery.Query.Term, map[int64]int64{}) && query.Repeatability > -1 {
   273  			// if bytes.Equal(query.BuiltQuery, expectedQuery.BuiltQuery) {
   274  			actualExecutions++
   275  		}
   276  	}
   277  
   278  	if expectedExecutions != actualExecutions {
   279  		t.Errorf("Expected number of executions (%d) does not match the actual number of executions (%d).", expectedExecutions, actualExecutions)
   280  		return false
   281  	}
   282  
   283  	return true
   284  }
   285  
   286  // AssertExecuted asserts that the method was executed.
   287  // It can produce a false result when an argument is a pointer type and the underlying value changed after executing the mocked method.
   288  func (m *Mock) AssertExecuted(t testingT, expectedQuery *MockQuery) bool {
   289  	if !m.queryWasExecuted(expectedQuery) {
   290  		t.Errorf("The query \"%s\" should have been executed, but was not.", expectedQuery.Query.Term.String())
   291  		return false
   292  	}
   293  	return true
   294  }
   295  
   296  // AssertNotExecuted asserts that the method was not executed.
   297  // It can produce a false result when an argument is a pointer type and the underlying value changed after executing the mocked method.
   298  func (m *Mock) AssertNotExecuted(t testingT, expectedQuery *MockQuery) bool {
   299  	if m.queryWasExecuted(expectedQuery) {
   300  		t.Errorf("The query \"%s\" was executed, but should NOT have been.", expectedQuery.Query.Term.String())
   301  		return false
   302  	}
   303  	return true
   304  }
   305  
   306  func (m *Mock) IsConnected() bool {
   307  	return true
   308  }
   309  
   310  func (m *Mock) Query(ctx context.Context, q Query) (*Cursor, error) {
   311  	found, query := m.findExpectedQuery(q)
   312  
   313  	if found < 0 {
   314  		panic(fmt.Sprintf("rethinkdb: mock: This query was unexpected:\n\t\t%s", q.Term.String()))
   315  	} else {
   316  		m.mu.Lock()
   317  		switch {
   318  		case query.Repeatability == 1:
   319  			query.Repeatability = -1
   320  			query.executed++
   321  
   322  		case query.Repeatability > 1:
   323  			query.Repeatability--
   324  			query.executed++
   325  
   326  		case query.Repeatability == 0:
   327  			query.executed++
   328  		}
   329  		m.mu.Unlock()
   330  	}
   331  
   332  	// add the query
   333  	m.mu.Lock()
   334  	m.Queries = append(m.Queries, *newMockQuery(m, q))
   335  	m.mu.Unlock()
   336  
   337  	// block if specified
   338  	if query.WaitFor != nil {
   339  		<-query.WaitFor
   340  	}
   341  
   342  	// Return error without building cursor if non-nil
   343  	if query.Error != nil {
   344  		return nil, query.Error
   345  	}
   346  
   347  	if ctx == nil {
   348  		ctx = context.Background()
   349  	}
   350  
   351  	conn := newConnection(newMockConn(query.Response), "mock", &ConnectOpts{})
   352  
   353  	query.Query.Type = p.Query_CONTINUE
   354  	query.Query.Token = conn.nextToken()
   355  
   356  	// Build cursor and return
   357  	c := newCursor(ctx, conn, "", query.Query.Token, query.Query.Term, query.Query.Opts)
   358  	c.finished = true
   359  	c.fetching = false
   360  	c.isAtom = true
   361  	c.finished = false
   362  	c.releaseConn = func() error { return conn.Close() }
   363  
   364  	conn.cursors[query.Query.Token] = c
   365  	go conn.readSocket()
   366  	go conn.processResponses()
   367  
   368  	c.mu.Lock()
   369  	err := c.fetchMore()
   370  	c.mu.Unlock()
   371  	if err != nil {
   372  		return nil, err
   373  	}
   374  
   375  	return c, nil
   376  }
   377  
   378  func (m *Mock) Exec(ctx context.Context, q Query) error {
   379  	_, err := m.Query(ctx, q)
   380  
   381  	return err
   382  }
   383  
   384  func (m *Mock) newQuery(t Term, opts map[string]interface{}) (Query, error) {
   385  	return newQuery(t, opts, &m.opts)
   386  }
   387  
   388  func (m *Mock) findExpectedQuery(q Query) (int, *MockQuery) {
   389  	m.mu.Lock()
   390  	defer m.mu.Unlock()
   391  
   392  	for i, query := range m.ExpectedQueries {
   393  		// if bytes.Equal(query.BuiltQuery, builtQuery) && query.Repeatability > -1 {
   394  		if query.Query.Term.compare(*q.Term, map[int64]int64{}) && query.Repeatability > -1 {
   395  			return i, query
   396  		}
   397  	}
   398  
   399  	return -1, nil
   400  }
   401  
   402  func (m *Mock) queryWasExecuted(expectedQuery *MockQuery) bool {
   403  	for _, query := range m.queries() {
   404  		if query.Query.Term.compare(*expectedQuery.Query.Term, map[int64]int64{}) {
   405  			// if bytes.Equal(query.BuiltQuery, expectedQuery.BuiltQuery) {
   406  			return true
   407  		}
   408  	}
   409  
   410  	// we didn't find the expected query
   411  	return false
   412  }
   413  
   414  func (m *Mock) expectedQueries() []*MockQuery {
   415  	m.mu.Lock()
   416  	defer m.mu.Unlock()
   417  	return append([]*MockQuery{}, m.ExpectedQueries...)
   418  }
   419  
   420  func (m *Mock) queries() []MockQuery {
   421  	m.mu.Lock()
   422  	defer m.mu.Unlock()
   423  	return append([]MockQuery{}, m.Queries...)
   424  }
   425  
// mockConn is an in-memory connection that the mock hands to newConnection;
// Write receives query tokens from the driver and Read serves the mocked
// response for each of them.
type mockConn struct {
	// mu serializes Read calls.
	mu sync.Mutex
	// value holds the encoded body of a response whose header Read has
	// already returned but whose body has not been read yet; nil means the
	// next Read produces a new header.
	value []byte
	// tokens carries query tokens from Write to Read so each response is
	// addressed to the query that was written.
	tokens chan int64
	// valueGetter yields the next batch of response values; a nil batch
	// marks the end of the data.
	valueGetter func() []interface{}
}
   432  
   433  func newMockConn(response interface{}) *mockConn {
   434  	c := &mockConn{tokens: make(chan int64, 1)}
   435  	switch g := response.(type) {
   436  	case chan []interface{}:
   437  		c.valueGetter = func() []interface{} { return <-g }
   438  	case func() []interface{}:
   439  		c.valueGetter = g
   440  	default:
   441  		responseVal := reflect.ValueOf(response)
   442  		if responseVal.Kind() == reflect.Slice || responseVal.Kind() == reflect.Array {
   443  			responses := make([]interface{}, responseVal.Len())
   444  			for i := 0; i < responseVal.Len(); i++ {
   445  				responses[i] = responseVal.Index(i).Interface()
   446  			}
   447  			c.valueGetter = funcGetter(responses)
   448  		} else {
   449  			c.valueGetter = funcGetter([]interface{}{response})
   450  		}
   451  	}
   452  	return c
   453  }
   454  
   455  func funcGetter(responses []interface{}) func() []interface{} {
   456  	done := false
   457  	return func() []interface{} {
   458  		if done {
   459  			return nil
   460  		}
   461  		done = true
   462  		return responses
   463  	}
   464  }
   465  
   466  func (c *mockConn) Read(b []byte) (n int, err error) {
   467  	c.mu.Lock()
   468  	defer c.mu.Unlock()
   469  
   470  	if c.value == nil {
   471  		values := c.valueGetter()
   472  
   473  		jresps := make([]json.RawMessage, len(values))
   474  		for i := range values {
   475  			coded, err := encoding.Encode(values[i])
   476  			if err != nil {
   477  				panic(fmt.Sprintf("failed to encode response: %v", err))
   478  			}
   479  			jresps[i], err = json.Marshal(coded)
   480  			if err != nil {
   481  				panic(fmt.Sprintf("failed to encode response: %v", err))
   482  			}
   483  		}
   484  
   485  		token := <-c.tokens
   486  		resp := Response{
   487  			Token:     token,
   488  			Responses: jresps,
   489  			Type:      p.Response_SUCCESS_PARTIAL,
   490  		}
   491  		if values == nil {
   492  			resp.Type = p.Response_SUCCESS_SEQUENCE
   493  		}
   494  
   495  		c.value, err = json.Marshal(resp)
   496  		if err != nil {
   497  			panic(fmt.Sprintf("failed to encode response: %v", err))
   498  		}
   499  
   500  		if len(b) != respHeaderLen {
   501  			panic("wrong header len")
   502  		}
   503  		binary.LittleEndian.PutUint64(b[:8], uint64(token))
   504  		binary.LittleEndian.PutUint32(b[8:], uint32(len(c.value)))
   505  		return len(b), nil
   506  	} else {
   507  		copy(b, c.value)
   508  		c.value = nil
   509  		return len(b), nil
   510  	}
   511  }
   512  
   513  func (c *mockConn) Write(b []byte) (n int, err error) {
   514  	if len(b) < 8 {
   515  		panic("connBad socket write")
   516  	}
   517  	token := int64(binary.LittleEndian.Uint64(b[:8]))
   518  	c.tokens <- token
   519  	return len(b), nil
   520  }
   521  func (c *mockConn) Close() error                       { return nil }
   522  func (c *mockConn) LocalAddr() net.Addr                { panic("not implemented") }
   523  func (c *mockConn) RemoteAddr() net.Addr               { panic("not implemented") }
   524  func (c *mockConn) SetDeadline(t time.Time) error      { panic("not implemented") }
   525  func (c *mockConn) SetReadDeadline(t time.Time) error  { panic("not implemented") }
   526  func (c *mockConn) SetWriteDeadline(t time.Time) error { panic("not implemented") }