github.com/hasnat/dolt/go@v0.0.0-20210628190320-9eb5d843fbb7/store/nbs/s3_fake_test.go

// Copyright 2019 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// This file incorporates work covered by the following copyright and
// permission notice:
//
//  Copyright 2016 Attic Labs, Inc. All rights reserved.
//  Licensed under the Apache License, version 2.0:
//  http://www.apache.org/licenses/LICENSE-2.0

package nbs

import (
	"bytes"
	"io"
	"io/ioutil"
	"net/url"
	"strconv"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/request"
	"github.com/aws/aws-sdk-go/service/s3"
	"github.com/stretchr/testify/assert"

	"github.com/dolthub/dolt/go/store/d"
	"github.com/dolthub/dolt/go/store/hash"
)

// mockAWSError satisfies the awserr.Error surface that the code under test
// inspects, so the fake can return AWS-style error codes such as NoSuchKey.
type mockAWSError string

func (m mockAWSError) Error() string   { return string(m) }
func (m mockAWSError) Code() string    { return string(m) }
func (m mockAWSError) Message() string { return string(m) }
func (m mockAWSError) OrigErr() error  { return nil }

func makeFakeS3(t *testing.T) *fakeS3 {
	return &fakeS3{
		assert:     assert.New(t),
		data:       map[string][]byte{},
		inProgress: map[string]fakeS3Multipart{},
		parts:      map[string][]byte{},
	}
}

// fakeS3 is an in-memory stand-in for the subset of the S3 API that NBS
// uses: Get/Put of whole objects plus the multipart-upload call family.
type fakeS3 struct {
	assert *assert.Assertions

	mu                sync.Mutex
	data              map[string][]byte
	inProgressCounter int
	inProgress        map[string]fakeS3Multipart // Key -> {UploadId, Etags...}
	parts             map[string][]byte          // ETag -> data
	getCount          int
}

type fakeS3Multipart struct {
	uploadID string
	etags    []string
}

func (m *fakeS3) readerForTable(name addr) (chunkReader, error) {
	m.mu.Lock()
	defer m.mu.Unlock()
	if buff, present := m.data[name.String()]; present {
		ti, err := parseTableIndex(buff)
		if err != nil {
			return nil, err
		}
		return newTableReader(ti, tableReaderAtFromBytes(buff), s3BlockSize), nil
	}
	return nil, nil
}

func (m *fakeS3) readerForTableWithNamespace(ns string, name addr) (chunkReader, error) {
	m.mu.Lock()
	defer m.mu.Unlock()
	key := name.String()
	if ns != "" {
		key = ns + "/" + key
	}
	if buff, present := m.data[key]; present {
		ti, err := parseTableIndex(buff)
		if err != nil {
			return nil, err
		}
		return newTableReader(ti, tableReaderAtFromBytes(buff), s3BlockSize), nil
	}
	return nil, nil
}

func (m *fakeS3) AbortMultipartUploadWithContext(ctx aws.Context, input *s3.AbortMultipartUploadInput, opts ...request.Option) (*s3.AbortMultipartUploadOutput, error) {
	m.assert.NotNil(input.Bucket, "Bucket is a required field")
	m.assert.NotNil(input.Key, "Key is a required field")
	m.assert.NotNil(input.UploadId, "UploadId is a required field")

	m.mu.Lock()
	defer m.mu.Unlock()
	m.assert.Equal(m.inProgress[*input.Key].uploadID, *input.UploadId)
	for _, etag := range m.inProgress[*input.Key].etags {
		delete(m.parts, etag)
	}
	delete(m.inProgress, *input.Key)
	return &s3.AbortMultipartUploadOutput{}, nil
}

func (m *fakeS3) CreateMultipartUploadWithContext(ctx aws.Context, input *s3.CreateMultipartUploadInput, opts ...request.Option) (*s3.CreateMultipartUploadOutput, error) {
	m.assert.NotNil(input.Bucket, "Bucket is a required field")
	m.assert.NotNil(input.Key, "Key is a required field")

	out := &s3.CreateMultipartUploadOutput{
		Bucket: input.Bucket,
		Key:    input.Key,
	}

	m.mu.Lock()
	defer m.mu.Unlock()
	uploadID := strconv.Itoa(m.inProgressCounter)
	out.UploadId = aws.String(uploadID)
	m.inProgress[*input.Key] = fakeS3Multipart{uploadID, nil}
	m.inProgressCounter++
	return out, nil
}

func (m *fakeS3) UploadPartWithContext(ctx aws.Context, input *s3.UploadPartInput, opts ...request.Option) (*s3.UploadPartOutput, error) {
	m.assert.NotNil(input.Bucket, "Bucket is a required field")
	m.assert.NotNil(input.Key, "Key is a required field")
	m.assert.NotNil(input.PartNumber, "PartNumber is a required field")
	m.assert.NotNil(input.UploadId, "UploadId is a required field")
	m.assert.NotNil(input.Body, "Body is a required field")

	data, err := ioutil.ReadAll(input.Body)
	m.assert.NoError(err)

	m.mu.Lock()
	defer m.mu.Unlock()
	etag := hash.Of(data).String() + time.Now().String()
	m.parts[etag] = data

	inProgress, present := m.inProgress[*input.Key]
	m.assert.True(present)
	m.assert.Equal(inProgress.uploadID, *input.UploadId)
	inProgress.etags = append(inProgress.etags, etag)
	m.inProgress[*input.Key] = inProgress
	return &s3.UploadPartOutput{ETag: aws.String(etag)}, nil
}

func (m *fakeS3) UploadPartCopyWithContext(ctx aws.Context, input *s3.UploadPartCopyInput, opts ...request.Option) (*s3.UploadPartCopyOutput, error) {
	m.assert.NotNil(input.Bucket, "Bucket is a required field")
	m.assert.NotNil(input.Key, "Key is a required field")
	m.assert.NotNil(input.PartNumber, "PartNumber is a required field")
	m.assert.NotNil(input.UploadId, "UploadId is a required field")
	m.assert.NotNil(input.CopySource, "CopySource is a required field")

	unescaped, err := url.QueryUnescape(*input.CopySource)
	m.assert.NoError(err)
	slash := strings.LastIndex(unescaped, "/")
	m.assert.NotEqual(-1, slash, "Malformed CopySource %s", unescaped)
	src := unescaped[slash+1:]

	m.mu.Lock()
	defer m.mu.Unlock()
	obj, present := m.data[src]
	if !present {
		return nil, mockAWSError("NoSuchKey")
	}
	if input.CopySourceRange != nil {
		start, end := parseRange(*input.CopySourceRange, len(obj))
		obj = obj[start:end]
	}
	etag := hash.Of(obj).String() + time.Now().String()
	m.parts[etag] = obj

	inProgress, present := m.inProgress[*input.Key]
	m.assert.True(present)
	m.assert.Equal(inProgress.uploadID, *input.UploadId)
	inProgress.etags = append(inProgress.etags, etag)
	m.inProgress[*input.Key] = inProgress
	return &s3.UploadPartCopyOutput{CopyPartResult: &s3.CopyPartResult{ETag: aws.String(etag)}}, nil
}

func (m *fakeS3) CompleteMultipartUploadWithContext(ctx aws.Context, input *s3.CompleteMultipartUploadInput, opts ...request.Option) (*s3.CompleteMultipartUploadOutput, error) {
	m.assert.NotNil(input.Bucket, "Bucket is a required field")
	m.assert.NotNil(input.Key, "Key is a required field")
	m.assert.NotNil(input.UploadId, "UploadId is a required field")
	m.assert.NotNil(input.MultipartUpload, "MultipartUpload is a required field")
	m.assert.True(len(input.MultipartUpload.Parts) > 0)

	m.mu.Lock()
	defer m.mu.Unlock()
	m.assert.Equal(m.inProgress[*input.Key].uploadID, *input.UploadId)
	for idx, part := range input.MultipartUpload.Parts {
		m.assert.EqualValues(idx+1, *part.PartNumber) // Part numbers are 1-indexed
		m.data[*input.Key] = append(m.data[*input.Key], m.parts[*part.ETag]...)
		delete(m.parts, *part.ETag)
	}
	delete(m.inProgress, *input.Key)

	return &s3.CompleteMultipartUploadOutput{Bucket: input.Bucket, Key: input.Key}, nil
}

func (m *fakeS3) GetObjectWithContext(ctx aws.Context, input *s3.GetObjectInput, opts ...request.Option) (*s3.GetObjectOutput, error) {
	m.assert.NotNil(input.Bucket, "Bucket is a required field")
	m.assert.NotNil(input.Key, "Key is a required field")

	m.mu.Lock()
	defer m.mu.Unlock()
	m.getCount++ // incremented under mu so concurrent reads don't race on the counter
	obj, present := m.data[*input.Key]
	if !present {
		return nil, mockAWSError("NoSuchKey")
	}
	if input.Range != nil {
		start, end := parseRange(*input.Range, len(obj))
		obj = obj[start:end]
	}

	return &s3.GetObjectOutput{
		Body:          ioutil.NopCloser(bytes.NewReader(obj)),
		ContentLength: aws.Int64(int64(len(obj))),
	}, nil
}

// parseRange interprets an HTTP Range header value ("bytes=start-end" or
// "bytes=-n") against an object of the given total size, returning slice
// bounds with an exclusive end.
func parseRange(hdr string, total int) (start, end int) {
	d.PanicIfFalse(len(hdr) > len(s3RangePrefix))
	hdr = hdr[len(s3RangePrefix):]
	d.PanicIfFalse(hdr[0] == '=')
	hdr = hdr[1:]
	if hdr[0] == '-' {
		// Negative range: "-n" means the final n bytes of the object.
		fromEnd, err := strconv.Atoi(hdr[1:])
		d.PanicIfError(err)
		return total - fromEnd, total
	}
	ends := strings.Split(hdr, "-")
	d.PanicIfFalse(len(ends) == 2)
	start, err := strconv.Atoi(ends[0])
	d.PanicIfError(err)
	end, err = strconv.Atoi(ends[1])
	d.PanicIfError(err)
	return start, end + 1 // the HTTP Range header specifies ranges inclusively, so widen the end by one
}

func (m *fakeS3) PutObjectWithContext(ctx aws.Context, input *s3.PutObjectInput, opts ...request.Option) (*s3.PutObjectOutput, error) {
	m.assert.NotNil(input.Bucket, "Bucket is a required field")
	m.assert.NotNil(input.Key, "Key is a required field")

	buff := &bytes.Buffer{}
	_, err := io.Copy(buff, input.Body)
	m.assert.NoError(err)
	m.mu.Lock()
	defer m.mu.Unlock()
	m.data[*input.Key] = buff.Bytes()

	return &s3.PutObjectOutput{}, nil
}
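
// The test below is an editorial sketch, not part of the original file: it
// shows how the fake can be driven directly through the same call surface the
// NBS code exercises. The test name, bucket, keys, and payloads are
// hypothetical; everything else (makeFakeS3, the *WithContext methods,
// s3RangePrefix) comes from this package.
func TestFakeS3Sketch(t *testing.T) {
	s3svc := makeFakeS3(t)
	ctx := aws.BackgroundContext()
	bucket := aws.String("test-bucket") // hypothetical bucket name

	// PutObject stores the body verbatim; GetObject with an HTTP Range header
	// returns the inclusive byte range interpreted by parseRange.
	_, err := s3svc.PutObjectWithContext(ctx, &s3.PutObjectInput{
		Bucket: bucket,
		Key:    aws.String("obj"),
		Body:   bytes.NewReader([]byte("0123456789")),
	})
	assert.NoError(t, err)

	got, err := s3svc.GetObjectWithContext(ctx, &s3.GetObjectInput{
		Bucket: bucket,
		Key:    aws.String("obj"),
		Range:  aws.String(s3RangePrefix + "=2-5"),
	})
	assert.NoError(t, err)
	body, err := ioutil.ReadAll(got.Body)
	assert.NoError(t, err)
	assert.Equal(t, []byte("2345"), body) // bytes 2..5, inclusive

	// The multipart path: parts are buffered under their ETags until
	// CompleteMultipartUpload concatenates them, in order, under the key.
	mpu, err := s3svc.CreateMultipartUploadWithContext(ctx, &s3.CreateMultipartUploadInput{
		Bucket: bucket,
		Key:    aws.String("multi"),
	})
	assert.NoError(t, err)

	var parts []*s3.CompletedPart
	for i, chunk := range [][]byte{[]byte("abc"), []byte("def")} {
		up, err := s3svc.UploadPartWithContext(ctx, &s3.UploadPartInput{
			Bucket:     bucket,
			Key:        aws.String("multi"),
			UploadId:   mpu.UploadId,
			PartNumber: aws.Int64(int64(i + 1)), // 1-indexed, as the fake asserts
			Body:       bytes.NewReader(chunk),
		})
		assert.NoError(t, err)
		parts = append(parts, &s3.CompletedPart{ETag: up.ETag, PartNumber: aws.Int64(int64(i + 1))})
	}

	_, err = s3svc.CompleteMultipartUploadWithContext(ctx, &s3.CompleteMultipartUploadInput{
		Bucket:          bucket,
		Key:             aws.String("multi"),
		UploadId:        mpu.UploadId,
		MultipartUpload: &s3.CompletedMultipartUpload{Parts: parts},
	})
	assert.NoError(t, err)
	assert.Equal(t, []byte("abcdef"), s3svc.data["multi"])
}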