github.com/hasnat/dolt/go@v0.0.0-20210628190320-9eb5d843fbb7/store/nbs/s3_fake_test.go (about)

     1  // Copyright 2019 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  //
    15  // This file incorporates work covered by the following copyright and
    16  // permission notice:
    17  //
    18  // Copyright 2016 Attic Labs, Inc. All rights reserved.
    19  // Licensed under the Apache License, version 2.0:
    20  // http://www.apache.org/licenses/LICENSE-2.0
    21  
    22  package nbs
    23  
    24  import (
    25  	"bytes"
    26  	"io"
    27  	"io/ioutil"
    28  	"net/url"
    29  	"strconv"
    30  	"strings"
    31  	"sync"
    32  	"testing"
    33  	"time"
    34  
    35  	"github.com/aws/aws-sdk-go/aws"
    36  	"github.com/aws/aws-sdk-go/aws/request"
    37  	"github.com/aws/aws-sdk-go/service/s3"
    38  	"github.com/stretchr/testify/assert"
    39  
    40  	"github.com/dolthub/dolt/go/store/d"
    41  	"github.com/dolthub/dolt/go/store/hash"
    42  )
    43  
    44  type mockAWSError string
    45  
    46  func (m mockAWSError) Error() string   { return string(m) }
    47  func (m mockAWSError) Code() string    { return string(m) }
    48  func (m mockAWSError) Message() string { return string(m) }
    49  func (m mockAWSError) OrigErr() error  { return nil }
    50  
    51  func makeFakeS3(t *testing.T) *fakeS3 {
    52  	return &fakeS3{
    53  		assert:     assert.New(t),
    54  		data:       map[string][]byte{},
    55  		inProgress: map[string]fakeS3Multipart{},
    56  		parts:      map[string][]byte{},
    57  	}
    58  }
    59  
// fakeS3 is an in-memory fake of the S3 object and multipart-upload calls,
// for use by this package's tests. Request validation goes through |assert|,
// so a malformed request fails the test instead of producing an error return.
type fakeS3 struct {
	assert *assert.Assertions

	// mu guards data, inProgressCounter, inProgress and parts; the fakeS3
	// methods lock it before touching those fields.
	mu                sync.Mutex
	data              map[string][]byte          // object key -> object contents
	inProgressCounter int                        // next multipart upload ID to hand out
	inProgress        map[string]fakeS3Multipart // Key -> {UploadId, Etags...}
	parts             map[string][]byte          // ETag -> data
	// getCount counts GetObjectWithContext calls.
	// NOTE(review): confirm every read and write of it happens under mu.
	getCount          int
}
    70  
