github.com/hawser/git-hawser@v2.5.2+incompatible/tq/adapterbase.go (about)

     1  package tq
     2  
     3  import (
     4  	"fmt"
     5  	"net/http"
     6  	"regexp"
     7  	"strings"
     8  	"sync"
     9  
    10  	"github.com/git-lfs/git-lfs/fs"
    11  	"github.com/git-lfs/git-lfs/lfsapi"
    12  	"github.com/rubyist/tracerx"
    13  )
    14  
    15  // adapterBase implements the common functionality for core adapters which
    16  // process transfers with N workers handling an oid each, and which wait for
    17  // authentication to succeed on one worker before proceeding
    18  type adapterBase struct {
    19  	fs           *fs.Filesystem
    20  	name         string
    21  	direction    Direction
    22  	transferImpl transferImplementation
    23  	apiClient    *lfsapi.Client
    24  	remote       string
    25  	jobChan      chan *job
    26  	debugging    bool
    27  	cb           ProgressCallback
    28  	// WaitGroup to sync the completion of all workers
    29  	workerWait sync.WaitGroup
    30  	// WaitGroup to sync the completion of all in-flight jobs
    31  	jobWait *sync.WaitGroup
    32  	// WaitGroup to serialise the first transfer response to perform login if needed
    33  	authWait sync.WaitGroup
    34  }
    35  
    36  // transferImplementation must be implemented to provide the actual upload/download
    37  // implementation for all core transfer approaches that use adapterBase for
    38  // convenience. This function will be called on multiple goroutines so it
    39  // must be either stateless or thread safe. However it will never be called
    40  // for the same oid in parallel.
    41  // If authOkFunc is not nil, implementations must call it as early as possible
    42  // when authentication succeeded, before the whole file content is transferred
    43  type transferImplementation interface {
    44  	// WorkerStarting is called when a worker goroutine starts to process jobs
    45  	// Implementations can run some startup logic here & return some context if needed
    46  	WorkerStarting(workerNum int) (interface{}, error)
    47  	// WorkerEnding is called when a worker goroutine is shutting down
    48  	// Implementations can clean up per-worker resources here, context is as returned from WorkerStarting
    49  	WorkerEnding(workerNum int, ctx interface{})
    50  	// DoTransfer performs a single transfer within a worker. ctx is any context returned from WorkerStarting
    51  	DoTransfer(ctx interface{}, t *Transfer, cb ProgressCallback, authOkFunc func()) error
    52  }
    53  
    54  func newAdapterBase(f *fs.Filesystem, name string, dir Direction, ti transferImplementation) *adapterBase {
    55  	return &adapterBase{
    56  		fs:           f,
    57  		name:         name,
    58  		direction:    dir,
    59  		transferImpl: ti,
    60  		jobWait:      new(sync.WaitGroup),
    61  	}
    62  }
    63  
    64  func (a *adapterBase) Name() string {
    65  	return a.name
    66  }
    67  
    68  func (a *adapterBase) Direction() Direction {
    69  	return a.direction
    70  }
    71  
    72  func (a *adapterBase) Begin(cfg AdapterConfig, cb ProgressCallback) error {
    73  	a.apiClient = cfg.APIClient()
    74  	a.remote = cfg.Remote()
    75  	a.cb = cb
    76  	a.jobChan = make(chan *job, 100)
    77  	a.debugging = a.apiClient.OSEnv().Bool("GIT_TRANSFER_TRACE", false)
    78  	maxConcurrency := cfg.ConcurrentTransfers()
    79  
    80  	a.Trace("xfer: adapter %q Begin() with %d workers", a.Name(), maxConcurrency)
    81  
    82  	a.workerWait.Add(maxConcurrency)
    83  	a.authWait.Add(1)
    84  	for i := 0; i < maxConcurrency; i++ {
    85  		ctx, err := a.transferImpl.WorkerStarting(i)
    86  		if err != nil {
    87  			return err
    88  		}
    89  		go a.worker(i, ctx)
    90  	}
    91  	a.Trace("xfer: adapter %q started", a.Name())
    92  	return nil
    93  }
    94  
    95  type job struct {
    96  	T *Transfer
    97  
    98  	results chan<- TransferResult
    99  	wg      *sync.WaitGroup
   100  }
   101  
   102  func (j *job) Done(err error) {
   103  	j.results <- TransferResult{j.T, err}
   104  	j.wg.Done()
   105  }
   106  
   107  func (a *adapterBase) Add(transfers ...*Transfer) <-chan TransferResult {
   108  	results := make(chan TransferResult, len(transfers))
   109  
   110  	a.jobWait.Add(len(transfers))
   111  
   112  	go func() {
   113  		for _, t := range transfers {
   114  			a.jobChan <- &job{t, results, a.jobWait}
   115  		}
   116  		a.jobWait.Wait()
   117  
   118  		close(results)
   119  	}()
   120  
   121  	return results
   122  }
   123  
   124  func (a *adapterBase) End() {
   125  	a.Trace("xfer: adapter %q End()", a.Name())
   126  
   127  	a.jobWait.Wait()
   128  	close(a.jobChan)
   129  
   130  	// wait for all transfers to complete
   131  	a.workerWait.Wait()
   132  
   133  	a.Trace("xfer: adapter %q stopped", a.Name())
   134  }
   135  
   136  func (a *adapterBase) Trace(format string, args ...interface{}) {
   137  	if !a.debugging {
   138  		return
   139  	}
   140  	tracerx.Printf(format, args...)
   141  }
   142  
   143  // worker function, many of these run per adapter
   144  func (a *adapterBase) worker(workerNum int, ctx interface{}) {
   145  	a.Trace("xfer: adapter %q worker %d starting", a.Name(), workerNum)
   146  	waitForAuth := workerNum > 0
   147  	signalAuthOnResponse := workerNum == 0
   148  
   149  	// First worker is the only one allowed to start immediately
   150  	// The rest wait until successful response from 1st worker to
   151  	// make sure only 1 login prompt is presented if necessary
   152  	// Deliberately outside jobChan processing so we know worker 0 will process 1st item
   153  	if waitForAuth {
   154  		a.Trace("xfer: adapter %q worker %d waiting for Auth", a.Name(), workerNum)
   155  		a.authWait.Wait()
   156  		a.Trace("xfer: adapter %q worker %d auth signal received", a.Name(), workerNum)
   157  	}
   158  
   159  	for job := range a.jobChan {
   160  		t := job.T
   161  
   162  		var authCallback func()
   163  		if signalAuthOnResponse {
   164  			authCallback = func() {
   165  				a.authWait.Done()
   166  				signalAuthOnResponse = false
   167  			}
   168  		}
   169  		a.Trace("xfer: adapter %q worker %d processing job for %q", a.Name(), workerNum, t.Oid)
   170  
   171  		// Actual transfer happens here
   172  		var err error
   173  		if t.Size < 0 {
   174  			err = fmt.Errorf("Git LFS: object %q has invalid size (got: %d)", t.Oid, t.Size)
   175  		} else {
   176  			err = a.transferImpl.DoTransfer(ctx, t, a.cb, authCallback)
   177  		}
   178  
   179  		// Mark the job as completed, and alter all listeners
   180  		job.Done(err)
   181  
   182  		a.Trace("xfer: adapter %q worker %d finished job for %q", a.Name(), workerNum, t.Oid)
   183  	}
   184  	// This will only happen if no jobs were submitted; just wake up all workers to finish
   185  	if signalAuthOnResponse {
   186  		a.authWait.Done()
   187  	}
   188  	a.Trace("xfer: adapter %q worker %d stopping", a.Name(), workerNum)
   189  	a.transferImpl.WorkerEnding(workerNum, ctx)
   190  	a.workerWait.Done()
   191  }
   192  
   193  var httpRE = regexp.MustCompile(`\Ahttps?://`)
   194  
   195  func (a *adapterBase) newHTTPRequest(method string, rel *Action) (*http.Request, error) {
   196  	if !httpRE.MatchString(rel.Href) {
   197  		urlfragment := strings.SplitN(rel.Href, "?", 2)[0]
   198  		return nil, fmt.Errorf("missing protocol: %q", urlfragment)
   199  	}
   200  
   201  	req, err := http.NewRequest(method, rel.Href, nil)
   202  	if err != nil {
   203  		return nil, err
   204  	}
   205  
   206  	for key, value := range rel.Header {
   207  		req.Header.Set(key, value)
   208  	}
   209  
   210  	return req, nil
   211  }
   212  
   213  func (a *adapterBase) doHTTP(t *Transfer, req *http.Request) (*http.Response, error) {
   214  	if t.Authenticated {
   215  		return a.apiClient.Do(req)
   216  	}
   217  	return a.apiClient.DoWithAuth(a.remote, req)
   218  }
   219  
   220  func advanceCallbackProgress(cb ProgressCallback, t *Transfer, numBytes int64) {
   221  	if cb != nil {
   222  		// Must split into max int sizes since read count is int
   223  		const maxInt = int(^uint(0) >> 1)
   224  		for read := int64(0); read < numBytes; {
   225  			remainder := numBytes - read
   226  			if remainder > int64(maxInt) {
   227  				read += int64(maxInt)
   228  				cb(t.Name, t.Size, read, maxInt)
   229  			} else {
   230  				read += remainder
   231  				cb(t.Name, t.Size, read, int(remainder))
   232  			}
   233  
   234  		}
   235  	}
   236  }