github.com/rogpeppe/juju@v0.0.0-20140613142852-6337964b789e/worker/peergrouper/mock_test.go (about)

     1  // Copyright 2014 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package peergrouper
     5  
     6  import (
     7  	"encoding/json"
     8  	"fmt"
     9  	"net"
    10  	"path"
    11  	"reflect"
    12  	"strconv"
    13  	"sync"
    14  
    15  	"github.com/juju/errors"
    16  	"github.com/juju/utils/voyeur"
    17  	"launchpad.net/tomb"
    18  
    19  	"github.com/juju/juju/instance"
    20  	"github.com/juju/juju/network"
    21  	"github.com/juju/juju/replicaset"
    22  	"github.com/juju/juju/state"
    23  	"github.com/juju/juju/worker"
    24  )
    25  
    26  // This file holds helper functions for mocking pieces of State and replicaset
    27  // that we don't want to directly depend on in unit tests.
    28  
    29  type fakeState struct {
    30  	mu           sync.Mutex
    31  	machines     map[string]*fakeMachine
    32  	stateServers voyeur.Value // of *state.StateServerInfo
    33  	session      *fakeMongoSession
    34  	check        func(st *fakeState) error
    35  }
    36  
    37  var (
    38  	_ stateInterface = (*fakeState)(nil)
    39  	_ stateMachine   = (*fakeMachine)(nil)
    40  	_ mongoSession   = (*fakeMongoSession)(nil)
    41  )
    42  
    43  type errorPattern struct {
    44  	pattern string
    45  	errFunc func() error
    46  }
    47  
    48  var (
    49  	errorsMutex   sync.Mutex
    50  	errorPatterns []errorPattern
    51  )
    52  
    53  // setErrorFor causes the given error to be returned
    54  // from any mock call that matches the given
    55  // string, which may contain wildcards as
    56  // in path.Match.
    57  //
    58  // The standard form for errors is:
    59  //    Type.Function <arg>...
    60  // See individual functions for details.
    61  func setErrorFor(what string, err error) {
    62  	setErrorFuncFor(what, func() error {
    63  		return err
    64  	})
    65  }
    66  
    67  // setErrorFuncFor causes the given function
    68  // to be invoked to return the error for the
    69  // given pattern.
    70  func setErrorFuncFor(what string, errFunc func() error) {
    71  	errorsMutex.Lock()
    72  	defer errorsMutex.Unlock()
    73  	errorPatterns = append(errorPatterns, errorPattern{
    74  		pattern: what,
    75  		errFunc: errFunc,
    76  	})
    77  }
    78  
    79  // errorFor concatenates the call name
    80  // with all the args, space separated,
    81  // and returns any error registered with
    82  // setErrorFor that matches the resulting string.
    83  func errorFor(name string, args ...interface{}) error {
    84  	errorsMutex.Lock()
    85  	s := name
    86  	for _, arg := range args {
    87  		s += " " + fmt.Sprint(arg)
    88  	}
    89  	f := func() error { return nil }
    90  	for _, pattern := range errorPatterns {
    91  		if ok, _ := path.Match(pattern.pattern, s); ok {
    92  			f = pattern.errFunc
    93  			break
    94  		}
    95  	}
    96  	errorsMutex.Unlock()
    97  	err := f()
    98  	logger.Errorf("errorFor %q -> %v", s, err)
    99  	return err
   100  }
   101  
   102  func resetErrors() {
   103  	errorsMutex.Lock()
   104  	defer errorsMutex.Unlock()
   105  	errorPatterns = errorPatterns[:0]
   106  }
   107  
   108  func newFakeState() *fakeState {
   109  	st := &fakeState{
   110  		machines: make(map[string]*fakeMachine),
   111  	}
   112  	st.session = newFakeMongoSession(st)
   113  	st.stateServers.Set(&state.StateServerInfo{})
   114  	return st
   115  }
   116  
   117  func (st *fakeState) MongoSession() mongoSession {
   118  	return st.session
   119  }
   120  
   121  func (st *fakeState) checkInvariants() {
   122  	if st.check == nil {
   123  		return
   124  	}
   125  	if err := st.check(st); err != nil {
   126  		// Force a panic, otherwise we can deadlock
   127  		// when called from within the worker.
   128  		go panic(err)
   129  		select {}
   130  	}
   131  }
   132  
   133  // checkInvariants checks that all the expected invariants
   134  // in the state hold true. Currently we check that:
   135  // - total number of votes is odd.
   136  // - member voting status implies that machine has vote.
   137  func checkInvariants(st *fakeState) error {
   138  	members := st.session.members.Get().([]replicaset.Member)
   139  	voteCount := 0
   140  	for _, m := range members {
   141  		votes := 1
   142  		if m.Votes != nil {
   143  			votes = *m.Votes
   144  		}
   145  		voteCount += votes
   146  		if id, ok := m.Tags[jujuMachineTag]; ok {
   147  			if votes > 0 {
   148  				m := st.machine(id)
   149  				if m == nil {
   150  					return fmt.Errorf("voting member with machine id %q has no associated Machine", id)
   151  				}
   152  				if !m.HasVote() {
   153  					return fmt.Errorf("machine %q should be marked as having the vote, but does not", id)
   154  				}
   155  			}
   156  		}
   157  	}
   158  	if voteCount%2 != 1 {
   159  		return fmt.Errorf("total vote count is not odd (got %d)", voteCount)
   160  	}
   161  	return nil
   162  }
   163  
   164  type invariantChecker interface {
   165  	checkInvariants()
   166  }
   167  
   168  // machine is similar to Machine except that
   169  // it bypasses the error mocking machinery.
   170  // It returns nil if there is no machine with the
   171  // given id.
   172  func (st *fakeState) machine(id string) *fakeMachine {
   173  	st.mu.Lock()
   174  	defer st.mu.Unlock()
   175  	return st.machines[id]
   176  }
   177  
   178  func (st *fakeState) Machine(id string) (stateMachine, error) {
   179  	if err := errorFor("State.Machine", id); err != nil {
   180  		return nil, err
   181  	}
   182  	if m := st.machine(id); m != nil {
   183  		return m, nil
   184  	}
   185  	return nil, errors.NotFoundf("machine %s", id)
   186  }
   187  
   188  func (st *fakeState) addMachine(id string, wantsVote bool) *fakeMachine {
   189  	st.mu.Lock()
   190  	defer st.mu.Unlock()
   191  	logger.Infof("fakeState.addMachine %q", id)
   192  	if st.machines[id] != nil {
   193  		panic(fmt.Errorf("id %q already used", id))
   194  	}
   195  	m := &fakeMachine{
   196  		checker: st,
   197  		doc: machineDoc{
   198  			id:        id,
   199  			wantsVote: wantsVote,
   200  		},
   201  	}
   202  	st.machines[id] = m
   203  	m.val.Set(m.doc)
   204  	return m
   205  }
   206  
   207  func (st *fakeState) removeMachine(id string) {
   208  	st.mu.Lock()
   209  	defer st.mu.Unlock()
   210  	if st.machines[id] == nil {
   211  		panic(fmt.Errorf("removing non-existent machine %q", id))
   212  	}
   213  	delete(st.machines, id)
   214  }
   215  
   216  func (st *fakeState) setStateServers(ids ...string) {
   217  	st.stateServers.Set(&state.StateServerInfo{
   218  		MachineIds: ids,
   219  	})
   220  }
   221  
   222  func (st *fakeState) StateServerInfo() (*state.StateServerInfo, error) {
   223  	if err := errorFor("State.StateServerInfo"); err != nil {
   224  		return nil, err
   225  	}
   226  	return deepCopy(st.stateServers.Get()).(*state.StateServerInfo), nil
   227  }
   228  
   229  func (st *fakeState) WatchStateServerInfo() state.NotifyWatcher {
   230  	return WatchValue(&st.stateServers)
   231  }
   232  
   233  type fakeMachine struct {
   234  	mu      sync.Mutex
   235  	val     voyeur.Value // of machineDoc
   236  	doc     machineDoc
   237  	checker invariantChecker
   238  }
   239  
   240  type machineDoc struct {
   241  	id             string
   242  	wantsVote      bool
   243  	hasVote        bool
   244  	instanceId     instance.Id
   245  	mongoHostPorts []network.HostPort
   246  	apiHostPorts   []network.HostPort
   247  }
   248  
   249  func (m *fakeMachine) Refresh() error {
   250  	if err := errorFor("Machine.Refresh", m.doc.id); err != nil {
   251  		return err
   252  	}
   253  	m.doc = m.val.Get().(machineDoc)
   254  	return nil
   255  }
   256  
   257  func (m *fakeMachine) GoString() string {
   258  	return fmt.Sprintf("&fakeMachine{%#v}", m.doc)
   259  }
   260  
   261  func (m *fakeMachine) Id() string {
   262  	return m.doc.id
   263  }
   264  
   265  func (m *fakeMachine) InstanceId() (instance.Id, error) {
   266  	if err := errorFor("Machine.InstanceId", m.doc.id); err != nil {
   267  		return "", err
   268  	}
   269  	return m.doc.instanceId, nil
   270  }
   271  
   272  func (m *fakeMachine) Watch() state.NotifyWatcher {
   273  	return WatchValue(&m.val)
   274  }
   275  
   276  func (m *fakeMachine) WantsVote() bool {
   277  	return m.doc.wantsVote
   278  }
   279  
   280  func (m *fakeMachine) HasVote() bool {
   281  	return m.doc.hasVote
   282  }
   283  
   284  func (m *fakeMachine) MongoHostPorts() []network.HostPort {
   285  	return m.doc.mongoHostPorts
   286  }
   287  
   288  func (m *fakeMachine) APIHostPorts() []network.HostPort {
   289  	return m.doc.apiHostPorts
   290  }
   291  
