github.com/makyo/juju@v0.0.0-20160425123129-2608902037e9/worker/testing/runner.go (about)

     1  // Copyright 2015 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package testing
     5  
     6  import (
     7  	"sort"
     8  
     9  	"github.com/juju/errors"
    10  	"github.com/juju/testing"
    11  	jc "github.com/juju/testing/checkers"
    12  	gc "gopkg.in/check.v1"
    13  
    14  	"github.com/juju/juju/worker"
    15  )
    16  
    17  var _ worker.Runner = (*StubRunner)(nil)
    18  
    19  var runnerMethodNames = []string{
    20  	"StartWorker",
    21  	"StopWorker",
    22  }
    23  
// StubRunner is a testing stub for worker.Runner. Each Runner method
// records its call (name and arguments) on Stub and returns the error, if
// any, obtained from Stub.NextErr.
type StubRunner struct {
	// Worker is embedded so its methods are promoted onto StubRunner.
	// NewStubRunner sets it to NewStubWorker(stub), sharing the same stub.
	worker.Worker
	// Stub is the underlying testing stub.
	Stub *testing.Stub
	// CallWhenStarted indicates that the newWorker func should be
	// called when StartWorker is called. It is only called when the stub
	// returned no error for that StartWorker call. The worker it creates
	// is not retained.
	CallWhenStarted bool
}
    33  
    34  // NewStubRunner returns a new StubRunner.
    35  func NewStubRunner(stub *testing.Stub) *StubRunner {
    36  	return &StubRunner{
    37  		Worker: NewStubWorker(stub),
    38  		Stub:   stub,
    39  	}
    40  }
    41  
    42  func (r *StubRunner) validMethodName(funcName string) bool {
    43  	for _, knownName := range runnerMethodNames {
    44  		if funcName == knownName {
    45  			return true
    46  		}
    47  	}
    48  	return false
    49  }
    50  
    51  func (r *StubRunner) checkCallIDs(c *gc.C, methName string, skipMismatch bool, expected []string) {
    52  	var ids []string
    53  	for _, call := range r.Stub.Calls() {
    54  		if !r.validMethodName(call.FuncName) {
    55  			c.Logf("invalid called func name %q (must be one of %#v)", call.FuncName, runnerMethodNames)
    56  			c.FailNow()
    57  		}
    58  		if methName != "" {
    59  			if skipMismatch && call.FuncName != methName {
    60  				continue
    61  			}
    62  			c.Check(call.FuncName, gc.Equals, methName)
    63  		}
    64  		ids = append(ids, call.Args[0].(string))
    65  	}
    66  	sort.Strings(ids)
    67  	sort.Strings(expected)
    68  	c.Check(ids, jc.DeepEquals, expected)
    69  }
    70  
    71  // CheckCallIDs verifies that the worker IDs in all calls match the
    72  // provided ones. If a method name is provided as well then all calls must
    73  // have that method name.
    74  func (r *StubRunner) CheckCallIDs(c *gc.C, methName string, expected ...string) {
    75  	r.checkCallIDs(c, methName, false, expected)
    76  }
    77  
    78  // StartWorker implements worker.Runner.
    79  func (r *StubRunner) StartWorker(id string, newWorker func() (worker.Worker, error)) error {
    80  	r.Stub.AddCall("StartWorker", id, newWorker)
    81  	if err := r.Stub.NextErr(); err != nil {
    82  		return errors.Trace(err)
    83  	}
    84  
    85  	if r.CallWhenStarted {
    86  		// TODO(ericsnow) Save the workers?
    87  		if _, err := newWorker(); err != nil {
    88  			return errors.Trace(err)
    89  		}
    90  	}
    91  	return nil
    92  }
    93  
    94  // StopWorker implements worker.Runner.
    95  func (r *StubRunner) StopWorker(id string) error {
    96  	r.Stub.AddCall("StopWorker", id)
    97  	if err := r.Stub.NextErr(); err != nil {
    98  		return errors.Trace(err)
    99  	}
   100  
   101  	// Do nothing.
   102  	return nil
   103  }