github.com/sequix/cortex@v1.1.6/pkg/chunk/aws/mock.go (about)

     1  package aws
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"fmt"
     7  	"io/ioutil"
     8  	"sort"
     9  	"strings"
    10  	"sync"
    11  
    12  	"github.com/aws/aws-sdk-go/aws"
    13  	"github.com/aws/aws-sdk-go/aws/awserr"
    14  	"github.com/aws/aws-sdk-go/aws/request"
    15  	"github.com/aws/aws-sdk-go/service/dynamodb"
    16  	"github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface"
    17  	"github.com/aws/aws-sdk-go/service/s3"
    18  	"github.com/aws/aws-sdk-go/service/s3/s3iface"
    19  	"github.com/sequix/cortex/pkg/util"
    20  	"github.com/go-kit/kit/log/level"
    21  )
    22  
    23  const arnPrefix = "arn:"
    24  
    25  type mockDynamoDBClient struct {
    26  	dynamodbiface.DynamoDBAPI
    27  
    28  	mtx            sync.RWMutex
    29  	unprocessed    int
    30  	provisionedErr int
    31  	errAfter       int
    32  	tables         map[string]*mockDynamoDBTable
    33  }
    34  
    35  type mockDynamoDBTable struct {
    36  	items       map[string][]mockDynamoDBItem
    37  	read, write int64
    38  	tags        []*dynamodb.Tag
    39  }
    40  
    41  type mockDynamoDBItem map[string]*dynamodb.AttributeValue
    42  
    43  func newMockDynamoDB(unprocessed int, provisionedErr int) *mockDynamoDBClient {
    44  	return &mockDynamoDBClient{
    45  		tables:         map[string]*mockDynamoDBTable{},
    46  		unprocessed:    unprocessed,
    47  		provisionedErr: provisionedErr,
    48  	}
    49  }
    50  
    51  func (a dynamoDBStorageClient) setErrorParameters(provisionedErr, errAfter int) {
    52  	if m, ok := a.DynamoDB.(*mockDynamoDBClient); ok {
    53  		m.provisionedErr = provisionedErr
    54  		m.errAfter = errAfter
    55  	}
    56  }
    57  
    58  func (m *mockDynamoDBClient) createTable(name string) {
    59  	m.mtx.Lock()
    60  	defer m.mtx.Unlock()
    61  	m.tables[name] = &mockDynamoDBTable{
    62  		items: map[string][]mockDynamoDBItem{},
    63  	}
    64  }
    65  
    66  func (m *mockDynamoDBClient) batchWriteItemRequest(_ context.Context, input *dynamodb.BatchWriteItemInput) dynamoDBRequest {
    67  	m.mtx.Lock()
    68  	defer m.mtx.Unlock()
    69  
    70  	resp := &dynamodb.BatchWriteItemOutput{
    71  		UnprocessedItems: map[string][]*dynamodb.WriteRequest{},
    72  	}
    73  
    74  	if m.errAfter > 0 {
    75  		m.errAfter--
    76  	} else if m.provisionedErr > 0 {
    77  		m.provisionedErr--
    78  		return &dynamoDBMockRequest{
    79  			result: resp,
    80  			err:    awserr.New(dynamodb.ErrCodeProvisionedThroughputExceededException, "", nil),
    81  		}
    82  	}
    83  
    84  	for tableName, writeRequests := range input.RequestItems {
    85  		table, ok := m.tables[tableName]
    86  		if !ok {
    87  			return &dynamoDBMockRequest{
    88  				result: &dynamodb.BatchWriteItemOutput{},
    89  				err:    fmt.Errorf("table not found: %s", tableName),
    90  			}
    91  		}
    92  
    93  		for _, writeRequest := range writeRequests {
    94  			if m.unprocessed > 0 {
    95  				m.unprocessed--
    96  				resp.UnprocessedItems[tableName] = append(resp.UnprocessedItems[tableName], writeRequest)
    97  				continue
    98  			}
    99  
   100  			hashValue := *writeRequest.PutRequest.Item[hashKey].S
   101  			rangeValue := writeRequest.PutRequest.Item[rangeKey].B
   102  
   103  			items := table.items[hashValue]
   104  
   105  			// insert in order
   106  			i := sort.Search(len(items), func(i int) bool {
   107  				return bytes.Compare(items[i][rangeKey].B, rangeValue) >= 0
   108  			})
   109  			if i >= len(items) || !bytes.Equal(items[i][rangeKey].B, rangeValue) {
   110  				items = append(items, nil)
   111  				copy(items[i+1:], items[i:])
   112  			} else {
   113  				return &dynamoDBMockRequest{
   114  					result: &dynamodb.BatchWriteItemOutput{},
   115  					err:    fmt.Errorf("Duplicate entry"),
   116  				}
   117  			}
   118  			items[i] = writeRequest.PutRequest.Item
   119  
   120  			table.items[hashValue] = items
   121  		}
   122  	}
   123  	return &dynamoDBMockRequest{result: resp}
   124  }
   125  
   126  func (m *mockDynamoDBClient) batchGetItemRequest(_ context.Context, input *dynamodb.BatchGetItemInput) dynamoDBRequest {
   127  	m.mtx.Lock()
   128  	defer m.mtx.Unlock()
   129  
   130  	resp := &dynamodb.BatchGetItemOutput{
   131  		Responses:       map[string][]map[string]*dynamodb.AttributeValue{},
   132  		UnprocessedKeys: map[string]*dynamodb.KeysAndAttributes{},
   133  	}
   134  
   135  	if m.errAfter > 0 {
   136  		m.errAfter--
   137  	} else if m.provisionedErr > 0 {
   138  		m.provisionedErr--
   139  		return &dynamoDBMockRequest{
   140  			result: resp,
   141  			err:    awserr.New(dynamodb.ErrCodeProvisionedThroughputExceededException, "", nil),
   142  		}
   143  	}
   144  
   145  	for tableName, readRequests := range input.RequestItems {
   146  		table, ok := m.tables[tableName]
   147  		if !ok {
   148  			return &dynamoDBMockRequest{
   149  				result: &dynamodb.BatchGetItemOutput{},
   150  				err:    fmt.Errorf("table not found"),
   151  			}
   152  		}
   153  
   154  		unprocessed := &dynamodb.KeysAndAttributes{
   155  			AttributesToGet:          readRequests.AttributesToGet,
   156  			ConsistentRead:           readRequests.ConsistentRead,
   157  			ExpressionAttributeNames: readRequests.ExpressionAttributeNames,
   158  		}
   159  		for _, readRequest := range readRequests.Keys {
   160  			if m.unprocessed > 0 {
   161  				m.unprocessed--
   162  				unprocessed.Keys = append(unprocessed.Keys, readRequest)
   163  				resp.UnprocessedKeys[tableName] = unprocessed
   164  				continue
   165  			}
   166  
   167  			hashValue := *readRequest[hashKey].S
   168  			rangeValue := readRequest[rangeKey].B
   169  			items := table.items[hashValue]
   170  
   171  			// insert in order
   172  			i := sort.Search(len(items), func(i int) bool {
   173  				return bytes.Compare(items[i][rangeKey].B, rangeValue) >= 0
   174  			})
   175  			if i >= len(items) || !bytes.Equal(items[i][rangeKey].B, rangeValue) {
   176  				return &dynamoDBMockRequest{
   177  					result: &dynamodb.BatchGetItemOutput{},
   178  					err:    fmt.Errorf("Couldn't find item"),
   179  				}
   180  			}
   181  
   182  			// Only return AttributesToGet!
   183  			item := map[string]*dynamodb.AttributeValue{}
   184  			for _, key := range readRequests.AttributesToGet {
   185  				item[*key] = items[i][*key]
   186  			}
   187  			resp.Responses[tableName] = append(resp.Responses[tableName], item)
   188  		}
   189  	}
   190  	return &dynamoDBMockRequest{
   191  		result: resp,
   192  	}
   193  }
   194  
   195  func (m *mockDynamoDBClient) queryRequest(_ context.Context, input *dynamodb.QueryInput) dynamoDBRequest {
   196  	result := &dynamodb.QueryOutput{
   197  		Items: []map[string]*dynamodb.AttributeValue{},
   198  	}
   199  
   200  	// Required filters
   201  	hashValue := *input.KeyConditions[hashKey].AttributeValueList[0].S
   202  
   203  	// Optional filters
   204  	var (
   205  		rangeValueFilter     []byte
   206  		rangeValueFilterType string
   207  	)
   208  	if c, ok := input.KeyConditions[rangeKey]; ok {
   209  		rangeValueFilter = c.AttributeValueList[0].B
   210  		rangeValueFilterType = *c.ComparisonOperator
   211  	}
   212  
   213  	// Filter by HashValue, RangeValue and Value if it exists
   214  	items := m.tables[*input.TableName].items[hashValue]
   215  	for _, item := range items {
   216  		rangeValue := item[rangeKey].B
   217  		if rangeValueFilterType == dynamodb.ComparisonOperatorGe && bytes.Compare(rangeValue, rangeValueFilter) < 0 {
   218  			continue
   219  		}
   220  		if rangeValueFilterType == dynamodb.ComparisonOperatorBeginsWith && !bytes.HasPrefix(rangeValue, rangeValueFilter) {
   221  			continue
   222  		}
   223  
   224  		if item[valueKey] != nil {
   225  			value := item[valueKey].B
   226  
   227  			// Apply filterExpression if it exists (supporting only v = :v)
   228  			if input.FilterExpression != nil {
   229  				if *input.FilterExpression == fmt.Sprintf("%s = :v", valueKey) {
   230  					filterValue := input.ExpressionAttributeValues[":v"].B
   231  					if !bytes.Equal(value, filterValue) {
   232  						continue
   233  					}
   234  				} else {
   235  					level.Warn(util.Logger).Log("msg", "unsupported FilterExpression", "expression", *input.FilterExpression)
   236  				}
   237  			}
   238  		}
   239  
   240  		result.Items = append(result.Items, item)
   241  	}
   242  
   243  	return &dynamoDBMockRequest{
   244  		result: result,
   245  	}
   246  }
   247  
   248  type dynamoDBMockRequest struct {
   249  	result interface{}
   250  	err    error
   251  }
   252  
   253  func (m *dynamoDBMockRequest) NextPage() dynamoDBRequest {
   254  	return m
   255  }
   256  func (m *dynamoDBMockRequest) Send() error {
   257  	return m.err
   258  }
   259  func (m *dynamoDBMockRequest) Data() interface{} {
   260  	return m.result
   261  }
   262  func (m *dynamoDBMockRequest) Error() error {
   263  	return m.err
   264  }
   265  func (m *dynamoDBMockRequest) HasNextPage() bool {
   266  	return false
   267  }
   268  func (m *dynamoDBMockRequest) Retryable() bool {
   269  	return false
   270  }
   271  
   272  func (m *mockDynamoDBClient) ListTablesPagesWithContext(_ aws.Context, input *dynamodb.ListTablesInput, fn func(*dynamodb.ListTablesOutput, bool) bool, _ ...request.Option) error {
   273  	m.mtx.RLock()
   274  	defer m.mtx.RUnlock()
   275  
   276  	var tableNames []*string
   277  	for tableName := range m.tables {
   278  		func(tableName string) {
   279  			tableNames = append(tableNames, &tableName)
   280  		}(tableName)
   281  	}
   282  	fn(&dynamodb.ListTablesOutput{
   283  		TableNames: tableNames,
   284  	}, true)
   285  
   286  	return nil
   287  }
   288  
   289  // CreateTable implements StorageClient.
   290  func (m *mockDynamoDBClient) CreateTableWithContext(_ aws.Context, input *dynamodb.CreateTableInput, _ ...request.Option) (*dynamodb.CreateTableOutput, error) {
   291  	m.mtx.Lock()
   292  	defer m.mtx.Unlock()
   293  
   294  	if _, ok := m.tables[*input.TableName]; ok {
   295  		return nil, fmt.Errorf("table already exists")
   296  	}
   297  
   298  	m.tables[*input.TableName] = &mockDynamoDBTable{
   299  		items: map[string][]mockDynamoDBItem{},
   300  		write: *input.ProvisionedThroughput.WriteCapacityUnits,
   301  		read:  *input.ProvisionedThroughput.ReadCapacityUnits,
   302  	}
   303  
   304  	return &dynamodb.CreateTableOutput{
   305  		TableDescription: &dynamodb.TableDescription{
   306  			TableArn: aws.String(arnPrefix + *input.TableName),
   307  		},
   308  	}, nil
   309  }
   310  
   311  // DescribeTable implements StorageClient.
   312  func (m *mockDynamoDBClient) DescribeTableWithContext(_ aws.Context, input *dynamodb.DescribeTableInput, _ ...request.Option) (*dynamodb.DescribeTableOutput, error) {
   313  	m.mtx.RLock()
   314  	defer m.mtx.RUnlock()
   315  
   316  	table, ok := m.tables[*input.TableName]
   317  	if !ok {
   318  		return nil, fmt.Errorf("not found")
   319  	}
   320  
   321  	return &dynamodb.DescribeTableOutput{
   322  		Table: &dynamodb.TableDescription{
   323  			TableName:   input.TableName,
   324  			TableStatus: aws.String(dynamodb.TableStatusActive),
   325  			ProvisionedThroughput: &dynamodb.ProvisionedThroughputDescription{
   326  				ReadCapacityUnits:  aws.Int64(table.read),
   327  				WriteCapacityUnits: aws.Int64(table.write),
   328  			},
   329  			TableArn: aws.String(arnPrefix + *input.TableName),
   330  		},
   331  	}, nil
   332  }
   333  
   334  // UpdateTable implements StorageClient.
   335  func (m *mockDynamoDBClient) UpdateTableWithContext(_ aws.Context, input *dynamodb.UpdateTableInput, _ ...request.Option) (*dynamodb.UpdateTableOutput, error) {
   336  	m.mtx.Lock()
   337  	defer m.mtx.Unlock()
   338  
   339  	table, ok := m.tables[*input.TableName]
   340  	if !ok {
   341  		return nil, fmt.Errorf("not found")
   342  	}
   343  
   344  	table.read = *input.ProvisionedThroughput.ReadCapacityUnits
   345  	table.write = *input.ProvisionedThroughput.WriteCapacityUnits
   346  
   347  	return &dynamodb.UpdateTableOutput{
   348  		TableDescription: &dynamodb.TableDescription{
   349  			TableArn: aws.String(arnPrefix + *input.TableName),
   350  		},
   351  	}, nil
   352  }
   353  
   354  func (m *mockDynamoDBClient) TagResourceWithContext(_ aws.Context, input *dynamodb.TagResourceInput, _ ...request.Option) (*dynamodb.TagResourceOutput, error) {
   355  	m.mtx.Lock()
   356  	defer m.mtx.Unlock()
   357  
   358  	if len(input.Tags) == 0 {
   359  		return nil, fmt.Errorf("tags are required")
   360  	}
   361  
   362  	if !strings.HasPrefix(*input.ResourceArn, arnPrefix) {
   363  		return nil, fmt.Errorf("not an arn: %v", *input.ResourceArn)
   364  	}
   365  
   366  	table, ok := m.tables[strings.TrimPrefix(*input.ResourceArn, arnPrefix)]
   367  	if !ok {
   368  		return nil, fmt.Errorf("not found")
   369  	}
   370  
   371  	table.tags = input.Tags
   372  	return &dynamodb.TagResourceOutput{}, nil
   373  }
   374  
   375  func (m *mockDynamoDBClient) ListTagsOfResourceWithContext(_ aws.Context, input *dynamodb.ListTagsOfResourceInput, _ ...request.Option) (*dynamodb.ListTagsOfResourceOutput, error) {
   376  	m.mtx.RLock()
   377  	defer m.mtx.RUnlock()
   378  
   379  	if !strings.HasPrefix(*input.ResourceArn, arnPrefix) {
   380  		return nil, fmt.Errorf("not an arn: %v", *input.ResourceArn)
   381  	}
   382  
   383  	table, ok := m.tables[strings.TrimPrefix(*input.ResourceArn, arnPrefix)]
   384  	if !ok {
   385  		return nil, fmt.Errorf("not found")
   386  	}
   387  
   388  	return &dynamodb.ListTagsOfResourceOutput{
   389  		Tags: table.tags,
   390  	}, nil
   391  }
   392  
   393  type mockS3 struct {
   394  	s3iface.S3API
   395  	sync.RWMutex
   396  	objects map[string][]byte
   397  }
   398  
   399  func newMockS3() *mockS3 {
   400  	return &mockS3{
   401  		objects: map[string][]byte{},
   402  	}
   403  }
   404  
   405  func (m *mockS3) PutObjectWithContext(_ aws.Context, req *s3.PutObjectInput, _ ...request.Option) (*s3.PutObjectOutput, error) {
   406  	m.Lock()
   407  	defer m.Unlock()
   408  
   409  	buf, err := ioutil.ReadAll(req.Body)
   410  	if err != nil {
   411  		return nil, err
   412  	}
   413  
   414  	m.objects[*req.Key] = buf
   415  	return &s3.PutObjectOutput{}, nil
   416  }
   417  
   418  func (m *mockS3) GetObjectWithContext(_ aws.Context, req *s3.GetObjectInput, _ ...request.Option) (*s3.GetObjectOutput, error) {
   419  	m.RLock()
   420  	defer m.RUnlock()
   421  
   422  	buf, ok := m.objects[*req.Key]
   423  	if !ok {
   424  		return nil, fmt.Errorf("Not found")
   425  	}
   426  
   427  	return &s3.GetObjectOutput{
   428  		Body: ioutil.NopCloser(bytes.NewReader(buf)),
   429  	}, nil
   430  }