github.com/Schaudge/grailbase@v0.0.0-20240223061707-44c758a471c0/file/s3file/file_write.go (about) 1 package s3file 2 3 import ( 4 "bytes" 5 "context" 6 "fmt" 7 "sort" 8 "sync" 9 "time" 10 11 "github.com/aws/aws-sdk-go/aws" 12 "github.com/aws/aws-sdk-go/service/s3" 13 "github.com/aws/aws-sdk-go/service/s3/s3iface" 14 "github.com/Schaudge/grailbase/errors" 15 "github.com/Schaudge/grailbase/file" 16 "github.com/Schaudge/grailbase/log" 17 ) 18 19 // A helper class for driving s3manager.Uploader through an io.Writer-like 20 // interface. Its write() method will feed data incrementally to the uploader, 21 // and finish() will wait for all the uploads to finish. 22 type s3Uploader struct { 23 ctx context.Context 24 client s3iface.S3API 25 path, bucket, key string 26 opts file.Opts 27 s3opts Options 28 uploadID string 29 createTime time.Time // time of file.Create() call 30 // curBuf is only accessed by the handleRequests thread. 31 curBuf *[]byte 32 nextPartNum int64 33 34 bufPool sync.Pool 35 reqCh chan uploadChunk 36 err errors.Once 37 sg sync.WaitGroup 38 mu sync.Mutex 39 parts []*s3.CompletedPart 40 } 41 42 type uploadChunk struct { 43 client s3iface.S3API 44 uploadID string 45 partNum int64 46 buf *[]byte 47 } 48 49 const uploadParallelism = 16 50 51 // UploadPartSize is the size of a chunk during multi-part uploads. It is 52 // exposed only for unittests. 53 var UploadPartSize = 16 << 20 54 55 func newUploader(ctx context.Context, clientsForAction clientsForActionFunc, opts Options, path, bucket, key string, fileOpts file.Opts) (*s3Uploader, error) { 56 clients, err := clientsForAction(ctx, "PutObject", bucket, key) 57 if err != nil { 58 return nil, errors.E(err, "s3file.write", path) 59 } 60 params := &s3.CreateMultipartUploadInput{ 61 Bucket: aws.String(bucket), 62 Key: aws.String(key), 63 } 64 // Add any non-default options 65 if opts.ServerSideEncryption != "" { 66 params.SetServerSideEncryption(opts.ServerSideEncryption) 67 } 68 69 u := &s3Uploader{ 70 ctx: ctx, 71 path: path, 72 bucket: bucket, 73 key: key, 74 opts: fileOpts, 75 s3opts: opts, 76 createTime: time.Now(), 77 bufPool: sync.Pool{New: func() interface{} { slice := make([]byte, UploadPartSize); return &slice }}, 78 nextPartNum: 1, 79 } 80 policy := newBackoffPolicy(clients, file.Opts{}) 81 for { 82 var ids s3RequestIDs 83 resp, err := policy.client().CreateMultipartUploadWithContext(ctx, 84 params, ids.captureOption()) 85 if policy.shouldRetry(ctx, err, path) { 86 continue 87 } 88 if err != nil { 89 return nil, annotate(err, ids, &policy, "s3file.CreateMultipartUploadWithContext", path) 90 } 91 u.client = policy.client() 92 u.uploadID = *resp.UploadId 93 if u.uploadID == "" { 94 panic(fmt.Sprintf("empty uploadID: %+v, awsrequestID: %v", resp, ids)) 95 } 96 break 97 } 98 99 u.reqCh = make(chan uploadChunk, uploadParallelism) 100 for i := 0; i < uploadParallelism; i++ { 101 u.sg.Add(1) 102 go u.uploadThread() 103 } 104 return u, nil 105 } 106 107 func (u *s3Uploader) uploadThread() { 108 defer u.sg.Done() 109 for chunk := range u.reqCh { 110 policy := newBackoffPolicy([]s3iface.S3API{chunk.client}, file.Opts{}) 111 retry: 112 params := &s3.UploadPartInput{ 113 Bucket: aws.String(u.bucket), 114 Key: aws.String(u.key), 115 Body: bytes.NewReader(*chunk.buf), 116 UploadId: aws.String(chunk.uploadID), 117 PartNumber: &chunk.partNum, 118 } 119 var ids s3RequestIDs 120 resp, err := chunk.client.UploadPartWithContext(u.ctx, params, ids.captureOption()) 121 if policy.shouldRetry(u.ctx, err, u.path) { 122 goto retry 123 } 124 u.bufPool.Put(chunk.buf) 125 if err != nil { 126 u.err.Set(annotate(err, ids, &policy, fmt.Sprintf("s3file.UploadPartWithContext s3://%s/%s", u.bucket, u.key))) 127 continue 128 } 129 partNum := chunk.partNum 130 completed := &s3.CompletedPart{ETag: resp.ETag, PartNumber: &partNum} 131 u.mu.Lock() 132 u.parts = append(u.parts, completed) 133 u.mu.Unlock() 134 } 135 } 136 137 // write appends data to file. It can be called only by the request thread. 138 func (u *s3Uploader) write(buf []byte) { 139 if len(buf) == 0 { 140 panic("empty buf in write") 141 } 142 for len(buf) > 0 { 143 if u.curBuf == nil { 144 u.curBuf = u.bufPool.Get().(*[]byte) 145 *u.curBuf = (*u.curBuf)[:0] 146 } 147 if cap(*u.curBuf) != UploadPartSize { 148 panic("empty buf") 149 } 150 uploadBuf := *u.curBuf 151 space := uploadBuf[len(uploadBuf):cap(uploadBuf)] 152 n := len(buf) 153 if n < len(space) { 154 copy(space, buf) 155 *u.curBuf = uploadBuf[0 : len(uploadBuf)+n] 156 return 157 } 158 copy(space, buf) 159 buf = buf[len(space):] 160 *u.curBuf = uploadBuf[0:cap(uploadBuf)] 161 u.reqCh <- uploadChunk{client: u.client, uploadID: u.uploadID, partNum: u.nextPartNum, buf: u.curBuf} 162 u.nextPartNum++ 163 u.curBuf = nil 164 } 165 } 166 167 func (u *s3Uploader) abort() error { 168 policy := newBackoffPolicy([]s3iface.S3API{u.client}, file.Opts{}) 169 for { 170 var ids s3RequestIDs 171 _, err := u.client.AbortMultipartUploadWithContext(u.ctx, &s3.AbortMultipartUploadInput{ 172 Bucket: aws.String(u.bucket), 173 Key: aws.String(u.key), 174 UploadId: aws.String(u.uploadID), 175 }, ids.captureOption()) 176 if !policy.shouldRetry(u.ctx, err, u.path) { 177 if err != nil { 178 err = annotate(err, ids, &policy, fmt.Sprintf("s3file.AbortMultiPartUploadWithContext s3://%s/%s", u.bucket, u.key)) 179 } 180 return err 181 } 182 } 183 } 184 185 // finish finishes writing. It can be called only by the request thread. 186 func (u *s3Uploader) finish() error { 187 if u.curBuf != nil && len(*u.curBuf) > 0 { 188 u.reqCh <- uploadChunk{client: u.client, uploadID: u.uploadID, partNum: u.nextPartNum, buf: u.curBuf} 189 u.curBuf = nil 190 } 191 close(u.reqCh) 192 u.sg.Wait() 193 policy := newBackoffPolicy([]s3iface.S3API{u.client}, file.Opts{}) 194 if err := u.err.Err(); err != nil { 195 u.abort() // nolint: errcheck 196 return err 197 } 198 if len(u.parts) == 0 { 199 // Special case: an empty file. CompleteMultiPartUpload with empty parts causes an error, 200 // so work around the bug by issuing a separate PutObject request. 201 u.abort() // nolint: errcheck 202 for { 203 input := &s3.PutObjectInput{ 204 Bucket: aws.String(u.bucket), 205 Key: aws.String(u.key), 206 Body: bytes.NewReader(nil), 207 } 208 if u.s3opts.ServerSideEncryption != "" { 209 input.SetServerSideEncryption(u.s3opts.ServerSideEncryption) 210 } 211 212 var ids s3RequestIDs 213 _, err := u.client.PutObjectWithContext(u.ctx, input, ids.captureOption()) 214 if !policy.shouldRetry(u.ctx, err, u.path) { 215 if err != nil { 216 err = annotate(err, ids, &policy, fmt.Sprintf("s3file.PutObjectWithContext s3://%s/%s", u.bucket, u.key)) 217 } 218 u.err.Set(err) 219 break 220 } 221 } 222 return u.err.Err() 223 } 224 // Common case. Complete the multi-part upload. 225 closeStartTime := time.Now() 226 sort.Slice(u.parts, func(i, j int) bool { // Parts must be sorted in PartNumber order. 227 return *u.parts[i].PartNumber < *u.parts[j].PartNumber 228 }) 229 params := &s3.CompleteMultipartUploadInput{ 230 Bucket: aws.String(u.bucket), 231 Key: aws.String(u.key), 232 UploadId: aws.String(u.uploadID), 233 MultipartUpload: &s3.CompletedMultipartUpload{Parts: u.parts}, 234 } 235 for { 236 var ids s3RequestIDs 237 _, err := u.client.CompleteMultipartUploadWithContext(u.ctx, params, ids.captureOption()) 238 if aerr, ok := getAWSError(err); ok && aerr.Code() == "NoSuchUpload" { 239 if u.opts.IgnoreNoSuchUpload { 240 // Here we managed to upload >=1 part, so the uploadID must have been 241 // valid some point in the past. 242 // 243 // TODO(saito) we could check that upload isn't too old (say <= 7 days), 244 // or that the file actually exists. 245 log.Error.Printf("close %s: IgnoreNoSuchUpload is set; ignoring %v %+v", u.path, err, ids) 246 err = nil 247 } 248 } 249 if !policy.shouldRetry(u.ctx, err, u.path) { 250 if err != nil { 251 err = annotate(err, ids, &policy, 252 fmt.Sprintf("s3file.CompleteMultipartUploadWithContext s3://%s/%s, "+ 253 "created at %v, started closing at %v, failed at %v", 254 u.bucket, u.key, u.createTime, closeStartTime, time.Now())) 255 } 256 u.err.Set(err) 257 break 258 } 259 } 260 if u.err.Err() != nil { 261 u.abort() // nolint: errcheck 262 } 263 return u.err.Err() 264 }