github.com/aclements/go-misc@v0.0.0-20240129233631-2f6ede80790c/go-weave/weave/weave.go (about)

     1  // Copyright 2016 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package weave
     6  
     7  import (
     8  	"errors"
     9  	"fmt"
    10  
    11  	"github.com/aclements/go-misc/go-weave/amb"
    12  )
    13  
    14  // TODO: Implement simple partial order reduction. If the next actions
    15  // on T1 and T2 commute, then we know that [T1,T2,...] and [T2,T1,...]
    16  // are equivalent (however, we can't just cut off T2, since we still
    17  // need [T2,T2,...]).
    18  
    19  // TODO: Implement a PCT scheduler (https://www.microsoft.com/en-us/research/publication/a-randomized-scheduler-with-probabilistic-guarantees-of-finding-bugs/)
    20  
// Scheduler explores interleavings of cooperatively scheduled threads.
// At most one thread runs at a time: the scheduler wakes one thread and
// waits until that thread yields back (via wakeSched) or exits before
// choosing the next one. Which runnable thread runs next is a
// nondeterministic choice made through the amb package using Strategy.
type Scheduler struct {
	// Strategy is the amb exploration strategy. Run copies it into
	// the underlying amb.Scheduler.
	Strategy amb.Strategy

	// as is the underlying amb scheduler used for every
	// nondeterministic choice (see scheduler and Amb).
	as amb.Scheduler

	// nextid is the ID newThread assigns to the next thread; Run
	// resets it to 0 at the start of each execution.
	nextid int
	// runnable holds threads the scheduler may choose to run.
	runnable []*thread
	// blocked holds threads parked by block until unblock.
	blocked []*thread
	// curThread is the thread most recently woken by scheduler.
	curThread *thread
	// goErr is the first panic value raised by a thread, or
	// amb.PathTerminated if amb cut off the current path. Once
	// non-nil, the execution is being torn down.
	goErr interface{}

	// wakeSched wakes the scheduler to select the next thread to
	// run. The waking thread must immediately block on
	// thread.wake or exit.
	wakeSched chan void

	// trace is the execution trace attached to errors reported by
	// Run; reset at the start of each execution. (Presumably
	// appended to elsewhere in this package — confirm.)
	trace []traceEntry
}
    39  
// globalSched is the Scheduler whose Run is currently in progress, or
// nil. Run sets it on entry, clears it on return, and refuses to start
// while it is already set. (Presumably read by other files in this
// package to locate the active Scheduler — confirm.)
var globalSched *Scheduler

// void is a zero-size type used as the element type of signal-only
// channels such as Scheduler.wakeSched and thread.wake.
type void struct{}
    43  
// thread is one goroutine under the control of a Scheduler. It runs
// only after being woken on wake and hands control back to the
// scheduler when it yields or exits.
type thread struct {
	sched   *Scheduler // Scheduler that owns this thread
	id      int        // Assigned from Scheduler.nextid; printed as "T<id>"
	index   int        // Index in Scheduler.runnable or .blocked
	blocked bool       // Set by block, cleared by unblock

	// tls holds this thread's values keyed by *TLS (presumably
	// thread-local storage; TLS is defined elsewhere — confirm).
	// It starts out nil.
	tls map[*TLS]interface{}

	wake chan void // Send void{} to wake this thread; closed when the thread exits normally
}
    54  
    55  func (t *thread) String() string {
    56  	return fmt.Sprintf("T%d", t.id)
    57  }
    58  
// debug enables printing of thread lifecycle events and scheduling
// decisions to standard output.
const debug = false
    60  
    61  func (s *Scheduler) newThread() *thread {
    62  	thr := &thread{s, s.nextid, -1, false, nil, make(chan void)}
    63  	s.nextid++
    64  	if thr.id != -1 {
    65  		thr.index = len(s.runnable)
    66  		s.runnable = append(s.runnable, thr)
    67  	}
    68  	return thr
    69  }
    70  
    71  func (s *Scheduler) Run(main func()) {
    72  	if globalSched != nil {
    73  		panic("only one weave.Scheduler can be active at a time")
    74  	}
    75  	globalSched = s
    76  	defer func() { globalSched = nil }()
    77  
    78  	s.as = amb.Scheduler{Strategy: s.Strategy}
    79  
    80  	s.as.Run(func() {
    81  		// Initialize state.
    82  		s.nextid = 0
    83  		s.runnable = s.runnable[:0]
    84  		s.blocked = s.blocked[:0]
    85  		s.curThread = nil
    86  		s.goErr = nil
    87  		s.wakeSched = make(chan void)
    88  		s.trace = nil
    89  		s.goNoSched(main)
    90  		s.scheduler()
    91  		if s.goErr != nil {
    92  			panic(errorWithTrace{s.goErr, s.trace})
    93  		}
    94  		if len(s.blocked) != 0 {
    95  			panic(errorWithTrace{fmt.Sprintf("threads asleep: %s", s.blocked), s.trace})
    96  		}
    97  		if debug {
    98  			fmt.Println("run done")
    99  		}
   100  	})
   101  }
   102  
   103  func (s *Scheduler) goNoSched(f func()) {
   104  	thr := s.newThread()
   105  	go func() {
   106  		defer func() {
   107  			goErr := recover()
   108  
   109  			if debug {
   110  				if goErr == threadAbort {
   111  					fmt.Printf("%v aborted\n", thr)
   112  				} else if goErr != nil {
   113  					fmt.Printf("%v panicked: %v\n", thr, goErr)
   114  				} else {
   115  					fmt.Printf("%v exiting normally\n", thr)
   116  				}
   117  			}
   118  
   119  			// Remove this thread from runnable.
   120  			s.runnable[thr.index] = s.runnable[len(s.runnable)-1]
   121  			s.runnable[thr.index].index = thr.index
   122  			s.runnable = s.runnable[:len(s.runnable)-1]
   123  
   124  			// If this is a thread abort, notify the
   125  			// scheduler that we're done aborting and
   126  			// exit.
   127  			if goErr == threadAbort {
   128  				s.wakeSched <- void{}
   129  				return
   130  			}
   131  
   132  			// If we're panicking, report the error so the
   133  			// scheduler can shut down this execution.
   134  			//
   135  			// TODO: Capture the stack trace.
   136  			if goErr != nil {
   137  				if s.goErr == nil {
   138  					s.goErr = goErr
   139  				}
   140  				s.wakeSched <- void{}
   141  				return
   142  			}
   143  
   144  			// Otherwise, this is a regular thread exit.
   145  			close(thr.wake)
   146  			s.wakeSched <- void{}
   147  		}()
   148  		if debug {
   149  			fmt.Printf("%v started\n", thr)
   150  		}
   151  		thr.desched()
   152  		f()
   153  	}()
   154  }
   155  
