github.com/cloudbase/juju-core@v0.0.0-20140504232958-a7271ac7912f/worker/peergrouper/mock_test.go (about)

     1  // Copyright 2014 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package peergrouper
     5  
     6  import (
     7  	"encoding/json"
     8  	"fmt"
     9  	"path"
    10  	"reflect"
    11  	"sync"
    12  
    13  	"launchpad.net/tomb"
    14  
    15  	"launchpad.net/juju-core/errors"
    16  	"launchpad.net/juju-core/replicaset"
    17  	"launchpad.net/juju-core/state"
    18  	"launchpad.net/juju-core/utils/voyeur"
    19  	"launchpad.net/juju-core/worker"
    20  )
    21  
    22  // This file holds helper functions for mocking pieces of State and replicaset
    23  // that we don't want to directly depend on in unit tests.
    24  
// fakeState is an in-memory mock of the state used by the peergrouper
// worker. It holds the fake machines, the published StateServerInfo and
// the fake mongo session.
type fakeState struct {
	// mu guards machines.
	mu           sync.Mutex
	machines     map[string]*fakeMachine
	stateServers voyeur.Value // of *state.StateServerInfo
	// session is the value returned by MongoSession; it reports back to
	// this fakeState for invariant checking.
	session      *fakeMongoSession
	// check, if non-nil, is run by the checkInvariants method (after
	// machine mutations and successful Session.Set calls); an error
	// from it crashes the test process.
	check        func(st *fakeState) error
}
    32  
    33  var (
    34  	_ stateInterface = (*fakeState)(nil)
    35  	_ stateMachine   = (*fakeMachine)(nil)
    36  	_ mongoSession   = (*fakeMongoSession)(nil)
    37  )
    38  
// errorPattern pairs a call-description pattern, in path.Match syntax,
// with the function that supplies the error for calls matching it.
type errorPattern struct {
	// pattern is matched against strings built by errorFor.
	pattern string
	// errFunc produces the error returned for a matching call.
	errFunc func() error
}
    43  
    44  var (
    45  	errorsMutex   sync.Mutex
    46  	errorPatterns []errorPattern
    47  )
    48  
    49  // setErrorFor causes the given error to be returned
    50  // from any mock call that matches the given
    51  // string, which may contain wildcards as
    52  // in path.Match.
    53  //
    54  // The standard form for errors is:
    55  //    Type.Function <arg>...
    56  // See individual functions for details.
    57  func setErrorFor(what string, err error) {
    58  	setErrorFuncFor(what, func() error {
    59  		return err
    60  	})
    61  }
    62  
    63  // setErrorFuncFor causes the given function
    64  // to be invoked to return the error for the
    65  // given pattern.
    66  func setErrorFuncFor(what string, errFunc func() error) {
    67  	errorsMutex.Lock()
    68  	defer errorsMutex.Unlock()
    69  	errorPatterns = append(errorPatterns, errorPattern{
    70  		pattern: what,
    71  		errFunc: errFunc,
    72  	})
    73  }
    74  
    75  // errorFor concatenates the call name
    76  // with all the args, space separated,
    77  // and returns any error registered with
    78  // setErrorFor that matches the resulting string.
    79  func errorFor(name string, args ...interface{}) error {
    80  	errorsMutex.Lock()
    81  	s := name
    82  	for _, arg := range args {
    83  		s += " " + fmt.Sprint(arg)
    84  	}
    85  	f := func() error { return nil }
    86  	for _, pattern := range errorPatterns {
    87  		if ok, _ := path.Match(pattern.pattern, s); ok {
    88  			f = pattern.errFunc
    89  			break
    90  		}
    91  	}
    92  	errorsMutex.Unlock()
    93  	err := f()
    94  	logger.Errorf("errorFor %q -> %v", s, err)
    95  	return err
    96  }
    97  
    98  func resetErrors() {
    99  	errorsMutex.Lock()
   100  	defer errorsMutex.Unlock()
   101  	errorPatterns = errorPatterns[:0]
   102  }
   103  
   104  func newFakeState() *fakeState {
   105  	st := &fakeState{
   106  		machines: make(map[string]*fakeMachine),
   107  	}
   108  	st.session = newFakeMongoSession(st)
   109  	st.stateServers.Set(&state.StateServerInfo{})
   110  	return st
   111  }
   112  
   113  func (st *fakeState) MongoSession() mongoSession {
   114  	return st.session
   115  }
   116  
   117  func (st *fakeState) checkInvariants() {
   118  	if st.check == nil {
   119  		return
   120  	}
   121  	if err := st.check(st); err != nil {
   122  		// Force a panic, otherwise we can deadlock
   123  		// when called from within the worker.
   124  		go panic(err)
   125  		select {}
   126  	}
   127  }
   128  
   129  // checkInvariants checks that all the expected invariants
   130  // in the state hold true. Currently we check that:
   131  // - total number of votes is odd.
   132  // - member voting status implies that machine has vote.
   133  func checkInvariants(st *fakeState) error {
   134  	members := st.session.members.Get().([]replicaset.Member)
   135  	voteCount := 0
   136  	for _, m := range members {
   137  		votes := 1
   138  		if m.Votes != nil {
   139  			votes = *m.Votes
   140  		}
   141  		voteCount += votes
   142  		if id, ok := m.Tags["juju-machine-id"]; ok {
   143  			if votes > 0 {
   144  				m := st.machine(id)
   145  				if m == nil {
   146  					return fmt.Errorf("voting member with machine id %q has no associated Machine", id)
   147  				}
   148  				if !m.HasVote() {
   149  					return fmt.Errorf("machine %q should be marked as having the vote, but does not", id)
   150  				}
   151  			}
   152  		}
   153  	}
   154  	if voteCount%2 != 1 {
   155  		return fmt.Errorf("total vote count is not odd (got %d)", voteCount)
   156  	}
   157  	return nil
   158  }
   159  
