github.com/aclements/go-misc@v0.0.0-20240129233631-2f6ede80790c/go-weave/weave/weave.go (about) 1 // Copyright 2016 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package weave 6 7 import ( 8 "errors" 9 "fmt" 10 11 "github.com/aclements/go-misc/go-weave/amb" 12 ) 13 14 // TODO: Implement simple partial order reduction. If the next actions 15 // on T1 and T2 commute, then we know that [T1,T2,...] and [T2,T1,...] 16 // are equivalent (however, we can't just cut off T2, since we still 17 // need [T2,T2,...]). 18 19 // TODO: Implement a PCT scheduler (https://www.microsoft.com/en-us/research/publication/a-randomized-scheduler-with-probabilistic-guarantees-of-finding-bugs/) 20 21 type Scheduler struct { 22 Strategy amb.Strategy 23 24 as amb.Scheduler 25 26 nextid int 27 runnable []*thread 28 blocked []*thread 29 curThread *thread 30 goErr interface{} 31 32 // wakeSched wakes the scheduler to select the next thread to 33 // run. The waking thread must immediately block on 34 // thread.wake or exit. 35 wakeSched chan void 36 37 trace []traceEntry 38 } 39 40 var globalSched *Scheduler 41 42 type void struct{} 43 44 type thread struct { 45 sched *Scheduler 46 id int 47 index int // Index in Scheduler.runnable or .blocked 48 blocked bool 49 50 tls map[*TLS]interface{} 51 52 wake chan void // Send void{} to wake this thread 53 } 54 55 func (t *thread) String() string { 56 return fmt.Sprintf("T%d", t.id) 57 } 58 59 const debug = false 60 61 func (s *Scheduler) newThread() *thread { 62 thr := &thread{s, s.nextid, -1, false, nil, make(chan void)} 63 s.nextid++ 64 if thr.id != -1 { 65 thr.index = len(s.runnable) 66 s.runnable = append(s.runnable, thr) 67 } 68 return thr 69 } 70 71 func (s *Scheduler) Run(main func()) { 72 if globalSched != nil { 73 panic("only one weave.Scheduler can be active at a time") 74 } 75 globalSched = s 76 defer func() { globalSched = nil }() 77 78 s.as = amb.Scheduler{Strategy: s.Strategy} 79 80 s.as.Run(func() { 81 // Initialize state. 82 s.nextid = 0 83 s.runnable = s.runnable[:0] 84 s.blocked = s.blocked[:0] 85 s.curThread = nil 86 s.goErr = nil 87 s.wakeSched = make(chan void) 88 s.trace = nil 89 s.goNoSched(main) 90 s.scheduler() 91 if s.goErr != nil { 92 panic(errorWithTrace{s.goErr, s.trace}) 93 } 94 if len(s.blocked) != 0 { 95 panic(errorWithTrace{fmt.Sprintf("threads asleep: %s", s.blocked), s.trace}) 96 } 97 if debug { 98 fmt.Println("run done") 99 } 100 }) 101 } 102 103 func (s *Scheduler) goNoSched(f func()) { 104 thr := s.newThread() 105 go func() { 106 defer func() { 107 goErr := recover() 108 109 if debug { 110 if goErr == threadAbort { 111 fmt.Printf("%v aborted\n", thr) 112 } else if goErr != nil { 113 fmt.Printf("%v panicked: %v\n", thr, goErr) 114 } else { 115 fmt.Printf("%v exiting normally\n", thr) 116 } 117 } 118 119 // Remove this thread from runnable. 120 s.runnable[thr.index] = s.runnable[len(s.runnable)-1] 121 s.runnable[thr.index].index = thr.index 122 s.runnable = s.runnable[:len(s.runnable)-1] 123 124 // If this is a thread abort, notify the 125 // scheduler that we're done aborting and 126 // exit. 127 if goErr == threadAbort { 128 s.wakeSched <- void{} 129 return 130 } 131 132 // If we're panicking, report the error so the 133 // scheduler can shut down this execution. 134 // 135 // TODO: Capture the stack trace. 136 if goErr != nil { 137 if s.goErr == nil { 138 s.goErr = goErr 139 } 140 s.wakeSched <- void{} 141 return 142 } 143 144 // Otherwise, this is a regular thread exit. 145 close(thr.wake) 146 s.wakeSched <- void{} 147 }() 148 if debug { 149 fmt.Printf("%v started\n", thr) 150 } 151 thr.desched() 152 f() 153 }() 154 } 155 156 func (s *Scheduler) Go(f func()) { 157 s.goNoSched(f) 158 s.Sched() 159 } 160 161 var threadAbort = errors.New("thread aborted because of panic in another thread") 162 163 // scheduler runs on the top-level thread and coordinates which thread 164 // to execute next. 165 func (s *Scheduler) scheduler() { 166 for len(s.runnable) > 0 { 167 // Pick a thread to run. If we're aborting, we just 168 // pick runnable[0], since it's not useful to explore 169 // this, and we might be aborting because amb 170 // terminated this path anyway. 171 var tid int 172 if s.goErr == nil { 173 // Amb may panic with PathTerminated. 174 func() { 175 defer func() { 176 err := recover() 177 if err == amb.PathTerminated { 178 s.goErr = err 179 } else if err != nil { 180 panic(err) 181 } 182 }() 183 tid = s.as.Amb(len(s.runnable)) 184 }() 185 } 186 s.curThread = s.runnable[tid] 187 188 if debug { 189 fmt.Printf("scheduling %v from %v\n", s.curThread, s.runnable) 190 } 191 192 // Switch to that thread. 193 s.curThread.wake <- void{} 194 195 // Wait for thread to deschedule. 196 <-s.wakeSched 197 if s.goErr != nil { 198 // This state will signal all threads to exit, 199 // but we have to wake blocked threads so they 200 // can exit, too. 201 s.runnable = append(s.runnable, s.blocked...) 202 s.blocked = nil 203 } 204 } 205 } 206 207 func (s *Scheduler) Sched() { 208 this := s.curThread 209 s.wakeSched <- void{} 210 this.desched() 211 } 212 213 func (t *thread) desched() { 214 <-t.wake 215 if t.sched.goErr != nil { 216 // We're shutting down this execution. 217 panic(threadAbort) 218 } 219 } 220 221 func (s *Scheduler) Amb(n int) int { 222 return s.as.Amb(n) 223 } 224 225 func (t *thread) block(abortf func()) { 226 if t.blocked { 227 panic("thread blocked multiple times") 228 } 229 t.blocked = true 230 231 s := t.sched 232 s.runnable[t.index] = s.runnable[len(s.runnable)-1] 233 s.runnable[t.index].index = t.index 234 s.runnable = s.runnable[:len(s.runnable)-1] 235 236 t.index = len(s.blocked) 237 s.blocked = append(s.blocked, t) 238 239 if abortf != nil { 240 defer func() { 241 if abortf != nil { 242 abortf() 243 } 244 }() 245 } 246 t.sched.Sched() 247 abortf = nil 248 } 249 250 func (t *thread) unblock() { 251 if !t.blocked { 252 panic("thread unblocked while not blocked") 253 } 254 t.blocked = false 255 256 s := t.sched 257 s.blocked[t.index] = s.blocked[len(s.blocked)-1] 258 s.blocked[t.index].index = t.index 259 s.blocked = s.blocked[:len(s.blocked)-1] 260 261 t.index = len(s.runnable) 262 s.runnable = append(s.runnable, t) 263 }