github.com/cloud-green/juju@v0.0.0-20151002100041-a00291338d3d/worker/peergrouper/mock_test.go (about)

     1  // Copyright 2014 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package peergrouper
     5  
     6  import (
     7  	"encoding/json"
     8  	"fmt"
     9  	"net"
    10  	"path"
    11  	"reflect"
    12  	"strconv"
    13  	"sync"
    14  
    15  	"github.com/juju/errors"
    16  	"github.com/juju/replicaset"
    17  	"github.com/juju/utils/voyeur"
    18  	"launchpad.net/tomb"
    19  
    20  	"github.com/juju/juju/instance"
    21  	"github.com/juju/juju/network"
    22  	"github.com/juju/juju/state"
    23  	"github.com/juju/juju/worker"
    24  )
    25  
    26  // This file holds helper functions for mocking pieces of State and replicaset
    27  // that we don't want to directly depend on in unit tests.
    28  
    29  type fakeState struct {
    30  	mu           sync.Mutex
    31  	machines     map[string]*fakeMachine
    32  	stateServers voyeur.Value // of *state.StateServerInfo
    33  	session      *fakeMongoSession
    34  	check        func(st *fakeState) error
    35  }
    36  
    37  var (
    38  	_ stateInterface = (*fakeState)(nil)
    39  	_ stateMachine   = (*fakeMachine)(nil)
    40  	_ mongoSession   = (*fakeMongoSession)(nil)
    41  )
    42  
    43  type errorPattern struct {
    44  	pattern string
    45  	errFunc func() error
    46  }
    47  
    48  var (
    49  	errorsMutex   sync.Mutex
    50  	errorPatterns []errorPattern
    51  )
    52  
    53  // setErrorFor causes the given error to be returned
    54  // from any mock call that matches the given
    55  // string, which may contain wildcards as
    56  // in path.Match.
    57  //
    58  // The standard form for errors is:
    59  //    Type.Function <arg>...
    60  // See individual functions for details.
    61  func setErrorFor(what string, err error) {
    62  	setErrorFuncFor(what, func() error {
    63  		return err
    64  	})
    65  }
    66  
    67  // setErrorFuncFor causes the given function
    68  // to be invoked to return the error for the
    69  // given pattern.
    70  func setErrorFuncFor(what string, errFunc func() error) {
    71  	errorsMutex.Lock()
    72  	defer errorsMutex.Unlock()
    73  	errorPatterns = append(errorPatterns, errorPattern{
    74  		pattern: what,
    75  		errFunc: errFunc,
    76  	})
    77  }
    78  
    79  // errorFor concatenates the call name
    80  // with all the args, space separated,
    81  // and returns any error registered with
    82  // setErrorFor that matches the resulting string.
    83  func errorFor(name string, args ...interface{}) error {
    84  	errorsMutex.Lock()
    85  	s := name
    86  	for _, arg := range args {
    87  		s += " " + fmt.Sprint(arg)
    88  	}
    89  	f := func() error { return nil }
    90  	for _, pattern := range errorPatterns {
    91  		if ok, _ := path.Match(pattern.pattern, s); ok {
    92  			f = pattern.errFunc
    93  			break
    94  		}
    95  	}
    96  	errorsMutex.Unlock()
    97  	err := f()
    98  	logger.Errorf("errorFor %q -> %v", s, err)
    99  	return err
   100  }
   101  
   102  func resetErrors() {
   103  	errorsMutex.Lock()
   104  	defer errorsMutex.Unlock()
   105  	errorPatterns = errorPatterns[:0]
   106  }
   107  
   108  func NewFakeState() *fakeState {
   109  	st := &fakeState{
   110  		machines: make(map[string]*fakeMachine),
   111  	}
   112  	st.session = newFakeMongoSession(st)
   113  	st.stateServers.Set(&state.StateServerInfo{})
   114  	return st
   115  }
   116  
   117  func (st *fakeState) MongoSession() mongoSession {
   118  	return st.session
   119  }
   120  
