github.com/mattyw/juju@v0.0.0-20140610034352-732aecd63861/worker/peergrouper/mock_test.go (about)

     1  // Copyright 2014 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package peergrouper
     5  
     6  import (
     7  	"encoding/json"
     8  	"fmt"
     9  	"net"
    10  	"path"
    11  	"reflect"
    12  	"strconv"
    13  	"sync"
    14  
    15  	"github.com/juju/errors"
    16  	"github.com/juju/utils/voyeur"
    17  	"launchpad.net/tomb"
    18  
    19  	"github.com/juju/juju/instance"
    20  	"github.com/juju/juju/replicaset"
    21  	"github.com/juju/juju/state"
    22  	"github.com/juju/juju/worker"
    23  )
    24  
    25  // This file holds helper functions for mocking pieces of State and replicaset
    26  // that we don't want to directly depend on in unit tests.
    27  
// fakeState is an in-memory fake of stateInterface used by the
// peergrouper unit tests.
type fakeState struct {
	// mu guards machines.
	mu sync.Mutex
	// machines holds the registered fake machines, keyed by machine id.
	machines map[string]*fakeMachine
	// stateServers holds the current state server info.
	stateServers voyeur.Value // of *state.StateServerInfo
	// session is the fake mongo session returned by MongoSession.
	session *fakeMongoSession
	// check, when non-nil, is run by checkInvariants after mutations;
	// a non-nil result makes checkInvariants panic.
	check func(st *fakeState) error
}
    35  
    36  var (
    37  	_ stateInterface = (*fakeState)(nil)
    38  	_ stateMachine   = (*fakeMachine)(nil)
    39  	_ mongoSession   = (*fakeMongoSession)(nil)
    40  )
    41  
// errorPattern associates a call-description pattern with the function
// that produces the error to inject for matching calls.
type errorPattern struct {
	// pattern is matched with path.Match against strings of the form
	// "Type.Function arg1 arg2 ...".
	pattern string
	// errFunc is invoked to obtain the error for a matching call.
	errFunc func() error
}
    46  
    47  var (
    48  	errorsMutex   sync.Mutex
    49  	errorPatterns []errorPattern
    50  )
    51  
    52  // setErrorFor causes the given error to be returned
    53  // from any mock call that matches the given
    54  // string, which may contain wildcards as
    55  // in path.Match.
    56  //
    57  // The standard form for errors is:
    58  //    Type.Function <arg>...
    59  // See individual functions for details.
    60  func setErrorFor(what string, err error) {
    61  	setErrorFuncFor(what, func() error {
    62  		return err
    63  	})
    64  }
    65  
    66  // setErrorFuncFor causes the given function
    67  // to be invoked to return the error for the
    68  // given pattern.
    69  func setErrorFuncFor(what string, errFunc func() error) {
    70  	errorsMutex.Lock()
    71  	defer errorsMutex.Unlock()
    72  	errorPatterns = append(errorPatterns, errorPattern{
    73  		pattern: what,
    74  		errFunc: errFunc,
    75  	})
    76  }
    77  
    78  // errorFor concatenates the call name
    79  // with all the args, space separated,
    80  // and returns any error registered with
    81  // setErrorFor that matches the resulting string.
    82  func errorFor(name string, args ...interface{}) error {
    83  	errorsMutex.Lock()
    84  	s := name
    85  	for _, arg := range args {
    86  		s += " " + fmt.Sprint(arg)
    87  	}
    88  	f := func() error { return nil }
    89  	for _, pattern := range errorPatterns {
    90  		if ok, _ := path.Match(pattern.pattern, s); ok {
    91  			f = pattern.errFunc
    92  			break
    93  		}
    94  	}
    95  	errorsMutex.Unlock()
    96  	err := f()
    97  	logger.Errorf("errorFor %q -> %v", s, err)
    98  	return err
    99  }
   100  
   101  func resetErrors() {
   102  	errorsMutex.Lock()
   103  	defer errorsMutex.Unlock()
   104  	errorPatterns = errorPatterns[:0]
   105  }
   106  
   107  func newFakeState() *fakeState {
   108  	st := &fakeState{
   109  		machines: make(map[string]*fakeMachine),
   110  	}
   111  	st.session = newFakeMongoSession(st)
   112  	st.stateServers.Set(&state.StateServerInfo{})
   113  	return st
   114  }
   115  
   116  func (st *fakeState) MongoSession() mongoSession {
   117  	return st.session
   118  }
   119  
   120  func (st *fakeState) checkInvariants() {
   121  	if st.check == nil {
   122  		return
   123  	}
   124  	if err := st.check(st); err != nil {
   125  		// Force a panic, otherwise we can deadlock
   126  		// when called from within the worker.
   127  		go panic(err)
   128  		select {}
   129  	}
   130  }
   131  
   132  // checkInvariants checks that all the expected invariants
   133  // in the state hold true. Currently we check that:
   134  // - total number of votes is odd.
   135  // - member voting status implies that machine has vote.
   136  func checkInvariants(st *fakeState) error {
   137  	members := st.session.members.Get().([]replicaset.Member)
   138  	voteCount := 0
   139  	for _, m := range members {
   140  		votes := 1
   141  		if m.Votes != nil {
   142  			votes = *m.Votes
   143  		}
   144  		voteCount += votes
   145  		if id, ok := m.Tags[jujuMachineTag]; ok {
   146  			if votes > 0 {
   147  				m := st.machine(id)
   148  				if m == nil {
   149  					return fmt.Errorf("voting member with machine id %q has no associated Machine", id)
   150  				}
   151  				if !m.HasVote() {
   152  					return fmt.Errorf("machine %q should be marked as having the vote, but does not", id)
   153  				}
   154  			}
   155  		}
   156  	}
   157  	if voteCount%2 != 1 {
   158  		return fmt.Errorf("total vote count is not odd (got %d)", voteCount)
   159  	}
   160  	return nil
   161  }
   162  
