github.com/rogpeppe/juju@v0.0.0-20140613142852-6337964b789e/worker/peergrouper/mock_test.go (about) 1 // Copyright 2014 Canonical Ltd. 2 // Licensed under the AGPLv3, see LICENCE file for details. 3 4 package peergrouper 5 6 import ( 7 "encoding/json" 8 "fmt" 9 "net" 10 "path" 11 "reflect" 12 "strconv" 13 "sync" 14 15 "github.com/juju/errors" 16 "github.com/juju/utils/voyeur" 17 "launchpad.net/tomb" 18 19 "github.com/juju/juju/instance" 20 "github.com/juju/juju/network" 21 "github.com/juju/juju/replicaset" 22 "github.com/juju/juju/state" 23 "github.com/juju/juju/worker" 24 ) 25 26 // This file holds helper functions for mocking pieces of State and replicaset 27 // that we don't want to directly depend on in unit tests. 28 29 type fakeState struct { 30 mu sync.Mutex 31 machines map[string]*fakeMachine 32 stateServers voyeur.Value // of *state.StateServerInfo 33 session *fakeMongoSession 34 check func(st *fakeState) error 35 } 36 37 var ( 38 _ stateInterface = (*fakeState)(nil) 39 _ stateMachine = (*fakeMachine)(nil) 40 _ mongoSession = (*fakeMongoSession)(nil) 41 ) 42 43 type errorPattern struct { 44 pattern string 45 errFunc func() error 46 } 47 48 var ( 49 errorsMutex sync.Mutex 50 errorPatterns []errorPattern 51 ) 52 53 // setErrorFor causes the given error to be returned 54 // from any mock call that matches the given 55 // string, which may contain wildcards as 56 // in path.Match. 57 // 58 // The standard form for errors is: 59 // Type.Function <arg>... 60 // See individual functions for details. 61 func setErrorFor(what string, err error) { 62 setErrorFuncFor(what, func() error { 63 return err 64 }) 65 } 66 67 // setErrorFuncFor causes the given function 68 // to be invoked to return the error for the 69 // given pattern. 70 func setErrorFuncFor(what string, errFunc func() error) { 71 errorsMutex.Lock() 72 defer errorsMutex.Unlock() 73 errorPatterns = append(errorPatterns, errorPattern{ 74 pattern: what, 75 errFunc: errFunc, 76 }) 77 } 78 79 // errorFor concatenates the call name 80 // with all the args, space separated, 81 // and returns any error registered with 82 // setErrorFor that matches the resulting string. 83 func errorFor(name string, args ...interface{}) error { 84 errorsMutex.Lock() 85 s := name 86 for _, arg := range args { 87 s += " " + fmt.Sprint(arg) 88 } 89 f := func() error { return nil } 90 for _, pattern := range errorPatterns { 91 if ok, _ := path.Match(pattern.pattern, s); ok { 92 f = pattern.errFunc 93 break 94 } 95 } 96 errorsMutex.Unlock() 97 err := f() 98 logger.Errorf("errorFor %q -> %v", s, err) 99 return err 100 } 101 102 func resetErrors() { 103 errorsMutex.Lock() 104 defer errorsMutex.Unlock() 105 errorPatterns = errorPatterns[:0] 106 } 107 108 func newFakeState() *fakeState { 109 st := &fakeState{ 110 machines: make(map[string]*fakeMachine), 111 } 112 st.session = newFakeMongoSession(st) 113 st.stateServers.Set(&state.StateServerInfo{}) 114 return st 115 } 116 117 func (st *fakeState) MongoSession() mongoSession { 118 return st.session 119 } 120 121 func (st *fakeState) checkInvariants() { 122 if st.check == nil { 123 return 124 } 125 if err := st.check(st); err != nil { 126 // Force a panic, otherwise we can deadlock 127 // when called from within the worker. 128 go panic(err) 129 select {} 130 } 131 } 132 133 // checkInvariants checks that all the expected invariants 134 // in the state hold true. Currently we check that: 135 // - total number of votes is odd. 136 // - member voting status implies that machine has vote. 137 func checkInvariants(st *fakeState) error { 138 members := st.session.members.Get().([]replicaset.Member) 139 voteCount := 0 140 for _, m := range members { 141 votes := 1 142 if m.Votes != nil { 143 votes = *m.Votes 144 } 145 voteCount += votes 146 if id, ok := m.Tags[jujuMachineTag]; ok { 147 if votes > 0 { 148 m := st.machine(id) 149 if m == nil { 150 return fmt.Errorf("voting member with machine id %q has no associated Machine", id) 151 } 152 if !m.HasVote() { 153 return fmt.Errorf("machine %q should be marked as having the vote, but does not", id) 154 } 155 } 156 } 157 } 158 if voteCount%2 != 1 { 159 return fmt.Errorf("total vote count is not odd (got %d)", voteCount) 160 } 161 return nil 162 } 163 164 type invariantChecker interface { 165 checkInvariants() 166 } 167 168 // machine is similar to Machine except that 169 // it bypasses the error mocking machinery. 170 // It returns nil if there is no machine with the 171 // given id. 172 func (st *fakeState) machine(id string) *fakeMachine { 173 st.mu.Lock() 174 defer st.mu.Unlock() 175 return st.machines[id] 176 } 177 178 func (st *fakeState) Machine(id string) (stateMachine, error) { 179 if err := errorFor("State.Machine", id); err != nil { 180 return nil, err 181 } 182 if m := st.machine(id); m != nil { 183 return m, nil 184 } 185 return nil, errors.NotFoundf("machine %s", id) 186 } 187 188 func (st *fakeState) addMachine(id string, wantsVote bool) *fakeMachine { 189 st.mu.Lock() 190 defer st.mu.Unlock() 191 logger.Infof("fakeState.addMachine %q", id) 192 if st.machines[id] != nil { 193 panic(fmt.Errorf("id %q already used", id)) 194 } 195 m := &fakeMachine{ 196 checker: st, 197 doc: machineDoc{ 198 id: id, 199 wantsVote: wantsVote, 200 }, 201 } 202 st.machines[id] = m 203 m.val.Set(m.doc) 204 return m 205 } 206 207 func (st *fakeState) removeMachine(id string) { 208 st.mu.Lock() 209 defer st.mu.Unlock() 210 if st.machines[id] == nil { 211 panic(fmt.Errorf("removing non-existent machine %q", id)) 212 } 213 delete(st.machines, id) 214 } 215 216 func (st *fakeState) setStateServers(ids ...string) { 217 st.stateServers.Set(&state.StateServerInfo{ 218 MachineIds: ids, 219 }) 220 } 221 222 func (st *fakeState) StateServerInfo() (*state.StateServerInfo, error) { 223 if err := errorFor("State.StateServerInfo"); err != nil { 224 return nil, err 225 } 226 return deepCopy(st.stateServers.Get()).(*state.StateServerInfo), nil 227 } 228 229 func (st *fakeState) WatchStateServerInfo() state.NotifyWatcher { 230 return WatchValue(&st.stateServers) 231 } 232 233 type fakeMachine struct { 234 mu sync.Mutex 235 val voyeur.Value // of machineDoc 236 doc machineDoc 237 checker invariantChecker 238 } 239 240 type machineDoc struct { 241 id string 242 wantsVote bool 243 hasVote bool 244 instanceId instance.Id 245 mongoHostPorts []network.HostPort 246 apiHostPorts []network.HostPort 247 } 248 249 func (m *fakeMachine) Refresh() error { 250 if err := errorFor("Machine.Refresh", m.doc.id); err != nil { 251 return err 252 } 253 m.doc = m.val.Get().(machineDoc) 254 return nil 255 } 256 257 func (m *fakeMachine) GoString() string { 258 return fmt.Sprintf("&fakeMachine{%#v}", m.doc) 259 } 260 261 func (m *fakeMachine) Id() string { 262 return m.doc.id 263 } 264 265 func (m *fakeMachine) InstanceId() (instance.Id, error) { 266 if err := errorFor("Machine.InstanceId", m.doc.id); err != nil { 267 return "", err 268 } 269 return m.doc.instanceId, nil 270 } 271 272 func (m *fakeMachine) Watch() state.NotifyWatcher { 273 return WatchValue(&m.val) 274 } 275 276 func (m *fakeMachine) WantsVote() bool { 277 return m.doc.wantsVote 278 } 279 280 func (m *fakeMachine) HasVote() bool { 281 return m.doc.hasVote 282 } 283 284 func (m *fakeMachine) MongoHostPorts() []network.HostPort { 285 return m.doc.mongoHostPorts 286 } 287 288 func (m *fakeMachine) APIHostPorts() []network.HostPort { 289 return m.doc.apiHostPorts 290 } 291 292 // mutate atomically changes the machineDoc of 293 // the receiver by mutating it with the provided function. 294 func (m *fakeMachine) mutate(f func(*machineDoc)) { 295 m.mu.Lock() 296 doc := m.val.Get().(machineDoc) 297 f(&doc) 298 m.val.Set(doc) 299 f(&m.doc) 300 m.mu.Unlock() 301 m.checker.checkInvariants() 302 } 303 304 func (m *fakeMachine) setStateHostPort(hostPort string) { 305 var mongoHostPorts []network.HostPort 306 if hostPort != "" { 307 host, portStr, err := net.SplitHostPort(hostPort) 308 if err != nil { 309 panic(err) 310 } 311 port, err := strconv.Atoi(portStr) 312 if err != nil { 313 panic(err) 314 } 315 mongoHostPorts = network.AddressesWithPort(network.NewAddresses(host), port) 316 mongoHostPorts[0].Scope = network.ScopeCloudLocal 317 } 318 319 m.mutate(func(doc *machineDoc) { 320 doc.mongoHostPorts = mongoHostPorts 321 }) 322 } 323 324 func (m *fakeMachine) setMongoHostPorts(hostPorts []network.HostPort) { 325 m.mutate(func(doc *machineDoc) { 326 doc.mongoHostPorts = hostPorts 327 }) 328 } 329 330 func (m *fakeMachine) setAPIHostPorts(hostPorts []network.HostPort) { 331 m.mutate(func(doc *machineDoc) { 332 doc.apiHostPorts = hostPorts 333 }) 334 } 335 336 func (m *fakeMachine) setInstanceId(instanceId instance.Id) { 337 m.mutate(func(doc *machineDoc) { 338 doc.instanceId = instanceId 339 }) 340 } 341 342 // SetHasVote implements stateMachine.SetHasVote. 343 func (m *fakeMachine) SetHasVote(hasVote bool) error { 344 if err := errorFor("Machine.SetHasVote", m.doc.id, hasVote); err != nil { 345 return err 346 } 347 m.mutate(func(doc *machineDoc) { 348 doc.hasVote = hasVote 349 }) 350 return nil 351 } 352 353 func (m *fakeMachine) setWantsVote(wantsVote bool) { 354 m.mutate(func(doc *machineDoc) { 355 doc.wantsVote = wantsVote 356 }) 357 } 358 359 type fakeMongoSession struct { 360 // If InstantlyReady is true, replica status of 361 // all members will be instantly reported as ready. 362 InstantlyReady bool 363 364 checker invariantChecker 365 members voyeur.Value // of []replicaset.Member 366 status voyeur.Value // of *replicaset.Status 367 } 368 369 // newFakeMongoSession returns a mock implementation of mongoSession. 370 func newFakeMongoSession(checker invariantChecker) *fakeMongoSession { 371 s := new(fakeMongoSession) 372 s.checker = checker 373 s.members.Set([]replicaset.Member(nil)) 374 s.status.Set(&replicaset.Status{}) 375 return s 376 } 377 378 // CurrentMembers implements mongoSession.CurrentMembers. 379 func (session *fakeMongoSession) CurrentMembers() ([]replicaset.Member, error) { 380 if err := errorFor("Session.CurrentMembers"); err != nil { 381 return nil, err 382 } 383 return deepCopy(session.members.Get()).([]replicaset.Member), nil 384 } 385 386 // CurrentStatus implements mongoSession.CurrentStatus. 387 func (session *fakeMongoSession) CurrentStatus() (*replicaset.Status, error) { 388 if err := errorFor("Session.CurrentStatus"); err != nil { 389 return nil, err 390 } 391 return deepCopy(session.status.Get()).(*replicaset.Status), nil 392 } 393 394 // setStatus sets the status of the current members of the session. 395 func (session *fakeMongoSession) setStatus(members []replicaset.MemberStatus) { 396 session.status.Set(deepCopy(&replicaset.Status{ 397 Members: members, 398 })) 399 } 400 401 // Set implements mongoSession.Set 402 func (session *fakeMongoSession) Set(members []replicaset.Member) error { 403 if err := errorFor("Session.Set"); err != nil { 404 logger.Infof("not setting replicaset members to %#v", members) 405 return err 406 } 407 logger.Infof("setting replicaset members to %#v", members) 408 session.members.Set(deepCopy(members)) 409 if session.InstantlyReady { 410 statuses := make([]replicaset.MemberStatus, len(members)) 411 for i, m := range members { 412 statuses[i] = replicaset.MemberStatus{ 413 Id: m.Id, 414 Address: m.Address, 415 Healthy: true, 416 State: replicaset.SecondaryState, 417 } 418 if i == 0 { 419 statuses[i].State = replicaset.PrimaryState 420 } 421 } 422 session.setStatus(statuses) 423 } 424 session.checker.checkInvariants() 425 return nil 426 } 427 428 // deepCopy makes a deep copy of any type by marshalling 429 // it as JSON, then unmarshalling it. 430 func deepCopy(x interface{}) interface{} { 431 v := reflect.ValueOf(x) 432 data, err := json.Marshal(x) 433 if err != nil { 434 panic(fmt.Errorf("cannot marshal %#v: %v", x, err)) 435 } 436 newv := reflect.New(v.Type()) 437 if err := json.Unmarshal(data, newv.Interface()); err != nil { 438 panic(fmt.Errorf("cannot unmarshal %q into %s", data, newv.Type())) 439 } 440 // sanity check 441 newx := newv.Elem().Interface() 442 if !reflect.DeepEqual(newx, x) { 443 panic(fmt.Errorf("value not deep-copied correctly")) 444 } 445 return newx 446 } 447 448 type notifier struct { 449 tomb tomb.Tomb 450 w *voyeur.Watcher 451 changes chan struct{} 452 } 453 454 // WatchValue returns a NotifyWatcher that triggers 455 // when the given value changes. Its Wait and Err methods 456 // never return a non-nil error. 457 func WatchValue(val *voyeur.Value) state.NotifyWatcher { 458 n := ¬ifier{ 459 w: val.Watch(), 460 changes: make(chan struct{}), 461 } 462 go n.loop() 463 return n 464 } 465 466 func (n *notifier) loop() { 467 defer n.tomb.Done() 468 for n.w.Next() { 469 select { 470 case n.changes <- struct{}{}: 471 case <-n.tomb.Dying(): 472 } 473 } 474 } 475 476 // Changes returns a channel that sends a value when the value changes. 477 // The value itself can be retrieved by calling the value's Get method. 478 func (n *notifier) Changes() <-chan struct{} { 479 return n.changes 480 } 481 482 // Kill stops the notifier but does not wait for it to finish. 483 func (n *notifier) Kill() { 484 n.tomb.Kill(nil) 485 n.w.Close() 486 } 487 488 func (n *notifier) Err() error { 489 return n.tomb.Err() 490 } 491 492 // Wait waits for the notifier to finish. It always returns nil. 493 func (n *notifier) Wait() error { 494 return n.tomb.Wait() 495 } 496 497 func (n *notifier) Stop() error { 498 return worker.Stop(n) 499 }