github.com/keybase/client/go@v0.0.0-20241007131713-f10651d043c8/chat/attachments/s3.go

package attachments

import (
	"bytes"
	"context"
	"crypto/md5"
	"encoding/hex"
	"errors"
	"fmt"
	"io"
	"sync"

	"github.com/keybase/client/go/chat/attachments/progress"
	"github.com/keybase/client/go/chat/s3"
	"github.com/keybase/client/go/chat/types"
	"github.com/keybase/client/go/protocol/chat1"
	"golang.org/x/sync/errgroup"
)

const s3PipelineMaxWidth = 10

type s3UploadPipeliner struct {
	sync.Mutex
	width   int
	waiters []chan struct{}
}

func (s *s3UploadPipeliner) QueueForTakeoff(ctx context.Context) error {
	s.Lock()
	if s.width >= s3PipelineMaxWidth {
		ch := make(chan struct{})
		s.waiters = append(s.waiters, ch)
		s.Unlock()
		select {
		case <-ch:
		case <-ctx.Done():
			return ctx.Err()
		}
		s.Lock()
		s.width++
		s.Unlock()
		return nil
	}
	s.width++
	s.Unlock()
	return nil
}

func (s *s3UploadPipeliner) Complete() {
	s.Lock()
	defer s.Unlock()
	if len(s.waiters) > 0 {
		close(s.waiters[0])
		if len(s.waiters) > 1 {
			s.waiters = s.waiters[1:]
		} else {
			s.waiters = nil
		}
	}
	if s.width > 0 {
		s.width--
	}
}

var s3UploadPipeline = &s3UploadPipeliner{}

const minMultiSize = 5 * 1024 * 1024 // can't use Multi API with parts less than 5MB
const blockSize = 5 * 1024 * 1024    // 5MB is the minimum Multi part size

// ErrAbortOnPartMismatch is returned when there is a mismatch between a current
// part and a previous attempt part. If ErrAbortOnPartMismatch is returned,
// the caller should abort the upload attempt and start from scratch.
var ErrAbortOnPartMismatch = errors.New("local part mismatch, aborting upload")

// PutS3Result is the success result of calling PutS3.
type PutS3Result struct {
	Region   string
	Endpoint string
	Bucket   string
	Path     string
	Size     int64
}

// PutS3 uploads the data in Reader r to S3. It chooses whether to use
// putSingle or putMultiPipeline based on the size of the object.
func (a *S3Store) PutS3(ctx context.Context, r io.Reader, size int64, task *UploadTask, previous *AttachmentInfo) (res *PutS3Result, err error) {
	defer a.Trace(ctx, &err, "PutS3")()
	region := a.regionFromParams(task.S3Params)
	b := a.s3Conn(task.S3Signer, region, task.S3Params.AccessKey, task.S3Params.Token).Bucket(
		task.S3Params.Bucket)

	multiPartUpload := size > minMultiSize
	if multiPartUpload && a.G().Env.GetAttachmentDisableMulti() {
		a.Debug(ctx, "PutS3: multi part upload manually disabled, overriding for size: %v", size)
		multiPartUpload = false
	}

	if !multiPartUpload {
		if err := a.putSingle(ctx, r, size, task.S3Params, b, task.Progress); err != nil {
			return nil, err
		}
	} else {
		objectKey, err := a.putMultiPipeline(ctx, r, size, task, b, previous)
		if err != nil {
			return nil, err
		}
		task.S3Params.ObjectKey = objectKey
	}

	s3res := PutS3Result{
		Region:   task.S3Params.RegionName,
		Endpoint: task.S3Params.RegionEndpoint,
		Bucket:   task.S3Params.Bucket,
		Path:     task.S3Params.ObjectKey,
		Size:     size,
	}
	return &s3res, nil
}
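// Illustrative sketch, not part of the original file: the size-threshold
// decision PutS3 makes above, pulled out as a standalone helper. The name
// wouldUseMultipart and the disableMulti parameter (standing in for
// a.G().Env.GetAttachmentDisableMulti()) are assumptions for illustration.
func wouldUseMultipart(size int64, disableMulti bool) bool {
	// anything strictly larger than minMultiSize goes through the Multi API,
	// unless the environment override disables multipart uploads entirely
	return size > minMultiSize && !disableMulti
}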
// putSingle uploads data in r to S3 with the Put API. It has to be
// used for anything less than 5MB. It can be used for anything up
// to 5GB, but putMultiPipeline is best for anything over 5MB.
func (a *S3Store) putSingle(ctx context.Context, r io.Reader, size int64, params chat1.S3Params,
	b s3.BucketInt, progressReporter types.ProgressReporter) (err error) {
	defer a.Trace(ctx, &err, fmt.Sprintf("putSingle(size=%d)", size))()

	progWriter := progress.NewProgressWriter(progressReporter, size)
	tee := io.TeeReader(r, progWriter)

	if err := b.PutReader(ctx, params.ObjectKey, tee, size, "application/octet-stream", s3.ACL(params.Acl),
		s3.Options{}); err != nil {
		a.Debug(ctx, "putSingle: failed: %s", err)
		return NewErrorWrapper("failed putSingle", err)
	}
	progWriter.Finish()
	return nil
}
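// Illustrative sketch, not part of the original file: the io.TeeReader pattern
// putSingle relies on, with a plain byte counter standing in for the progress
// writer. Every byte the uploader pulls from tee is also written to the
// counter, so progress tracks how much of r has actually been consumed.
// byteCounter and exampleTeeProgress are hypothetical names for illustration.
type byteCounter struct{ n int64 }

func (c *byteCounter) Write(p []byte) (int, error) {
	c.n += int64(len(p))
	return len(p), nil
}

func exampleTeeProgress(r io.Reader) (int64, error) {
	var c byteCounter
	tee := io.TeeReader(r, &c)
	// a real caller would hand tee to PutReader; here we just drain it
	if _, err := io.Copy(io.Discard, tee); err != nil {
		return 0, err
	}
	return c.n, nil
}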
// putMultiPipeline uploads data in r to S3 using the Multi API. It uses a
// pipeline to upload 10 blocks of data concurrently. Each block is 5MB.
// It returns the object key if there are no errors. putMultiPipeline will
// return a different object key from params.ObjectKey if a previous Put is
// successfully resumed and completed.
func (a *S3Store) putMultiPipeline(ctx context.Context, r io.Reader, size int64, task *UploadTask, b s3.BucketInt, previous *AttachmentInfo) (res string, err error) {
	defer a.Trace(ctx, &err, fmt.Sprintf("putMultiPipeline(size=%d)", size))()

	var multi s3.MultiInt
	if previous != nil {
		a.Debug(ctx, "putMultiPipeline: previous exists. Changing object key from %q to %q",
			task.S3Params.ObjectKey, previous.ObjectKey)
		task.S3Params.ObjectKey = previous.ObjectKey
	}

	multi, err = b.Multi(ctx, task.S3Params.ObjectKey, "application/octet-stream", s3.ACL(task.S3Params.Acl))
	if err != nil {
		a.Debug(ctx, "putMultiPipeline: b.Multi error: %s", err.Error())
		return "", NewErrorWrapper("s3 Multi error", err)
	}

	var previousParts map[int]s3.Part
	if previous != nil {
		previousParts = make(map[int]s3.Part)
		list, err := multi.ListParts(ctx)
		if err != nil {
			a.Debug(ctx, "putMultiPipeline: ignoring multi.ListParts error: %s", err)
			// dump previous since we can't check it anymore
			previous = nil
		} else {
			for _, p := range list {
				previousParts[p.N] = p
			}
		}
	}

	// ectx must be used inside every eg.Go() func, since eg cancels ectx as
	// soon as any of those funcs returns an error or eg.Wait() returns.
	a.Debug(ctx, "putMultiPipeline: beginning parts uploader process")
	eg, ectx := errgroup.WithContext(ctx)
	blockCh := make(chan job)
	retCh := make(chan s3.Part)
	eg.Go(func() error {
		defer close(blockCh)
		return a.makeBlockJobs(ectx, r, blockCh, task.stashKey(), previous)
	})
	eg.Go(func() error {
		for lb := range blockCh {
			if err := s3UploadPipeline.QueueForTakeoff(ectx); err != nil {
				return err
			}
			b := lb
			eg.Go(func() error {
				defer s3UploadPipeline.Complete()
				if err := a.uploadPart(ectx, task, b, previous, previousParts, multi, retCh); err != nil {
					return err
				}
				return nil
			})
		}
		return nil
	})
	go func() {
		err := eg.Wait()
		if err != nil {
			a.Debug(ctx, "putMultiPipeline: error waiting: %+v", err)
		}
		close(retCh)
	}()

	var parts []s3.Part
	progWriter := progress.NewProgressWriter(task.Progress, size)
	for p := range retCh {
		parts = append(parts, p)
		progWriter.Update(int(p.Size))
	}
	if err := eg.Wait(); err != nil {
		return "", err
	}

	if a.blockLimit > 0 {
		return "", errors.New("block limit hit, not completing multi upload")
	}
	a.Debug(ctx, "putMultiPipeline: all parts uploaded, completing request")
	if err := multi.Complete(ctx, parts); err != nil {
		a.Debug(ctx, "putMultiPipeline: Complete() failed: %s", err)
		return "", err
	}
	a.Debug(ctx, "putMultiPipeline: success, %d parts", len(parts))
	// Just to make sure the UI gets the 100% call
	progWriter.Finish()
	return task.S3Params.ObjectKey, nil
}

type job struct {
	block []byte
	index int
	hash  string
}

func (j job) etag() string {
	return `"` + j.hash + `"`
}

// makeBlockJobs reads ciphertext chunks from r and creates jobs that it puts onto blockCh.
// If this is a resumed upload, it verifies the blocks against the local stash before
// creating jobs.
func (a *S3Store) makeBlockJobs(ctx context.Context, r io.Reader, blockCh chan job, stashKey StashKey, previous *AttachmentInfo) error {
	var partNumber int
	for {
		partNumber++
		block := make([]byte, blockSize)
		// must call io.ReadFull to ensure full block read
		n, err := io.ReadFull(r, block)
		// io.ErrUnexpectedEOF will be returned for last partial block,
		// which is ok.
		if err != nil && err != io.ErrUnexpectedEOF && err != io.EOF {
			return err
		}
		if n < blockSize {
			block = block[:n]
		}
		if n > 0 {
			md5sum := md5.Sum(block)
			md5hex := hex.EncodeToString(md5sum[:])

			if previous != nil {
				// resuming an upload, so check the local stash record and
				// abort on mismatch before adding a job for this block,
				// because continuing would amount to nonce reuse
				lhash, found := previous.Parts[partNumber]
				if found && lhash != md5hex {
					a.Debug(ctx, "makeBlockJobs: part %d failed local part record verification", partNumber)
					return ErrAbortOnPartMismatch
				}
			}

			if err := a.addJob(ctx, blockCh, block, partNumber, md5hex); err != nil {
				return err
			}
		}
		if err == io.EOF || err == io.ErrUnexpectedEOF {
			break
		}

		if a.blockLimit > 0 && partNumber >= a.blockLimit {
			a.Debug(ctx, "makeBlockJobs: hit blockLimit of %d", a.blockLimit)
			break
		}
	}
	return nil
}
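// Illustrative sketch, not part of the original file: the hash makeBlockJobs
// records for each block and the quoted form (job.etag) that uploadPart later
// compares against the ETag S3 reports for an already-uploaded part. For parts
// uploaded without SSE-KMS, S3's part ETag is the MD5 of the part data, which
// is what that comparison relies on. exampleBlockHash is a hypothetical name.
func exampleBlockHash(block []byte) (hash, etag string) {
	sum := md5.Sum(block)
	hash = hex.EncodeToString(sum[:]) // what goes into the local stash and job.hash
	return hash, `"` + hash + `"`     // S3 returns ETags wrapped in double quotes
}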
// addJob creates a job and puts it on blockCh, unless the context is canceled
// before blockCh can accept it.
func (a *S3Store) addJob(ctx context.Context, blockCh chan job, block []byte, partNumber int, hash string) error {
	// Create a job, unless the context has been canceled.
	select {
	case blockCh <- job{block: block, index: partNumber, hash: hash}:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}
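// Illustrative sketch, not part of the original file: the resume decision
// uploadPart (below) makes for each part when a previous attempt exists,
// pulled out as a standalone helper. resumeAction, exampleResumeDecision, and
// the constant names are assumptions for illustration.
type resumeAction int

const (
	resumeSkip   resumeAction = iota // part already on S3 with matching size and ETag
	resumeAbort                      // non-empty part on S3 that does not match: restart from scratch
	resumeUpload                     // no usable record on S3: upload the part
)

func exampleResumeDecision(p s3.Part, blockLen int, etag string) resumeAction {
	switch {
	case int(p.Size) == blockLen && p.ETag == etag:
		return resumeSkip
	case p.Size > 0:
		return resumeAbort
	default:
		return resumeUpload
	}
}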
// uploadPart handles uploading a job to S3. The job `b` has already passed local stash verification.
// If this is a resumed upload, it checks the previous parts reported by S3 and will skip uploading
// any that already exist.
func (a *S3Store) uploadPart(ctx context.Context, task *UploadTask, b job, previous *AttachmentInfo, previousParts map[int]s3.Part, multi s3.MultiInt, retCh chan s3.Part) (err error) {
	defer a.Trace(ctx, &err, fmt.Sprintf("uploadPart(%d)", b.index))()

	// check to see if this part has already been uploaded.
	// for job `b` to be here, it has already passed local stash verification.
	if previous != nil {
		// check s3 previousParts for this block
		p, ok := previousParts[b.index]
		if ok && int(p.Size) == len(b.block) && p.ETag == b.etag() {
			a.Debug(ctx, "uploadPart: part %d already uploaded to s3", b.index)

			// part already uploaded, so put it in the retCh unless the context
			// has been canceled
			select {
			case retCh <- p:
			case <-ctx.Done():
				return ctx.Err()
			}

			// nothing else to do
			return nil
		}

		if p.Size > 0 {
			// only abort if the part size from s3 is > 0.
			a.Debug(ctx, "uploadPart: part %d s3 mismatch: size %d != expected %d or etag %s != expected %s",
				b.index, p.Size, len(b.block), p.ETag, b.etag())
			return ErrAbortOnPartMismatch
		}

		// this part doesn't exist on s3, so it needs to be uploaded
		a.Debug(ctx, "uploadPart: part %d not uploaded to s3 by previous upload attempt", b.index)
	}

	// stash part info locally before attempting the S3 put; doing it in this
	// order is important for security.
	if err := a.stash.RecordPart(task.stashKey(), b.index, b.hash); err != nil {
		a.Debug(ctx, "uploadPart: StashRecordPart error: %s", err)
	}

	part, putErr := multi.PutPart(ctx, b.index, bytes.NewReader(b.block))
	if putErr != nil {
		return NewErrorWrapper(fmt.Sprintf("failed to put part %d", b.index), putErr)
	}

	// put the successfully uploaded part information in the retCh
	// unless the context has been canceled.
	select {
	case retCh <- part:
	case <-ctx.Done():
		a.Debug(ctx, "uploadPart: upload part %d, context canceled", b.index)
		return ctx.Err()
	}

	return nil
}

type ErrorWrapper struct {
	prefix string
	err    error
}

func NewErrorWrapper(prefix string, err error) *ErrorWrapper {
	return &ErrorWrapper{prefix: prefix, err: err}
}

func (e *ErrorWrapper) Error() string {
	return fmt.Sprintf("%s: %s (%T)", e.prefix, e.err, e.err)
}

func (e *ErrorWrapper) Details() string {
	switch err := e.err.(type) {
	case *s3.Error:
		return fmt.Sprintf("%s: error %q, status code: %d, code: %s, message: %s, bucket: %s", e.prefix, e.err, err.StatusCode, err.Code, err.Message, err.BucketName)
	default:
		return fmt.Sprintf("%s: error %q, no details for type %T", e.prefix, e.err, e.err)
	}
}

type S3Signer struct {
	ri func() chat1.RemoteInterface
}

func NewS3Signer(ri func() chat1.RemoteInterface) *S3Signer {
	return &S3Signer{
		ri: ri,
	}
}

// Sign implements the github.com/keybase/client/go/chat/s3.Signer interface.
func (s *S3Signer) Sign(payload []byte) ([]byte, error) {
	arg := chat1.S3SignArg{
		Payload:   payload,
		Version:   1,
		TempCreds: true,
	}
	return s.ri().S3Sign(context.Background(), arg)
}
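// Illustrative sketch, not part of the original file: surfacing the extra S3
// detail ErrorWrapper carries when an upload fails. Note that ErrorWrapper does
// not implement Unwrap, so errors.As only matches when the wrapper itself is
// returned (as putSingle and uploadPart do above), not when it is wrapped
// further. exampleErrorDetails is a hypothetical name for illustration.
func exampleErrorDetails(err error) string {
	var ew *ErrorWrapper
	if errors.As(err, &ew) {
		// Details includes the status code, S3 error code, message, and bucket
		// when the inner error is an *s3.Error.
		return ew.Details()
	}
	return err.Error()
}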