// invariantChecker is implemented by fakeState. fakeMachine.mutate and
// fakeMongoSession.Set call checkInvariants after each change so that
// tests can detect broken invariants as soon as they occur.
type invariantChecker interface {
	checkInvariants()
}
   166  
   167  // machine is similar to Machine except that
   168  // it bypasses the error mocking machinery.
   169  // It returns nil if there is no machine with the
   170  // given id.
   171  func (st *fakeState) machine(id string) *fakeMachine {
   172  	st.mu.Lock()
   173  	defer st.mu.Unlock()
   174  	return st.machines[id]
   175  }
   176  
   177  func (st *fakeState) Machine(id string) (stateMachine, error) {
   178  	if err := errorFor("State.Machine", id); err != nil {
   179  		return nil, err
   180  	}
   181  	if m := st.machine(id); m != nil {
   182  		return m, nil
   183  	}
   184  	return nil, errors.NotFoundf("machine %s", id)
   185  }
   186  
   187  func (st *fakeState) addMachine(id string, wantsVote bool) *fakeMachine {
   188  	st.mu.Lock()
   189  	defer st.mu.Unlock()
   190  	logger.Infof("fakeState.addMachine %q", id)
   191  	if st.machines[id] != nil {
   192  		panic(fmt.Errorf("id %q already used", id))
   193  	}
   194  	m := &fakeMachine{
   195  		checker: st,
   196  		doc: machineDoc{
   197  			id:        id,
   198  			wantsVote: wantsVote,
   199  		},
   200  	}
   201  	st.machines[id] = m
   202  	m.val.Set(m.doc)
   203  	return m
   204  }
   205  
   206  func (st *fakeState) removeMachine(id string) {
   207  	st.mu.Lock()
   208  	defer st.mu.Unlock()
   209  	if st.machines[id] == nil {
   210  		panic(fmt.Errorf("removing non-existent machine %q", id))
   211  	}
   212  	delete(st.machines, id)
   213  }
   214  
   215  func (st *fakeState) setStateServers(ids ...string) {
   216  	st.stateServers.Set(&state.StateServerInfo{
   217  		MachineIds: ids,
   218  	})
   219  }
   220  
   221  func (st *fakeState) StateServerInfo() (*state.StateServerInfo, error) {
   222  	if err := errorFor("State.StateServerInfo"); err != nil {
   223  		return nil, err
   224  	}
   225  	return deepCopy(st.stateServers.Get()).(*state.StateServerInfo), nil
   226  }
   227  
   228  func (st *fakeState) WatchStateServerInfo() state.NotifyWatcher {
   229  	return WatchValue(&st.stateServers)
   230  }
   231  