// checkInvariants runs st.check, if one has been set. If the check
// reports an error, the whole process crashes.
func (st *fakeState) checkInvariants() {
	if st.check == nil {
		return
	}
	if err := st.check(st); err != nil {
		// Force a panic, otherwise we can deadlock
		// when called from within the worker.
		// The panic happens in a new goroutine with no deferred recover,
		// so it cannot be caught and it takes down the process. This
		// goroutine then blocks forever so its caller never proceeds.
		go panic(err)
		select {}
	}
}
   132  
   133  // checkInvariants checks that all the expected invariants
   134  // in the state hold true. Currently we check that:
   135  // - total number of votes is odd.
   136  // - member voting status implies that machine has vote.
   137  func checkInvariants(st *fakeState) error {
   138  	members := st.session.members.Get().([]replicaset.Member)
   139  	voteCount := 0
   140  	for _, m := range members {
   141  		votes := 1
   142  		if m.Votes != nil {
   143  			votes = *m.Votes
   144  		}
   145  		voteCount += votes
   146  		if id, ok := m.Tags[jujuMachineKey]; ok {
   147  			if votes > 0 {
   148  				m := st.machine(id)
   149  				if m == nil {
   150  					return fmt.Errorf("voting member with machine id %q has no associated Machine", id)
   151  				}
   152  				if !m.HasVote() {
   153  					return fmt.Errorf("machine %q should be marked as having the vote, but does not", id)
   154  				}
   155  			}
   156  		}
   157  	}
   158  	if voteCount%2 != 1 {
   159  		return fmt.Errorf("total vote count is not odd (got %d)", voteCount)
   160  	}
   161  	return nil
   162  }
   163  
   164  type invariantChecker interface {
   165  	checkInvariants()
   166  }
   167  
   168  // machine is similar to Machine except that
   169  // it bypasses the error mocking machinery.
   170  // It returns nil if there is no machine with the
   171  // given id.
   172  func (st *fakeState) machine(id string) *fakeMachine {
   173  	st.mu.Lock()
   174  	defer st.mu.Unlock()
   175  	return st.machines[id]
   176  }
   177  
   178  func (st *fakeState) Machine(id string) (stateMachine, error) {
   179  	if err := errorFor("State.Machine", id); err != nil {
   180  		return nil, err
   181  	}
   182  	if m := st.machine(id); m != nil {
   183  		return m, nil
   184  	}
   185  	return nil, errors.NotFoundf("machine %s", id)
   186  }
   187  
   188  func (st *fakeState) addMachine(id string, wantsVote bool) *fakeMachine {
   189  	st.mu.Lock()
   190  	defer st.mu.Unlock()
   191  	logger.Infof("fakeState.addMachine %q", id)
   192  	if st.machines[id] != nil {
   193  		panic(fmt.Errorf("id %q already used", id))
   194  	}
   195  	m := &fakeMachine{
   196  		checker: st,
   197  		doc: machineDoc{
   198  			id:        id,
   199  			wantsVote: wantsVote,
   200  		},
   201  	}
   202  	st.machines[id] = m
   203  	m.val.Set(m.doc)
   204  	return m
   205  }
   206  
   207  func (st *fakeState) removeMachine(id string) {
   208  	st.mu.Lock()
   209  	defer st.mu.Unlock()
   210  	if st.machines[id] == nil {
   211  		panic(fmt.Errorf("removing non-existent machine %q", id))
   212  	}
   213  	delete(st.machines, id)
   214  }
   215  
   216  func (st *fakeState) setStateServers(ids ...string) {
   217  	st.stateServers.Set(&state.StateServerInfo{
   218  		MachineIds: ids,
   219  	})
   220  }
   221  
   222  func (st *fakeState) StateServerInfo() (*state.StateServerInfo, error) {
   223  	if err := errorFor("State.StateServerInfo"); err != nil {
   224  		return nil, err
   225  	}
   226  	return deepCopy(st.stateServers.Get()).(*state.StateServerInfo), nil
   227  }
   228  
   229  func (st *fakeState) WatchStateServerInfo() state.NotifyWatcher {
   230  	return WatchValue(&st.stateServers)
   231  }
   232  
   233  type fakeMachine struct {
   234  	mu      sync.Mutex
   235  	val     voyeur.Value // of machineDoc
   236  	doc     machineDoc
   237  	checker invariantChecker
   238  }
   239  
   240  type machineDoc struct {
   241  	id             string
   242  	wantsVote      bool
   243  	hasVote        bool
   244  	instanceId     instance.Id
   245  	mongoHostPorts []network.HostPort
   246  	apiHostPorts   []network.HostPort
   247  }
   248  
   249  func (m *fakeMachine) Refresh() error {
   250  	m.mu.Lock()
   251  	defer m.mu.Unlock()
   252  	if err := errorFor("Machine.Refresh", m.doc.id); err != nil {
   253  		return err
   254  	}
   255  	m.doc = m.val.Get().(machineDoc)
   256  	return nil
   257  }
   258  
   259  func (m *fakeMachine) GoString() string {
   260  	m.mu.Lock()
   261  	defer m.mu.Unlock()
   262  	return fmt.Sprintf("&fakeMachine{%#v}", m.doc)
   263  }
   264  
   265  func (m *fakeMachine) Id() string {
   266  	m.mu.Lock()
   267  	defer m.mu.Unlock()
   268  	return m.doc.id
   269  }
   270  
   271  func (m *fakeMachine) InstanceId() (instance.Id, error) {
   272  	m.mu.Lock()
   273  	defer m.mu.Unlock()
   274  	if err := errorFor("Machine.InstanceId", m.doc.id); err != nil {
   275  		return "", err
   276  	}
   277  	return m.doc.instanceId, nil
   278  }
   279  
   280  func (m *fakeMachine) Watch() state.NotifyWatcher {
   281  	m.mu.Lock()
   282  	defer m.mu.Unlock()
   283  	return WatchValue(&m.val)
   284  }
   285  
   286  func (m *fakeMachine) WantsVote() bool {
   287  	m.mu.Lock()
   288  	defer m.mu.Unlock()
   289  	return m.doc.wantsVote
   290  }
   291  
   292  func (m *fakeMachine) HasVote() bool {
   293  	m.mu.Lock()
   294  	defer m.mu.Unlock()
   295  	return m.doc.hasVote
   296  }
   297  
   298  func (m *fakeMachine) MongoHostPorts() []network.HostPort {
   299  	m.mu.Lock()
   300  	defer m.mu.Unlock()
   301  	return m.doc.mongoHostPorts
   302  }
   303  
   304  func (m *fakeMachine) APIHostPorts() []network.HostPort {
   305  	m.mu.Lock()
   306  	defer m.mu.Unlock()
   307  	return m.doc.apiHostPorts
   308  }
   309  