// invariantChecker is implemented by values (such as *fakeState) that
// the fakes notify after they change, so that global consistency can be
// verified.
type invariantChecker interface {
	checkInvariants()
}
   163  
   164  // machine is similar to Machine except that
   165  // it bypasses the error mocking machinery.
   166  // It returns nil if there is no machine with the
   167  // given id.
   168  func (st *fakeState) machine(id string) *fakeMachine {
   169  	st.mu.Lock()
   170  	defer st.mu.Unlock()
   171  	return st.machines[id]
   172  }
   173  
   174  func (st *fakeState) Machine(id string) (stateMachine, error) {
   175  	if err := errorFor("State.Machine", id); err != nil {
   176  		return nil, err
   177  	}
   178  	if m := st.machine(id); m != nil {
   179  		return m, nil
   180  	}
   181  	return nil, errors.NotFoundf("machine %s", id)
   182  }
   183  
   184  func (st *fakeState) addMachine(id string, wantsVote bool) *fakeMachine {
   185  	st.mu.Lock()
   186  	defer st.mu.Unlock()
   187  	logger.Infof("fakeState.addMachine %q", id)
   188  	if st.machines[id] != nil {
   189  		panic(fmt.Errorf("id %q already used", id))
   190  	}
   191  	m := &fakeMachine{
   192  		checker: st,
   193  		doc: machineDoc{
   194  			id:        id,
   195  			wantsVote: wantsVote,
   196  		},
   197  	}
   198  	st.machines[id] = m
   199  	m.val.Set(m.doc)
   200  	return m
   201  }
   202  
   203  func (st *fakeState) removeMachine(id string) {
   204  	st.mu.Lock()
   205  	defer st.mu.Unlock()
   206  	if st.machines[id] == nil {
   207  		panic(fmt.Errorf("removing non-existent machine %q", id))
   208  	}
   209  	delete(st.machines, id)
   210  }
   211  
   212  func (st *fakeState) setStateServers(ids ...string) {
   213  	st.stateServers.Set(&state.StateServerInfo{
   214  		MachineIds: ids,
   215  	})
   216  }
   217  
   218  func (st *fakeState) StateServerInfo() (*state.StateServerInfo, error) {
   219  	if err := errorFor("State.StateServerInfo"); err != nil {
   220  		return nil, err
   221  	}
   222  	return deepCopy(st.stateServers.Get()).(*state.StateServerInfo), nil
   223  }
   224  
   225  func (st *fakeState) WatchStateServerInfo() state.NotifyWatcher {
   226  	return WatchValue(&st.stateServers)
   227  }
   228  
