github.com/keybase/client/go@v0.0.0-20241007131713-f10651d043c8/chat/attachments/s3.go

package attachments

import (
	"bytes"
	"context"
	"crypto/md5"
	"encoding/hex"
	"errors"
	"fmt"
	"io"
	"sync"

	"github.com/keybase/client/go/chat/attachments/progress"
	"github.com/keybase/client/go/chat/s3"
	"github.com/keybase/client/go/chat/types"
	"github.com/keybase/client/go/protocol/chat1"
	"golang.org/x/sync/errgroup"
)

const s3PipelineMaxWidth = 10

type s3UploadPipeliner struct {
	sync.Mutex
	width   int
	waiters []chan struct{}
}

func (s *s3UploadPipeliner) QueueForTakeoff(ctx context.Context) error {
	s.Lock()
	if s.width >= s3PipelineMaxWidth {
		ch := make(chan struct{})
		s.waiters = append(s.waiters, ch)
		s.Unlock()
		select {
		case <-ch:
		case <-ctx.Done():
			return ctx.Err()
		}
		s.Lock()
		s.width++
		s.Unlock()
		return nil
	}
	s.width++
	s.Unlock()
	return nil
}

func (s *s3UploadPipeliner) Complete() {
	s.Lock()
	defer s.Unlock()
	if len(s.waiters) > 0 {
		close(s.waiters[0])
		if len(s.waiters) > 1 {
			s.waiters = s.waiters[1:]
		} else {
			s.waiters = nil
		}
	}
	if s.width > 0 {
		s.width--
	}
}

var s3UploadPipeline = &s3UploadPipeliner{}
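
// examplePipelineUsage is an illustrative sketch, not part of the original
// file: it shows the intended reserve/release pattern for s3UploadPipeline,
// mirroring how putMultiPipeline (below) caps concurrent part uploads at
// s3PipelineMaxWidth. The upload callback is a hypothetical placeholder.
func examplePipelineUsage(ctx context.Context, blocks [][]byte,
	upload func(context.Context, []byte) error) error {
	eg, ectx := errgroup.WithContext(ctx)
	for _, blk := range blocks {
		// reserve a slot, blocking until one frees up or the context is canceled
		if err := s3UploadPipeline.QueueForTakeoff(ectx); err != nil {
			return err
		}
		blk := blk
		eg.Go(func() error {
			// release the slot when this part finishes
			defer s3UploadPipeline.Complete()
			return upload(ectx, blk)
		})
	}
	return eg.Wait()
}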

const minMultiSize = 5 * 1024 * 1024 // can't use Multi API with parts less than 5MB
const blockSize = 5 * 1024 * 1024    // 5MB is the minimum Multi part size

// ErrAbortOnPartMismatch is returned when there is a mismatch between a current
// part and a previous attempt part.  If ErrAbortOnPartMismatch is returned,
// the caller should abort the upload attempt and start from scratch.
var ErrAbortOnPartMismatch = errors.New("local part mismatch, aborting upload")

// PutS3Result is the success result of calling PutS3.
type PutS3Result struct {
	Region   string
	Endpoint string
	Bucket   string
	Path     string
	Size     int64
}

// PutS3 uploads the data in Reader r to S3.  It chooses whether to use
// putSingle or putMultiPipeline based on the size of the object.
func (a *S3Store) PutS3(ctx context.Context, r io.Reader, size int64, task *UploadTask, previous *AttachmentInfo) (res *PutS3Result, err error) {
	defer a.Trace(ctx, &err, "PutS3")()
	region := a.regionFromParams(task.S3Params)
	b := a.s3Conn(task.S3Signer, region, task.S3Params.AccessKey, task.S3Params.Token).Bucket(
		task.S3Params.Bucket)

	multiPartUpload := size > minMultiSize
	if multiPartUpload && a.G().Env.GetAttachmentDisableMulti() {
		a.Debug(ctx, "PutS3: multi part upload manually disabled, overriding for size: %v", size)
		multiPartUpload = false
	}

	if !multiPartUpload {
		if err := a.putSingle(ctx, r, size, task.S3Params, b, task.Progress); err != nil {
			return nil, err
		}
	} else {
		objectKey, err := a.putMultiPipeline(ctx, r, size, task, b, previous)
		if err != nil {
			return nil, err
		}
		task.S3Params.ObjectKey = objectKey
	}

	s3res := PutS3Result{
		Region:   task.S3Params.RegionName,
		Endpoint: task.S3Params.RegionEndpoint,
		Bucket:   task.S3Params.Bucket,
		Path:     task.S3Params.ObjectKey,
		Size:     size,
	}
	return &s3res, nil
}
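
// examplePartCount is an illustrative sketch, not part of the original file:
// with minMultiSize and blockSize both 5MB, an object at or below 5MB takes
// the putSingle path, while anything larger is split by makeBlockJobs into
// ceil(size/blockSize) parts, only the last of which may be smaller than 5MB.
func examplePartCount(size int64) int {
	if size <= minMultiSize {
		return 1 // uploaded in one shot by putSingle
	}
	return int((size + blockSize - 1) / blockSize) // e.g. a 12MB object becomes 3 parts
}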

// putSingle uploads data in r to S3 with the Put API.  It has to be
// used for anything less than 5MB.  It can be used for anything up
// to 5GB, but putMultiPipeline is best for anything over 5MB.
func (a *S3Store) putSingle(ctx context.Context, r io.Reader, size int64, params chat1.S3Params,
	b s3.BucketInt, progressReporter types.ProgressReporter) (err error) {
	defer a.Trace(ctx, &err, fmt.Sprintf("putSingle(size=%d)", size))()

	progWriter := progress.NewProgressWriter(progressReporter, size)
	tee := io.TeeReader(r, progWriter)

	if err := b.PutReader(ctx, params.ObjectKey, tee, size, "application/octet-stream", s3.ACL(params.Acl),
		s3.Options{}); err != nil {
		a.Debug(ctx, "putSingle: failed: %s", err)
		return NewErrorWrapper("failed putSingle", err)
	}
	progWriter.Finish()
	return nil
}

// putMultiPipeline uploads data in r to S3 using the Multi API.  It uses a
// pipeline to upload up to 10 blocks of data concurrently.
// Each block is 5MB. It returns the object key if there are no errors.  putMultiPipeline
// will return a different object key from params.ObjectKey if a previous Put is
// successfully resumed and completed.
func (a *S3Store) putMultiPipeline(ctx context.Context, r io.Reader, size int64, task *UploadTask, b s3.BucketInt, previous *AttachmentInfo) (res string, err error) {
	defer a.Trace(ctx, &err, fmt.Sprintf("putMultiPipeline(size=%d)", size))()

	var multi s3.MultiInt
	if previous != nil {
		a.Debug(ctx, "putMultiPipeline: previous exists. Changing object key from %q to %q",
			task.S3Params.ObjectKey, previous.ObjectKey)
		task.S3Params.ObjectKey = previous.ObjectKey
	}

	multi, err = b.Multi(ctx, task.S3Params.ObjectKey, "application/octet-stream", s3.ACL(task.S3Params.Acl))
	if err != nil {
		a.Debug(ctx, "putMultiPipeline: b.Multi error: %s", err.Error())
		return "", NewErrorWrapper("s3 Multi error", err)
	}

	var previousParts map[int]s3.Part
	if previous != nil {
		previousParts = make(map[int]s3.Part)
		list, err := multi.ListParts(ctx)
		if err != nil {
			a.Debug(ctx, "putMultiPipeline: ignoring multi.ListParts error: %s", err)
			// dump previous since we can't check it anymore
			previous = nil
		} else {
			for _, p := range list {
				previousParts[p.N] = p
			}
		}
	}

	// need to use ectx for everything inside the eg.Go() funcs since eg
	// will cancel ectx in eg.Wait().
	a.Debug(ctx, "putMultiPipeline: beginning parts uploader process")
	eg, ectx := errgroup.WithContext(ctx)
	blockCh := make(chan job)
	retCh := make(chan s3.Part)
	eg.Go(func() error {
		defer close(blockCh)
		return a.makeBlockJobs(ectx, r, blockCh, task.stashKey(), previous)
	})
	eg.Go(func() error {
		for lb := range blockCh {
			if err := s3UploadPipeline.QueueForTakeoff(ectx); err != nil {
				return err
			}
			b := lb
			eg.Go(func() error {
				defer s3UploadPipeline.Complete()
				if err := a.uploadPart(ectx, task, b, previous, previousParts, multi, retCh); err != nil {
					return err
				}
				return nil
			})
		}
		return nil
	})
	go func() {
		err := eg.Wait()
		if err != nil {
			a.Debug(ctx, "putMultiPipeline: error waiting: %+v", err)
		}
		close(retCh)
	}()

	var parts []s3.Part
	progWriter := progress.NewProgressWriter(task.Progress, size)
	for p := range retCh {
		parts = append(parts, p)
		progWriter.Update(int(p.Size))
	}
	if err := eg.Wait(); err != nil {
		return "", err
	}

	if a.blockLimit > 0 {
		return "", errors.New("block limit hit, not completing multi upload")
	}
	a.Debug(ctx, "putMultiPipeline: all parts uploaded, completing request")
	if err := multi.Complete(ctx, parts); err != nil {
		a.Debug(ctx, "putMultiPipeline: Complete() failed: %s", err)
		return "", err
	}
	a.Debug(ctx, "putMultiPipeline: success, %d parts", len(parts))
	// Just to make sure the UI gets the 100% call
	progWriter.Finish()
	return task.S3Params.ObjectKey, nil
}

type job struct {
	block []byte
	index int
	hash  string
}

func (j job) etag() string {
	return `"` + j.hash + `"`
}
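
// exampleJobEtag is an illustrative sketch, not part of the original file:
// makeBlockJobs (below) stores the hex MD5 of each block in job.hash, and
// etag() wraps it in quotes because S3 reports a part's ETag as the quoted
// MD5 hex, which is what uploadPart compares against the parts from ListParts.
func exampleJobEtag(block []byte) string {
	sum := md5.Sum(block)
	j := job{block: block, index: 1, hash: hex.EncodeToString(sum[:])}
	return j.etag() // the MD5 hex wrapped in double quotes, the same format as s3.Part.ETag
}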

// makeBlockJobs reads ciphertext chunks from r and creates jobs that it puts onto blockCh.
// If this is a resumed upload, it verifies the blocks against the local stash before
// creating jobs.
func (a *S3Store) makeBlockJobs(ctx context.Context, r io.Reader, blockCh chan job, stashKey StashKey, previous *AttachmentInfo) error {
	var partNumber int
	for {
		partNumber++
		block := make([]byte, blockSize)
		// must call io.ReadFull to ensure full block read
		n, err := io.ReadFull(r, block)
		// io.ErrUnexpectedEOF will be returned for last partial block,
		// which is ok.
		if err != nil && err != io.ErrUnexpectedEOF && err != io.EOF {
			return err
		}
		if n < blockSize {
			block = block[:n]
		}
		if n > 0 {
			md5sum := md5.Sum(block)
			md5hex := hex.EncodeToString(md5sum[:])

			if previous != nil {
				// resuming an upload, so check local stash record
				// and abort on mismatch before adding a job for this block
				// because if we don't it amounts to nonce reuse
				lhash, found := previous.Parts[partNumber]
				if found && lhash != md5hex {
					a.Debug(ctx, "makeBlockJobs: part %d failed local part record verification", partNumber)
					return ErrAbortOnPartMismatch
				}
			}

			if err := a.addJob(ctx, blockCh, block, partNumber, md5hex); err != nil {
				return err
			}
		}
		if err == io.EOF || err == io.ErrUnexpectedEOF {
			break
		}

		if a.blockLimit > 0 && partNumber >= a.blockLimit {
			a.Debug(ctx, "makeBlockJobs: hit blockLimit of %d", a.blockLimit)
			break
		}
	}
	return nil
}

// addJob creates a job and puts it on blockCh, unless the context is canceled before blockCh is ready to receive it.
func (a *S3Store) addJob(ctx context.Context, blockCh chan job, block []byte, partNumber int, hash string) error {
	// Create a job, unless the context has been canceled.
	select {
	case blockCh <- job{block: block, index: partNumber, hash: hash}:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

// uploadPart handles uploading a job to S3.  The job `b` has already passed local stash verification.
// If this is a resumed upload, it checks the previous parts reported by S3 and will skip uploading
// any that already exist.
func (a *S3Store) uploadPart(ctx context.Context, task *UploadTask, b job, previous *AttachmentInfo, previousParts map[int]s3.Part, multi s3.MultiInt, retCh chan s3.Part) (err error) {
	defer a.Trace(ctx, &err, fmt.Sprintf("uploadPart(%d)", b.index))()

	// check to see if this part has already been uploaded.
	// for job `b` to be here, it has already passed local stash verification.
	if previous != nil {
		// check s3 previousParts for this block
		p, ok := previousParts[b.index]
		if ok && int(p.Size) == len(b.block) && p.ETag == b.etag() {
			a.Debug(ctx, "uploadPart: part %d already uploaded to s3", b.index)

			// part already uploaded, so put it in the retCh unless the context
			// has been canceled
			select {
			case retCh <- p:
			case <-ctx.Done():
				return ctx.Err()
			}

			// nothing else to do
			return nil
		}

		if p.Size > 0 {
			// only abort if the part size from s3 is > 0.
			a.Debug(ctx, "uploadPart: part %d s3 mismatch:  size %d != expected %d or etag %s != expected %s",
				b.index, p.Size, len(b.block), p.ETag, b.etag())
			return ErrAbortOnPartMismatch
		}

		// this part doesn't exist on s3, so it needs to be uploaded
		a.Debug(ctx, "uploadPart: part %d not uploaded to s3 by previous upload attempt", b.index)
	}

	// stash part info locally before attempting the S3 put; recording it
	// first matters for security, since a resumed upload checks these
	// records to detect part mismatches (see ErrAbortOnPartMismatch).
	if err := a.stash.RecordPart(task.stashKey(), b.index, b.hash); err != nil {
		a.Debug(ctx, "uploadPart: StashRecordPart error: %s", err)
	}

	part, putErr := multi.PutPart(ctx, b.index, bytes.NewReader(b.block))
	if putErr != nil {
		return NewErrorWrapper(fmt.Sprintf("failed to put part %d", b.index), putErr)
	}

	// put the successfully uploaded part information in the retCh
	// unless the context has been canceled.
	select {
	case retCh <- part:
	case <-ctx.Done():
		a.Debug(ctx, "uploadPart: upload part %d, context canceled", b.index)
		return ctx.Err()
	}

	return nil
}

type ErrorWrapper struct {
	prefix string
	err    error
}

func NewErrorWrapper(prefix string, err error) *ErrorWrapper {
	return &ErrorWrapper{prefix: prefix, err: err}
}

func (e *ErrorWrapper) Error() string {
	return fmt.Sprintf("%s: %s (%T)", e.prefix, e.err, e.err)
}

func (e *ErrorWrapper) Details() string {
	switch err := e.err.(type) {
	case *s3.Error:
		return fmt.Sprintf("%s: error %q, status code: %d, code: %s, message: %s, bucket: %s", e.prefix, e.err, err.StatusCode, err.Code, err.Message, err.BucketName)
	default:
		return fmt.Sprintf("%s: error %q, no details for type %T", e.prefix, e.err, e.err)
	}
}

type S3Signer struct {
	ri func() chat1.RemoteInterface
}

func NewS3Signer(ri func() chat1.RemoteInterface) *S3Signer {
	return &S3Signer{
		ri: ri,
	}
}

// Sign implements the github.com/keybase/client/go/chat/s3.Signer interface.
func (s *S3Signer) Sign(payload []byte) ([]byte, error) {
	arg := chat1.S3SignArg{
		Payload:   payload,
		Version:   1,
		TempCreds: true,
	}
	return s.ri().S3Sign(context.Background(), arg)
}
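
// exampleSignerWiring is an illustrative sketch, not part of the original
// file: PutS3 receives the signer through task.S3Signer and hands it to
// a.s3Conn, so callers are assumed to build one from a RemoteInterface getter
// roughly like this. The function name here is hypothetical.
func exampleSignerWiring(ri func() chat1.RemoteInterface) s3.Signer {
	// NewS3Signer satisfies s3.Signer via Sign, which asks the chat server to
	// sign the request payload using temporary credentials.
	return NewS3Signer(ri)
}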