github.com/makyo/juju@v0.0.0-20160425123129-2608902037e9/worker/peergrouper/mock_test.go (about)

     1  // Copyright 2014 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package peergrouper
     5  
     6  import (
     7  	"encoding/json"
     8  	"fmt"
     9  	"net"
    10  	"path"
    11  	"reflect"
    12  	"strconv"
    13  	"sync"
    14  
    15  	"github.com/juju/errors"
    16  	"github.com/juju/replicaset"
    17  	"github.com/juju/utils/voyeur"
    18  	"launchpad.net/tomb"
    19  
    20  	"github.com/juju/juju/apiserver/common/networkingcommon"
    21  	"github.com/juju/juju/apiserver/testing"
    22  	"github.com/juju/juju/environs/config"
    23  	"github.com/juju/juju/instance"
    24  	"github.com/juju/juju/network"
    25  	"github.com/juju/juju/state"
    26  	"github.com/juju/juju/status"
    27  	coretesting "github.com/juju/juju/testing"
    28  	"github.com/juju/juju/worker"
    29  )
    30  
    31  // This file holds helper functions for mocking pieces of State and replicaset
    32  // that we don't want to directly depend on in unit tests.
    33  
    34  type fakeState struct {
    35  	mu          sync.Mutex
    36  	errors      errorPatterns
    37  	machines    map[string]*fakeMachine
    38  	controllers voyeur.Value // of *state.ControllerInfo
    39  	statuses    voyeur.Value // of statuses collection
    40  	session     *fakeMongoSession
    41  	check       func(st *fakeState) error
    42  }
    43  
    44  var (
    45  	_ stateInterface = (*fakeState)(nil)
    46  	_ stateMachine   = (*fakeMachine)(nil)
    47  	_ mongoSession   = (*fakeMongoSession)(nil)
    48  )
    49  
    50  type errorPatterns struct {
    51  	patterns []errorPattern
    52  }
    53  
    54  type errorPattern struct {
    55  	pattern string
    56  	errFunc func() error
    57  }
    58  
    59  // setErrorFor causes the given error to be returned
    60  // from any mock call that matches the given
    61  // string, which may contain wildcards as
    62  // in path.Match.
    63  //
    64  // The standard form for errors is:
    65  //    Type.Function <arg>...
    66  // See individual functions for details.
    67  func (e *errorPatterns) setErrorFor(what string, err error) {
    68  	e.setErrorFuncFor(what, func() error {
    69  		return err
    70  	})
    71  }
    72  
    73  // setErrorFuncFor causes the given function
    74  // to be invoked to return the error for the
    75  // given pattern.
    76  func (e *errorPatterns) setErrorFuncFor(what string, errFunc func() error) {
    77  	e.patterns = append(e.patterns, errorPattern{
    78  		pattern: what,
    79  		errFunc: errFunc,
    80  	})
    81  }
    82  
    83  // errorFor concatenates the call name
    84  // with all the args, space separated,
    85  // and returns any error registered with
    86  // setErrorFor that matches the resulting string.
    87  func (e *errorPatterns) errorFor(name string, args ...interface{}) error {
    88  	s := name
    89  	for _, arg := range args {
    90  		s += " " + fmt.Sprint(arg)
    91  	}
    92  	f := func() error { return nil }
    93  	for _, pattern := range e.patterns {
    94  		if ok, _ := path.Match(pattern.pattern, s); ok {
    95  			f = pattern.errFunc
    96  			break
    97  		}
    98  	}
    99  	err := f()
   100  	logger.Errorf("errorFor %q -> %v", s, err)
   101  	return err
   102  }
   103  
   104  func (e *errorPatterns) resetErrors() {
   105  	e.patterns = e.patterns[:0]
   106  }
   107  
   108  func NewFakeState() *fakeState {
   109  	st := &fakeState{
   110  		machines: make(map[string]*fakeMachine),
   111  	}
   112  	st.session = newFakeMongoSession(st, &st.errors)
   113  	st.controllers.Set(&state.ControllerInfo{})
   114  	return st
   115  }
   116  
   117  func (st *fakeState) MongoSession() mongoSession {
   118  	return st.session
   119  }
   120  
   121  func (st *fakeState) checkInvariants() {
   122  	if st.check == nil {
   123  		return
   124  	}
   125  	if err := st.check(st); err != nil {
   126  		// Force a panic, otherwise we can deadlock
   127  		// when called from within the worker.
   128  		go panic(err)
   129  		select {}
   130  	}
   131  }
   132  
   133  // checkInvariants checks that all the expected invariants
   134  // in the state hold true. Currently we check that:
   135  // - total number of votes is odd.
   136  // - member voting status implies that machine has vote.
   137  func checkInvariants(st *fakeState) error {
   138  	members := st.session.members.Get().([]replicaset.Member)
   139  	voteCount := 0
   140  	for _, m := range members {
   141  		votes := 1
   142  		if m.Votes != nil {
   143  			votes = *m.Votes
   144  		}
   145  		voteCount += votes
   146  		if id, ok := m.Tags[jujuMachineKey]; ok {
   147  			if votes > 0 {
   148  				m := st.machine(id)
   149  				if m == nil {
   150  					return fmt.Errorf("voting member with machine id %q has no associated Machine", id)
   151  				}
   152  				if !m.HasVote() {
   153  					return fmt.Errorf("machine %q should be marked as having the vote, but does not", id)
   154  				}
   155  			}
   156  		}
   157  	}
   158  	if voteCount%2 != 1 {
   159  		return fmt.Errorf("total vote count is not odd (got %d)", voteCount)
   160  	}
   161  	return nil
   162  }
   163  
   164  type invariantChecker interface {
   165  	checkInvariants()
   166  }
   167  
   168  // machine is similar to Machine except that
   169  // it bypasses the error mocking machinery.
   170  // It returns nil if there is no machine with the
   171  // given id.
   172  func (st *fakeState) machine(id string) *fakeMachine {
   173  	st.mu.Lock()
   174  	defer st.mu.Unlock()
   175  	return st.machines[id]
   176  }
   177  
   178  func (st *fakeState) Machine(id string) (stateMachine, error) {
   179  	if err := st.errors.errorFor("State.Machine", id); err != nil {
   180  		return nil, err
   181  	}
   182  	if m := st.machine(id); m != nil {
   183  		return m, nil
   184  	}
   185  	return nil, errors.NotFoundf("machine %s", id)
   186  }
   187  
   188  func (st *fakeState) addMachine(id string, wantsVote bool) *fakeMachine {
   189  	st.mu.Lock()
   190  	defer st.mu.Unlock()
   191  	logger.Infof("fakeState.addMachine %q", id)
   192  	if st.machines[id] != nil {
   193  		panic(fmt.Errorf("id %q already used", id))
   194  	}
   195  	m := &fakeMachine{
   196  		errors:  &st.errors,
   197  		checker: st,
   198  		doc: machineDoc{
   199  			id:        id,
   200  			wantsVote: wantsVote,
   201  		},
   202  	}
   203  	st.machines[id] = m
   204  	m.val.Set(m.doc)
   205  	return m
   206  }
   207  
   208  func (st *fakeState) removeMachine(id string) {
   209  	st.mu.Lock()
   210  	defer st.mu.Unlock()
   211  	if st.machines[id] == nil {
   212  		panic(fmt.Errorf("removing non-existent machine %q", id))
   213  	}
   214  	delete(st.machines, id)
   215  }
   216  
   217  func (st *fakeState) setControllers(ids ...string) {
   218  	st.controllers.Set(&state.ControllerInfo{
   219  		MachineIds: ids,
   220  	})
   221  }
   222  
   223  func (st *fakeState) ControllerInfo() (*state.ControllerInfo, error) {
   224  	if err := st.errors.errorFor("State.ControllerInfo"); err != nil {
   225  		return nil, err
   226  	}
   227  	return deepCopy(st.controllers.Get()).(*state.ControllerInfo), nil
   228  }
   229  
   230  func (st *fakeState) WatchControllerInfo() state.NotifyWatcher {
   231  	return WatchValue(&st.controllers)
   232  }
   233  
   234  func (st *fakeState) WatchControllerStatusChanges() state.StringsWatcher {
   235  	return WatchStrings(&st.statuses)
   236  }
   237  
   238  func (st *fakeState) Space(name string) (SpaceReader, error) {
   239  	foo := []networkingcommon.BackingSpace{
   240  		&testing.FakeSpace{SpaceName: "Space" + name},
   241  		&testing.FakeSpace{SpaceName: "Space" + name},
   242  		&testing.FakeSpace{SpaceName: "Space" + name},
   243  	}
   244  	return foo[0].(SpaceReader), nil
   245  }
   246  
   247  func (st *fakeState) SetOrGetMongoSpaceName(mongoSpaceName network.SpaceName) (network.SpaceName, error) {
   248  	inf, _ := st.ControllerInfo()
   249  	strMongoSpaceName := string(mongoSpaceName)
   250  
   251  	if inf.MongoSpaceState == state.MongoSpaceUnknown {
   252  		inf.MongoSpaceName = strMongoSpaceName
   253  		inf.MongoSpaceState = state.MongoSpaceValid
   254  		st.controllers.Set(inf)
   255  	}
   256  	return network.SpaceName(inf.MongoSpaceName), nil
   257  }
   258  
   259  func (st *fakeState) SetMongoSpaceState(mongoSpaceState state.MongoSpaceStates) error {
   260  	inf, _ := st.ControllerInfo()
   261  	inf.MongoSpaceState = mongoSpaceState
   262  	st.controllers.Set(inf)
   263  	return nil
   264  }
   265  
   266  func (st *fakeState) getMongoSpaceName() string {
   267  	inf, _ := st.ControllerInfo()
   268  	return inf.MongoSpaceName
   269  }
   270  
   271  func (st *fakeState) getMongoSpaceState() state.MongoSpaceStates {
   272  	inf, _ := st.ControllerInfo()
   273  	return inf.MongoSpaceState
   274  }
   275  
   276  func (st *fakeState) ModelConfig() (*config.Config, error) {
   277  	attrs := coretesting.FakeConfig()
   278  	cfg, err := config.New(config.NoDefaults, attrs)
   279  	return cfg, err
   280  }
   281  
   282  type fakeMachine struct {
   283  	mu      sync.Mutex
   284  	errors  *errorPatterns
   285  	val     voyeur.Value // of machineDoc
   286  	doc     machineDoc
   287  	checker invariantChecker
   288  }
   289  
   290  type machineDoc struct {
   291  	id             string
   292  	wantsVote      bool
   293  	hasVote        bool
   294  	instanceId     instance.Id
   295  	mongoHostPorts []network.HostPort
   296  	apiHostPorts   []network.HostPort
   297  }
   298  
   299  func (m *fakeMachine) Refresh() error {
   300  	m.mu.Lock()
   301  	defer m.mu.Unlock()
   302  	if err := m.errors.errorFor("Machine.Refresh", m.doc.id); err != nil {
   303  		return err
   304  	}
   305  	m.doc = m.val.Get().(machineDoc)
   306  	return nil
   307  }
   308  
   309  func (m *fakeMachine) GoString() string {
   310  	m.mu.Lock()
   311  	defer m.mu.Unlock()
   312  	return fmt.Sprintf("&fakeMachine{%#v}", m.doc)
   313  }
   314  
   315  func (m *fakeMachine) Id() string {
   316  	m.mu.Lock()
   317  	defer m.mu.Unlock()
   318  	return m.doc.id
   319  }
   320  
   321  func (m *fakeMachine) InstanceId() (instance.Id, error) {
   322  	m.mu.Lock()
   323  	defer m.mu.Unlock()
   324  	if err := m.errors.errorFor("Machine.InstanceId", m.doc.id); err != nil {
   325  		return "", err
   326  	}
   327  	return m.doc.instanceId, nil
   328  }
   329  
   330  func (m *fakeMachine) Watch() state.NotifyWatcher {
   331  	m.mu.Lock()
   332  	defer m.mu.Unlock()
   333  	return WatchValue(&m.val)
   334  }
   335  
   336  func (m *fakeMachine) WantsVote() bool {
   337  	m.mu.Lock()
   338  	defer m.mu.Unlock()
   339  	return m.doc.wantsVote
   340  }
   341  
   342  func (m *fakeMachine) HasVote() bool {
   343  	m.mu.Lock()
   344  	defer m.mu.Unlock()
   345  	return m.doc.hasVote
   346  }
   347  
   348  func (m *fakeMachine) MongoHostPorts() []network.HostPort {
   349  	m.mu.Lock()
   350  	defer m.mu.Unlock()
   351  	return m.doc.mongoHostPorts
   352  }
   353  
   354  func (m *fakeMachine) APIHostPorts() []network.HostPort {
   355  	m.mu.Lock()
   356  	defer m.mu.Unlock()
   357  	return m.doc.apiHostPorts
   358  }
   359  
   360  func (m *fakeMachine) Status() (status.StatusInfo, error) {
   361  	return status.StatusInfo{
   362  		Status: status.StatusStarted,
   363  	}, nil
   364  }
   365  
   366  // mutate atomically changes the machineDoc of
   367  // the receiver by mutating it with the provided function.
   368  func (m *fakeMachine) mutate(f func(*machineDoc)) {
   369  	m.mu.Lock()
   370  	doc := m.val.Get().(machineDoc)
   371  	f(&doc)
   372  	m.val.Set(doc)
   373  	f(&m.doc)
   374  	m.mu.Unlock()
   375  	m.checker.checkInvariants()
   376  }
   377  
   378  func (m *fakeMachine) setStateHostPort(hostPort string) {
   379  	var mongoHostPorts []network.HostPort
   380  	if hostPort != "" {
   381  		host, portStr, err := net.SplitHostPort(hostPort)
   382  		if err != nil {
   383  			panic(err)
   384  		}
   385  		port, err := strconv.Atoi(portStr)
   386  		if err != nil {
   387  			panic(err)
   388  		}
   389  		mongoHostPorts = network.NewHostPorts(port, host)
   390  		mongoHostPorts[0].Scope = network.ScopeCloudLocal
   391  	}
   392  
   393  	m.mutate(func(doc *machineDoc) {
   394  		doc.mongoHostPorts = mongoHostPorts
   395  	})
   396  }
   397  
   398  func (m *fakeMachine) setMongoHostPorts(hostPorts []network.HostPort) {
   399  	m.mutate(func(doc *machineDoc) {
   400  		doc.mongoHostPorts = hostPorts
   401  	})
   402  }
   403  
   404  func (m *fakeMachine) setAPIHostPorts(hostPorts []network.HostPort) {
   405  	m.mutate(func(doc *machineDoc) {
   406  		doc.apiHostPorts = hostPorts
   407  	})
   408  }
   409  
   410  func (m *fakeMachine) setInstanceId(instanceId instance.Id) {
   411  	m.mutate(func(doc *machineDoc) {
   412  		doc.instanceId = instanceId
   413  	})
   414  }
   415  
   416  // SetHasVote implements stateMachine.SetHasVote.
   417  func (m *fakeMachine) SetHasVote(hasVote bool) error {
   418  	if err := m.errors.errorFor("Machine.SetHasVote", m.doc.id, hasVote); err != nil {
   419  		return err
   420  	}
   421  	m.mutate(func(doc *machineDoc) {
   422  		doc.hasVote = hasVote
   423  	})
   424  	return nil
   425  }
   426  
   427  func (m *fakeMachine) setWantsVote(wantsVote bool) {
   428  	m.mutate(func(doc *machineDoc) {
   429  		doc.wantsVote = wantsVote
   430  	})
   431  }
   432  
   433  type fakeMongoSession struct {
   434  	// If InstantlyReady is true, replica status of
   435  	// all members will be instantly reported as ready.
   436  	InstantlyReady bool
   437  
   438  	errors  *errorPatterns
   439  	checker invariantChecker
   440  	members voyeur.Value // of []replicaset.Member
   441  	status  voyeur.Value // of *replicaset.Status
   442  }
   443  
   444  // newFakeMongoSession returns a mock implementation of mongoSession.
   445  func newFakeMongoSession(checker invariantChecker, errors *errorPatterns) *fakeMongoSession {
   446  	s := new(fakeMongoSession)
   447  	s.checker = checker
   448  	s.errors = errors
   449  	s.members.Set([]replicaset.Member(nil))
   450  	s.status.Set(&replicaset.Status{})
   451  	return s
   452  }
   453  
   454  // CurrentMembers implements mongoSession.CurrentMembers.
   455  func (session *fakeMongoSession) CurrentMembers() ([]replicaset.Member, error) {
   456  	if err := session.errors.errorFor("Session.CurrentMembers"); err != nil {
   457  		return nil, err
   458  	}
   459  	return deepCopy(session.members.Get()).([]replicaset.Member), nil
   460  }
   461  
   462  // CurrentStatus implements mongoSession.CurrentStatus.
   463  func (session *fakeMongoSession) CurrentStatus() (*replicaset.Status, error) {
   464  	if err := session.errors.errorFor("Session.CurrentStatus"); err != nil {
   465  		return nil, err
   466  	}
   467  	return deepCopy(session.status.Get()).(*replicaset.Status), nil
   468  }
   469  
   470  // setStatus sets the status of the current members of the session.
   471  func (session *fakeMongoSession) setStatus(members []replicaset.MemberStatus) {
   472  	session.status.Set(deepCopy(&replicaset.Status{
   473  		Members: members,
   474  	}))
   475  }
   476  
   477  // Set implements mongoSession.Set
   478  func (session *fakeMongoSession) Set(members []replicaset.Member) error {
   479  	if err := session.errors.errorFor("Session.Set"); err != nil {
   480  		logger.Infof("not setting replicaset members to %#v", members)
   481  		return err
   482  	}
   483  	logger.Infof("setting replicaset members to %#v", members)
   484  	session.members.Set(deepCopy(members))
   485  	if session.InstantlyReady {
   486  		statuses := make([]replicaset.MemberStatus, len(members))
   487  		for i, m := range members {
   488  			statuses[i] = replicaset.MemberStatus{
   489  				Id:      m.Id,
   490  				Address: m.Address,
   491  				Healthy: true,
   492  				State:   replicaset.SecondaryState,
   493  			}
   494  			if i == 0 {
   495  				statuses[i].State = replicaset.PrimaryState
   496  			}
   497  		}
   498  		session.setStatus(statuses)
   499  	}
   500  	session.checker.checkInvariants()
   501  	return nil
   502  }
   503  
   504  // deepCopy makes a deep copy of any type by marshalling
   505  // it as JSON, then unmarshalling it.
   506  func deepCopy(x interface{}) interface{} {
   507  	v := reflect.ValueOf(x)
   508  	data, err := json.Marshal(x)
   509  	if err != nil {
   510  		panic(fmt.Errorf("cannot marshal %#v: %v", x, err))
   511  	}
   512  	newv := reflect.New(v.Type())
   513  	if err := json.Unmarshal(data, newv.Interface()); err != nil {
   514  		panic(fmt.Errorf("cannot unmarshal %q into %s", data, newv.Type()))
   515  	}
   516  	// sanity check
   517  	newx := newv.Elem().Interface()
   518  	if !reflect.DeepEqual(newx, x) {
   519  		panic(fmt.Errorf("value not deep-copied correctly"))
   520  	}
   521  	return newx
   522  }
   523  
   524  type notifier struct {
   525  	tomb    tomb.Tomb
   526  	w       *voyeur.Watcher
   527  	changes chan struct{}
   528  }
   529  
   530  // WatchValue returns a NotifyWatcher that triggers
   531  // when the given value changes. Its Wait and Err methods
   532  // never return a non-nil error.
   533  func WatchValue(val *voyeur.Value) state.NotifyWatcher {
   534  	n := &notifier{
   535  		w:       val.Watch(),
   536  		changes: make(chan struct{}),
   537  	}
   538  	go n.loop()
   539  	return n
   540  }
   541  
   542  func (n *notifier) loop() {
   543  	defer n.tomb.Done()
   544  	for n.w.Next() {
   545  		select {
   546  		case n.changes <- struct{}{}:
   547  		case <-n.tomb.Dying():
   548  		}
   549  	}
   550  }
   551  
   552  // Changes returns a channel that sends a value when the value changes.
   553  // The value itself can be retrieved by calling the value's Get method.
   554  func (n *notifier) Changes() <-chan struct{} {
   555  	return n.changes
   556  }
   557  
   558  // Kill stops the notifier but does not wait for it to finish.
   559  func (n *notifier) Kill() {
   560  	n.tomb.Kill(nil)
   561  	n.w.Close()
   562  }
   563  
   564  func (n *notifier) Err() error {
   565  	return n.tomb.Err()
   566  }
   567  
   568  // Wait waits for the notifier to finish. It always returns nil.
   569  func (n *notifier) Wait() error {
   570  	return n.tomb.Wait()
   571  }
   572  
   573  func (n *notifier) Stop() error {
   574  	return worker.Stop(n)
   575  }
   576  
   577  type stringsNotifier struct {
   578  	tomb    tomb.Tomb
   579  	w       *voyeur.Watcher
   580  	changes chan []string
   581  }
   582  
   583  // WatchStrings returns a StringsWatcher that triggers
   584  // when the given value changes. Its Wait and Err methods
   585  // never return a non-nil error.
   586  func WatchStrings(val *voyeur.Value) state.StringsWatcher {
   587  	n := &stringsNotifier{
   588  		w:       val.Watch(),
   589  		changes: make(chan []string),
   590  	}
   591  	go n.loop()
   592  	return n
   593  }
   594  
   595  func (n *stringsNotifier) loop() {
   596  	defer n.tomb.Done()
   597  	for n.w.Next() {
   598  		select {
   599  		case n.changes <- []string{}:
   600  		case <-n.tomb.Dying():
   601  		}
   602  	}
   603  }
   604  
   605  // Changes returns a channel that sends a value when the value changes.
   606  // The value itself can be retrieved by calling the value's Get method.
   607  func (n *stringsNotifier) Changes() <-chan []string {
   608  	return n.changes
   609  }
   610  
   611  // Kill stops the notifier but does not wait for it to finish.
   612  func (n *stringsNotifier) Kill() {
   613  	n.tomb.Kill(nil)
   614  	n.w.Close()
   615  }
   616  
   617  func (n *stringsNotifier) Err() error {
   618  	return n.tomb.Err()
   619  }
   620  
   621  // Wait waits for the notifier to finish. It always returns nil.
   622  func (n *stringsNotifier) Wait() error {
   623  	return n.tomb.Wait()
   624  }
   625  
   626  func (n *stringsNotifier) Stop() error {
   627  	return worker.Stop(n)
   628  }