github.com/golang/mock@v1.6.0/gomock/controller.go (about)

     1  // Copyright 2010 Google Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package gomock is a mock framework for Go.
    16  //
    17  // Standard usage:
    18  //   (1) Define an interface that you wish to mock.
    19  //         type MyInterface interface {
    20  //           SomeMethod(x int64, y string)
    21  //         }
    22  //   (2) Use mockgen to generate a mock from the interface.
    23  //   (3) Use the mock in a test:
    24  //         func TestMyThing(t *testing.T) {
    25  //           mockCtrl := gomock.NewController(t)
    26  //           defer mockCtrl.Finish()
    27  //
    28  //           mockObj := something.NewMockMyInterface(mockCtrl)
    29  //           mockObj.EXPECT().SomeMethod(4, "blah")
    30  //           // pass mockObj to a real object and play with it.
    31  //         }
    32  //
    33  // By default, expected calls are not enforced to run in any particular order.
    34  // Call order dependency can be enforced by use of InOrder and/or Call.After.
    35  // Call.After can create more varied call order dependencies, but InOrder is
    36  // often more convenient.
    37  //
    38  // The following examples create equivalent call order dependencies.
    39  //
    40  // Example of using Call.After to chain expected call order:
    41  //
    42  //     firstCall := mockObj.EXPECT().SomeMethod(1, "first")
    43  //     secondCall := mockObj.EXPECT().SomeMethod(2, "second").After(firstCall)
    44  //     mockObj.EXPECT().SomeMethod(3, "third").After(secondCall)
    45  //
    46  // Example of using InOrder to declare expected call order:
    47  //
    48  //     gomock.InOrder(
    49  //         mockObj.EXPECT().SomeMethod(1, "first"),
    50  //         mockObj.EXPECT().SomeMethod(2, "second"),
    51  //         mockObj.EXPECT().SomeMethod(3, "third"),
    52  //     )
    53  package gomock
    54  
    55  import (
    56  	"context"
    57  	"fmt"
    58  	"reflect"
    59  	"runtime"
    60  	"sync"
    61  )
    62  
    63  // A TestReporter is something that can be used to report test failures.  It
    64  // is satisfied by the standard library's *testing.T.
    65  type TestReporter interface {
    66  	Errorf(format string, args ...interface{})
    67  	Fatalf(format string, args ...interface{})
    68  }
    69  
    70  // TestHelper is a TestReporter that has the Helper method.  It is satisfied
    71  // by the standard library's *testing.T.
    72  type TestHelper interface {
    73  	TestReporter
    74  	Helper()
    75  }
    76  
    77  // cleanuper is used to check if TestHelper also has the `Cleanup` method. A
    78  // common pattern is to pass in a `*testing.T` to
    79  // `NewController(t TestReporter)`. In Go 1.14+, `*testing.T` has a cleanup
    80  // method. This can be utilized to call `Finish()` so the caller of this library
    81  // does not have to.
    82  type cleanuper interface {
    83  	Cleanup(func())
    84  }
    85  