// Go starts f in a new thread and then yields to the scheduler via
// Sched, which may choose any runnable thread (including the new one
// or the caller) to run next. Presumably it must be called from a
// thread managed by s, since Sched relies on s.curThread — confirm.
func (s *Scheduler) Go(f func()) {
	s.goNoSched(f)
	s.Sched()
}
   160  
// threadAbort is the panic value desched raises to unwind a thread when
// the current execution is being torn down (Scheduler.goErr != nil).
// goNoSched recognizes it and exits the thread without recording an
// error.
var threadAbort = errors.New("thread aborted because of panic in another thread")
   162  
   163  // scheduler runs on the top-level thread and coordinates which thread
   164  // to execute next.
   165  func (s *Scheduler) scheduler() {
   166  	for len(s.runnable) > 0 {
   167  		// Pick a thread to run. If we're aborting, we just
   168  		// pick runnable[0], since it's not useful to explore
   169  		// this, and we might be aborting because amb
   170  		// terminated this path anyway.
   171  		var tid int
   172  		if s.goErr == nil {
   173  			// Amb may panic with PathTerminated.
   174  			func() {
   175  				defer func() {
   176  					err := recover()
   177  					if err == amb.PathTerminated {
   178  						s.goErr = err
   179  					} else if err != nil {
   180  						panic(err)
   181  					}
   182  				}()
   183  				tid = s.as.Amb(len(s.runnable))
   184  			}()
   185  		}
   186  		s.curThread = s.runnable[tid]
   187  
   188  		if debug {
   189  			fmt.Printf("scheduling %v from %v\n", s.curThread, s.runnable)
   190  		}
   191  
   192  		// Switch to that thread.
   193  		s.curThread.wake <- void{}
   194  
   195  		// Wait for thread to deschedule.
   196  		<-s.wakeSched
   197  		if s.goErr != nil {
   198  			// This state will signal all threads to exit,
   199  			// but we have to wake blocked threads so they
   200  			// can exit, too.
   201  			s.runnable = append(s.runnable, s.blocked...)
   202  			s.blocked = nil
   203  		}
   204  	}
   205  }
   206  
   207  func (s *Scheduler) Sched() {
   208  	this := s.curThread
   209  	s.wakeSched <- void{}
   210  	this.desched()
   211  }
   212  
   213  func (t *thread) desched() {
   214  	<-t.wake
   215  	if t.sched.goErr != nil {
   216  		// We're shutting down this execution.
   217  		panic(threadAbort)
   218  	}
   219  }
   220  
// Amb makes an n-way nondeterministic choice using the underlying amb
// scheduler and returns the chosen alternative. scheduler uses the same
// call to index s.runnable, so the result is presumably in [0, n). Amb
// may panic with amb.PathTerminated if amb cuts off the current path.
func (s *Scheduler) Amb(n int) int {
	return s.as.Amb(n)
}
   224  
   225  func (t *thread) block(abortf func()) {
   226  	if t.blocked {
   227  		panic("thread blocked multiple times")
   228  	}
   229  	t.blocked = true
   230  
   231  	s := t.sched
   232  	s.runnable[t.index] = s.runnable[len(s.runnable)-1]
   233  	s.runnable[t.index].index = t.index
   234  	s.runnable = s.runnable[:len(s.runnable)-1]
   235  
   236  	t.index = len(s.blocked)
   237  	s.blocked = append(s.blocked, t)
   238  
   239  	if abortf != nil {
   240  		defer func() {
   241  			if abortf != nil {
   242  				abortf()
   243  			}
   244  		}()
   245  	}
   246  	t.sched.Sched()
   247  	abortf = nil
   248  }
   249  
   250  func (t *thread) unblock() {
   251  	if !t.blocked {
   252  		panic("thread unblocked while not blocked")
   253  	}
   254  	t.blocked = false
   255  
   256  	s := t.sched
   257  	s.blocked[t.index] = s.blocked[len(s.blocked)-1]
   258  	s.blocked[t.index].index = t.index
   259  	s.blocked = s.blocked[:len(s.blocked)-1]
   260  
   261  	t.index = len(s.runnable)
   262  	s.runnable = append(s.runnable, t)
   263  }