// github.com/keybase/client/go@v0.0.0-20240309051027-028f7c731f8b/chat/attachments/s3.go

package attachments

import (
	"bytes"
	"context"
	"crypto/md5"
	"encoding/hex"
	"errors"
	"fmt"
	"io"
	"sync"

	"github.com/keybase/client/go/chat/attachments/progress"
	"github.com/keybase/client/go/chat/s3"
	"github.com/keybase/client/go/chat/types"
	"github.com/keybase/client/go/protocol/chat1"
	"golang.org/x/sync/errgroup"
)

const s3PipelineMaxWidth = 10

type s3UploadPipeliner struct {
	sync.Mutex
	width   int
	waiters []chan struct{}
}

func (s *s3UploadPipeliner) QueueForTakeoff(ctx context.Context) error {
	s.Lock()
	if s.width >= s3PipelineMaxWidth {
		ch := make(chan struct{})
		s.waiters = append(s.waiters, ch)
		s.Unlock()
		select {
		case <-ch:
		case <-ctx.Done():
			return ctx.Err()
		}
		s.Lock()
		s.width++
		s.Unlock()
		return nil
	}
	s.width++
	s.Unlock()
	return nil
}

func (s *s3UploadPipeliner) Complete() {
	s.Lock()
	defer s.Unlock()
	if len(s.waiters) > 0 {
		close(s.waiters[0])
		if len(s.waiters) > 1 {
			s.waiters = s.waiters[1:]
		} else {
			s.waiters = nil
		}
	}
	if s.width > 0 {
		s.width--
	}
}

var s3UploadPipeline = &s3UploadPipeliner{}

const minMultiSize = 5 * 1024 * 1024 // can't use Multi API with parts less than 5MB
const blockSize = 5 * 1024 * 1024    // 5MB is the minimum Multi part size

// ErrAbortOnPartMismatch is returned when there is a mismatch between a current
// part and a previous attempt part. If ErrAbortOnPartMismatch is returned,
// the caller should abort the upload attempt and start from scratch.
var ErrAbortOnPartMismatch = errors.New("local part mismatch, aborting upload")

// PutS3Result is the success result of calling PutS3.
type PutS3Result struct {
	Region   string
	Endpoint string
	Bucket   string
	Path     string
	Size     int64
}

// PutS3 uploads the data in Reader r to S3. It chooses whether to use
// putSingle or putMultiPipeline based on the size of the object.
func (a *S3Store) PutS3(ctx context.Context, r io.Reader, size int64, task *UploadTask, previous *AttachmentInfo) (res *PutS3Result, err error) {
	defer a.Trace(ctx, &err, "PutS3")()
	region := a.regionFromParams(task.S3Params)
	b := a.s3Conn(task.S3Signer, region, task.S3Params.AccessKey).Bucket(task.S3Params.Bucket)

	multiPartUpload := size > minMultiSize
	if multiPartUpload && a.G().Env.GetAttachmentDisableMulti() {
		a.Debug(ctx, "PutS3: multi part upload manually disabled, overriding for size: %v", size)
		multiPartUpload = false
	}

	if !multiPartUpload {
		if err := a.putSingle(ctx, r, size, task.S3Params, b, task.Progress); err != nil {
			return nil, err
		}
	} else {
		objectKey, err := a.putMultiPipeline(ctx, r, size, task, b, previous)
		if err != nil {
			return nil, err
		}
		task.S3Params.ObjectKey = objectKey
	}

	s3res := PutS3Result{
		Region:   task.S3Params.RegionName,
		Endpoint: task.S3Params.RegionEndpoint,
		Bucket:   task.S3Params.Bucket,
		Path:     task.S3Params.ObjectKey,
		Size:     size,
	}
	return &s3res, nil
}
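
// The sketch below is illustrative and not part of the original upload path
// (the helper name is hypothetical). It restates the routing decision PutS3
// makes above: payloads at or below minMultiSize go through the single Put
// API, while larger ones are split into blockSize (5MB) parts by the
// multipart pipeline. The GetAttachmentDisableMulti override is not modeled.
func exampleRoutePut(size int64) (useMulti bool, numParts int64) {
	useMulti = size > minMultiSize
	if !useMulti {
		return false, 1
	}
	// Ceiling division: a 12MB payload yields three parts (5MB, 5MB, 2MB).
	numParts = (size + blockSize - 1) / blockSize
	return true, numParts
}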

// putSingle uploads the data in r to S3 with the Put API. It has to be
// used for anything less than 5MB. It can be used for anything up
// to 5GB, but putMultiPipeline is best for anything over 5MB.
func (a *S3Store) putSingle(ctx context.Context, r io.Reader, size int64, params chat1.S3Params,
	b s3.BucketInt, progressReporter types.ProgressReporter) (err error) {
	defer a.Trace(ctx, &err, fmt.Sprintf("putSingle(size=%d)", size))()

	progWriter := progress.NewProgressWriter(progressReporter, size)
	tee := io.TeeReader(r, progWriter)

	if err := b.PutReader(ctx, params.ObjectKey, tee, size, "application/octet-stream", s3.ACL(params.Acl),
		s3.Options{}); err != nil {
		a.Debug(ctx, "putSingle: failed: %s", err)
		return NewErrorWrapper("failed putSingle", err)
	}
	progWriter.Finish()
	return nil
}
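
// Illustrative sketch, not used by this file (the helper name is
// hypothetical): the TeeReader pattern putSingle relies on, isolated. Every
// read the uploader performs on the returned reader is mirrored into the
// progress writer, which is assumed to report completed/total bytes to the
// supplied types.ProgressReporter as the upload consumes data.
func exampleProgressReader(r io.Reader, size int64, report types.ProgressReporter) io.Reader {
	pw := progress.NewProgressWriter(report, size)
	return io.TeeReader(r, pw)
}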

// putMultiPipeline uploads data in r to S3 using the Multi API. It uses a
// pipeline to upload up to 10 blocks of data concurrently. Each block is 5MB.
// It returns the object key if there were no errors. putMultiPipeline will
// return a different object key from params.ObjectKey if a previous Put is
// successfully resumed and completed.
func (a *S3Store) putMultiPipeline(ctx context.Context, r io.Reader, size int64, task *UploadTask, b s3.BucketInt, previous *AttachmentInfo) (res string, err error) {
	defer a.Trace(ctx, &err, fmt.Sprintf("putMultiPipeline(size=%d)", size))()

	var multi s3.MultiInt
	if previous != nil {
		a.Debug(ctx, "putMultiPipeline: previous exists. Changing object key from %q to %q",
			task.S3Params.ObjectKey, previous.ObjectKey)
		task.S3Params.ObjectKey = previous.ObjectKey
	}

	multi, err = b.Multi(ctx, task.S3Params.ObjectKey, "application/octet-stream", s3.ACL(task.S3Params.Acl))
	if err != nil {
		a.Debug(ctx, "putMultiPipeline: b.Multi error: %s", err.Error())
		return "", NewErrorWrapper("s3 Multi error", err)
	}

	var previousParts map[int]s3.Part
	if previous != nil {
		previousParts = make(map[int]s3.Part)
		list, err := multi.ListParts(ctx)
		if err != nil {
			a.Debug(ctx, "putMultiPipeline: ignoring multi.ListParts error: %s", err)
			// dump previous since we can't check it anymore
			previous = nil
		} else {
			for _, p := range list {
				previousParts[p.N] = p
			}
		}
	}

	// Everything run inside eg.Go() funcs must use ectx, since eg cancels
	// ectx once eg.Wait() returns.
	a.Debug(ctx, "putMultiPipeline: beginning parts uploader process")
	eg, ectx := errgroup.WithContext(ctx)
	blockCh := make(chan job)
	retCh := make(chan s3.Part)
	eg.Go(func() error {
		defer close(blockCh)
		return a.makeBlockJobs(ectx, r, blockCh, task.stashKey(), previous)
	})
	eg.Go(func() error {
		for lb := range blockCh {
			if err := s3UploadPipeline.QueueForTakeoff(ectx); err != nil {
				return err
			}
			b := lb
			eg.Go(func() error {
				defer s3UploadPipeline.Complete()
				if err := a.uploadPart(ectx, task, b, previous, previousParts, multi, retCh); err != nil {
					return err
				}
				return nil
			})
		}
		return nil
	})
	go func() {
		err := eg.Wait()
		if err != nil {
			a.Debug(ctx, "putMultiPipeline: error waiting: %+v", err)
		}
		close(retCh)
	}()

	var parts []s3.Part
	progWriter := progress.NewProgressWriter(task.Progress, size)
	for p := range retCh {
		parts = append(parts, p)
		progWriter.Update(int(p.Size))
	}
	if err := eg.Wait(); err != nil {
		return "", err
	}

	if a.blockLimit > 0 {
		return "", errors.New("block limit hit, not completing multi upload")
	}
	a.Debug(ctx, "putMultiPipeline: all parts uploaded, completing request")
	if err := multi.Complete(ctx, parts); err != nil {
		a.Debug(ctx, "putMultiPipeline: Complete() failed: %s", err)
		return "", err
	}
	a.Debug(ctx, "putMultiPipeline: success, %d parts", len(parts))
	// Just to make sure the UI gets the 100% call
	progWriter.Finish()
	return task.S3Params.ObjectKey, nil
}

type job struct {
	block []byte
	index int
	hash  string
}

func (j job) etag() string {
	return `"` + j.hash + `"`
}

// makeBlockJobs reads ciphertext chunks from r and creates jobs that it puts onto blockCh.
// If this is a resumed upload, it verifies the blocks against the local stash before
// creating jobs.
func (a *S3Store) makeBlockJobs(ctx context.Context, r io.Reader, blockCh chan job, stashKey StashKey, previous *AttachmentInfo) error {
	var partNumber int
	for {
		partNumber++
		block := make([]byte, blockSize)
		// must call io.ReadFull to ensure full block read
		n, err := io.ReadFull(r, block)
		// io.ErrUnexpectedEOF will be returned for the last partial block,
		// which is ok.
		if err != nil && err != io.ErrUnexpectedEOF && err != io.EOF {
			return err
		}
		if n < blockSize {
			block = block[:n]
		}
		if n > 0 {
			md5sum := md5.Sum(block)
			md5hex := hex.EncodeToString(md5sum[:])

			if previous != nil {
				// Resuming an upload, so check the local stash record and
				// abort on mismatch before adding a job for this block,
				// because continuing would amount to nonce reuse.
				lhash, found := previous.Parts[partNumber]
				if found && lhash != md5hex {
					a.Debug(ctx, "makeBlockJobs: part %d failed local part record verification", partNumber)
					return ErrAbortOnPartMismatch
				}
			}

			if err := a.addJob(ctx, blockCh, block, partNumber, md5hex); err != nil {
				return err
			}
		}
		if err == io.EOF || err == io.ErrUnexpectedEOF {
			break
		}

		if a.blockLimit > 0 && partNumber >= a.blockLimit {
			a.Debug(ctx, "makeBlockJobs: hit blockLimit of %d", a.blockLimit)
			break
		}
	}
	return nil
}
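
// Illustrative sketch, not used by this file (the helper name is
// hypothetical): the MD5 hex computed here is what makeBlockJobs records in
// the local stash for each block; wrapped in quotes it matches job.etag(),
// which uploadPart compares against the ETag S3 reports for an already
// uploaded part when resuming.
func exampleBlockHash(block []byte) (md5hex, etag string) {
	sum := md5.Sum(block)
	md5hex = hex.EncodeToString(sum[:])
	return md5hex, `"` + md5hex + `"`
}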

// addJob creates a job and puts it on blockCh, unless the context is canceled
// before blockCh can accept it.
func (a *S3Store) addJob(ctx context.Context, blockCh chan job, block []byte, partNumber int, hash string) error {
	// Create a job, unless the context has been canceled.
	select {
	case blockCh <- job{block: block, index: partNumber, hash: hash}:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}
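
// Illustrative sketch, not used by this file (the helper name is
// hypothetical): the same cancellation-aware send addJob performs on blockCh
// is used twice by uploadPart below when delivering finished parts to retCh,
// so no goroutine blocks forever on a channel whose consumer has gone away.
func exampleSendPart(ctx context.Context, retCh chan<- s3.Part, p s3.Part) error {
	select {
	case retCh <- p:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}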

// uploadPart handles uploading a job to S3. The job `b` has already passed local stash verification.
// If this is a resumed upload, it checks the previous parts reported by S3 and will skip uploading
// any that already exist.
func (a *S3Store) uploadPart(ctx context.Context, task *UploadTask, b job, previous *AttachmentInfo, previousParts map[int]s3.Part, multi s3.MultiInt, retCh chan s3.Part) (err error) {
	defer a.Trace(ctx, &err, fmt.Sprintf("uploadPart(%d)", b.index))()

	// Check to see if this part has already been uploaded. For job `b` to be
	// here, it has already passed local stash verification.
	if previous != nil {
		// check s3 previousParts for this block
		p, ok := previousParts[b.index]
		if ok && int(p.Size) == len(b.block) && p.ETag == b.etag() {
			a.Debug(ctx, "uploadPart: part %d already uploaded to s3", b.index)

			// part already uploaded, so put it in the retCh unless the context
			// has been canceled
			select {
			case retCh <- p:
			case <-ctx.Done():
				return ctx.Err()
			}

			// nothing else to do
			return nil
		}

		if p.Size > 0 {
			// only abort if the part size from s3 is > 0.
			a.Debug(ctx, "uploadPart: part %d s3 mismatch: size %d != expected %d or etag %s != expected %s",
				b.index, p.Size, len(b.block), p.ETag, b.etag())
			return ErrAbortOnPartMismatch
		}

		// this part doesn't exist on s3, so it needs to be uploaded
		a.Debug(ctx, "uploadPart: part %d not uploaded to s3 by previous upload attempt", b.index)
	}

	// Stash part info locally before attempting the S3 put; recording it
	// first is important for security.
	if err := a.stash.RecordPart(task.stashKey(), b.index, b.hash); err != nil {
		a.Debug(ctx, "uploadPart: StashRecordPart error: %s", err)
	}

	part, putErr := multi.PutPart(ctx, b.index, bytes.NewReader(b.block))
	if putErr != nil {
		return NewErrorWrapper(fmt.Sprintf("failed to put part %d", b.index), putErr)
	}

	// put the successfully uploaded part information in the retCh
	// unless the context has been canceled.
	select {
	case retCh <- part:
	case <-ctx.Done():
		a.Debug(ctx, "uploadPart: upload part %d, context canceled", b.index)
		return ctx.Err()
	}

	return nil
}

type ErrorWrapper struct {
	prefix string
	err    error
}

func NewErrorWrapper(prefix string, err error) *ErrorWrapper {
	return &ErrorWrapper{prefix: prefix, err: err}
}

func (e *ErrorWrapper) Error() string {
	return fmt.Sprintf("%s: %s (%T)", e.prefix, e.err, e.err)
}

func (e *ErrorWrapper) Details() string {
	switch err := e.err.(type) {
	case *s3.Error:
		return fmt.Sprintf("%s: error %q, status code: %d, code: %s, message: %s, bucket: %s", e.prefix, e.err, err.StatusCode, err.Code, err.Message, err.BucketName)
	default:
		return fmt.Sprintf("%s: error %q, no details for type %T", e.prefix, e.err, e.err)
	}
}

type S3Signer struct {
	ri func() chat1.RemoteInterface
}

func NewS3Signer(ri func() chat1.RemoteInterface) *S3Signer {
	return &S3Signer{
		ri: ri,
	}
}

// Sign implements the github.com/keybase/client/go/chat/s3.Signer interface.
func (s *S3Signer) Sign(payload []byte) ([]byte, error) {
	arg := chat1.S3SignArg{
		Payload: payload,
		Version: 1,
	}
	return s.ri().S3Sign(context.Background(), arg)
}
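
// Illustrative sketch, not used by this file (the helper name is
// hypothetical): a signer is built from a getter for the current
// chat1.RemoteInterface, and Sign simply forwards the payload to the server's
// S3Sign RPC as shown above.
func exampleSign(ri func() chat1.RemoteInterface, payload []byte) ([]byte, error) {
	return NewS3Signer(ri).Sign(payload)
}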