// fakeMachine is an in-memory fake of stateMachine.
type fakeMachine struct {
	// mu is held by mutate while it updates val and doc.
	mu sync.Mutex
	// val holds the most recently published document; Watch observes
	// it and Refresh copies it into doc.
	val voyeur.Value // of machineDoc
	// doc is the document returned by the accessor methods.
	doc machineDoc
	// checker is asked to verify invariants after every mutate.
	checker invariantChecker
}
   238  
// machineDoc holds the attributes of a fakeMachine. fakeMachine.GoString
// formats it with %#v, so the field order below determines that output.
type machineDoc struct {
	id             string
	wantsVote      bool
	hasVote        bool
	instanceId     instance.Id
	mongoHostPorts []instance.HostPort
	apiHostPorts   []instance.HostPort
}
   247  
   248  func (m *fakeMachine) Refresh() error {
   249  	if err := errorFor("Machine.Refresh", m.doc.id); err != nil {
   250  		return err
   251  	}
   252  	m.doc = m.val.Get().(machineDoc)
   253  	return nil
   254  }
   255  
   256  func (m *fakeMachine) GoString() string {
   257  	return fmt.Sprintf("&fakeMachine{%#v}", m.doc)
   258  }
   259  
   260  func (m *fakeMachine) Id() string {
   261  	return m.doc.id
   262  }
   263  
   264  func (m *fakeMachine) InstanceId() (instance.Id, error) {
   265  	if err := errorFor("Machine.InstanceId", m.doc.id); err != nil {
   266  		return "", err
   267  	}
   268  	return m.doc.instanceId, nil
   269  }
   270  
   271  func (m *fakeMachine) Watch() state.NotifyWatcher {
   272  	return WatchValue(&m.val)
   273  }
   274  
   275  func (m *fakeMachine) WantsVote() bool {
   276  	return m.doc.wantsVote
   277  }
   278  
   279  func (m *fakeMachine) HasVote() bool {
   280  	return m.doc.hasVote
   281  }
   282  
   283  func (m *fakeMachine) MongoHostPorts() []instance.HostPort {
   284  	return m.doc.mongoHostPorts
   285  }
   286  
   287  func (m *fakeMachine) APIHostPorts() []instance.HostPort {
   288  	return m.doc.apiHostPorts
   289  }
   290  
   291  // mutate atomically changes the machineDoc of
   292  // the receiver by mutating it with the provided function.
   293  func (m *fakeMachine) mutate(f func(*machineDoc)) {
   294  	m.mu.Lock()
   295  	doc := m.val.Get().(machineDoc)
   296  	f(&doc)
   297  	m.val.Set(doc)
   298  	f(&m.doc)
   299  	m.mu.Unlock()
   300  	m.checker.checkInvariants()
   301  }
   302  
   303  func (m *fakeMachine) setStateHostPort(hostPort string) {
   304  	var mongoHostPorts []instance.HostPort
   305  	if hostPort != "" {
   306  		host, portStr, err := net.SplitHostPort(hostPort)
   307  		if err != nil {
   308  			panic(err)
   309  		}
   310  		port, err := strconv.Atoi(portStr)
   311  		if err != nil {
   312  			panic(err)
   313  		}
   314  		mongoHostPorts = instance.AddressesWithPort(instance.NewAddresses(host), port)
   315  		mongoHostPorts[0].NetworkScope = instance.NetworkCloudLocal
   316  	}
   317  
   318  	m.mutate(func(doc *machineDoc) {
   319  		doc.mongoHostPorts = mongoHostPorts
   320  	})
   321  }
   322  
   323  func (m *fakeMachine) setMongoHostPorts(hostPorts []instance.HostPort) {
   324  	m.mutate(func(doc *machineDoc) {
   325  		doc.mongoHostPorts = hostPorts
   326  	})
   327  }
   328  
   329  func (m *fakeMachine) setAPIHostPorts(hostPorts []instance.HostPort) {
   330  	m.mutate(func(doc *machineDoc) {
   331  		doc.apiHostPorts = hostPorts
   332  	})
   333  }
   334  
   335  func (m *fakeMachine) setInstanceId(instanceId instance.Id) {
   336  	m.mutate(func(doc *machineDoc) {
   337  		doc.instanceId = instanceId
   338  	})
   339  }
   340  
   341  // SetHasVote implements stateMachine.SetHasVote.
   342  func (m *fakeMachine) SetHasVote(hasVote bool) error {
   343  	if err := errorFor("Machine.SetHasVote", m.doc.id, hasVote); err != nil {
   344  		return err
   345  	}
   346  	m.mutate(func(doc *machineDoc) {
   347  		doc.hasVote = hasVote
   348  	})
   349  	return nil
   350  }
   351  
   352  func (m *fakeMachine) setWantsVote(wantsVote bool) {
   353  	m.mutate(func(doc *machineDoc) {
   354  		doc.wantsVote = wantsVote
   355  	})
   356  }
   357  
