github.com/cloudbase/juju-core@v0.0.0-20140504232958-a7271ac7912f/worker/peergrouper/mock_test.go (about) 1 // Copyright 2014 Canonical Ltd. 2 // Licensed under the AGPLv3, see LICENCE file for details. 3 4 package peergrouper 5 6 import ( 7 "encoding/json" 8 "fmt" 9 "path" 10 "reflect" 11 "sync" 12 13 "launchpad.net/tomb" 14 15 "launchpad.net/juju-core/errors" 16 "launchpad.net/juju-core/replicaset" 17 "launchpad.net/juju-core/state" 18 "launchpad.net/juju-core/utils/voyeur" 19 "launchpad.net/juju-core/worker" 20 ) 21 22 // This file holds helper functions for mocking pieces of State and replicaset 23 // that we don't want to directly depend on in unit tests. 24 25 type fakeState struct { 26 mu sync.Mutex 27 machines map[string]*fakeMachine 28 stateServers voyeur.Value // of *state.StateServerInfo 29 session *fakeMongoSession 30 check func(st *fakeState) error 31 } 32 33 var ( 34 _ stateInterface = (*fakeState)(nil) 35 _ stateMachine = (*fakeMachine)(nil) 36 _ mongoSession = (*fakeMongoSession)(nil) 37 ) 38 39 type errorPattern struct { 40 pattern string 41 errFunc func() error 42 } 43 44 var ( 45 errorsMutex sync.Mutex 46 errorPatterns []errorPattern 47 ) 48 49 // setErrorFor causes the given error to be returned 50 // from any mock call that matches the given 51 // string, which may contain wildcards as 52 // in path.Match. 53 // 54 // The standard form for errors is: 55 // Type.Function <arg>... 56 // See individual functions for details. 57 func setErrorFor(what string, err error) { 58 setErrorFuncFor(what, func() error { 59 return err 60 }) 61 } 62 63 // setErrorFuncFor causes the given function 64 // to be invoked to return the error for the 65 // given pattern. 66 func setErrorFuncFor(what string, errFunc func() error) { 67 errorsMutex.Lock() 68 defer errorsMutex.Unlock() 69 errorPatterns = append(errorPatterns, errorPattern{ 70 pattern: what, 71 errFunc: errFunc, 72 }) 73 } 74 75 // errorFor concatenates the call name 76 // with all the args, space separated, 77 // and returns any error registered with 78 // setErrorFor that matches the resulting string. 79 func errorFor(name string, args ...interface{}) error { 80 errorsMutex.Lock() 81 s := name 82 for _, arg := range args { 83 s += " " + fmt.Sprint(arg) 84 } 85 f := func() error { return nil } 86 for _, pattern := range errorPatterns { 87 if ok, _ := path.Match(pattern.pattern, s); ok { 88 f = pattern.errFunc 89 break 90 } 91 } 92 errorsMutex.Unlock() 93 err := f() 94 logger.Errorf("errorFor %q -> %v", s, err) 95 return err 96 } 97 98 func resetErrors() { 99 errorsMutex.Lock() 100 defer errorsMutex.Unlock() 101 errorPatterns = errorPatterns[:0] 102 } 103 104 func newFakeState() *fakeState { 105 st := &fakeState{ 106 machines: make(map[string]*fakeMachine), 107 } 108 st.session = newFakeMongoSession(st) 109 st.stateServers.Set(&state.StateServerInfo{}) 110 return st 111 } 112 113 func (st *fakeState) MongoSession() mongoSession { 114 return st.session 115 } 116 117 func (st *fakeState) checkInvariants() { 118 if st.check == nil { 119 return 120 } 121 if err := st.check(st); err != nil { 122 // Force a panic, otherwise we can deadlock 123 // when called from within the worker. 124 go panic(err) 125 select {} 126 } 127 } 128 129 // checkInvariants checks that all the expected invariants 130 // in the state hold true. Currently we check that: 131 // - total number of votes is odd. 132 // - member voting status implies that machine has vote. 133 func checkInvariants(st *fakeState) error { 134 members := st.session.members.Get().([]replicaset.Member) 135 voteCount := 0 136 for _, m := range members { 137 votes := 1 138 if m.Votes != nil { 139 votes = *m.Votes 140 } 141 voteCount += votes 142 if id, ok := m.Tags["juju-machine-id"]; ok { 143 if votes > 0 { 144 m := st.machine(id) 145 if m == nil { 146 return fmt.Errorf("voting member with machine id %q has no associated Machine", id) 147 } 148 if !m.HasVote() { 149 return fmt.Errorf("machine %q should be marked as having the vote, but does not", id) 150 } 151 } 152 } 153 } 154 if voteCount%2 != 1 { 155 return fmt.Errorf("total vote count is not odd (got %d)", voteCount) 156 } 157 return nil 158 } 159 160 type invariantChecker interface { 161 checkInvariants() 162 } 163 164 // machine is similar to Machine except that 165 // it bypasses the error mocking machinery. 166 // It returns nil if there is no machine with the 167 // given id. 168 func (st *fakeState) machine(id string) *fakeMachine { 169 st.mu.Lock() 170 defer st.mu.Unlock() 171 return st.machines[id] 172 } 173 174 func (st *fakeState) Machine(id string) (stateMachine, error) { 175 if err := errorFor("State.Machine", id); err != nil { 176 return nil, err 177 } 178 if m := st.machine(id); m != nil { 179 return m, nil 180 } 181 return nil, errors.NotFoundf("machine %s", id) 182 } 183 184 func (st *fakeState) addMachine(id string, wantsVote bool) *fakeMachine { 185 st.mu.Lock() 186 defer st.mu.Unlock() 187 logger.Infof("fakeState.addMachine %q", id) 188 if st.machines[id] != nil { 189 panic(fmt.Errorf("id %q already used", id)) 190 } 191 m := &fakeMachine{ 192 checker: st, 193 doc: machineDoc{ 194 id: id, 195 wantsVote: wantsVote, 196 }, 197 } 198 st.machines[id] = m 199 m.val.Set(m.doc) 200 return m 201 } 202 203 func (st *fakeState) removeMachine(id string) { 204 st.mu.Lock() 205 defer st.mu.Unlock() 206 if st.machines[id] == nil { 207 panic(fmt.Errorf("removing non-existent machine %q", id)) 208 } 209 delete(st.machines, id) 210 } 211 212 func (st *fakeState) setStateServers(ids ...string) { 213 st.stateServers.Set(&state.StateServerInfo{ 214 MachineIds: ids, 215 }) 216 } 217 218 func (st *fakeState) StateServerInfo() (*state.StateServerInfo, error) { 219 if err := errorFor("State.StateServerInfo"); err != nil { 220 return nil, err 221 } 222 return deepCopy(st.stateServers.Get()).(*state.StateServerInfo), nil 223 } 224 225 func (st *fakeState) WatchStateServerInfo() state.NotifyWatcher { 226 return WatchValue(&st.stateServers) 227 } 228 229 type fakeMachine struct { 230 mu sync.Mutex 231 val voyeur.Value // of machineDoc 232 doc machineDoc 233 checker invariantChecker 234 } 235 236 func (m *fakeMachine) Refresh() error { 237 if err := errorFor("Machine.Refresh", m.doc.id); err != nil { 238 return err 239 } 240 m.doc = m.val.Get().(machineDoc) 241 return nil 242 } 243 244 func (m *fakeMachine) GoString() string { 245 return fmt.Sprintf("&fakeMachine{%#v}", m.doc) 246 } 247 248 func (m *fakeMachine) Id() string { 249 return m.doc.id 250 } 251 252 func (m *fakeMachine) Watch() state.NotifyWatcher { 253 return WatchValue(&m.val) 254 } 255 256 func (m *fakeMachine) WantsVote() bool { 257 return m.doc.wantsVote 258 } 259 260 func (m *fakeMachine) HasVote() bool { 261 return m.doc.hasVote 262 } 263 264 func (m *fakeMachine) StateHostPort() string { 265 return m.doc.hostPort 266 } 267 268 // mutate atomically changes the machineDoc of 269 // the receiver by mutating it with the provided function. 270 func (m *fakeMachine) mutate(f func(*machineDoc)) { 271 m.mu.Lock() 272 doc := m.val.Get().(machineDoc) 273 f(&doc) 274 m.val.Set(doc) 275 f(&m.doc) 276 m.mu.Unlock() 277 m.checker.checkInvariants() 278 } 279 280 func (m *fakeMachine) setStateHostPort(hostPort string) { 281 m.mutate(func(doc *machineDoc) { 282 doc.hostPort = hostPort 283 }) 284 } 285 286 // SetHasVote implements stateMachine.SetHasVote. 287 func (m *fakeMachine) SetHasVote(hasVote bool) error { 288 if err := errorFor("Machine.SetHasVote", m.doc.id, hasVote); err != nil { 289 return err 290 } 291 m.mutate(func(doc *machineDoc) { 292 doc.hasVote = hasVote 293 }) 294 return nil 295 } 296 297 func (m *fakeMachine) setWantsVote(wantsVote bool) { 298 m.mutate(func(doc *machineDoc) { 299 doc.wantsVote = wantsVote 300 }) 301 } 302 303 type machineDoc struct { 304 id string 305 wantsVote bool 306 hasVote bool 307 hostPort string 308 } 309 310 type fakeMongoSession struct { 311 // If InstantlyReady is true, replica status of 312 // all members will be instantly reported as ready. 313 InstantlyReady bool 314 315 checker invariantChecker 316 members voyeur.Value // of []replicaset.Member 317 status voyeur.Value // of *replicaset.Status 318 } 319 320 // newFakeMongoSession returns a mock implementation of mongoSession. 321 func newFakeMongoSession(checker invariantChecker) *fakeMongoSession { 322 s := new(fakeMongoSession) 323 s.checker = checker 324 s.members.Set([]replicaset.Member(nil)) 325 s.status.Set(&replicaset.Status{}) 326 return s 327 } 328 329 // CurrentMembers implements mongoSession.CurrentMembers. 330 func (session *fakeMongoSession) CurrentMembers() ([]replicaset.Member, error) { 331 if err := errorFor("Session.CurrentMembers"); err != nil { 332 return nil, err 333 } 334 return deepCopy(session.members.Get()).([]replicaset.Member), nil 335 } 336 337 // CurrentStatus implements mongoSession.CurrentStatus. 338 func (session *fakeMongoSession) CurrentStatus() (*replicaset.Status, error) { 339 if err := errorFor("Session.CurrentStatus"); err != nil { 340 return nil, err 341 } 342 return deepCopy(session.status.Get()).(*replicaset.Status), nil 343 } 344 345 // setStatus sets the status of the current members of the session. 346 func (session *fakeMongoSession) setStatus(members []replicaset.MemberStatus) { 347 session.status.Set(deepCopy(&replicaset.Status{ 348 Members: members, 349 })) 350 } 351 352 // Set implements mongoSession.Set 353 func (session *fakeMongoSession) Set(members []replicaset.Member) error { 354 if err := errorFor("Session.Set"); err != nil { 355 logger.Infof("not setting replicaset members to %#v", members) 356 return err 357 } 358 logger.Infof("setting replicaset members to %#v", members) 359 session.members.Set(deepCopy(members)) 360 if session.InstantlyReady { 361 statuses := make([]replicaset.MemberStatus, len(members)) 362 for i, m := range members { 363 statuses[i] = replicaset.MemberStatus{ 364 Id: m.Id, 365 Address: m.Address, 366 Healthy: true, 367 State: replicaset.SecondaryState, 368 } 369 if i == 0 { 370 statuses[i].State = replicaset.PrimaryState 371 } 372 } 373 session.setStatus(statuses) 374 } 375 session.checker.checkInvariants() 376 return nil 377 } 378 379 // deepCopy makes a deep copy of any type by marshalling 380 // it as JSON, then unmarshalling it. 381 func deepCopy(x interface{}) interface{} { 382 v := reflect.ValueOf(x) 383 data, err := json.Marshal(x) 384 if err != nil { 385 panic(fmt.Errorf("cannot marshal %#v: %v", x, err)) 386 } 387 newv := reflect.New(v.Type()) 388 if err := json.Unmarshal(data, newv.Interface()); err != nil { 389 panic(fmt.Errorf("cannot unmarshal %q into %s", data, newv.Type())) 390 } 391 // sanity check 392 newx := newv.Elem().Interface() 393 if !reflect.DeepEqual(newx, x) { 394 panic(fmt.Errorf("value not deep-copied correctly")) 395 } 396 return newx 397 } 398 399 type notifier struct { 400 tomb tomb.Tomb 401 w *voyeur.Watcher 402 changes chan struct{} 403 } 404 405 // WatchValue returns a NotifyWatcher that triggers 406 // when the given value changes. Its Wait and Err methods 407 // never return a non-nil error. 408 func WatchValue(val *voyeur.Value) state.NotifyWatcher { 409 n := ¬ifier{ 410 w: val.Watch(), 411 changes: make(chan struct{}), 412 } 413 go n.loop() 414 return n 415 } 416 417 func (n *notifier) loop() { 418 defer n.tomb.Done() 419 for n.w.Next() { 420 select { 421 case n.changes <- struct{}{}: 422 case <-n.tomb.Dying(): 423 } 424 } 425 } 426 427 // Changes returns a channel that sends a value when the value changes. 428 // The value itself can be retrieved by calling the value's Get method. 429 func (n *notifier) Changes() <-chan struct{} { 430 return n.changes 431 } 432 433 // Kill stops the notifier but does not wait for it to finish. 434 func (n *notifier) Kill() { 435 n.tomb.Kill(nil) 436 n.w.Close() 437 } 438 439 func (n *notifier) Err() error { 440 return n.tomb.Err() 441 } 442 443 // Wait waits for the notifier to finish. It always returns nil. 444 func (n *notifier) Wait() error { 445 return n.tomb.Wait() 446 } 447 448 func (n *notifier) Stop() error { 449 return worker.Stop(n) 450 }