// mutate atomically changes the machineDoc of
// the receiver by mutating it with the provided function.
//
// f is applied twice. First it is applied to a copy of the published
// document in m.val, which is then republished, so watchers of m.val see
// the change. Then it is applied to the local m.doc, so accessors observe
// the change without a Refresh. m.mu is released before the invariant
// check runs. Presumably this is because a configured check such as
// checkInvariants calls HasVote, which takes m.mu.
func (m *fakeMachine) mutate(f func(*machineDoc)) {
	m.mu.Lock()
	doc := m.val.Get().(machineDoc)
	f(&doc)
	m.val.Set(doc)
	f(&m.doc)
	m.mu.Unlock()
	m.checker.checkInvariants()
}
   321  
   322  func (m *fakeMachine) setStateHostPort(hostPort string) {
   323  	var mongoHostPorts []network.HostPort
   324  	if hostPort != "" {
   325  		host, portStr, err := net.SplitHostPort(hostPort)
   326  		if err != nil {
   327  			panic(err)
   328  		}
   329  		port, err := strconv.Atoi(portStr)
   330  		if err != nil {
   331  			panic(err)
   332  		}
   333  		mongoHostPorts = network.NewHostPorts(port, host)
   334  		mongoHostPorts[0].Scope = network.ScopeCloudLocal
   335  	}
   336  
   337  	m.mutate(func(doc *machineDoc) {
   338  		doc.mongoHostPorts = mongoHostPorts
   339  	})
   340  }
   341  
   342  func (m *fakeMachine) setMongoHostPorts(hostPorts []network.HostPort) {
   343  	m.mutate(func(doc *machineDoc) {
   344  		doc.mongoHostPorts = hostPorts
   345  	})
   346  }
   347  
   348  func (m *fakeMachine) setAPIHostPorts(hostPorts []network.HostPort) {
   349  	m.mutate(func(doc *machineDoc) {
   350  		doc.apiHostPorts = hostPorts
   351  	})
   352  }
   353  
   354  func (m *fakeMachine) setInstanceId(instanceId instance.Id) {
   355  	m.mutate(func(doc *machineDoc) {
   356  		doc.instanceId = instanceId
   357  	})
   358  }
   359  
   360  // SetHasVote implements stateMachine.SetHasVote.
   361  func (m *fakeMachine) SetHasVote(hasVote bool) error {
   362  	if err := errorFor("Machine.SetHasVote", m.doc.id, hasVote); err != nil {
   363  		return err
   364  	}
   365  	m.mutate(func(doc *machineDoc) {
   366  		doc.hasVote = hasVote
   367  	})
   368  	return nil
   369  }
   370  
   371  func (m *fakeMachine) setWantsVote(wantsVote bool) {
   372  	m.mutate(func(doc *machineDoc) {
   373  		doc.wantsVote = wantsVote
   374  	})
   375  }
   376  
   377  type fakeMongoSession struct {
   378  	// If InstantlyReady is true, replica status of
   379  	// all members will be instantly reported as ready.
   380  	InstantlyReady bool
   381  
   382  	checker invariantChecker
   383  	members voyeur.Value // of []replicaset.Member
   384  	status  voyeur.Value // of *replicaset.Status
   385  }
   386  
   387  // newFakeMongoSession returns a mock implementation of mongoSession.
   388  func newFakeMongoSession(checker invariantChecker) *fakeMongoSession {
   389  	s := new(fakeMongoSession)
   390  	s.checker = checker
   391  	s.members.Set([]replicaset.Member(nil))
   392  	s.status.Set(&replicaset.Status{})
   393  	return s
   394  }
   395  
   396  // CurrentMembers implements mongoSession.CurrentMembers.
   397  func (session *fakeMongoSession) CurrentMembers() ([]replicaset.Member, error) {
   398  	if err := errorFor("Session.CurrentMembers"); err != nil {
   399  		return nil, err
   400  	}
   401  	return deepCopy(session.members.Get()).([]replicaset.Member), nil
   402  }
   403  
   404  // CurrentStatus implements mongoSession.CurrentStatus.
   405  func (session *fakeMongoSession) CurrentStatus() (*replicaset.Status, error) {
   406  	if err := errorFor("Session.CurrentStatus"); err != nil {
   407  		return nil, err
   408  	}
   409  	return deepCopy(session.status.Get()).(*replicaset.Status), nil
   410  }
   411  
   412  // setStatus sets the status of the current members of the session.
   413  func (session *fakeMongoSession) setStatus(members []replicaset.MemberStatus) {
   414  	session.status.Set(deepCopy(&replicaset.Status{
   415  		Members: members,
   416  	}))
   417  }
   418  
   419  // Set implements mongoSession.Set
   420  func (session *fakeMongoSession) Set(members []replicaset.Member) error {
   421  	if err := errorFor("Session.Set"); err != nil {
   422  		logger.Infof("not setting replicaset members to %#v", members)
   423  		return err
   424  	}
   425  	logger.Infof("setting replicaset members to %#v", members)
   426  	session.members.Set(deepCopy(members))
   427  	if session.InstantlyReady {
   428  		statuses := make([]replicaset.MemberStatus, len(members))
   429  		for i, m := range members {
   430  			statuses[i] = replicaset.MemberStatus{
   431  				Id:      m.Id,
   432  				Address: m.Address,
   433  				Healthy: true,
   434  				State:   replicaset.SecondaryState,
   435  			}
   436  			if i == 0 {
   437  				statuses[i].State = replicaset.PrimaryState
   438  			}
   439  		}
   440  		session.setStatus(statuses)
   441  	}
   442  	session.checker.checkInvariants()
   443  	return nil
   444  }
   445  
   446  // deepCopy makes a deep copy of any type by marshalling
   447  // it as JSON, then unmarshalling it.
   448  func deepCopy(x interface{}) interface{} {
   449  	v := reflect.ValueOf(x)
   450  	data, err := json.Marshal(x)
   451  	if err != nil {
   452  		panic(fmt.Errorf("cannot marshal %#v: %v", x, err))
   453  	}
   454  	newv := reflect.New(v.Type())
   455  	if err := json.Unmarshal(data, newv.Interface()); err != nil {
   456  		panic(fmt.Errorf("cannot unmarshal %q into %s", data, newv.Type()))
   457  	}
   458  	// sanity check
   459  	newx := newv.Elem().Interface()
   460  	if !reflect.DeepEqual(newx, x) {
   461  		panic(fmt.Errorf("value not deep-copied correctly"))
   462  	}
   463  	return newx
   464  }
   465  
   466  type notifier struct {
   467  	tomb    tomb.Tomb
   468  	w       *voyeur.Watcher
   469  	changes chan struct{}
   470  }
   471  
   472  // WatchValue returns a NotifyWatcher that triggers
   473  // when the given value changes. Its Wait and Err methods
   474  // never return a non-nil error.
   475  func WatchValue(val *voyeur.Value) state.NotifyWatcher {
   476  	n := &notifier{
   477  		w:       val.Watch(),
   478  		changes: make(chan struct{}),
   479  	}
   480  	go n.loop()
   481  	return n
   482  }
   483  