// fakeMongoSession is an in-memory fake of mongoSession. The replica set
// members and status are held in voyeur values.
type fakeMongoSession struct {
	// If InstantlyReady is true, replica status of
	// all members will be instantly reported as ready.
	InstantlyReady bool

	// checker is asked to verify invariants after each successful Set.
	checker invariantChecker
	members voyeur.Value // of []replicaset.Member
	status  voyeur.Value // of *replicaset.Status
}
   367  
   368  // newFakeMongoSession returns a mock implementation of mongoSession.
   369  func newFakeMongoSession(checker invariantChecker) *fakeMongoSession {
   370  	s := new(fakeMongoSession)
   371  	s.checker = checker
   372  	s.members.Set([]replicaset.Member(nil))
   373  	s.status.Set(&replicaset.Status{})
   374  	return s
   375  }
   376  
   377  // CurrentMembers implements mongoSession.CurrentMembers.
   378  func (session *fakeMongoSession) CurrentMembers() ([]replicaset.Member, error) {
   379  	if err := errorFor("Session.CurrentMembers"); err != nil {
   380  		return nil, err
   381  	}
   382  	return deepCopy(session.members.Get()).([]replicaset.Member), nil
   383  }
   384  
   385  // CurrentStatus implements mongoSession.CurrentStatus.
   386  func (session *fakeMongoSession) CurrentStatus() (*replicaset.Status, error) {
   387  	if err := errorFor("Session.CurrentStatus"); err != nil {
   388  		return nil, err
   389  	}
   390  	return deepCopy(session.status.Get()).(*replicaset.Status), nil
   391  }
   392  
   393  // setStatus sets the status of the current members of the session.
   394  func (session *fakeMongoSession) setStatus(members []replicaset.MemberStatus) {
   395  	session.status.Set(deepCopy(&replicaset.Status{
   396  		Members: members,
   397  	}))
   398  }
   399  
   400  // Set implements mongoSession.Set
   401  func (session *fakeMongoSession) Set(members []replicaset.Member) error {
   402  	if err := errorFor("Session.Set"); err != nil {
   403  		logger.Infof("not setting replicaset members to %#v", members)
   404  		return err
   405  	}
   406  	logger.Infof("setting replicaset members to %#v", members)
   407  	session.members.Set(deepCopy(members))
   408  	if session.InstantlyReady {
   409  		statuses := make([]replicaset.MemberStatus, len(members))
   410  		for i, m := range members {
   411  			statuses[i] = replicaset.MemberStatus{
   412  				Id:      m.Id,
   413  				Address: m.Address,
   414  				Healthy: true,
   415  				State:   replicaset.SecondaryState,
   416  			}
   417  			if i == 0 {
   418  				statuses[i].State = replicaset.PrimaryState
   419  			}
   420  		}
   421  		session.setStatus(statuses)
   422  	}
   423  	session.checker.checkInvariants()
   424  	return nil
   425  }
   426  
   427  // deepCopy makes a deep copy of any type by marshalling
   428  // it as JSON, then unmarshalling it.
   429  func deepCopy(x interface{}) interface{} {
   430  	v := reflect.ValueOf(x)
   431  	data, err := json.Marshal(x)
   432  	if err != nil {
   433  		panic(fmt.Errorf("cannot marshal %#v: %v", x, err))
   434  	}
   435  	newv := reflect.New(v.Type())
   436  	if err := json.Unmarshal(data, newv.Interface()); err != nil {
   437  		panic(fmt.Errorf("cannot unmarshal %q into %s", data, newv.Type()))
   438  	}
   439  	// sanity check
   440  	newx := newv.Elem().Interface()
   441  	if !reflect.DeepEqual(newx, x) {
   442  		panic(fmt.Errorf("value not deep-copied correctly"))
   443  	}
   444  	return newx
   445  }
   446  
