gopkg.in/rethinkdb/rethinkdb-go.v6@v6.2.2/mock.go (about) 1 package rethinkdb 2 3 import ( 4 "encoding/binary" 5 "encoding/json" 6 "fmt" 7 "gopkg.in/rethinkdb/rethinkdb-go.v6/encoding" 8 "net" 9 "reflect" 10 "sync" 11 "time" 12 13 "golang.org/x/net/context" 14 p "gopkg.in/rethinkdb/rethinkdb-go.v6/ql2" 15 ) 16 17 // Mocking is based on the amazing package github.com/stretchr/testify 18 19 // testingT is an interface wrapper around *testing.T 20 type testingT interface { 21 Logf(format string, args ...interface{}) 22 Errorf(format string, args ...interface{}) 23 FailNow() 24 } 25 26 // MockAnything can be used in place of any term, this is useful when you want 27 // mock similar queries or queries that you don't quite know the exact structure 28 // of. 29 func MockAnything() Term { 30 t := constructRootTerm("MockAnything", p.Term_DATUM, nil, nil) 31 t.isMockAnything = true 32 33 return t 34 } 35 36 func (t Term) MockAnything() Term { 37 t = constructMethodTerm(t, "MockAnything", p.Term_DATUM, nil, nil) 38 t.isMockAnything = true 39 40 return t 41 } 42 43 // MockQuery represents a mocked query and is used for setting expectations, 44 // as well as recording activity. 45 type MockQuery struct { 46 parent *Mock 47 48 // Holds the query and term 49 Query Query 50 51 // Holds the JSON representation of query 52 BuiltQuery []byte 53 54 // Holds the response that should be returned when this method is executed. 55 Response interface{} 56 57 // Holds the error that should be returned when this method is executed. 58 Error error 59 60 // The number of times to return the return arguments when setting 61 // expectations. 0 means to always return the value. 62 Repeatability int 63 64 // Holds a channel that will be used to block the Return until it either 65 // recieves a message or is connClosed. nil means it returns immediately. 66 WaitFor <-chan time.Time 67 68 // Amount of times this query has been executed 69 executed int 70 } 71 72 func newMockQuery(parent *Mock, q Query) *MockQuery { 73 // Build and marshal term 74 builtQuery, err := json.Marshal(q.Build()) 75 if err != nil { 76 panic(fmt.Sprintf("Failed to build query: %s", err)) 77 } 78 79 return &MockQuery{ 80 parent: parent, 81 Query: q, 82 BuiltQuery: builtQuery, 83 Response: make([]interface{}, 0), 84 Repeatability: 0, 85 WaitFor: nil, 86 } 87 } 88 89 func newMockQueryFromTerm(parent *Mock, t Term, opts map[string]interface{}) *MockQuery { 90 q, err := parent.newQuery(t, opts) 91 if err != nil { 92 panic(fmt.Sprintf("Failed to build query: %s", err)) 93 } 94 95 return newMockQuery(parent, q) 96 } 97 98 func (mq *MockQuery) lock() { 99 mq.parent.mu.Lock() 100 } 101 102 func (mq *MockQuery) unlock() { 103 mq.parent.mu.Unlock() 104 } 105 106 // Return specifies the return arguments for the expectation. 107 // 108 // mock.On(r.Table("test")).Return(nil, errors.New("failed")) 109 // 110 // values of `chan []interface{}` type will turn to delayed data that produce data 111 // when there is an elements available on the channel. These elements are chunk of responses. 112 // Values of `func() []interface{}` type will produce data by calling the function. E.g. 113 // Closing channel or returning nil from func means end of data. 114 // 115 // f := func() []interface{} { return []interface{}{1, 2} } 116 // mock.On(r.Table("test1")).Return(f) 117 // 118 // ch := make(chan []interface{}) 119 // mock.On(r.Table("test1")).Return(ch) 120 // 121 // Running the query above will block until a value is pushed onto ch. 122 func (mq *MockQuery) Return(response interface{}, err error) *MockQuery { 123 mq.lock() 124 defer mq.unlock() 125 126 mq.Response = response 127 mq.Error = err 128 129 return mq 130 } 131 132 // Once indicates that that the mock should only return the value once. 133 // 134 // mock.On(r.Table("test")).Return(result, nil).Once() 135 func (mq *MockQuery) Once() *MockQuery { 136 return mq.Times(1) 137 } 138 139 // Twice indicates that that the mock should only return the value twice. 140 // 141 // mock.On(r.Table("test")).Return(result, nil).Twice() 142 func (mq *MockQuery) Twice() *MockQuery { 143 return mq.Times(2) 144 } 145 146 // Times indicates that that the mock should only return the indicated number 147 // of times. 148 // 149 // mock.On(r.Table("test")).Return(result, nil).Times(5) 150 func (mq *MockQuery) Times(i int) *MockQuery { 151 mq.lock() 152 defer mq.unlock() 153 mq.Repeatability = i 154 return mq 155 } 156 157 // WaitUntil sets the channel that will block the mock's return until its connClosed 158 // or a message is received. 159 // 160 // mock.On(r.Table("test")).WaitUntil(time.After(time.Second)) 161 func (mq *MockQuery) WaitUntil(w <-chan time.Time) *MockQuery { 162 mq.lock() 163 defer mq.unlock() 164 mq.WaitFor = w 165 return mq 166 } 167 168 // After sets how long to block until the query returns 169 // 170 // mock.On(r.Table("test")).After(time.Second) 171 func (mq *MockQuery) After(d time.Duration) *MockQuery { 172 return mq.WaitUntil(time.After(d)) 173 } 174 175 // On chains a new expectation description onto the mocked interface. This 176 // allows syntax like. 177 // 178 // Mock. 179 // On(r.Table("test")).Return(result, nil). 180 // On(r.Table("test2")).Return(nil, errors.New("Some Error")) 181 func (mq *MockQuery) On(t Term) *MockQuery { 182 return mq.parent.On(t) 183 } 184 185 // Mock is used to mock query execution and verify that the expected queries are 186 // being executed. Mocks are used by creating an instance using NewMock and then 187 // passing this when running your queries instead of a session. For example: 188 // 189 // mock := r.NewMock() 190 // mock.On(r.Table("test")).Return([]interface{}{data}, nil) 191 // 192 // cursor, err := r.Table("test").Run(mock) 193 // 194 // mock.AssertExpectations(t) 195 type Mock struct { 196 mu sync.Mutex 197 opts ConnectOpts 198 199 ExpectedQueries []*MockQuery 200 Queries []MockQuery 201 } 202 203 // NewMock creates an instance of Mock, you can optionally pass ConnectOpts to 204 // the function, if passed any mocked query will be generated using those 205 // options. 206 func NewMock(opts ...ConnectOpts) *Mock { 207 m := &Mock{ 208 ExpectedQueries: make([]*MockQuery, 0), 209 Queries: make([]MockQuery, 0), 210 } 211 212 if len(opts) > 0 { 213 m.opts = opts[0] 214 } 215 216 return m 217 } 218 219 // On starts a description of an expectation of the specified query 220 // being executed. 221 // 222 // mock.On(r.Table("test")) 223 func (m *Mock) On(t Term, opts ...map[string]interface{}) *MockQuery { 224 var qopts map[string]interface{} 225 if len(opts) > 0 { 226 qopts = opts[0] 227 } 228 229 m.mu.Lock() 230 defer m.mu.Unlock() 231 mq := newMockQueryFromTerm(m, t, qopts) 232 m.ExpectedQueries = append(m.ExpectedQueries, mq) 233 return mq 234 } 235 236 // AssertExpectations asserts that everything specified with On and Return was 237 // in fact executed as expected. Queries may have been executed in any order. 238 func (m *Mock) AssertExpectations(t testingT) bool { 239 var somethingMissing bool 240 var failedExpectations int 241 242 // iterate through each expectation 243 expectedQueries := m.expectedQueries() 244 for _, expectedQuery := range expectedQueries { 245 if !m.queryWasExecuted(expectedQuery) && expectedQuery.executed == 0 { 246 somethingMissing = true 247 failedExpectations++ 248 t.Logf("❌\t%s", expectedQuery.Query.Term.String()) 249 } else { 250 m.mu.Lock() 251 if expectedQuery.Repeatability > 0 { 252 somethingMissing = true 253 failedExpectations++ 254 } else { 255 t.Logf("✅\t%s", expectedQuery.Query.Term.String()) 256 } 257 m.mu.Unlock() 258 } 259 } 260 261 if somethingMissing { 262 t.Errorf("FAIL: %d out of %d expectation(s) were met.\n\tThe query you are testing needs to be executed %d more times(s).", len(expectedQueries)-failedExpectations, len(expectedQueries), failedExpectations) 263 } 264 265 return !somethingMissing 266 } 267 268 // AssertNumberOfExecutions asserts that the query was executed expectedExecutions times. 269 func (m *Mock) AssertNumberOfExecutions(t testingT, expectedQuery *MockQuery, expectedExecutions int) bool { 270 var actualExecutions int 271 for _, query := range m.queries() { 272 if query.Query.Term.compare(*expectedQuery.Query.Term, map[int64]int64{}) && query.Repeatability > -1 { 273 // if bytes.Equal(query.BuiltQuery, expectedQuery.BuiltQuery) { 274 actualExecutions++ 275 } 276 } 277 278 if expectedExecutions != actualExecutions { 279 t.Errorf("Expected number of executions (%d) does not match the actual number of executions (%d).", expectedExecutions, actualExecutions) 280 return false 281 } 282 283 return true 284 } 285 286 // AssertExecuted asserts that the method was executed. 287 // It can produce a false result when an argument is a pointer type and the underlying value changed after executing the mocked method. 288 func (m *Mock) AssertExecuted(t testingT, expectedQuery *MockQuery) bool { 289 if !m.queryWasExecuted(expectedQuery) { 290 t.Errorf("The query \"%s\" should have been executed, but was not.", expectedQuery.Query.Term.String()) 291 return false 292 } 293 return true 294 } 295 296 // AssertNotExecuted asserts that the method was not executed. 297 // It can produce a false result when an argument is a pointer type and the underlying value changed after executing the mocked method. 298 func (m *Mock) AssertNotExecuted(t testingT, expectedQuery *MockQuery) bool { 299 if m.queryWasExecuted(expectedQuery) { 300 t.Errorf("The query \"%s\" was executed, but should NOT have been.", expectedQuery.Query.Term.String()) 301 return false 302 } 303 return true 304 } 305 306 func (m *Mock) IsConnected() bool { 307 return true 308 } 309 310 func (m *Mock) Query(ctx context.Context, q Query) (*Cursor, error) { 311 found, query := m.findExpectedQuery(q) 312 313 if found < 0 { 314 panic(fmt.Sprintf("rethinkdb: mock: This query was unexpected:\n\t\t%s", q.Term.String())) 315 } else { 316 m.mu.Lock() 317 switch { 318 case query.Repeatability == 1: 319 query.Repeatability = -1 320 query.executed++ 321 322 case query.Repeatability > 1: 323 query.Repeatability-- 324 query.executed++ 325 326 case query.Repeatability == 0: 327 query.executed++ 328 } 329 m.mu.Unlock() 330 } 331 332 // add the query 333 m.mu.Lock() 334 m.Queries = append(m.Queries, *newMockQuery(m, q)) 335 m.mu.Unlock() 336 337 // block if specified 338 if query.WaitFor != nil { 339 <-query.WaitFor 340 } 341 342 // Return error without building cursor if non-nil 343 if query.Error != nil { 344 return nil, query.Error 345 } 346 347 if ctx == nil { 348 ctx = context.Background() 349 } 350 351 conn := newConnection(newMockConn(query.Response), "mock", &ConnectOpts{}) 352 353 query.Query.Type = p.Query_CONTINUE 354 query.Query.Token = conn.nextToken() 355 356 // Build cursor and return 357 c := newCursor(ctx, conn, "", query.Query.Token, query.Query.Term, query.Query.Opts) 358 c.finished = true 359 c.fetching = false 360 c.isAtom = true 361 c.finished = false 362 c.releaseConn = func() error { return conn.Close() } 363 364 conn.cursors[query.Query.Token] = c 365 go conn.readSocket() 366 go conn.processResponses() 367 368 c.mu.Lock() 369 err := c.fetchMore() 370 c.mu.Unlock() 371 if err != nil { 372 return nil, err 373 } 374 375 return c, nil 376 } 377 378 func (m *Mock) Exec(ctx context.Context, q Query) error { 379 _, err := m.Query(ctx, q) 380 381 return err 382 } 383 384 func (m *Mock) newQuery(t Term, opts map[string]interface{}) (Query, error) { 385 return newQuery(t, opts, &m.opts) 386 } 387 388 func (m *Mock) findExpectedQuery(q Query) (int, *MockQuery) { 389 m.mu.Lock() 390 defer m.mu.Unlock() 391 392 for i, query := range m.ExpectedQueries { 393 // if bytes.Equal(query.BuiltQuery, builtQuery) && query.Repeatability > -1 { 394 if query.Query.Term.compare(*q.Term, map[int64]int64{}) && query.Repeatability > -1 { 395 return i, query 396 } 397 } 398 399 return -1, nil 400 } 401 402 func (m *Mock) queryWasExecuted(expectedQuery *MockQuery) bool { 403 for _, query := range m.queries() { 404 if query.Query.Term.compare(*expectedQuery.Query.Term, map[int64]int64{}) { 405 // if bytes.Equal(query.BuiltQuery, expectedQuery.BuiltQuery) { 406 return true 407 } 408 } 409 410 // we didn't find the expected query 411 return false 412 } 413 414 func (m *Mock) expectedQueries() []*MockQuery { 415 m.mu.Lock() 416 defer m.mu.Unlock() 417 return append([]*MockQuery{}, m.ExpectedQueries...) 418 } 419 420 func (m *Mock) queries() []MockQuery { 421 m.mu.Lock() 422 defer m.mu.Unlock() 423 return append([]MockQuery{}, m.Queries...) 424 } 425 426 type mockConn struct { 427 mu sync.Mutex 428 value []byte 429 tokens chan int64 430 valueGetter func() []interface{} 431 } 432 433 func newMockConn(response interface{}) *mockConn { 434 c := &mockConn{tokens: make(chan int64, 1)} 435 switch g := response.(type) { 436 case chan []interface{}: 437 c.valueGetter = func() []interface{} { return <-g } 438 case func() []interface{}: 439 c.valueGetter = g 440 default: 441 responseVal := reflect.ValueOf(response) 442 if responseVal.Kind() == reflect.Slice || responseVal.Kind() == reflect.Array { 443 responses := make([]interface{}, responseVal.Len()) 444 for i := 0; i < responseVal.Len(); i++ { 445 responses[i] = responseVal.Index(i).Interface() 446 } 447 c.valueGetter = funcGetter(responses) 448 } else { 449 c.valueGetter = funcGetter([]interface{}{response}) 450 } 451 } 452 return c 453 } 454 455 func funcGetter(responses []interface{}) func() []interface{} { 456 done := false 457 return func() []interface{} { 458 if done { 459 return nil 460 } 461 done = true 462 return responses 463 } 464 } 465 466 func (c *mockConn) Read(b []byte) (n int, err error) { 467 c.mu.Lock() 468 defer c.mu.Unlock() 469 470 if c.value == nil { 471 values := c.valueGetter() 472 473 jresps := make([]json.RawMessage, len(values)) 474 for i := range values { 475 coded, err := encoding.Encode(values[i]) 476 if err != nil { 477 panic(fmt.Sprintf("failed to encode response: %v", err)) 478 } 479 jresps[i], err = json.Marshal(coded) 480 if err != nil { 481 panic(fmt.Sprintf("failed to encode response: %v", err)) 482 } 483 } 484 485 token := <-c.tokens 486 resp := Response{ 487 Token: token, 488 Responses: jresps, 489 Type: p.Response_SUCCESS_PARTIAL, 490 } 491 if values == nil { 492 resp.Type = p.Response_SUCCESS_SEQUENCE 493 } 494 495 c.value, err = json.Marshal(resp) 496 if err != nil { 497 panic(fmt.Sprintf("failed to encode response: %v", err)) 498 } 499 500 if len(b) != respHeaderLen { 501 panic("wrong header len") 502 } 503 binary.LittleEndian.PutUint64(b[:8], uint64(token)) 504 binary.LittleEndian.PutUint32(b[8:], uint32(len(c.value))) 505 return len(b), nil 506 } else { 507 copy(b, c.value) 508 c.value = nil 509 return len(b), nil 510 } 511 } 512 513 func (c *mockConn) Write(b []byte) (n int, err error) { 514 if len(b) < 8 { 515 panic("connBad socket write") 516 } 517 token := int64(binary.LittleEndian.Uint64(b[:8])) 518 c.tokens <- token 519 return len(b), nil 520 } 521 func (c *mockConn) Close() error { return nil } 522 func (c *mockConn) LocalAddr() net.Addr { panic("not implemented") } 523 func (c *mockConn) RemoteAddr() net.Addr { panic("not implemented") } 524 func (c *mockConn) SetDeadline(t time.Time) error { panic("not implemented") } 525 func (c *mockConn) SetReadDeadline(t time.Time) error { panic("not implemented") } 526 func (c *mockConn) SetWriteDeadline(t time.Time) error { panic("not implemented") }