github.com/abhinav/git-pr@v0.6.1-0.20171029234004-54218d68c11b/pr/walk.go (about)

     1  package pr
     2  
     3  import (
     4  	"container/list"
     5  	"fmt"
     6  	"runtime"
     7  	"sync"
     8  
     9  	"go.uber.org/multierr"
    10  
    11  	"github.com/google/go-github/github"
    12  )
    13  
    14  // Visitor defines what to do at each pull request during a walk.
    15  type Visitor interface {
    16  	// Visits the given pull request and returns a new visitor to visit its
    17  	// children.
    18  	//
    19  	// If a visitor was not returned, the children of this PR will not be
    20  	// visited.
    21  	//
    22  	// This function MAY be called concurrently. Implementations MUST be
    23  	// thread-safe.
    24  	Visit(*github.PullRequest) (Visitor, error)
    25  }
    26  
    27  //go:generate mockgen -package=prtest -destination=prtest/mocks.go github.com/abhinav/git-pr/pr Visitor
    28  
    29  // WalkConfig configures a pull request traversal.
    30  type WalkConfig struct {
    31  	// Maximum number of pull requests to visit at the same time.
    32  	//
    33  	// Defaults to the number of CPUs available to this process.
    34  	Concurrency int
    35  
    36  	// Children retrieves the children of the given pull request.
    37  	//
    38  	// The definition of what constitutes a child of a PR is left up to the
    39  	// implementation.
    40  	//
    41  	// This function MAY be called concurrently. Implementations MUST be
    42  	// thread-safe.
    43  	Children func(*github.PullRequest) ([]*github.PullRequest, error)
    44  }
    45  
    46  // Walk traverses a pull request tree by visiting the given pull requests and
    47  // their children in an unspecified order. The only ordering guarantee is that
    48  // parents are visited before their children.
    49  //
    50  // Errors encountered while visiting pull requests are collatted and presented
    51  // as one.
    52  func Walk(cfg WalkConfig, pulls []*github.PullRequest, v Visitor) error {
    53  	if cfg.Children == nil {
    54  		panic("WalkConfig.Children must be set")
    55  	}
    56  
    57  	if cfg.Concurrency <= 0 {
    58  		cfg.Concurrency = runtime.NumCPU()
    59  	}
    60  
    61  	w := walker{
    62  		// TODO: Magic number. Should make this customizable or leave it the
    63  		// same as Concurrency.
    64  		tasks:    make(chan task, 8),
    65  		children: cfg.Children,
    66  	}
    67  
    68  	w.ongoing.Add(len(pulls))
    69  	go func() {
    70  		// If pulls contains more than 8 items, we don't want to block on
    71  		// filling tasks just yet.
    72  		for _, pr := range pulls {
    73  			w.tasks <- task{PR: pr, Visitor: v}
    74  		}
    75  	}()
    76  
    77  	for i := 0; i < cfg.Concurrency; i++ {
    78  		go w.Worker()
    79  	}
    80  	w.ongoing.Wait()
    81  	close(w.tasks)
    82  
    83  	return multierr.Combine(w.errors...)
    84  }
    85  
    86  // Request to visit a single pull request with a specific visitor.
    87  type task struct {
    88  	PR      *github.PullRequest
    89  	Visitor Visitor
    90  }
    91  
    92  type walker struct {
    93  	// Incoming tasks. Any worker can handle these.
    94  	tasks chan task
    95  
    96  	// Number of ongoing tasks.
    97  	ongoing sync.WaitGroup
    98  
    99  	children func(*github.PullRequest) ([]*github.PullRequest, error)
   100  
   101  	// Errors encountered while processing.
   102  	errorsMu sync.Mutex
   103  	errors   []error
   104  }
   105  
   106  func (w *walker) Worker() {
   107  	// Walker-local buffer for incoming tasks that should be pushed into
   108  	// w.tasks when it's empty.
   109  	taskBuffer := list.New()
   110  
   111  worker:
   112  	for {
   113  	fill:
   114  		// Exhaust as much of the buffer as we can.
   115  		for taskBuffer.Len() > 0 {
   116  			e := taskBuffer.Front()
   117  			select {
   118  			case w.tasks <- e.Value.(task):
   119  				taskBuffer.Remove(e)
   120  			default:
   121  				// No more room in channel.
   122  				break fill
   123  			}
   124  		}
   125  
   126  		t, ok := <-w.tasks
   127  		if !ok {
   128  			// Channel closed. We're done.
   129  			break worker
   130  		}
   131  
   132  		newTasks, err := w.visit(t)
   133  		if err != nil {
   134  			w.errorsMu.Lock()
   135  			w.errors = append(w.errors, err)
   136  			w.errorsMu.Unlock()
   137  		} else if len(newTasks) > 0 {
   138  			for _, task := range newTasks {
   139  				taskBuffer.PushBack(task)
   140  			}
   141  		}
   142  		w.ongoing.Add(len(newTasks) - 1)
   143  	}
   144  }
   145  
   146  func (w *walker) visit(t task) (_ []task, err error) {
   147  	defer func() {
   148  		if x := recover(); x != nil {
   149  			if e, ok := x.(error); ok {
   150  				err = e
   151  			} else {
   152  				// TODO: log the panic
   153  				err = fmt.Errorf("panic: %v", x)
   154  			}
   155  		}
   156  	}()
   157  
   158  	v, err := t.Visitor.Visit(t.PR)
   159  	if err != nil {
   160  		return nil, err
   161  	}
   162  
   163  	if v == nil {
   164  		return nil, nil
   165  	}
   166  
   167  	children, err := w.children(t.PR)
   168  	if err != nil {
   169  		return nil, err
   170  	}
   171  
   172  	tasks := make([]task, len(children))
   173  	for i, pr := range children {
   174  		tasks[i] = task{PR: pr, Visitor: v}
   175  	}
   176  	return tasks, nil
   177  }