// notifier adapts a voyeur.Watcher to the NotifyWatcher interface; it is
// returned by WatchValue.
type notifier struct {
	// tomb tracks the lifetime of the loop goroutine.
	tomb tomb.Tomb
	// w watches the underlying voyeur.Value.
	w *voyeur.Watcher
	// changes receives one (unbuffered) send per observed change.
	changes chan struct{}
}
   452  
   453  // WatchValue returns a NotifyWatcher that triggers
   454  // when the given value changes. Its Wait and Err methods
   455  // never return a non-nil error.
   456  func WatchValue(val *voyeur.Value) state.NotifyWatcher {
   457  	n := &notifier{
   458  		w:       val.Watch(),
   459  		changes: make(chan struct{}),
   460  	}
   461  	go n.loop()
   462  	return n
   463  }
   464  
   465  func (n *notifier) loop() {
   466  	defer n.tomb.Done()
   467  	for n.w.Next() {
   468  		select {
   469  		case n.changes <- struct{}{}:
   470  		case <-n.tomb.Dying():
   471  		}
   472  	}
   473  }
   474  
   475  // Changes returns a channel that sends a value when the value changes.
   476  // The value itself can be retrieved by calling the value's Get method.
   477  func (n *notifier) Changes() <-chan struct{} {
   478  	return n.changes
   479  }
   480  
// Kill stops the notifier but does not wait for it to finish.
func (n *notifier) Kill() {
	// Kill the tomb first so a loop blocked sending on changes gives up;
	// closing the watcher then makes n.w.Next return false, ending loop.
	n.tomb.Kill(nil)
	n.w.Close()
}
   486  
   487  func (n *notifier) Err() error {
   488  	return n.tomb.Err()
   489  }
   490  
   491  // Wait waits for the notifier to finish. It always returns nil.
   492  func (n *notifier) Wait() error {
   493  	return n.tomb.Wait()
   494  }
   495  
   496  func (n *notifier) Stop() error {
   497  	return worker.Stop(n)
   498  }