github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/store/nbs/s3_fake_test.go (about) 1 // Copyright 2019 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // This file incorporates work covered by the following copyright and 16 // permission notice: 17 // 18 // Copyright 2016 Attic Labs, Inc. All rights reserved. 19 // Licensed under the Apache License, version 2.0: 20 // http://www.apache.org/licenses/LICENSE-2.0 21 22 package nbs 23 24 import ( 25 "bytes" 26 "context" 27 "io" 28 "net/url" 29 "strconv" 30 "strings" 31 "sync" 32 "testing" 33 "time" 34 35 "github.com/aws/aws-sdk-go/aws" 36 "github.com/aws/aws-sdk-go/aws/client/metadata" 37 "github.com/aws/aws-sdk-go/aws/request" 38 "github.com/aws/aws-sdk-go/service/s3" 39 "github.com/aws/aws-sdk-go/service/s3/s3iface" 40 "github.com/stretchr/testify/assert" 41 42 "github.com/dolthub/dolt/go/store/d" 43 "github.com/dolthub/dolt/go/store/hash" 44 ) 45 46 type mockAWSError string 47 48 func (m mockAWSError) Error() string { return string(m) } 49 func (m mockAWSError) Code() string { return string(m) } 50 func (m mockAWSError) Message() string { return string(m) } 51 func (m mockAWSError) OrigErr() error { return nil } 52 53 func makeFakeS3(t *testing.T) *fakeS3 { 54 return &fakeS3{ 55 assert: assert.New(t), 56 data: map[string][]byte{}, 57 inProgress: map[string]fakeS3Multipart{}, 58 parts: map[string][]byte{}, 59 } 60 } 61 62 type fakeS3 struct { 63 s3iface.S3API 64 65 assert *assert.Assertions 66 67 mu sync.Mutex 68 data map[string][]byte 69 inProgressCounter int 70 inProgress map[string]fakeS3Multipart // Key -> {UploadId, Etags...} 71 parts map[string][]byte // ETag -> data 72 getCount int 73 } 74 75 type fakeS3Multipart struct { 76 uploadID string 77 etags []string 78 } 79 80 func (m *fakeS3) readerForTable(ctx context.Context, name hash.Hash) (chunkReader, error) { 81 m.mu.Lock() 82 defer m.mu.Unlock() 83 if buff, present := m.data[name.String()]; present { 84 ti, err := parseTableIndexByCopy(ctx, buff, &UnlimitedQuotaProvider{}) 85 if err != nil { 86 return nil, err 87 } 88 tr, err := newTableReader(ti, tableReaderAtFromBytes(buff), s3BlockSize) 89 if err != nil { 90 ti.Close() 91 return nil, err 92 } 93 return tr, nil 94 } 95 return nil, nil 96 } 97 98 func (m *fakeS3) readerForTableWithNamespace(ctx context.Context, ns string, name hash.Hash) (chunkReader, error) { 99 m.mu.Lock() 100 defer m.mu.Unlock() 101 key := name.String() 102 if ns != "" { 103 key = ns + "/" + key 104 } 105 if buff, present := m.data[key]; present { 106 ti, err := parseTableIndexByCopy(ctx, buff, &UnlimitedQuotaProvider{}) 107 108 if err != nil { 109 return nil, err 110 } 111 112 tr, err := newTableReader(ti, tableReaderAtFromBytes(buff), s3BlockSize) 113 if err != nil { 114 return nil, err 115 } 116 return tr, nil 117 } 118 return nil, nil 119 } 120 121 func (m *fakeS3) AbortMultipartUploadWithContext(ctx aws.Context, input *s3.AbortMultipartUploadInput, opts ...request.Option) (*s3.AbortMultipartUploadOutput, error) { 122 m.assert.NotNil(input.Bucket, "Bucket is a required field") 123 m.assert.NotNil(input.Key, "Key is a required field") 124 m.assert.NotNil(input.UploadId, "UploadId is a required field") 125 126 m.mu.Lock() 127 defer m.mu.Unlock() 128 m.assert.Equal(m.inProgress[*input.Key].uploadID, *input.UploadId) 129 for _, etag := range m.inProgress[*input.Key].etags { 130 delete(m.parts, etag) 131 } 132 delete(m.inProgress, *input.Key) 133 return &s3.AbortMultipartUploadOutput{}, nil 134 } 135 136 func (m *fakeS3) CreateMultipartUploadWithContext(ctx aws.Context, input *s3.CreateMultipartUploadInput, opts ...request.Option) (*s3.CreateMultipartUploadOutput, error) { 137 m.assert.NotNil(input.Bucket, "Bucket is a required field") 138 m.assert.NotNil(input.Key, "Key is a required field") 139 140 out := &s3.CreateMultipartUploadOutput{ 141 Bucket: input.Bucket, 142 Key: input.Key, 143 } 144 145 m.mu.Lock() 146 defer m.mu.Unlock() 147 uploadID := strconv.Itoa(m.inProgressCounter) 148 out.UploadId = aws.String(uploadID) 149 m.inProgress[*input.Key] = fakeS3Multipart{uploadID, nil} 150 m.inProgressCounter++ 151 return out, nil 152 } 153 154 func (m *fakeS3) UploadPartWithContext(ctx aws.Context, input *s3.UploadPartInput, opts ...request.Option) (*s3.UploadPartOutput, error) { 155 m.assert.NotNil(input.Bucket, "Bucket is a required field") 156 m.assert.NotNil(input.Key, "Key is a required field") 157 m.assert.NotNil(input.PartNumber, "PartNumber is a required field") 158 m.assert.NotNil(input.UploadId, "UploadId is a required field") 159 m.assert.NotNil(input.Body, "Body is a required field") 160 161 data, err := io.ReadAll(input.Body) 162 m.assert.NoError(err) 163 164 m.mu.Lock() 165 defer m.mu.Unlock() 166 etag := hash.Of(data).String() + time.Now().String() 167 m.parts[etag] = data 168 169 inProgress, present := m.inProgress[*input.Key] 170 m.assert.True(present) 171 m.assert.Equal(inProgress.uploadID, *input.UploadId) 172 inProgress.etags = append(inProgress.etags, etag) 173 m.inProgress[*input.Key] = inProgress 174 return &s3.UploadPartOutput{ETag: aws.String(etag)}, nil 175 } 176 177 func (m *fakeS3) UploadPartCopyWithContext(ctx aws.Context, input *s3.UploadPartCopyInput, opts ...request.Option) (*s3.UploadPartCopyOutput, error) { 178 m.assert.NotNil(input.Bucket, "Bucket is a required field") 179 m.assert.NotNil(input.Key, "Key is a required field") 180 m.assert.NotNil(input.PartNumber, "PartNumber is a required field") 181 m.assert.NotNil(input.UploadId, "UploadId is a required field") 182 m.assert.NotNil(input.CopySource, "CopySource is a required field") 183 184 unescaped, err := url.QueryUnescape(*input.CopySource) 185 m.assert.NoError(err) 186 slash := strings.LastIndex(unescaped, "/") 187 m.assert.NotEqual(-1, slash, "Malformed CopySource %s", unescaped) 188 src := unescaped[slash+1:] 189 190 m.mu.Lock() 191 defer m.mu.Unlock() 192 obj, present := m.data[src] 193 if !present { 194 return nil, mockAWSError("NoSuchKey") 195 } 196 if input.CopySourceRange != nil { 197 start, end := parseRange(*input.CopySourceRange, len(obj)) 198 obj = obj[start:end] 199 } 200 etag := hash.Of(obj).String() + time.Now().String() 201 m.parts[etag] = obj 202 203 inProgress, present := m.inProgress[*input.Key] 204 m.assert.True(present) 205 m.assert.Equal(inProgress.uploadID, *input.UploadId) 206 inProgress.etags = append(inProgress.etags, etag) 207 m.inProgress[*input.Key] = inProgress 208 return &s3.UploadPartCopyOutput{CopyPartResult: &s3.CopyPartResult{ETag: aws.String(etag)}}, nil 209 } 210 211 func (m *fakeS3) CompleteMultipartUploadWithContext(ctx aws.Context, input *s3.CompleteMultipartUploadInput, opts ...request.Option) (*s3.CompleteMultipartUploadOutput, error) { 212 m.assert.NotNil(input.Bucket, "Bucket is a required field") 213 m.assert.NotNil(input.Key, "Key is a required field") 214 m.assert.NotNil(input.UploadId, "UploadId is a required field") 215 m.assert.NotNil(input.MultipartUpload, "MultipartUpload is a required field") 216 m.assert.True(len(input.MultipartUpload.Parts) > 0) 217 218 m.mu.Lock() 219 defer m.mu.Unlock() 220 m.assert.Equal(m.inProgress[*input.Key].uploadID, *input.UploadId) 221 for idx, part := range input.MultipartUpload.Parts { 222 m.assert.EqualValues(idx+1, *part.PartNumber) // Part numbers are 1-indexed 223 m.data[*input.Key] = append(m.data[*input.Key], m.parts[*part.ETag]...) 224 delete(m.parts, *part.ETag) 225 } 226 delete(m.inProgress, *input.Key) 227 228 return &s3.CompleteMultipartUploadOutput{Bucket: input.Bucket, Key: input.Key}, nil 229 } 230 231 func (m *fakeS3) GetObjectWithContext(ctx aws.Context, input *s3.GetObjectInput, opts ...request.Option) (*s3.GetObjectOutput, error) { 232 m.assert.NotNil(input.Bucket, "Bucket is a required field") 233 m.assert.NotNil(input.Key, "Key is a required field") 234 235 m.mu.Lock() 236 defer m.mu.Unlock() 237 m.getCount++ 238 obj, present := m.data[*input.Key] 239 if !present { 240 return nil, mockAWSError("NoSuchKey") 241 } 242 if input.Range != nil { 243 start, end := parseRange(*input.Range, len(obj)) 244 obj = obj[start:end] 245 } 246 247 return &s3.GetObjectOutput{ 248 Body: io.NopCloser(bytes.NewReader(obj)), 249 ContentLength: aws.Int64(int64(len(obj))), 250 }, nil 251 } 252 253 func parseRange(hdr string, total int) (start, end int) { 254 d.PanicIfFalse(len(hdr) > len(s3RangePrefix)) 255 hdr = hdr[len(s3RangePrefix):] 256 d.PanicIfFalse(hdr[0] == '=') 257 hdr = hdr[1:] 258 if hdr[0] == '-' { 259 // negative range 260 fromEnd, err := strconv.Atoi(hdr[1:]) 261 d.PanicIfError(err) 262 return total - fromEnd, total 263 } 264 ends := strings.Split(hdr, "-") 265 d.PanicIfFalse(len(ends) == 2) 266 start, err := strconv.Atoi(ends[0]) 267 d.PanicIfError(err) 268 end, err = strconv.Atoi(ends[1]) 269 d.PanicIfError(err) 270 return start, end + 1 // insanely, the HTTP range header specifies ranges inclusively. 271 } 272 273 func (m *fakeS3) PutObjectWithContext(ctx aws.Context, input *s3.PutObjectInput, opts ...request.Option) (*s3.PutObjectOutput, error) { 274 m.assert.NotNil(input.Bucket, "Bucket is a required field") 275 m.assert.NotNil(input.Key, "Key is a required field") 276 277 buff := &bytes.Buffer{} 278 _, err := io.Copy(buff, input.Body) 279 m.assert.NoError(err) 280 m.mu.Lock() 281 defer m.mu.Unlock() 282 m.data[*input.Key] = buff.Bytes() 283 284 return &s3.PutObjectOutput{}, nil 285 } 286 287 func (m *fakeS3) GetObjectRequest(input *s3.GetObjectInput) (*request.Request, *s3.GetObjectOutput) { 288 out := &s3.GetObjectOutput{} 289 var handlers request.Handlers 290 handlers.Send.PushBack(func(r *request.Request) { 291 res, err := m.GetObjectWithContext(r.Context(), input) 292 r.Error = err 293 if res != nil { 294 *(r.Data.(*s3.GetObjectOutput)) = *res 295 } 296 }) 297 return request.New(aws.Config{}, metadata.ClientInfo{}, handlers, nil, &request.Operation{ 298 Name: "GetObject", 299 HTTPMethod: "GET", 300 HTTPPath: "/{Bucket}/{Key+}", 301 }, input, out), out 302 } 303 304 func (m *fakeS3) PutObjectRequest(input *s3.PutObjectInput) (*request.Request, *s3.PutObjectOutput) { 305 out := &s3.PutObjectOutput{} 306 var handlers request.Handlers 307 handlers.Send.PushBack(func(r *request.Request) { 308 res, err := m.PutObjectWithContext(r.Context(), input) 309 r.Error = err 310 if res != nil { 311 *(r.Data.(*s3.PutObjectOutput)) = *res 312 } 313 }) 314 return request.New(aws.Config{}, metadata.ClientInfo{}, handlers, nil, &request.Operation{ 315 Name: "PutObject", 316 HTTPMethod: "PUT", 317 HTTPPath: "/{Bucket}/{Key+}", 318 }, input, out), out 319 }