// A Controller represents the top-level control of a mock ecosystem.  It
// defines the scope and lifetime of mock objects, as well as their
// expectations.  It is safe to call Controller's methods from multiple
// goroutines. Each test should create a new Controller and invoke Finish via
// defer.
//
//   func TestFoo(t *testing.T) {
//     ctrl := gomock.NewController(t)
//     defer ctrl.Finish()
//     // ..
//   }
//
//   func TestBar(t *testing.T) {
//     t.Run("Sub-Test-1", func(st *testing.T) {
//       ctrl := gomock.NewController(st)
//       defer ctrl.Finish()
//       // ..
//     })
//     t.Run("Sub-Test-2", func(st *testing.T) {
//       ctrl := gomock.NewController(st)
//       defer ctrl.Finish()
//       // ..
//     })
//   }
type Controller struct {
	// T should only be called within a generated mock. It is not intended to
	// be used in user code and may be changed in future versions. T is the
	// TestReporter passed in when creating the Controller via NewController.
	// If the TestReporter does not implement a TestHelper it will be wrapped
	// with a nopTestHelper.
	T             TestHelper
	mu            sync.Mutex // guards expectedCalls and finished
	expectedCalls *callSet   // expectations recorded and not yet removed
	finished      bool       // set by the first call to finish
}
   121  
   122  // NewController returns a new Controller. It is the preferred way to create a
   123  // Controller.
   124  //
   125  // New in go1.14+, if you are passing a *testing.T into this function you no
   126  // longer need to call ctrl.Finish() in your test methods.
   127  func NewController(t TestReporter) *Controller {
   128  	h, ok := t.(TestHelper)
   129  	if !ok {
   130  		h = &nopTestHelper{t}
   131  	}
   132  	ctrl := &Controller{
   133  		T:             h,
   134  		expectedCalls: newCallSet(),
   135  	}
   136  	if c, ok := isCleanuper(ctrl.T); ok {
   137  		c.Cleanup(func() {
   138  			ctrl.T.Helper()
   139  			ctrl.finish(true, nil)
   140  		})
   141  	}
   142  
   143  	return ctrl
   144  }
   145  
   146  type cancelReporter struct {
   147  	t      TestHelper
   148  	cancel func()
   149  }
   150  
   151  func (r *cancelReporter) Errorf(format string, args ...interface{}) {
   152  	r.t.Errorf(format, args...)
   153  }
   154  func (r *cancelReporter) Fatalf(format string, args ...interface{}) {
   155  	defer r.cancel()
   156  	r.t.Fatalf(format, args...)
   157  }
   158  
   159  func (r *cancelReporter) Helper() {
   160  	r.t.Helper()
   161  }
   162  
   163  // WithContext returns a new Controller and a Context, which is cancelled on any
   164  // fatal failure.
   165  func WithContext(ctx context.Context, t TestReporter) (*Controller, context.Context) {
   166  	h, ok := t.(TestHelper)
   167  	if !ok {
   168  		h = &nopTestHelper{t: t}
   169  	}
   170  
   171  	ctx, cancel := context.WithCancel(ctx)
   172  	return NewController(&cancelReporter{t: h, cancel: cancel}), ctx
   173  }
   174  
   175  type nopTestHelper struct {
   176  	t TestReporter
   177  }
   178  
   179  func (h *nopTestHelper) Errorf(format string, args ...interface{}) {
   180  	h.t.Errorf(format, args...)
   181  }
   182  func (h *nopTestHelper) Fatalf(format string, args ...interface{}) {
   183  	h.t.Fatalf(format, args...)
   184  }
   185  
   186  func (h nopTestHelper) Helper() {}
   187  
   188  // RecordCall is called by a mock. It should not be called by user code.
   189  func (ctrl *Controller) RecordCall(receiver interface{}, method string, args ...interface{}) *Call {
   190  	ctrl.T.Helper()
   191  
   192  	recv := reflect.ValueOf(receiver)
   193  	for i := 0; i < recv.Type().NumMethod(); i++ {
   194  		if recv.Type().Method(i).Name == method {
   195  			return ctrl.RecordCallWithMethodType(receiver, method, recv.Method(i).Type(), args...)
   196  		}
   197  	}
   198  	ctrl.T.Fatalf("gomock: failed finding method %s on %T", method, receiver)
   199  	panic("unreachable")
   200  }
   201  
   202  // RecordCallWithMethodType is called by a mock. It should not be called by user code.
   203  func (ctrl *Controller) RecordCallWithMethodType(receiver interface{}, method string, methodType reflect.Type, args ...interface{}) *Call {
   204  	ctrl.T.Helper()
   205  
   206  	call := newCall(ctrl.T, receiver, method, methodType, args...)
   207  
   208  	ctrl.mu.Lock()
   209  	defer ctrl.mu.Unlock()
   210  	ctrl.expectedCalls.Add(call)
   211  
   212  	return call
   213  }
   214  
   215  // Call is called by a mock. It should not be called by user code.
   216  func (ctrl *Controller) Call(receiver interface{}, method string, args ...interface{}) []interface{} {
   217  	ctrl.T.Helper()
   218  
   219  	// Nest this code so we can use defer to make sure the lock is released.
   220  	actions := func() []func([]interface{}) []interface{} {
   221  		ctrl.T.Helper()
   222  		ctrl.mu.Lock()
   223  		defer ctrl.mu.Unlock()
   224  
   225  		expected, err := ctrl.expectedCalls.FindMatch(receiver, method, args)
   226  		if err != nil {
   227  			// callerInfo's skip should be updated if the number of calls between the user's test
   228  			// and this line changes, i.e. this code is wrapped in another anonymous function.
   229  			// 0 is us, 1 is controller.Call(), 2 is the generated mock, and 3 is the user's test.
   230  			origin := callerInfo(3)
   231  			ctrl.T.Fatalf("Unexpected call to %T.%v(%v) at %s because: %s", receiver, method, args, origin, err)
   232  		}
   233  
   234  		// Two things happen here:
   235  		// * the matching call no longer needs to check prerequite calls,
   236  		// * and the prerequite calls are no longer expected, so remove them.
   237  		preReqCalls := expected.dropPrereqs()
   238  		for _, preReqCall := range preReqCalls {
   239  			ctrl.expectedCalls.Remove(preReqCall)
   240  		}
   241  
   242  		actions := expected.call()
   243  		if expected.exhausted() {
   244  			ctrl.expectedCalls.Remove(expected)
   245  		}
   246  		return actions
   247  	}()
   248  
   249  	var rets []interface{}
   250  	for _, action := range actions {
   251  		if r := action(args); r != nil {
   252  			rets = r
   253  		}
   254  	}
   255  
   256  	return rets
   257  }
   258  
   259  // Finish checks to see if all the methods that were expected to be called
   260  // were called. It should be invoked for each Controller. It is not idempotent
   261  // and therefore can only be invoked once.
   262  //
   263  // New in go1.14+, if you are passing a *testing.T into NewController function you no
   264  // longer need to call ctrl.Finish() in your test methods.
   265  func (ctrl *Controller) Finish() {
   266  	// If we're currently panicking, probably because this is a deferred call.
   267  	// This must be recovered in the deferred function.
   268  	err := recover()
   269  	ctrl.finish(false, err)
   270  }
   271  
   272  func (ctrl *Controller) finish(cleanup bool, panicErr interface{}) {
   273  	ctrl.T.Helper()
   274  
   275  	ctrl.mu.Lock()
   276  	defer ctrl.mu.Unlock()
   277  
   278  	if ctrl.finished {
   279  		if _, ok := isCleanuper(ctrl.T); !ok {
   280  			ctrl.T.Fatalf("Controller.Finish was called more than once. It has to be called exactly once.")
   281  		}
   282  		return
   283  	}
   284  	ctrl.finished = true
   285  
   286  	// Short-circuit, pass through the panic.
   287  	if panicErr != nil {
   288  		panic(panicErr)
   289  	}
   290  
   291  	// Check that all remaining expected calls are satisfied.
   292  	failures := ctrl.expectedCalls.Failures()
   293  	for _, call := range failures {
   294  		ctrl.T.Errorf("missing call(s) to %v", call)
   295  	}
   296  	if len(failures) != 0 {
   297  		if !cleanup {
   298  			ctrl.T.Fatalf("aborting test due to missing call(s)")
   299  			return
   300  		}
   301  		ctrl.T.Errorf("aborting test due to missing call(s)")
   302  	}
   303  }
   304  
   305  // callerInfo returns the file:line of the call site. skip is the number
   306  // of stack frames to skip when reporting. 0 is callerInfo's call site.
   307  func callerInfo(skip int) string {
   308  	if _, file, line, ok := runtime.Caller(skip + 1); ok {
   309  		return fmt.Sprintf("%s:%d", file, line)
   310  	}
   311  	return "unknown file"
   312  }
   313  
   314  // isCleanuper checks it if t's base TestReporter has a Cleanup method.
   315  func isCleanuper(t TestReporter) (cleanuper, bool) {
   316  	tr := unwrapTestReporter(t)
   317  	c, ok := tr.(cleanuper)
   318  	return c, ok
   319  }
   320  
   321  // unwrapTestReporter unwraps TestReporter to the base implementation.
   322  func unwrapTestReporter(t TestReporter) TestReporter {
   323  	tr := t
   324  	switch nt := t.(type) {
   325  	case *cancelReporter:
   326  		tr = nt.t
   327  		if h, check := tr.(*nopTestHelper); check {
   328  			tr = h.t
   329  		}
   330  	case *nopTestHelper:
   331  		tr = nt.t
   332  	default:
   333  		// not wrapped
   334  	}
   335  	return tr
   336  }