// fakeMachine is an in-memory mock of a state machine; it implements
// stateMachine.
type fakeMachine struct {
	// mu is held by mutate while it updates val and doc.
	mu      sync.Mutex
	// val holds the published document; Watch observes it.
	val     voyeur.Value // of machineDoc
	// doc is the machine's cached document, reloaded from val by Refresh.
	doc     machineDoc
	// checker is asked to check invariants after every mutate.
	checker invariantChecker
}
   235  
   236  func (m *fakeMachine) Refresh() error {
   237  	if err := errorFor("Machine.Refresh", m.doc.id); err != nil {
   238  		return err
   239  	}
   240  	m.doc = m.val.Get().(machineDoc)
   241  	return nil
   242  }
   243  
   244  func (m *fakeMachine) GoString() string {
   245  	return fmt.Sprintf("&fakeMachine{%#v}", m.doc)
   246  }
   247  
   248  func (m *fakeMachine) Id() string {
   249  	return m.doc.id
   250  }
   251  
   252  func (m *fakeMachine) Watch() state.NotifyWatcher {
   253  	return WatchValue(&m.val)
   254  }
   255  
   256  func (m *fakeMachine) WantsVote() bool {
   257  	return m.doc.wantsVote
   258  }
   259  
   260  func (m *fakeMachine) HasVote() bool {
   261  	return m.doc.hasVote
   262  }
   263  
   264  func (m *fakeMachine) StateHostPort() string {
   265  	return m.doc.hostPort
   266  }
   267  
   268  // mutate atomically changes the machineDoc of
   269  // the receiver by mutating it with the provided function.
   270  func (m *fakeMachine) mutate(f func(*machineDoc)) {
   271  	m.mu.Lock()
   272  	doc := m.val.Get().(machineDoc)
   273  	f(&doc)
   274  	m.val.Set(doc)
   275  	f(&m.doc)
   276  	m.mu.Unlock()
   277  	m.checker.checkInvariants()
   278  }
   279  
   280  func (m *fakeMachine) setStateHostPort(hostPort string) {
   281  	m.mutate(func(doc *machineDoc) {
   282  		doc.hostPort = hostPort
   283  	})
   284  }
   285  
   286  // SetHasVote implements stateMachine.SetHasVote.
   287  func (m *fakeMachine) SetHasVote(hasVote bool) error {
   288  	if err := errorFor("Machine.SetHasVote", m.doc.id, hasVote); err != nil {
   289  		return err
   290  	}
   291  	m.mutate(func(doc *machineDoc) {
   292  		doc.hasVote = hasVote
   293  	})
   294  	return nil
   295  }
   296  
   297  func (m *fakeMachine) setWantsVote(wantsVote bool) {
   298  	m.mutate(func(doc *machineDoc) {
   299  		doc.wantsVote = wantsVote
   300  	})
   301  }
   302  
// machineDoc holds the attributes of a fakeMachine. fakeMachine.GoString
// formats it with %#v, so its field order is visible in test output.
type machineDoc struct {
	id        string // returned by Id
	wantsVote bool   // returned by WantsVote; changed by setWantsVote
	hasVote   bool   // returned by HasVote; changed by SetHasVote
	hostPort  string // returned by StateHostPort; changed by setStateHostPort
}
   309  
