github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/store/nbs/s3_fake_test.go

// Copyright 2019 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// This file incorporates work covered by the following copyright and
// permission notice:
//
// Copyright 2016 Attic Labs, Inc. All rights reserved.
// Licensed under the Apache License, version 2.0:
// http://www.apache.org/licenses/LICENSE-2.0

package nbs

import (
	"bytes"
	"context"
	"io"
	"net/url"
	"strconv"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/client/metadata"
	"github.com/aws/aws-sdk-go/aws/request"
	"github.com/aws/aws-sdk-go/service/s3"
	"github.com/aws/aws-sdk-go/service/s3/s3iface"
	"github.com/stretchr/testify/assert"

	"github.com/dolthub/dolt/go/store/d"
	"github.com/dolthub/dolt/go/store/hash"
)

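// mockAWSError is a minimal AWS-style error: the same string is reported as
// both the error code and the message, and there is no wrapped cause.
// Returning it from fake operations (e.g. mockAWSError("NoSuchKey")) lets
// callers exercise their error-code handling without a real S3 endpoint.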
type mockAWSError string

func (m mockAWSError) Error() string   { return string(m) }
func (m mockAWSError) Code() string    { return string(m) }
func (m mockAWSError) Message() string { return string(m) }
func (m mockAWSError) OrigErr() error  { return nil }

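// makeFakeS3 constructs an empty in-memory fake S3 whose assertions report
// failures against the given test.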
func makeFakeS3(t *testing.T) *fakeS3 {
	return &fakeS3{
		assert:     assert.New(t),
		data:       map[string][]byte{},
		inProgress: map[string]fakeS3Multipart{},
		parts:      map[string][]byte{},
	}
}

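// fakeS3 is an in-memory fake of the subset of the S3 API exercised by the
// nbs tests: PutObject/GetObject plus the multipart-upload calls. All mutable
// state is guarded by mu, so one fake can be shared across goroutines within
// a test.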
type fakeS3 struct {
	s3iface.S3API

	assert *assert.Assertions

	mu                sync.Mutex
	data              map[string][]byte
	inProgressCounter int
	inProgress        map[string]fakeS3Multipart // Key -> {UploadId, Etags...}
	parts             map[string][]byte          // ETag -> data
	getCount          int
}

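// fakeS3Multipart records one in-progress multipart upload: its upload ID and
// the ETags of the parts staged so far, in upload order.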
type fakeS3Multipart struct {
	uploadID string
	etags    []string
}

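// readerForTable parses the stored bytes for the named table file and returns
// a chunkReader over them, or (nil, nil) if no such object exists.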
func (m *fakeS3) readerForTable(ctx context.Context, name hash.Hash) (chunkReader, error) {
	m.mu.Lock()
	defer m.mu.Unlock()
	if buff, present := m.data[name.String()]; present {
		ti, err := parseTableIndexByCopy(ctx, buff, &UnlimitedQuotaProvider{})
		if err != nil {
			return nil, err
		}
		tr, err := newTableReader(ti, tableReaderAtFromBytes(buff), s3BlockSize)
		if err != nil {
			ti.Close()
			return nil, err
		}
		return tr, nil
	}
	return nil, nil
}

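// readerForTableWithNamespace is readerForTable for objects stored under a
// key prefix ("<ns>/<name>") rather than at the bucket root.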
func (m *fakeS3) readerForTableWithNamespace(ctx context.Context, ns string, name hash.Hash) (chunkReader, error) {
	m.mu.Lock()
	defer m.mu.Unlock()
	key := name.String()
	if ns != "" {
		key = ns + "/" + key
	}
	if buff, present := m.data[key]; present {
		ti, err := parseTableIndexByCopy(ctx, buff, &UnlimitedQuotaProvider{})
		if err != nil {
			return nil, err
		}
		tr, err := newTableReader(ti, tableReaderAtFromBytes(buff), s3BlockSize)
		if err != nil {
			ti.Close()
			return nil, err
		}
		return tr, nil
	}
	return nil, nil
}

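// AbortMultipartUploadWithContext drops the staged parts for the given upload
// and forgets the in-progress entry, leaving m.data untouched.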
func (m *fakeS3) AbortMultipartUploadWithContext(ctx aws.Context, input *s3.AbortMultipartUploadInput, opts ...request.Option) (*s3.AbortMultipartUploadOutput, error) {
	m.assert.NotNil(input.Bucket, "Bucket is a required field")
	m.assert.NotNil(input.Key, "Key is a required field")
	m.assert.NotNil(input.UploadId, "UploadId is a required field")

	m.mu.Lock()
	defer m.mu.Unlock()
	m.assert.Equal(m.inProgress[*input.Key].uploadID, *input.UploadId)
	for _, etag := range m.inProgress[*input.Key].etags {
		delete(m.parts, etag)
	}
	delete(m.inProgress, *input.Key)
	return &s3.AbortMultipartUploadOutput{}, nil
}

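// CreateMultipartUploadWithContext registers a new in-progress upload for the
// key and returns a fresh UploadId taken from a simple counter.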
func (m *fakeS3) CreateMultipartUploadWithContext(ctx aws.Context, input *s3.CreateMultipartUploadInput, opts ...request.Option) (*s3.CreateMultipartUploadOutput, error) {
	m.assert.NotNil(input.Bucket, "Bucket is a required field")
	m.assert.NotNil(input.Key, "Key is a required field")

	out := &s3.CreateMultipartUploadOutput{
		Bucket: input.Bucket,
		Key:    input.Key,
	}

	m.mu.Lock()
	defer m.mu.Unlock()
	uploadID := strconv.Itoa(m.inProgressCounter)
	out.UploadId = aws.String(uploadID)
	m.inProgress[*input.Key] = fakeS3Multipart{uploadID, nil}
	m.inProgressCounter++
	return out, nil
}

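// UploadPartWithContext stages the part body under a generated ETag and
// appends that ETag to the key's in-progress upload. The ETag mixes the
// content hash with the current time so identical parts still get distinct
// tags.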
func (m *fakeS3) UploadPartWithContext(ctx aws.Context, input *s3.UploadPartInput, opts ...request.Option) (*s3.UploadPartOutput, error) {
	m.assert.NotNil(input.Bucket, "Bucket is a required field")
	m.assert.NotNil(input.Key, "Key is a required field")
	m.assert.NotNil(input.PartNumber, "PartNumber is a required field")
	m.assert.NotNil(input.UploadId, "UploadId is a required field")
	m.assert.NotNil(input.Body, "Body is a required field")

	data, err := io.ReadAll(input.Body)
	m.assert.NoError(err)

	m.mu.Lock()
	defer m.mu.Unlock()
	etag := hash.Of(data).String() + time.Now().String()
	m.parts[etag] = data

	inProgress, present := m.inProgress[*input.Key]
	m.assert.True(present)
	m.assert.Equal(inProgress.uploadID, *input.UploadId)
	inProgress.etags = append(inProgress.etags, etag)
	m.inProgress[*input.Key] = inProgress
	return &s3.UploadPartOutput{ETag: aws.String(etag)}, nil
}

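// UploadPartCopyWithContext stages a part by copying from an existing object
// (optionally restricted by CopySourceRange), mimicking S3's server-side part
// copy. Only the portion of CopySource after the last '/' is used to look up
// the source object.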
func (m *fakeS3) UploadPartCopyWithContext(ctx aws.Context, input *s3.UploadPartCopyInput, opts ...request.Option) (*s3.UploadPartCopyOutput, error) {
	m.assert.NotNil(input.Bucket, "Bucket is a required field")
	m.assert.NotNil(input.Key, "Key is a required field")
	m.assert.NotNil(input.PartNumber, "PartNumber is a required field")
	m.assert.NotNil(input.UploadId, "UploadId is a required field")
	m.assert.NotNil(input.CopySource, "CopySource is a required field")

	unescaped, err := url.QueryUnescape(*input.CopySource)
	m.assert.NoError(err)
	slash := strings.LastIndex(unescaped, "/")
	m.assert.NotEqual(-1, slash, "Malformed CopySource %s", unescaped)
	src := unescaped[slash+1:]

	m.mu.Lock()
	defer m.mu.Unlock()
	obj, present := m.data[src]
	if !present {
		return nil, mockAWSError("NoSuchKey")
	}
	if input.CopySourceRange != nil {
		start, end := parseRange(*input.CopySourceRange, len(obj))
		obj = obj[start:end]
	}
	etag := hash.Of(obj).String() + time.Now().String()
	m.parts[etag] = obj

	inProgress, present := m.inProgress[*input.Key]
	m.assert.True(present)
	m.assert.Equal(inProgress.uploadID, *input.UploadId)
	inProgress.etags = append(inProgress.etags, etag)
	m.inProgress[*input.Key] = inProgress
	return &s3.UploadPartCopyOutput{CopyPartResult: &s3.CopyPartResult{ETag: aws.String(etag)}}, nil
}

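// CompleteMultipartUploadWithContext concatenates the staged parts in part
// order into m.data[key] and clears the in-progress bookkeeping. It asserts
// that the caller supplies consecutive part numbers starting at 1.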
func (m *fakeS3) CompleteMultipartUploadWithContext(ctx aws.Context, input *s3.CompleteMultipartUploadInput, opts ...request.Option) (*s3.CompleteMultipartUploadOutput, error) {
	m.assert.NotNil(input.Bucket, "Bucket is a required field")
	m.assert.NotNil(input.Key, "Key is a required field")
	m.assert.NotNil(input.UploadId, "UploadId is a required field")
	m.assert.NotNil(input.MultipartUpload, "MultipartUpload is a required field")
	m.assert.True(len(input.MultipartUpload.Parts) > 0)

	m.mu.Lock()
	defer m.mu.Unlock()
	m.assert.Equal(m.inProgress[*input.Key].uploadID, *input.UploadId)
	for idx, part := range input.MultipartUpload.Parts {
		m.assert.EqualValues(idx+1, *part.PartNumber) // Part numbers are 1-indexed
		m.data[*input.Key] = append(m.data[*input.Key], m.parts[*part.ETag]...)
		delete(m.parts, *part.ETag)
	}
	delete(m.inProgress, *input.Key)

	return &s3.CompleteMultipartUploadOutput{Bucket: input.Bucket, Key: input.Key}, nil
}

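// GetObjectWithContext returns the stored object (or a "NoSuchKey" error),
// honoring an optional Range header, and bumps getCount so tests can assert
// how many reads were issued.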
func (m *fakeS3) GetObjectWithContext(ctx aws.Context, input *s3.GetObjectInput, opts ...request.Option) (*s3.GetObjectOutput, error) {
	m.assert.NotNil(input.Bucket, "Bucket is a required field")
	m.assert.NotNil(input.Key, "Key is a required field")

	m.mu.Lock()
	defer m.mu.Unlock()
	m.getCount++
	obj, present := m.data[*input.Key]
	if !present {
		return nil, mockAWSError("NoSuchKey")
	}
	if input.Range != nil {
		start, end := parseRange(*input.Range, len(obj))
		obj = obj[start:end]
	}

	return &s3.GetObjectOutput{
		Body:          io.NopCloser(bytes.NewReader(obj)),
		ContentLength: aws.Int64(int64(len(obj))),
	}, nil
}

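// parseRange converts an S3-style Range header (s3RangePrefix + "=first-last",
// or the suffix form "=-n") into half-open [start, end) offsets within an
// object of the given total size.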
func parseRange(hdr string, total int) (start, end int) {
	d.PanicIfFalse(len(hdr) > len(s3RangePrefix))
	hdr = hdr[len(s3RangePrefix):]
	d.PanicIfFalse(hdr[0] == '=')
	hdr = hdr[1:]
	if hdr[0] == '-' {
		// negative range
		fromEnd, err := strconv.Atoi(hdr[1:])
		d.PanicIfError(err)
		return total - fromEnd, total
	}
	ends := strings.Split(hdr, "-")
	d.PanicIfFalse(len(ends) == 2)
	start, err := strconv.Atoi(ends[0])
	d.PanicIfError(err)
	end, err = strconv.Atoi(ends[1])
	d.PanicIfError(err)
	return start, end + 1 // insanely, the HTTP range header specifies ranges inclusively.
}

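// PutObjectWithContext reads the body into memory and stores it verbatim
// under the key, replacing any previous contents.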
func (m *fakeS3) PutObjectWithContext(ctx aws.Context, input *s3.PutObjectInput, opts ...request.Option) (*s3.PutObjectOutput, error) {
	m.assert.NotNil(input.Bucket, "Bucket is a required field")
	m.assert.NotNil(input.Key, "Key is a required field")

	buff := &bytes.Buffer{}
	_, err := io.Copy(buff, input.Body)
	m.assert.NoError(err)
	m.mu.Lock()
	defer m.mu.Unlock()
	m.data[*input.Key] = buff.Bytes()

	return &s3.PutObjectOutput{}, nil
}

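// GetObjectRequest builds an SDK request.Request whose Send handler is routed
// to the in-memory GetObjectWithContext, so code that drives requests through
// the aws-sdk-go request machinery still hits the fake instead of the network.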
func (m *fakeS3) GetObjectRequest(input *s3.GetObjectInput) (*request.Request, *s3.GetObjectOutput) {
	out := &s3.GetObjectOutput{}
	var handlers request.Handlers
	handlers.Send.PushBack(func(r *request.Request) {
		res, err := m.GetObjectWithContext(r.Context(), input)
		r.Error = err
		if res != nil {
			*(r.Data.(*s3.GetObjectOutput)) = *res
		}
	})
	return request.New(aws.Config{}, metadata.ClientInfo{}, handlers, nil, &request.Operation{
		Name:       "GetObject",
		HTTPMethod: "GET",
		HTTPPath:   "/{Bucket}/{Key+}",
	}, input, out), out
}

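// PutObjectRequest mirrors GetObjectRequest for uploads: the constructed
// request's Send step delegates to PutObjectWithContext.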
func (m *fakeS3) PutObjectRequest(input *s3.PutObjectInput) (*request.Request, *s3.PutObjectOutput) {
	out := &s3.PutObjectOutput{}
	var handlers request.Handlers
	handlers.Send.PushBack(func(r *request.Request) {
		res, err := m.PutObjectWithContext(r.Context(), input)
		r.Error = err
		if res != nil {
			*(r.Data.(*s3.PutObjectOutput)) = *res
		}
	})
	return request.New(aws.Config{}, metadata.ClientInfo{}, handlers, nil, &request.Operation{
		Name:       "PutObject",
		HTTPMethod: "PUT",
		HTTPPath:   "/{Bucket}/{Key+}",
	}, input, out), out
}