github.com/grailbio/base@v0.0.11/file/s3file/file_write.go

package s3file

import (
	"bytes"
	"context"
	"fmt"
	"sort"
	"sync"
	"time"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/service/s3"
	"github.com/aws/aws-sdk-go/service/s3/s3iface"
	"github.com/grailbio/base/errors"
	"github.com/grailbio/base/file"
	"github.com/grailbio/base/log"
)

// s3Uploader is a helper for driving an S3 multipart upload through an
// io.Writer-like interface. Its write method feeds data incrementally to the
// upload workers, and finish waits for all parts to finish uploading.
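//
// Internally, write fills buffers of UploadPartSize bytes drawn from bufPool;
// each full buffer is sent to reqCh as an uploadChunk and uploaded as one part
// by one of uploadParallelism uploadThread goroutines. Completed parts are
// collected under mu, and the first error is recorded in err.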
type s3Uploader struct {
	ctx               context.Context
	client            s3iface.S3API
	path, bucket, key string
	opts              file.Opts
	s3opts            Options
	uploadID          string
	createTime        time.Time // time of file.Create() call
	// curBuf is only accessed by the handleRequests thread.
	curBuf      *[]byte
	nextPartNum int64

	bufPool sync.Pool
	reqCh   chan uploadChunk
	err     errors.Once
	sg      sync.WaitGroup
	mu      sync.Mutex
	parts   []*s3.CompletedPart
}

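// uploadChunk describes one part to be uploaded: the client and uploadID to
// use, the 1-based part number, and the buffer holding the part's bytes. The
// buffer is returned to s3Uploader.bufPool once the upload attempt completes.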
type uploadChunk struct {
	client   s3iface.S3API
	uploadID string
	partNum  int64
	buf      *[]byte
}

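// uploadParallelism is the number of upload worker goroutines started per
// file, and therefore the maximum number of parts uploaded concurrently;
// reqCh is buffered to the same depth.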
const uploadParallelism = 16

// UploadPartSize is the size of a chunk during multi-part uploads.  It is
// exposed only for unit tests.
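//
// With the default 16 MiB parts, a single object tops out at about 156 GiB
// (16 MiB x 10,000 parts), since S3 allows at most 10,000 parts per multipart
// upload and parts other than the last must be at least 5 MiB.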
var UploadPartSize = 16 << 20

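// newUploader starts a multipart upload of s3://bucket/key and returns an
// s3Uploader ready to accept writes. A sketch of the intended life cycle (the
// caller-side variable names here are illustrative, not part of this file):
//
//	u, err := newUploader(ctx, clientsForAction, opts, path, bucket, key, fileOpts)
//	if err != nil { ... }
//	u.write(data)    // called repeatedly by the request thread
//	err = u.finish() // flushes the last part and completes (or aborts) the upload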
func newUploader(ctx context.Context, clientsForAction clientsForActionFunc, opts Options, path, bucket, key string, fileOpts file.Opts) (*s3Uploader, error) {
	clients, err := clientsForAction(ctx, "PutObject", bucket, key)
	if err != nil {
		return nil, errors.E(err, "s3file.write", path)
	}
	params := &s3.CreateMultipartUploadInput{
		Bucket: aws.String(bucket),
		Key:    aws.String(key),
	}
	// Add any non-default options
	if opts.ServerSideEncryption != "" {
		params.SetServerSideEncryption(opts.ServerSideEncryption)
	}

	u := &s3Uploader{
		ctx:         ctx,
		path:        path,
		bucket:      bucket,
		key:         key,
		opts:        fileOpts,
		s3opts:      opts,
		createTime:  time.Now(),
		bufPool:     sync.Pool{New: func() interface{} { slice := make([]byte, UploadPartSize); return &slice }},
		nextPartNum: 1,
	}
	policy := newBackoffPolicy(clients, file.Opts{})
	for {
		var ids s3RequestIDs
		resp, err := policy.client().CreateMultipartUploadWithContext(ctx,
			params, ids.captureOption())
		if policy.shouldRetry(ctx, err, path) {
			continue
		}
		if err != nil {
			return nil, annotate(err, ids, &policy, "s3file.CreateMultipartUploadWithContext", path)
		}
		u.client = policy.client()
		u.uploadID = *resp.UploadId
		if u.uploadID == "" {
			panic(fmt.Sprintf("empty uploadID: %+v, awsrequestID: %v", resp, ids))
		}
		break
	}

	u.reqCh = make(chan uploadChunk, uploadParallelism)
	for i := 0; i < uploadParallelism; i++ {
		u.sg.Add(1)
		go u.uploadThread()
	}
	return u, nil
}

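// uploadThread is one of uploadParallelism workers. It receives uploadChunks
// from reqCh, uploads each as a single S3 part (retrying per the backoff
// policy), records the completed part under mu, and returns the chunk's buffer
// to bufPool whether or not the upload succeeded.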
func (u *s3Uploader) uploadThread() {
	defer u.sg.Done()
	for chunk := range u.reqCh {
		policy := newBackoffPolicy([]s3iface.S3API{chunk.client}, file.Opts{})
	retry:
		params := &s3.UploadPartInput{
			Bucket:     aws.String(u.bucket),
			Key:        aws.String(u.key),
			Body:       bytes.NewReader(*chunk.buf),
			UploadId:   aws.String(chunk.uploadID),
			PartNumber: &chunk.partNum,
		}
		var ids s3RequestIDs
		resp, err := chunk.client.UploadPartWithContext(u.ctx, params, ids.captureOption())
		if policy.shouldRetry(u.ctx, err, u.path) {
			goto retry
		}
		u.bufPool.Put(chunk.buf)
		if err != nil {
			u.err.Set(annotate(err, ids, &policy, fmt.Sprintf("s3file.UploadPartWithContext s3://%s/%s", u.bucket, u.key)))
			continue
		}
		partNum := chunk.partNum
		completed := &s3.CompletedPart{ETag: resp.ETag, PartNumber: &partNum}
		u.mu.Lock()
		u.parts = append(u.parts, completed)
		u.mu.Unlock()
	}
}

// write appends data to the file. It can be called only by the request thread.
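//
// Data is staged in a buffer of UploadPartSize bytes: for example, a single
// 40 MiB write with the default 16 MiB part size hands two full parts to the
// upload workers and leaves 8 MiB buffered until the next write or finish call.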
func (u *s3Uploader) write(buf []byte) {
	if len(buf) == 0 {
		panic("empty buf in write")
	}
	for len(buf) > 0 {
		if u.curBuf == nil {
			u.curBuf = u.bufPool.Get().(*[]byte)
			*u.curBuf = (*u.curBuf)[:0]
		}
		if cap(*u.curBuf) != UploadPartSize {
			panic("s3file.write: buffer capacity != UploadPartSize")
		}
		uploadBuf := *u.curBuf
		space := uploadBuf[len(uploadBuf):cap(uploadBuf)]
		n := len(buf)
		if n < len(space) {
			copy(space, buf)
			*u.curBuf = uploadBuf[0 : len(uploadBuf)+n]
			return
		}
		copy(space, buf)
		buf = buf[len(space):]
		*u.curBuf = uploadBuf[0:cap(uploadBuf)]
		u.reqCh <- uploadChunk{client: u.client, uploadID: u.uploadID, partNum: u.nextPartNum, buf: u.curBuf}
		u.nextPartNum++
		u.curBuf = nil
	}
}

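// abort cancels the multipart upload, telling S3 to discard any parts that
// were already uploaded. finish calls it best-effort (ignoring the error)
// whenever the upload cannot be completed.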
func (u *s3Uploader) abort() error {
	policy := newBackoffPolicy([]s3iface.S3API{u.client}, file.Opts{})
	for {
		var ids s3RequestIDs
		_, err := u.client.AbortMultipartUploadWithContext(u.ctx, &s3.AbortMultipartUploadInput{
			Bucket:   aws.String(u.bucket),
			Key:      aws.String(u.key),
			UploadId: aws.String(u.uploadID),
		}, ids.captureOption())
		if !policy.shouldRetry(u.ctx, err, u.path) {
			if err != nil {
				err = annotate(err, ids, &policy, fmt.Sprintf("s3file.AbortMultipartUploadWithContext s3://%s/%s", u.bucket, u.key))
			}
			return err
		}
	}
}

// finish finishes writing. It can be called only by the request thread.
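//
// It flushes any partially filled buffer as a final part, closes reqCh, and
// waits for the upload workers to drain. If nothing was written it aborts the
// multipart upload and writes an empty object with PutObject instead, since
// completing a multipart upload with an empty part list fails; otherwise it
// sorts the parts by part number and completes the upload. On any error the
// upload is aborted and the first recorded error is returned.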
func (u *s3Uploader) finish() error {
	if u.curBuf != nil && len(*u.curBuf) > 0 {
		u.reqCh <- uploadChunk{client: u.client, uploadID: u.uploadID, partNum: u.nextPartNum, buf: u.curBuf}
		u.curBuf = nil
	}
	close(u.reqCh)
	u.sg.Wait()
	policy := newBackoffPolicy([]s3iface.S3API{u.client}, file.Opts{})
	if err := u.err.Err(); err != nil {
		u.abort() // nolint: errcheck
		return err
	}
	if len(u.parts) == 0 {
		// Special case: an empty file. CompleteMultipartUpload with an empty parts list
		// causes an error, so work around it by issuing a separate PutObject request.
		u.abort() // nolint: errcheck
		for {
			input := &s3.PutObjectInput{
				Bucket: aws.String(u.bucket),
				Key:    aws.String(u.key),
				Body:   bytes.NewReader(nil),
			}
			if u.s3opts.ServerSideEncryption != "" {
				input.SetServerSideEncryption(u.s3opts.ServerSideEncryption)
			}

			var ids s3RequestIDs
			_, err := u.client.PutObjectWithContext(u.ctx, input, ids.captureOption())
			if !policy.shouldRetry(u.ctx, err, u.path) {
				if err != nil {
					err = annotate(err, ids, &policy, fmt.Sprintf("s3file.PutObjectWithContext s3://%s/%s", u.bucket, u.key))
				}
				u.err.Set(err)
				break
			}
		}
		return u.err.Err()
	}
	// Common case. Complete the multi-part upload.
	closeStartTime := time.Now()
	sort.Slice(u.parts, func(i, j int) bool { // Parts must be sorted in PartNumber order.
		return *u.parts[i].PartNumber < *u.parts[j].PartNumber
	})
	params := &s3.CompleteMultipartUploadInput{
		Bucket:          aws.String(u.bucket),
		Key:             aws.String(u.key),
		UploadId:        aws.String(u.uploadID),
		MultipartUpload: &s3.CompletedMultipartUpload{Parts: u.parts},
	}
	for {
		var ids s3RequestIDs
		_, err := u.client.CompleteMultipartUploadWithContext(u.ctx, params, ids.captureOption())
		if aerr, ok := getAWSError(err); ok && aerr.Code() == "NoSuchUpload" {
			if u.opts.IgnoreNoSuchUpload {
				// Here we managed to upload >=1 part, so the uploadID must have been
				// valid at some point in the past.
				//
				// TODO(saito) we could check that upload isn't too old (say <= 7 days),
				// or that the file actually exists.
				log.Error.Printf("close %s: IgnoreNoSuchUpload is set; ignoring %v %+v", u.path, err, ids)
				err = nil
			}
		}
		if !policy.shouldRetry(u.ctx, err, u.path) {
			if err != nil {
				err = annotate(err, ids, &policy,
					fmt.Sprintf("s3file.CompleteMultipartUploadWithContext s3://%s/%s, "+
						"created at %v, started closing at %v, failed at %v",
						u.bucket, u.key, u.createTime, closeStartTime, time.Now()))
			}
			u.err.Set(err)
			break
		}
	}
	if u.err.Err() != nil {
		u.abort() // nolint: errcheck
	}
	return u.err.Err()
}