// fakeS3Multipart tracks one in-progress multipart upload: the upload ID
// issued when the upload was created, and the ETags of the parts uploaded so
// far, in the order they were uploaded.
// NOTE: this struct may be built positionally, so keep the field order stable.
type fakeS3Multipart struct {
	uploadID string
	etags    []string
}
    75  
    76  func (m *fakeS3) readerForTable(name addr) (chunkReader, error) {
    77  	m.mu.Lock()
    78  	defer m.mu.Unlock()
    79  	if buff, present := m.data[name.String()]; present {
    80  		ti, err := parseTableIndex(buff)
    81  
    82  		if err != nil {
    83  			return nil, err
    84  		}
    85  		return newTableReader(ti, tableReaderAtFromBytes(buff), s3BlockSize), nil
    86  	}
    87  	return nil, nil
    88  }
    89  
    90  func (m *fakeS3) readerForTableWithNamespace(ns string, name addr) (chunkReader, error) {
    91  	m.mu.Lock()
    92  	defer m.mu.Unlock()
    93  	key := name.String()
    94  	if ns != "" {
    95  		key = ns + "/" + key
    96  	}
    97  	if buff, present := m.data[key]; present {
    98  		ti, err := parseTableIndex(buff)
    99  
   100  		if err != nil {
   101  			return nil, err
   102  		}
   103  
   104  		return newTableReader(ti, tableReaderAtFromBytes(buff), s3BlockSize), nil
   105  	}
   106  	return nil, nil
   107  }
   108  
   109  func (m *fakeS3) AbortMultipartUploadWithContext(ctx aws.Context, input *s3.AbortMultipartUploadInput, opts ...request.Option) (*s3.AbortMultipartUploadOutput, error) {
   110  	m.assert.NotNil(input.Bucket, "Bucket is a required field")
   111  	m.assert.NotNil(input.Key, "Key is a required field")
   112  	m.assert.NotNil(input.UploadId, "UploadId is a required field")
   113  
   114  	m.mu.Lock()
   115  	defer m.mu.Unlock()
   116  	m.assert.Equal(m.inProgress[*input.Key].uploadID, *input.UploadId)
   117  	for _, etag := range m.inProgress[*input.Key].etags {
   118  		delete(m.parts, etag)
   119  	}
   120  	delete(m.inProgress, *input.Key)
   121  	return &s3.AbortMultipartUploadOutput{}, nil
   122  }
   123  
   124  func (m *fakeS3) CreateMultipartUploadWithContext(ctx aws.Context, input *s3.CreateMultipartUploadInput, opts ...request.Option) (*s3.CreateMultipartUploadOutput, error) {
   125  	m.assert.NotNil(input.Bucket, "Bucket is a required field")
   126  	m.assert.NotNil(input.Key, "Key is a required field")
   127  
   128  	out := &s3.CreateMultipartUploadOutput{
   129  		Bucket: input.Bucket,
   130  		Key:    input.Key,
   131  	}
   132  
   133  	m.mu.Lock()
   134  	defer m.mu.Unlock()
   135  	uploadID := strconv.Itoa(m.inProgressCounter)
   136  	out.UploadId = aws.String(uploadID)
   137  	m.inProgress[*input.Key] = fakeS3Multipart{uploadID, nil}
   138  	m.inProgressCounter++
   139  	return out, nil
   140  }
   141  
   142  func (m *fakeS3) UploadPartWithContext(ctx aws.Context, input *s3.UploadPartInput, opts ...request.Option) (*s3.UploadPartOutput, error) {
   143  	m.assert.NotNil(input.Bucket, "Bucket is a required field")
   144  	m.assert.NotNil(input.Key, "Key is a required field")
   145  	m.assert.NotNil(input.PartNumber, "PartNumber is a required field")
   146  	m.assert.NotNil(input.UploadId, "UploadId is a required field")
   147  	m.assert.NotNil(input.Body, "Body is a required field")
   148  
   149  	data, err := ioutil.ReadAll(input.Body)
   150  	m.assert.NoError(err)
   151  
   152  	m.mu.Lock()
   153  	defer m.mu.Unlock()
   154  	etag := hash.Of(data).String() + time.Now().String()
   155  	m.parts[etag] = data
   156  
   157  	inProgress, present := m.inProgress[*input.Key]
   158  	m.assert.True(present)
   159  	m.assert.Equal(inProgress.uploadID, *input.UploadId)
   160  	inProgress.etags = append(inProgress.etags, etag)
   161  	m.inProgress[*input.Key] = inProgress
   162  	return &s3.UploadPartOutput{ETag: aws.String(etag)}, nil
   163  }
   164  
   165  func (m *fakeS3) UploadPartCopyWithContext(ctx aws.Context, input *s3.UploadPartCopyInput, opts ...request.Option) (*s3.UploadPartCopyOutput, error) {
   166  	m.assert.NotNil(input.Bucket, "Bucket is a required field")
   167  	m.assert.NotNil(input.Key, "Key is a required field")
   168  	m.assert.NotNil(input.PartNumber, "PartNumber is a required field")
   169  	m.assert.NotNil(input.UploadId, "UploadId is a required field")
   170  	m.assert.NotNil(input.CopySource, "CopySource is a required field")
   171  
   172  	unescaped, err := url.QueryUnescape(*input.CopySource)
   173  	m.assert.NoError(err)
   174  	slash := strings.LastIndex(unescaped, "/")
   175  	m.assert.NotEqual(-1, slash, "Malformed CopySource %s", unescaped)
   176  	src := unescaped[slash+1:]
   177  
   178  	m.mu.Lock()
   179  	defer m.mu.Unlock()
   180  	obj, present := m.data[src]
   181  	if !present {
   182  		return nil, mockAWSError("NoSuchKey")
   183  	}
   184  	if input.CopySourceRange != nil {
   185  		start, end := parseRange(*input.CopySourceRange, len(obj))
   186  		obj = obj[start:end]
   187  	}
   188  	etag := hash.Of(obj).String() + time.Now().String()
   189  	m.parts[etag] = obj
   190  
   191  	inProgress, present := m.inProgress[*input.Key]
   192  	m.assert.True(present)
   193  	m.assert.Equal(inProgress.uploadID, *input.UploadId)
   194  	inProgress.etags = append(inProgress.etags, etag)
   195  	m.inProgress[*input.Key] = inProgress
   196  	return &s3.UploadPartCopyOutput{CopyPartResult: &s3.CopyPartResult{ETag: aws.String(etag)}}, nil
   197  }
   198  
   199  func (m *fakeS3) CompleteMultipartUploadWithContext(ctx aws.Context, input *s3.CompleteMultipartUploadInput, opts ...request.Option) (*s3.CompleteMultipartUploadOutput, error) {
   200  	m.assert.NotNil(input.Bucket, "Bucket is a required field")
   201  	m.assert.NotNil(input.Key, "Key is a required field")
   202  	m.assert.NotNil(input.UploadId, "UploadId is a required field")
   203  	m.assert.NotNil(input.MultipartUpload, "MultipartUpload is a required field")
   204  	m.assert.True(len(input.MultipartUpload.Parts) > 0)
   205  
   206  	m.mu.Lock()
   207  	defer m.mu.Unlock()
   208  	m.assert.Equal(m.inProgress[*input.Key].uploadID, *input.UploadId)
   209  	for idx, part := range input.MultipartUpload.Parts {
   210  		m.assert.EqualValues(idx+1, *part.PartNumber) // Part numbers are 1-indexed
   211  		m.data[*input.Key] = append(m.data[*input.Key], m.parts[*part.ETag]...)
   212  		delete(m.parts, *part.ETag)
   213  	}
   214  	delete(m.inProgress, *input.Key)
   215  
   216  	return &s3.CompleteMultipartUploadOutput{Bucket: input.Bucket, Key: input.Key}, nil
   217  }
   218  
   219  func (m *fakeS3) GetObjectWithContext(ctx aws.Context, input *s3.GetObjectInput, opts ...request.Option) (*s3.GetObjectOutput, error) {
   220  	m.getCount++
   221  	m.assert.NotNil(input.Bucket, "Bucket is a required field")
   222  	m.assert.NotNil(input.Key, "Key is a required field")
   223  
   224  	m.mu.Lock()
   225  	defer m.mu.Unlock()
   226  	obj, present := m.data[*input.Key]
   227  	if !present {
   228  		return nil, mockAWSError("NoSuchKey")
   229  	}
   230  	if input.Range != nil {
   231  		start, end := parseRange(*input.Range, len(obj))
   232  		obj = obj[start:end]
   233  	}
   234  
   235  	return &s3.GetObjectOutput{
   236  		Body:          ioutil.NopCloser(bytes.NewReader(obj)),
   237  		ContentLength: aws.Int64(int64(len(obj))),
   238  	}, nil
   239  }
   240  
   241  func parseRange(hdr string, total int) (start, end int) {
   242  	d.PanicIfFalse(len(hdr) > len(s3RangePrefix))
   243  	hdr = hdr[len(s3RangePrefix):]
   244  	d.PanicIfFalse(hdr[0] == '=')
   245  	hdr = hdr[1:]
   246  	if hdr[0] == '-' {
   247  		// negative range
   248  		fromEnd, err := strconv.Atoi(hdr[1:])
   249  		d.PanicIfError(err)
   250  		return total - fromEnd, total
   251  	}
   252  	ends := strings.Split(hdr, "-")
   253  	d.PanicIfFalse(len(ends) == 2)
   254  	start, err := strconv.Atoi(ends[0])
   255  	d.PanicIfError(err)
   256  	end, err = strconv.Atoi(ends[1])
   257  	d.PanicIfError(err)
   258  	return start, end + 1 // insanely, the HTTP range header specifies ranges inclusively.
   259  }
   260  
   261  func (m *fakeS3) PutObjectWithContext(ctx aws.Context, input *s3.PutObjectInput, opts ...request.Option) (*s3.PutObjectOutput, error) {
   262  	m.assert.NotNil(input.Bucket, "Bucket is a required field")
   263  	m.assert.NotNil(input.Key, "Key is a required field")
   264  
   265  	buff := &bytes.Buffer{}
   266  	_, err := io.Copy(buff, input.Body)
   267  	m.assert.NoError(err)
   268  	m.mu.Lock()
   269  	defer m.mu.Unlock()
   270  	m.data[*input.Key] = buff.Bytes()
   271  
   272  	return &s3.PutObjectOutput{}, nil
   273  }