github.com/keybase/client/go@v0.0.0-20240309051027-028f7c731f8b/chat/attachments/s3.go

package attachments

import (
	"bytes"
	"context"
	"crypto/md5"
	"encoding/hex"
	"errors"
	"fmt"
	"io"
	"sync"

	"github.com/keybase/client/go/chat/attachments/progress"
	"github.com/keybase/client/go/chat/s3"
	"github.com/keybase/client/go/chat/types"
	"github.com/keybase/client/go/protocol/chat1"
	"golang.org/x/sync/errgroup"
)

const s3PipelineMaxWidth = 10

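// s3UploadPipeliner bounds the number of attachment part uploads that can
// be in flight at once to s3PipelineMaxWidth, queueing additional callers
// in FIFO order until a slot frees up.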
type s3UploadPipeliner struct {
	sync.Mutex
	width   int
	waiters []chan struct{}
}

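// QueueForTakeoff blocks until an upload slot is available or ctx is
// canceled, in which case it returns the context's error.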
func (s *s3UploadPipeliner) QueueForTakeoff(ctx context.Context) error {
	s.Lock()
	if s.width >= s3PipelineMaxWidth {
		ch := make(chan struct{})
		s.waiters = append(s.waiters, ch)
		s.Unlock()
		select {
		case <-ch:
		case <-ctx.Done():
			return ctx.Err()
		}
		s.Lock()
		s.width++
		s.Unlock()
		return nil
	}
	s.width++
	s.Unlock()
	return nil
}

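// Complete releases an upload slot and wakes the oldest waiter, if any.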
func (s *s3UploadPipeliner) Complete() {
	s.Lock()
	defer s.Unlock()
	if len(s.waiters) > 0 {
		close(s.waiters[0])
		if len(s.waiters) > 1 {
			s.waiters = s.waiters[1:]
		} else {
			s.waiters = nil
		}
	}
	if s.width > 0 {
		s.width--
	}
}

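// s3UploadPipeline is the process-wide limiter shared by all attachment
// part uploads in this package.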
var s3UploadPipeline = &s3UploadPipeliner{}

const minMultiSize = 5 * 1024 * 1024 // can't use Multi API with parts less than 5MB
const blockSize = 5 * 1024 * 1024    // 5MB is the minimum Multi part size

// ErrAbortOnPartMismatch is returned when there is a mismatch between a current
// part and a previous attempt part.  If ErrAbortOnPartMismatch is returned,
// the caller should abort the upload attempt and start from scratch.
var ErrAbortOnPartMismatch = errors.New("local part mismatch, aborting upload")

// PutS3Result is the success result of calling PutS3.
type PutS3Result struct {
	Region   string
	Endpoint string
	Bucket   string
	Path     string
	Size     int64
}

// PutS3 uploads the data in Reader r to S3.  It chooses whether to use
// putSingle or putMultiPipeline based on the size of the object.
func (a *S3Store) PutS3(ctx context.Context, r io.Reader, size int64, task *UploadTask, previous *AttachmentInfo) (res *PutS3Result, err error) {
	defer a.Trace(ctx, &err, "PutS3")()
	region := a.regionFromParams(task.S3Params)
	b := a.s3Conn(task.S3Signer, region, task.S3Params.AccessKey).Bucket(task.S3Params.Bucket)

	multiPartUpload := size > minMultiSize
	if multiPartUpload && a.G().Env.GetAttachmentDisableMulti() {
		a.Debug(ctx, "PutS3: multi part upload manually disabled, overriding for size: %v", size)
		multiPartUpload = false
	}

	if !multiPartUpload {
		if err := a.putSingle(ctx, r, size, task.S3Params, b, task.Progress); err != nil {
			return nil, err
		}
	} else {
		objectKey, err := a.putMultiPipeline(ctx, r, size, task, b, previous)
		if err != nil {
			return nil, err
		}
		task.S3Params.ObjectKey = objectKey
	}

	s3res := PutS3Result{
		Region:   task.S3Params.RegionName,
		Endpoint: task.S3Params.RegionEndpoint,
		Bucket:   task.S3Params.Bucket,
		Path:     task.S3Params.ObjectKey,
		Size:     size,
	}
	return &s3res, nil
}

// putSingle uploads data in r to S3 with the Put API.  It has to be
// used for anything less than 5MB.  It can be used for anything up
// to 5GB, but putMultiPipeline is best for anything over 5MB.
func (a *S3Store) putSingle(ctx context.Context, r io.Reader, size int64, params chat1.S3Params,
	b s3.BucketInt, progressReporter types.ProgressReporter) (err error) {
	defer a.Trace(ctx, &err, fmt.Sprintf("putSingle(size=%d)", size))()

	progWriter := progress.NewProgressWriter(progressReporter, size)
	tee := io.TeeReader(r, progWriter)

	if err := b.PutReader(ctx, params.ObjectKey, tee, size, "application/octet-stream", s3.ACL(params.Acl),
		s3.Options{}); err != nil {
		a.Debug(ctx, "putSingle: failed: %s", err)
		return NewErrorWrapper("failed putSingle", err)
	}
	progWriter.Finish()
	return nil
}

// putMultiPipeline uploads data in r to S3 using the Multi API.  It uses a
// pipeline to upload up to 10 blocks of data concurrently, where each block
// is 5MB.  It returns the object key if there are no errors.  putMultiPipeline
// will return a different object key from params.ObjectKey if a previous Put is
// successfully resumed and completed.
func (a *S3Store) putMultiPipeline(ctx context.Context, r io.Reader, size int64, task *UploadTask, b s3.BucketInt, previous *AttachmentInfo) (res string, err error) {
	defer a.Trace(ctx, &err, fmt.Sprintf("putMultiPipeline(size=%d)", size))()

	var multi s3.MultiInt
	if previous != nil {
		a.Debug(ctx, "putMultiPipeline: previous exists. Changing object key from %q to %q",
			task.S3Params.ObjectKey, previous.ObjectKey)
		task.S3Params.ObjectKey = previous.ObjectKey
	}

	multi, err = b.Multi(ctx, task.S3Params.ObjectKey, "application/octet-stream", s3.ACL(task.S3Params.Acl))
	if err != nil {
		a.Debug(ctx, "putMultiPipeline: b.Multi error: %s", err.Error())
		return "", NewErrorWrapper("s3 Multi error", err)
	}

	var previousParts map[int]s3.Part
	if previous != nil {
		previousParts = make(map[int]s3.Part)
		list, err := multi.ListParts(ctx)
		if err != nil {
			a.Debug(ctx, "putMultiPipeline: ignoring multi.ListParts error: %s", err)
			// dump previous since we can't check it anymore
			previous = nil
		} else {
			for _, p := range list {
				previousParts[p.N] = p
			}
		}
	}

	// ectx must be used inside everything run via eg.Go() below, since the
	// errgroup cancels ectx once eg.Wait() returns.
	a.Debug(ctx, "putMultiPipeline: beginning parts uploader process")
	eg, ectx := errgroup.WithContext(ctx)
	blockCh := make(chan job)
	retCh := make(chan s3.Part)
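	// producer: split r into blockSize chunks and send them down blockCh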
	eg.Go(func() error {
		defer close(blockCh)
		return a.makeBlockJobs(ectx, r, blockCh, task.stashKey(), previous)
	})
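	// consumer: upload each block, letting at most s3PipelineMaxWidth part
	// uploads run at the same time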
	eg.Go(func() error {
		for lb := range blockCh {
			if err := s3UploadPipeline.QueueForTakeoff(ectx); err != nil {
				return err
			}
			b := lb
			eg.Go(func() error {
				defer s3UploadPipeline.Complete()
				if err := a.uploadPart(ectx, task, b, previous, previousParts, multi, retCh); err != nil {
					return err
				}
				return nil
			})
		}
		return nil
	})
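	// close retCh once every producer and uploader goroutine has finished,
	// so the collection loop below can exit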
	go func() {
		err := eg.Wait()
		if err != nil {
			a.Debug(ctx, "putMultiPipeline: error waiting: %+v", err)
		}
		close(retCh)
	}()

	var parts []s3.Part
	progWriter := progress.NewProgressWriter(task.Progress, size)
	for p := range retCh {
		parts = append(parts, p)
		progWriter.Update(int(p.Size))
	}
	if err := eg.Wait(); err != nil {
		return "", err
	}

	if a.blockLimit > 0 {
		return "", errors.New("block limit hit, not completing multi upload")
	}
	a.Debug(ctx, "putMultiPipeline: all parts uploaded, completing request")
	if err := multi.Complete(ctx, parts); err != nil {
		a.Debug(ctx, "putMultiPipeline: Complete() failed: %s", err)
		return "", err
	}
	a.Debug(ctx, "putMultiPipeline: success, %d parts", len(parts))
	// Just to make sure the UI gets the 100% call
	progWriter.Finish()
	return task.S3Params.ObjectKey, nil
}

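// job is a single block of attachment data to be uploaded as one S3 part,
// along with its 1-based part number and the MD5 hex digest of the block.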
type job struct {
	block []byte
	index int
	hash  string
}

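// etag wraps the block's MD5 hex digest in quotes so it can be compared to
// the ETag S3 reports for the part (assumed to be the quoted MD5 of the
// part data for a plain part upload).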
func (j job) etag() string {
	return `"` + j.hash + `"`
}

// makeBlockJobs reads ciphertext chunks from r and creates jobs that it puts onto blockCh.
// If this is a resumed upload, it verifies the blocks against the local stash before
// creating jobs.
func (a *S3Store) makeBlockJobs(ctx context.Context, r io.Reader, blockCh chan job, stashKey StashKey, previous *AttachmentInfo) error {
	var partNumber int
	for {
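		// S3 multipart part numbers are 1-based, so bump the counter before
		// using it for this block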
		partNumber++
		block := make([]byte, blockSize)
		// must call io.ReadFull to ensure full block read
		n, err := io.ReadFull(r, block)
		// io.ErrUnexpectedEOF will be returned for last partial block,
		// which is ok.
		if err != nil && err != io.ErrUnexpectedEOF && err != io.EOF {
			return err
		}
		if n < blockSize {
			block = block[:n]
		}
		if n > 0 {
			md5sum := md5.Sum(block)
			md5hex := hex.EncodeToString(md5sum[:])

			if previous != nil {
				// resuming an upload, so check local stash record
				// and abort on mismatch before adding a job for this block
				// because if we don't it amounts to nonce reuse
				lhash, found := previous.Parts[partNumber]
				if found && lhash != md5hex {
					a.Debug(ctx, "makeBlockJobs: part %d failed local part record verification", partNumber)
					return ErrAbortOnPartMismatch
				}
			}

			if err := a.addJob(ctx, blockCh, block, partNumber, md5hex); err != nil {
				return err
			}
		}
		if err == io.EOF || err == io.ErrUnexpectedEOF {
			break
		}

		if a.blockLimit > 0 && partNumber >= a.blockLimit {
			a.Debug(ctx, "makeBlockJobs: hit blockLimit of %d", a.blockLimit)
			break
		}
	}
	return nil
}

// addJob creates a job and puts it on blockCh, unless the context is canceled
// before blockCh is ready to receive it.
func (a *S3Store) addJob(ctx context.Context, blockCh chan job, block []byte, partNumber int, hash string) error {
	// Create a job, unless the context has been canceled.
	select {
	case blockCh <- job{block: block, index: partNumber, hash: hash}:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

// uploadPart handles uploading a job to S3.  The job `b` has already passed local stash verification.
// If this is a resumed upload, it checks the previous parts reported by S3 and will skip uploading
// any that already exist.
func (a *S3Store) uploadPart(ctx context.Context, task *UploadTask, b job, previous *AttachmentInfo, previousParts map[int]s3.Part, multi s3.MultiInt, retCh chan s3.Part) (err error) {
	defer a.Trace(ctx, &err, fmt.Sprintf("uploadPart(%d)", b.index))()

	// check to see if this part has already been uploaded.
	// for job `b` to be here, it has already passed local stash verification.
	if previous != nil {
		// check s3 previousParts for this block
		p, ok := previousParts[b.index]
		if ok && int(p.Size) == len(b.block) && p.ETag == b.etag() {
			a.Debug(ctx, "uploadPart: part %d already uploaded to s3", b.index)

			// part already uploaded, so put it in the retCh unless the context
			// has been canceled
			select {
			case retCh <- p:
			case <-ctx.Done():
				return ctx.Err()
			}

			// nothing else to do
			return nil
		}

		if p.Size > 0 {
			// only abort if the part size from s3 is > 0.
			a.Debug(ctx, "uploadPart: part %d s3 mismatch:  size %d != expected %d or etag %s != expected %s",
				b.index, p.Size, len(b.block), p.ETag, b.etag())
			return ErrAbortOnPartMismatch
		}

		// this part doesn't exist on s3, so it needs to be uploaded
		a.Debug(ctx, "uploadPart: part %d not uploaded to s3 by previous upload attempt", b.index)
	}

	// stash part info locally before attempting the S3 put; recording the
	// part hash first is important for security, since it is what lets a
	// later resume verify this part against the local stash (see the nonce
	// reuse note in makeBlockJobs).
	if err := a.stash.RecordPart(task.stashKey(), b.index, b.hash); err != nil {
		a.Debug(ctx, "uploadPart: StashRecordPart error: %s", err)
	}

	part, putErr := multi.PutPart(ctx, b.index, bytes.NewReader(b.block))
	if putErr != nil {
		return NewErrorWrapper(fmt.Sprintf("failed to put part %d", b.index), putErr)
	}

	// put the successfully uploaded part information in the retCh
	// unless the context has been canceled.
	select {
	case retCh <- part:
	case <-ctx.Done():
		a.Debug(ctx, "uploadPart: upload part %d, context canceled", b.index)
		return ctx.Err()
	}

	return nil
}

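// ErrorWrapper annotates an error from the S3 layer with a short prefix
// describing the operation that failed.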
type ErrorWrapper struct {
	prefix string
	err    error
}

func NewErrorWrapper(prefix string, err error) *ErrorWrapper {
	return &ErrorWrapper{prefix: prefix, err: err}
}

func (e *ErrorWrapper) Error() string {
	return fmt.Sprintf("%s: %s (%T)", e.prefix, e.err, e.err)
}

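// Details expands the wrapped error: for an *s3.Error it includes the HTTP
// status code, S3 error code, message, and bucket name.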
func (e *ErrorWrapper) Details() string {
	switch err := e.err.(type) {
	case *s3.Error:
		return fmt.Sprintf("%s: error %q, status code: %d, code: %s, message: %s, bucket: %s", e.prefix, e.err, err.StatusCode, err.Code, err.Message, err.BucketName)
	default:
		return fmt.Sprintf("%s: error %q, no details for type %T", e.prefix, e.err, e.err)
	}
}

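// S3Signer signs S3 request payloads by delegating the Sign call to the
// chat remote interface, i.e. an S3Sign RPC to the server.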
type S3Signer struct {
	ri func() chat1.RemoteInterface
}

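// NewS3Signer returns an S3Signer backed by the given remote interface
// getter. A minimal sketch of how a caller might use it (names here are
// illustrative):
//
//	signer := NewS3Signer(ri)
//	sig, err := signer.Sign(payload)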
func NewS3Signer(ri func() chat1.RemoteInterface) *S3Signer {
	return &S3Signer{
		ri: ri,
	}
}

// Sign implements the github.com/keybase/client/go/chat/s3.Signer interface.
func (s *S3Signer) Sign(payload []byte) ([]byte, error) {
	arg := chat1.S3SignArg{
		Payload: payload,
		Version: 1,
	}
	return s.ri().S3Sign(context.Background(), arg)
}