// mutate changes the machine by applying f to two copies of its data:
// first to a copy of the published document in val, which is then
// published (so Watch observers fire), and then to the local doc. After
// that it runs the invariant checks.
//
// Concurrent mutate calls are serialised by m.mu. Refresh and the
// accessor methods do not take m.mu.
//
// f is called twice, so it should be a plain assignment with no side
// effects of its own.
func (m *fakeMachine) mutate(f func(*machineDoc)) {
	m.mu.Lock()
	// Update and publish a copy of the latest published document.
	doc := m.val.Get().(machineDoc)
	f(&doc)
	m.val.Set(doc)
	// Also update the local view, so this *fakeMachine reflects the
	// change without needing a Refresh.
	f(&m.doc)
	m.mu.Unlock()
	// Run the checks after releasing m.mu. Presumably this is so the
	// checker, which may inspect this machine, does not run under the
	// lock.
	m.checker.checkInvariants()
}
   303  
   304  func (m *fakeMachine) setStateHostPort(hostPort string) {
   305  	var mongoHostPorts []network.HostPort
   306  	if hostPort != "" {
   307  		host, portStr, err := net.SplitHostPort(hostPort)
   308  		if err != nil {
   309  			panic(err)
   310  		}
   311  		port, err := strconv.Atoi(portStr)
   312  		if err != nil {
   313  			panic(err)
   314  		}
   315  		mongoHostPorts = network.AddressesWithPort(network.NewAddresses(host), port)
   316  		mongoHostPorts[0].Scope = network.ScopeCloudLocal
   317  	}
   318  
   319  	m.mutate(func(doc *machineDoc) {
   320  		doc.mongoHostPorts = mongoHostPorts
   321  	})
   322  }
   323  
   324  func (m *fakeMachine) setMongoHostPorts(hostPorts []network.HostPort) {
   325  	m.mutate(func(doc *machineDoc) {
   326  		doc.mongoHostPorts = hostPorts
   327  	})
   328  }
   329  
   330  func (m *fakeMachine) setAPIHostPorts(hostPorts []network.HostPort) {
   331  	m.mutate(func(doc *machineDoc) {
   332  		doc.apiHostPorts = hostPorts
   333  	})
   334  }
   335  
   336  func (m *fakeMachine) setInstanceId(instanceId instance.Id) {
   337  	m.mutate(func(doc *machineDoc) {
   338  		doc.instanceId = instanceId
   339  	})
   340  }
   341  
   342  // SetHasVote implements stateMachine.SetHasVote.
   343  func (m *fakeMachine) SetHasVote(hasVote bool) error {
   344  	if err := errorFor("Machine.SetHasVote", m.doc.id, hasVote); err != nil {
   345  		return err
   346  	}
   347  	m.mutate(func(doc *machineDoc) {
   348  		doc.hasVote = hasVote
   349  	})
   350  	return nil
   351  }
   352  
   353  func (m *fakeMachine) setWantsVote(wantsVote bool) {
   354  	m.mutate(func(doc *machineDoc) {
   355  		doc.wantsVote = wantsVote
   356  	})
   357  }
   358  
   359  type fakeMongoSession struct {
   360  	// If InstantlyReady is true, replica status of
   361  	// all members will be instantly reported as ready.
   362  	InstantlyReady bool
   363  
   364  	checker invariantChecker
   365  	members voyeur.Value // of []replicaset.Member
   366  	status  voyeur.Value // of *replicaset.Status
   367  }
   368  
   369  // newFakeMongoSession returns a mock implementation of mongoSession.
   370  func newFakeMongoSession(checker invariantChecker) *fakeMongoSession {
   371  	s := new(fakeMongoSession)
   372  	s.checker = checker
   373  	s.members.Set([]replicaset.Member(nil))
   374  	s.status.Set(&replicaset.Status{})
   375  	return s
   376  }
   377  
   378  // CurrentMembers implements mongoSession.CurrentMembers.
   379  func (session *fakeMongoSession) CurrentMembers() ([]replicaset.Member, error) {
   380  	if err := errorFor("Session.CurrentMembers"); err != nil {
   381  		return nil, err
   382  	}
   383  	return deepCopy(session.members.Get()).([]replicaset.Member), nil
   384  }
   385  
   386  // CurrentStatus implements mongoSession.CurrentStatus.
   387  func (session *fakeMongoSession) CurrentStatus() (*replicaset.Status, error) {
   388  	if err := errorFor("Session.CurrentStatus"); err != nil {
   389  		return nil, err
   390  	}
   391  	return deepCopy(session.status.Get()).(*replicaset.Status), nil
   392  }
   393  
   394  // setStatus sets the status of the current members of the session.
   395  func (session *fakeMongoSession) setStatus(members []replicaset.MemberStatus) {
   396  	session.status.Set(deepCopy(&replicaset.Status{
   397  		Members: members,
   398  	}))
   399  }
   400  
   401  // Set implements mongoSession.Set
   402  func (session *fakeMongoSession) Set(members []replicaset.Member) error {
   403  	if err := errorFor("Session.Set"); err != nil {
   404  		logger.Infof("not setting replicaset members to %#v", members)
   405  		return err
   406  	}
   407  	logger.Infof("setting replicaset members to %#v", members)
   408  	session.members.Set(deepCopy(members))
   409  	if session.InstantlyReady {
   410  		statuses := make([]replicaset.MemberStatus, len(members))
   411  		for i, m := range members {
   412  			statuses[i] = replicaset.MemberStatus{
   413  				Id:      m.Id,
   414  				Address: m.Address,
   415  				Healthy: true,
   416  				State:   replicaset.SecondaryState,
   417  			}
   418  			if i == 0 {
   419  				statuses[i].State = replicaset.PrimaryState
   420  			}
   421  		}
   422  		session.setStatus(statuses)
   423  	}
   424  	session.checker.checkInvariants()
   425  	return nil
   426  }
   427  
   428  // deepCopy makes a deep copy of any type by marshalling
   429  // it as JSON, then unmarshalling it.
   430  func deepCopy(x interface{}) interface{} {
   431  	v := reflect.ValueOf(x)
   432  	data, err := json.Marshal(x)
   433  	if err != nil {
   434  		panic(fmt.Errorf("cannot marshal %#v: %v", x, err))
   435  	}
   436  	newv := reflect.New(v.Type())
   437  	if err := json.Unmarshal(data, newv.Interface()); err != nil {
   438  		panic(fmt.Errorf("cannot unmarshal %q into %s", data, newv.Type()))
   439  	}
   440  	// sanity check
   441  	newx := newv.Elem().Interface()
   442  	if !reflect.DeepEqual(newx, x) {
   443  		panic(fmt.Errorf("value not deep-copied correctly"))
   444  	}
   445  	return newx
   446  }
   447  
   448  type notifier struct {
   449  	tomb    tomb.Tomb
   450  	w       *voyeur.Watcher
   451  	changes chan struct{}
   452  }
   453  
   454  // WatchValue returns a NotifyWatcher that triggers
   455  // when the given value changes. Its Wait and Err methods
   456  // never return a non-nil error.
   457  func WatchValue(val *voyeur.Value) state.NotifyWatcher {
   458  	n := &notifier{
   459  		w:       val.Watch(),
   460  		changes: make(chan struct{}),
   461  	}
   462  	go n.loop()
   463  	return n
   464  }
   465  
   466  func (n *notifier) loop() {
   467  	defer n.tomb.Done()
   468  	for n.w.Next() {
   469  		select {
   470  		case n.changes <- struct{}{}:
   471  		case <-n.tomb.Dying():
   472  		}
   473  	}
   474  }
   475  
   476  // Changes returns a channel that sends a value when the value changes.
   477  // The value itself can be retrieved by calling the value's Get method.
   478  func (n *notifier) Changes() <-chan struct{} {
   479  	return n.changes
   480  }
   481  
   482  // Kill stops the notifier but does not wait for it to finish.
   483  func (n *notifier) Kill() {
   484  	n.tomb.Kill(nil)
   485  	n.w.Close()
   486  }
   487  
   488  func (n *notifier) Err() error {
   489  	return n.tomb.Err()
   490  }
   491  
   492  // Wait waits for the notifier to finish. It always returns nil.
   493  func (n *notifier) Wait() error {
   494  	return n.tomb.Wait()
   495  }
   496  
   497  func (n *notifier) Stop() error {
   498  	return worker.Stop(n)
   499  }