github.com/yankunsam/loki/v2@v2.6.3-0.20220817130409-389df5235c27/pkg/storage/chunk/client/aws/mock.go (about)

     1  package aws
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"fmt"
     7  	"io/ioutil"
     8  	"sort"
     9  	"strings"
    10  	"sync"
    11  
    12  	"github.com/aws/aws-sdk-go/aws"
    13  	"github.com/aws/aws-sdk-go/aws/awserr"
    14  	"github.com/aws/aws-sdk-go/aws/request"
    15  	"github.com/aws/aws-sdk-go/service/dynamodb"
    16  	"github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface"
    17  	"github.com/aws/aws-sdk-go/service/s3"
    18  	"github.com/aws/aws-sdk-go/service/s3/s3iface"
    19  	"github.com/go-kit/log/level"
    20  
    21  	util_log "github.com/grafana/loki/pkg/util/log"
    22  )
    23  
    24  const arnPrefix = "arn:"
    25  
    26  type mockDynamoDBClient struct {
    27  	dynamodbiface.DynamoDBAPI
    28  
    29  	mtx            sync.RWMutex
    30  	unprocessed    int
    31  	provisionedErr int
    32  	errAfter       int
    33  	tables         map[string]*mockDynamoDBTable
    34  }
    35  
    36  type mockDynamoDBTable struct {
    37  	items       map[string][]mockDynamoDBItem
    38  	read, write int64
    39  	tags        []*dynamodb.Tag
    40  }
    41  
    42  type mockDynamoDBItem map[string]*dynamodb.AttributeValue
    43  
    44  // nolint
    45  func newMockDynamoDB(unprocessed int, provisionedErr int) *mockDynamoDBClient {
    46  	return &mockDynamoDBClient{
    47  		tables:         map[string]*mockDynamoDBTable{},
    48  		unprocessed:    unprocessed,
    49  		provisionedErr: provisionedErr,
    50  	}
    51  }
    52  
    53  func (a dynamoDBStorageClient) setErrorParameters(provisionedErr, errAfter int) {
    54  	if m, ok := a.DynamoDB.(*mockDynamoDBClient); ok {
    55  		m.provisionedErr = provisionedErr
    56  		m.errAfter = errAfter
    57  	}
    58  }
    59  
    60  //nolint:unused //Leaving this around in the case we need to create a table via mock this is useful.
    61  func (m *mockDynamoDBClient) createTable(name string) {
    62  	m.mtx.Lock()
    63  	defer m.mtx.Unlock()
    64  	m.tables[name] = &mockDynamoDBTable{
    65  		items: map[string][]mockDynamoDBItem{},
    66  	}
    67  }
    68  
    69  func (m *mockDynamoDBClient) batchWriteItemRequest(_ context.Context, input *dynamodb.BatchWriteItemInput) dynamoDBRequest {
    70  	m.mtx.Lock()
    71  	defer m.mtx.Unlock()
    72  
    73  	resp := &dynamodb.BatchWriteItemOutput{
    74  		UnprocessedItems: map[string][]*dynamodb.WriteRequest{},
    75  	}
    76  
    77  	if m.errAfter > 0 {
    78  		m.errAfter--
    79  	} else if m.provisionedErr > 0 {
    80  		m.provisionedErr--
    81  		return &dynamoDBMockRequest{
    82  			result: resp,
    83  			err:    awserr.New(dynamodb.ErrCodeProvisionedThroughputExceededException, "", nil),
    84  		}
    85  	}
    86  
    87  	for tableName, writeRequests := range input.RequestItems {
    88  		table, ok := m.tables[tableName]
    89  		if !ok {
    90  			return &dynamoDBMockRequest{
    91  				result: &dynamodb.BatchWriteItemOutput{},
    92  				err:    fmt.Errorf("table not found: %s", tableName),
    93  			}
    94  		}
    95  
    96  		for _, writeRequest := range writeRequests {
    97  			if m.unprocessed > 0 {
    98  				m.unprocessed--
    99  				resp.UnprocessedItems[tableName] = append(resp.UnprocessedItems[tableName], writeRequest)
   100  				continue
   101  			}
   102  
   103  			hashValue := *writeRequest.PutRequest.Item[hashKey].S
   104  			rangeValue := writeRequest.PutRequest.Item[rangeKey].B
   105  
   106  			items := table.items[hashValue]
   107  
   108  			// insert in order
   109  			i := sort.Search(len(items), func(i int) bool {
   110  				return bytes.Compare(items[i][rangeKey].B, rangeValue) >= 0
   111  			})
   112  			if i >= len(items) || !bytes.Equal(items[i][rangeKey].B, rangeValue) {
   113  				items = append(items, nil)
   114  				copy(items[i+1:], items[i:])
   115  			} else {
   116  				return &dynamoDBMockRequest{
   117  					result: &dynamodb.BatchWriteItemOutput{},
   118  					err:    fmt.Errorf("Duplicate entry"),
   119  				}
   120  			}
   121  			items[i] = writeRequest.PutRequest.Item
   122  
   123  			table.items[hashValue] = items
   124  		}
   125  	}
   126  	return &dynamoDBMockRequest{result: resp}
   127  }
   128  
   129  func (m *mockDynamoDBClient) batchGetItemRequest(_ context.Context, input *dynamodb.BatchGetItemInput) dynamoDBRequest {
   130  	m.mtx.Lock()
   131  	defer m.mtx.Unlock()
   132  
   133  	resp := &dynamodb.BatchGetItemOutput{
   134  		Responses:       map[string][]map[string]*dynamodb.AttributeValue{},
   135  		UnprocessedKeys: map[string]*dynamodb.KeysAndAttributes{},
   136  	}
   137  
   138  	if m.errAfter > 0 {
   139  		m.errAfter--
   140  	} else if m.provisionedErr > 0 {
   141  		m.provisionedErr--
   142  		return &dynamoDBMockRequest{
   143  			result: resp,
   144  			err:    awserr.New(dynamodb.ErrCodeProvisionedThroughputExceededException, "", nil),
   145  		}
   146  	}
   147  
   148  	for tableName, readRequests := range input.RequestItems {
   149  		table, ok := m.tables[tableName]
   150  		if !ok {
   151  			return &dynamoDBMockRequest{
   152  				result: &dynamodb.BatchGetItemOutput{},
   153  				err:    fmt.Errorf("table not found"),
   154  			}
   155  		}
   156  
   157  		unprocessed := &dynamodb.KeysAndAttributes{
   158  			AttributesToGet:          readRequests.AttributesToGet,
   159  			ConsistentRead:           readRequests.ConsistentRead,
   160  			ExpressionAttributeNames: readRequests.ExpressionAttributeNames,
   161  		}
   162  		for _, readRequest := range readRequests.Keys {
   163  			if m.unprocessed > 0 {
   164  				m.unprocessed--
   165  				unprocessed.Keys = append(unprocessed.Keys, readRequest)
   166  				resp.UnprocessedKeys[tableName] = unprocessed
   167  				continue
   168  			}
   169  
   170  			hashValue := *readRequest[hashKey].S
   171  			rangeValue := readRequest[rangeKey].B
   172  			items := table.items[hashValue]
   173  
   174  			// insert in order
   175  			i := sort.Search(len(items), func(i int) bool {
   176  				return bytes.Compare(items[i][rangeKey].B, rangeValue) >= 0
   177  			})
   178  			if i >= len(items) || !bytes.Equal(items[i][rangeKey].B, rangeValue) {
   179  				return &dynamoDBMockRequest{
   180  					result: &dynamodb.BatchGetItemOutput{},
   181  					err:    fmt.Errorf("Couldn't find item"),
   182  				}
   183  			}
   184  
   185  			// Only return AttributesToGet!
   186  			item := map[string]*dynamodb.AttributeValue{}
   187  			for _, key := range readRequests.AttributesToGet {
   188  				item[*key] = items[i][*key]
   189  			}
   190  			resp.Responses[tableName] = append(resp.Responses[tableName], item)
   191  		}
   192  	}
   193  	return &dynamoDBMockRequest{
   194  		result: resp,
   195  	}
   196  }
   197  
   198  func (m *mockDynamoDBClient) QueryPagesWithContext(ctx aws.Context, input *dynamodb.QueryInput, fn func(*dynamodb.QueryOutput, bool) bool, opts ...request.Option) error {
   199  	result := &dynamodb.QueryOutput{
   200  		Items: []map[string]*dynamodb.AttributeValue{},
   201  	}
   202  
   203  	// Required filters
   204  	hashValue := *input.KeyConditions[hashKey].AttributeValueList[0].S
   205  
   206  	// Optional filters
   207  	var (
   208  		rangeValueFilter     []byte
   209  		rangeValueFilterType string
   210  	)
   211  	if c, ok := input.KeyConditions[rangeKey]; ok {
   212  		rangeValueFilter = c.AttributeValueList[0].B
   213  		rangeValueFilterType = *c.ComparisonOperator
   214  	}
   215  
   216  	// Filter by HashValue, RangeValue and Value if it exists
   217  	items := m.tables[*input.TableName].items[hashValue]
   218  	for _, item := range items {
   219  		rangeValue := item[rangeKey].B
   220  		if rangeValueFilterType == dynamodb.ComparisonOperatorGe && bytes.Compare(rangeValue, rangeValueFilter) < 0 {
   221  			continue
   222  		}
   223  		if rangeValueFilterType == dynamodb.ComparisonOperatorBeginsWith && !bytes.HasPrefix(rangeValue, rangeValueFilter) {
   224  			continue
   225  		}
   226  
   227  		if item[valueKey] != nil {
   228  			value := item[valueKey].B
   229  
   230  			// Apply filterExpression if it exists (supporting only v = :v)
   231  			if input.FilterExpression != nil {
   232  				if *input.FilterExpression == fmt.Sprintf("%s = :v", valueKey) {
   233  					filterValue := input.ExpressionAttributeValues[":v"].B
   234  					if !bytes.Equal(value, filterValue) {
   235  						continue
   236  					}
   237  				} else {
   238  					level.Warn(util_log.Logger).Log("msg", "unsupported FilterExpression", "expression", *input.FilterExpression)
   239  				}
   240  			}
   241  		}
   242  
   243  		result.Items = append(result.Items, item)
   244  	}
   245  	fn(result, true)
   246  	return nil
   247  }
   248  
   249  type dynamoDBMockRequest struct {
   250  	result interface{}
   251  	err    error
   252  }
   253  
   254  func (m *dynamoDBMockRequest) Send() error {
   255  	return m.err
   256  }
   257  
   258  func (m *dynamoDBMockRequest) Data() interface{} {
   259  	return m.result
   260  }
   261  
   262  func (m *dynamoDBMockRequest) Error() error {
   263  	return m.err
   264  }
   265  
   266  func (m *dynamoDBMockRequest) Retryable() bool {
   267  	return false
   268  }
   269  
   270  func (m *mockDynamoDBClient) ListTablesPagesWithContext(_ aws.Context, input *dynamodb.ListTablesInput, fn func(*dynamodb.ListTablesOutput, bool) bool, _ ...request.Option) error {
   271  	m.mtx.RLock()
   272  	defer m.mtx.RUnlock()
   273  
   274  	var tableNames []*string
   275  	for tableName := range m.tables {
   276  		func(tableName string) {
   277  			tableNames = append(tableNames, &tableName)
   278  		}(tableName)
   279  	}
   280  	fn(&dynamodb.ListTablesOutput{
   281  		TableNames: tableNames,
   282  	}, true)
   283  
   284  	return nil
   285  }
   286  
   287  // CreateTable implements StorageClient.
   288  func (m *mockDynamoDBClient) CreateTableWithContext(_ aws.Context, input *dynamodb.CreateTableInput, _ ...request.Option) (*dynamodb.CreateTableOutput, error) {
   289  	m.mtx.Lock()
   290  	defer m.mtx.Unlock()
   291  
   292  	if _, ok := m.tables[*input.TableName]; ok {
   293  		return nil, fmt.Errorf("table already exists")
   294  	}
   295  
   296  	m.tables[*input.TableName] = &mockDynamoDBTable{
   297  		items: map[string][]mockDynamoDBItem{},
   298  		write: *input.ProvisionedThroughput.WriteCapacityUnits,
   299  		read:  *input.ProvisionedThroughput.ReadCapacityUnits,
   300  	}
   301  
   302  	return &dynamodb.CreateTableOutput{
   303  		TableDescription: &dynamodb.TableDescription{
   304  			TableArn: aws.String(arnPrefix + *input.TableName),
   305  		},
   306  	}, nil
   307  }
   308  
   309  // DescribeTable implements StorageClient.
   310  func (m *mockDynamoDBClient) DescribeTableWithContext(_ aws.Context, input *dynamodb.DescribeTableInput, _ ...request.Option) (*dynamodb.DescribeTableOutput, error) {
   311  	m.mtx.RLock()
   312  	defer m.mtx.RUnlock()
   313  
   314  	table, ok := m.tables[*input.TableName]
   315  	if !ok {
   316  		return nil, fmt.Errorf("not found")
   317  	}
   318  
   319  	return &dynamodb.DescribeTableOutput{
   320  		Table: &dynamodb.TableDescription{
   321  			TableName:   input.TableName,
   322  			TableStatus: aws.String(dynamodb.TableStatusActive),
   323  			ProvisionedThroughput: &dynamodb.ProvisionedThroughputDescription{
   324  				ReadCapacityUnits:  aws.Int64(table.read),
   325  				WriteCapacityUnits: aws.Int64(table.write),
   326  			},
   327  			TableArn: aws.String(arnPrefix + *input.TableName),
   328  		},
   329  	}, nil
   330  }
   331  
   332  // UpdateTable implements StorageClient.
   333  func (m *mockDynamoDBClient) UpdateTableWithContext(_ aws.Context, input *dynamodb.UpdateTableInput, _ ...request.Option) (*dynamodb.UpdateTableOutput, error) {
   334  	m.mtx.Lock()
   335  	defer m.mtx.Unlock()
   336  
   337  	table, ok := m.tables[*input.TableName]
   338  	if !ok {
   339  		return nil, fmt.Errorf("not found")
   340  	}
   341  
   342  	table.read = *input.ProvisionedThroughput.ReadCapacityUnits
   343  	table.write = *input.ProvisionedThroughput.WriteCapacityUnits
   344  
   345  	return &dynamodb.UpdateTableOutput{
   346  		TableDescription: &dynamodb.TableDescription{
   347  			TableArn: aws.String(arnPrefix + *input.TableName),
   348  		},
   349  	}, nil
   350  }
   351  
   352  func (m *mockDynamoDBClient) TagResourceWithContext(_ aws.Context, input *dynamodb.TagResourceInput, _ ...request.Option) (*dynamodb.TagResourceOutput, error) {
   353  	m.mtx.Lock()
   354  	defer m.mtx.Unlock()
   355  
   356  	if len(input.Tags) == 0 {
   357  		return nil, fmt.Errorf("tags are required")
   358  	}
   359  
   360  	if !strings.HasPrefix(*input.ResourceArn, arnPrefix) {
   361  		return nil, fmt.Errorf("not an arn: %v", *input.ResourceArn)
   362  	}
   363  
   364  	table, ok := m.tables[strings.TrimPrefix(*input.ResourceArn, arnPrefix)]
   365  	if !ok {
   366  		return nil, fmt.Errorf("not found")
   367  	}
   368  
   369  	table.tags = input.Tags
   370  	return &dynamodb.TagResourceOutput{}, nil
   371  }
   372  
   373  func (m *mockDynamoDBClient) ListTagsOfResourceWithContext(_ aws.Context, input *dynamodb.ListTagsOfResourceInput, _ ...request.Option) (*dynamodb.ListTagsOfResourceOutput, error) {
   374  	m.mtx.RLock()
   375  	defer m.mtx.RUnlock()
   376  
   377  	if !strings.HasPrefix(*input.ResourceArn, arnPrefix) {
   378  		return nil, fmt.Errorf("not an arn: %v", *input.ResourceArn)
   379  	}
   380  
   381  	table, ok := m.tables[strings.TrimPrefix(*input.ResourceArn, arnPrefix)]
   382  	if !ok {
   383  		return nil, fmt.Errorf("not found")
   384  	}
   385  
   386  	return &dynamodb.ListTagsOfResourceOutput{
   387  		Tags: table.tags,
   388  	}, nil
   389  }
   390  
   391  type mockS3 struct {
   392  	s3iface.S3API
   393  	sync.RWMutex
   394  	objects map[string][]byte
   395  }
   396  
   397  func newMockS3() *mockS3 {
   398  	return &mockS3{
   399  		objects: map[string][]byte{},
   400  	}
   401  }
   402  
   403  func (m *mockS3) PutObjectWithContext(_ aws.Context, req *s3.PutObjectInput, _ ...request.Option) (*s3.PutObjectOutput, error) {
   404  	m.Lock()
   405  	defer m.Unlock()
   406  
   407  	buf, err := ioutil.ReadAll(req.Body)
   408  	if err != nil {
   409  		return nil, err
   410  	}
   411  
   412  	m.objects[*req.Key] = buf
   413  	return &s3.PutObjectOutput{}, nil
   414  }
   415  
   416  func (m *mockS3) GetObjectWithContext(_ aws.Context, req *s3.GetObjectInput, _ ...request.Option) (*s3.GetObjectOutput, error) {
   417  	m.RLock()
   418  	defer m.RUnlock()
   419  
   420  	buf, ok := m.objects[*req.Key]
   421  	if !ok {
   422  		return nil, fmt.Errorf("Not found")
   423  	}
   424  
   425  	return &s3.GetObjectOutput{
   426  		Body: ioutil.NopCloser(bytes.NewReader(buf)),
   427  	}, nil
   428  }