// loop forwards every change observed by n.w as a send on n.changes. It
// marks the tomb done when n.w.Next reports false. Kill closes the
// watcher; this assumes voyeur's Next returns false once the watcher is
// closed. If the tomb starts dying while a send is pending, the send is
// abandoned and the loop returns to n.w.Next.
func (n *notifier) loop() {
	defer n.tomb.Done()
	for n.w.Next() {
		select {
		case n.changes <- struct{}{}:
		case <-n.tomb.Dying():
		}
	}
}
   493  
   494  // Changes returns a channel that sends a value when the value changes.
   495  // The value itself can be retrieved by calling the value's Get method.
   496  func (n *notifier) Changes() <-chan struct{} {
   497  	return n.changes
   498  }
   499  
   500  // Kill stops the notifier but does not wait for it to finish.
   501  func (n *notifier) Kill() {
   502  	n.tomb.Kill(nil)
   503  	n.w.Close()
   504  }
   505  
   506  func (n *notifier) Err() error {
   507  	return n.tomb.Err()
   508  }
   509  
   510  // Wait waits for the notifier to finish. It always returns nil.
   511  func (n *notifier) Wait() error {
   512  	return n.tomb.Wait()
   513  }
   514  
   515  func (n *notifier) Stop() error {
   516  	return worker.Stop(n)
   517  }