// fakeMongoSession is an in-memory mock of a mongo replica set session;
// it implements mongoSession.
type fakeMongoSession struct {
	// If InstantlyReady is true, Set immediately publishes a status in
	// which every new member is healthy, the first is primary and the
	// rest are secondaries.
	InstantlyReady bool

	// checker is asked to check invariants after each successful Set.
	checker invariantChecker
	members voyeur.Value // of []replicaset.Member
	status  voyeur.Value // of *replicaset.Status
}
   319  
   320  // newFakeMongoSession returns a mock implementation of mongoSession.
   321  func newFakeMongoSession(checker invariantChecker) *fakeMongoSession {
   322  	s := new(fakeMongoSession)
   323  	s.checker = checker
   324  	s.members.Set([]replicaset.Member(nil))
   325  	s.status.Set(&replicaset.Status{})
   326  	return s
   327  }
   328  
   329  // CurrentMembers implements mongoSession.CurrentMembers.
   330  func (session *fakeMongoSession) CurrentMembers() ([]replicaset.Member, error) {
   331  	if err := errorFor("Session.CurrentMembers"); err != nil {
   332  		return nil, err
   333  	}
   334  	return deepCopy(session.members.Get()).([]replicaset.Member), nil
   335  }
   336  
   337  // CurrentStatus implements mongoSession.CurrentStatus.
   338  func (session *fakeMongoSession) CurrentStatus() (*replicaset.Status, error) {
   339  	if err := errorFor("Session.CurrentStatus"); err != nil {
   340  		return nil, err
   341  	}
   342  	return deepCopy(session.status.Get()).(*replicaset.Status), nil
   343  }
   344  
   345  // setStatus sets the status of the current members of the session.
   346  func (session *fakeMongoSession) setStatus(members []replicaset.MemberStatus) {
   347  	session.status.Set(deepCopy(&replicaset.Status{
   348  		Members: members,
   349  	}))
   350  }
   351  
   352  // Set implements mongoSession.Set
   353  func (session *fakeMongoSession) Set(members []replicaset.Member) error {
   354  	if err := errorFor("Session.Set"); err != nil {
   355  		logger.Infof("not setting replicaset members to %#v", members)
   356  		return err
   357  	}
   358  	logger.Infof("setting replicaset members to %#v", members)
   359  	session.members.Set(deepCopy(members))
   360  	if session.InstantlyReady {
   361  		statuses := make([]replicaset.MemberStatus, len(members))
   362  		for i, m := range members {
   363  			statuses[i] = replicaset.MemberStatus{
   364  				Id:      m.Id,
   365  				Address: m.Address,
   366  				Healthy: true,
   367  				State:   replicaset.SecondaryState,
   368  			}
   369  			if i == 0 {
   370  				statuses[i].State = replicaset.PrimaryState
   371  			}
   372  		}
   373  		session.setStatus(statuses)
   374  	}
   375  	session.checker.checkInvariants()
   376  	return nil
   377  }
   378  
   379  // deepCopy makes a deep copy of any type by marshalling
   380  // it as JSON, then unmarshalling it.
   381  func deepCopy(x interface{}) interface{} {
   382  	v := reflect.ValueOf(x)
   383  	data, err := json.Marshal(x)
   384  	if err != nil {
   385  		panic(fmt.Errorf("cannot marshal %#v: %v", x, err))
   386  	}
   387  	newv := reflect.New(v.Type())
   388  	if err := json.Unmarshal(data, newv.Interface()); err != nil {
   389  		panic(fmt.Errorf("cannot unmarshal %q into %s", data, newv.Type()))
   390  	}
   391  	// sanity check
   392  	newx := newv.Elem().Interface()
   393  	if !reflect.DeepEqual(newx, x) {
   394  		panic(fmt.Errorf("value not deep-copied correctly"))
   395  	}
   396  	return newx
   397  }
   398  
// notifier adapts a voyeur.Watcher to the state.NotifyWatcher interface
// (see WatchValue).
type notifier struct {
	// tomb tracks the lifetime of the loop goroutine.
	tomb    tomb.Tomb
	// w reports changes to the watched value.
	w       *voyeur.Watcher
	// changes receives an event for each change reported by w while the
	// notifier is alive.
	changes chan struct{}
}
   404  
   405  // WatchValue returns a NotifyWatcher that triggers
   406  // when the given value changes. Its Wait and Err methods
   407  // never return a non-nil error.
   408  func WatchValue(val *voyeur.Value) state.NotifyWatcher {
   409  	n := &notifier{
   410  		w:       val.Watch(),
   411  		changes: make(chan struct{}),
   412  	}
   413  	go n.loop()
   414  	return n
   415  }
   416  
   417  func (n *notifier) loop() {
   418  	defer n.tomb.Done()
   419  	for n.w.Next() {
   420  		select {
   421  		case n.changes <- struct{}{}:
   422  		case <-n.tomb.Dying():
   423  		}
   424  	}
   425  }
   426  
   427  // Changes returns a channel that sends a value when the value changes.
   428  // The value itself can be retrieved by calling the value's Get method.
   429  func (n *notifier) Changes() <-chan struct{} {
   430  	return n.changes
   431  }
   432  
// Kill stops the notifier but does not wait for it to finish.
// Killing the tomb releases loop from any pending send. Closing w
// (presumably by making w.Next return false) ends loop's wait for
// further changes, after which loop marks the tomb done.
func (n *notifier) Kill() {
	n.tomb.Kill(nil)
	n.w.Close()
}
   438  
   439  func (n *notifier) Err() error {
   440  	return n.tomb.Err()
   441  }
   442  
   443  // Wait waits for the notifier to finish. It always returns nil.
   444  func (n *notifier) Wait() error {
   445  	return n.tomb.Wait()
   446  }
   447  
   448  func (n *notifier) Stop() error {
   449  	return worker.Stop(n)
   450  }