github.com/treeverse/lakefs@v1.24.1-0.20240520134607-95648127bfb0/pkg/kv/cosmosdb/store.go (about)

     1  package cosmosdb
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/base32"
     7  	"encoding/json"
     8  	"errors"
     9  	"fmt"
    10  	"net/http"
    11  	"sort"
    12  
    13  	"github.com/Azure/azure-sdk-for-go/sdk/azcore"
    14  	"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
    15  	"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
    16  	"github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos"
    17  	"github.com/treeverse/lakefs/pkg/ident"
    18  	"github.com/treeverse/lakefs/pkg/kv"
    19  	"github.com/treeverse/lakefs/pkg/kv/kvparams"
    20  	"github.com/treeverse/lakefs/pkg/logging"
    21  )
    22  
// Driver is the kv driver for CosmosDB. It is registered under DriverName in
// init and its Open method returns a Store.
type Driver struct{}
    24  
// Store is the value returned by Driver.Open. Every key/value operation is
// executed against a single CosmosDB container.
type Store struct {
	// containerClient is the client of the container holding all documents.
	containerClient  *azcosmos.ContainerClient
	// consistencyLevel is sent with item writes and scan queries.
	consistencyLevel azcosmos.ConsistencyLevel
	// logger is the logger taken from the context passed to Open, tagged with the store name.
	logger           logging.Logger
}
    30  
    31  const (
    32  	DriverName = "cosmosdb"
    33  )
    34  
// encoding is the encoding used to encode the partition keys, ids and values.
// It must keep the encoded strings in the same order as the raw bytes, because
// scans filter and sort on the encoded key.
var encoding = base32.HexEncoding
    38  
    39  //nolint:gochecknoinits
    40  func init() {
    41  	kv.Register(DriverName, &Driver{})
    42  }
    43  
    44  // Open - opens and returns a KV store over CosmosDB. This function creates the DB session
    45  // and sets up the KV table.
    46  func (d *Driver) Open(ctx context.Context, kvParams kvparams.Config) (kv.Store, error) {
    47  	params := kvParams.CosmosDB
    48  	if params == nil {
    49  		return nil, fmt.Errorf("missing %s settings: %w", DriverName, kv.ErrDriverConfiguration)
    50  	}
    51  	if params.Endpoint == "" {
    52  		return nil, fmt.Errorf("missing endpoint: %w", kv.ErrDriverConfiguration)
    53  	}
    54  	if params.Database == "" {
    55  		return nil, fmt.Errorf("missing database: %w", kv.ErrDriverConfiguration)
    56  	}
    57  	if params.Container == "" {
    58  		return nil, fmt.Errorf("missing container: %w", kv.ErrDriverConfiguration)
    59  	}
    60  
    61  	logger := logging.FromContext(ctx).WithField("store", DriverName)
    62  	logger.Infof("CosmosDB: connecting to %s", params.Endpoint)
    63  
    64  	var client *azcosmos.Client
    65  	if params.Key != "" {
    66  		cred, err := azcosmos.NewKeyCredential(params.Key)
    67  		if err != nil {
    68  			return nil, fmt.Errorf("creating key: %w", err)
    69  		}
    70  
    71  		// hook for using emulator for testing
    72  		if params.Client == nil {
    73  			params.Client = http.DefaultClient
    74  		}
    75  		// Create a CosmosDB client
    76  		client, err = azcosmos.NewClientWithKey(params.Endpoint, cred, &azcosmos.ClientOptions{
    77  			ClientOptions: azcore.ClientOptions{
    78  				Transport: params.Client,
    79  			},
    80  		})
    81  		if err != nil {
    82  			return nil, fmt.Errorf("creating client using access key: %w", err)
    83  		}
    84  	} else {
    85  		cred, err := azidentity.NewDefaultAzureCredential(nil)
    86  		if err != nil {
    87  			return nil, fmt.Errorf("default creds: %w", err)
    88  		}
    89  		client, err = azcosmos.NewClient(params.Endpoint, cred, nil)
    90  		if err != nil {
    91  			return nil, fmt.Errorf("creating client with default creds: %w", err)
    92  		}
    93  	}
    94  
    95  	dbClient, err := getOrCreateDatabase(ctx, client, params)
    96  	if err != nil {
    97  		return nil, err
    98  	}
    99  
   100  	// Create container client
   101  	containerClient, err := getOrCreateContainer(ctx, dbClient, params)
   102  	if err != nil {
   103  		return nil, err
   104  	}
   105  
   106  	cLevel := azcosmos.ConsistencyLevelBoundedStaleness
   107  	if !params.StrongConsistency {
   108  		cLevel = azcosmos.ConsistencyLevelSession
   109  	}
   110  	return &Store{
   111  		containerClient:  containerClient,
   112  		consistencyLevel: cLevel,
   113  		logger:           logger,
   114  	}, nil
   115  }
   116  
   117  func getOrCreateDatabase(ctx context.Context, client *azcosmos.Client, params *kvparams.CosmosDB) (*azcosmos.DatabaseClient, error) {
   118  	_, err := client.CreateDatabase(ctx, azcosmos.DatabaseProperties{ID: params.Database}, nil)
   119  	if err != nil {
   120  		if errStatusCode(err) != http.StatusConflict {
   121  			return nil, fmt.Errorf("creating database: %w", err)
   122  		}
   123  	}
   124  	dbClient, err := client.NewDatabase(params.Database)
   125  	if err != nil {
   126  		return nil, fmt.Errorf("init database client: %w", err)
   127  	}
   128  	return dbClient, nil
   129  }
   130  
   131  func getOrCreateContainer(ctx context.Context, dbClient *azcosmos.DatabaseClient, params *kvparams.CosmosDB) (*azcosmos.ContainerClient, error) {
   132  	var opts *azcosmos.CreateContainerOptions
   133  	if params.Throughput > 0 {
   134  		var throughputProperties azcosmos.ThroughputProperties
   135  		if params.Autoscale {
   136  			throughputProperties = azcosmos.NewAutoscaleThroughputProperties(params.Throughput)
   137  		} else {
   138  			throughputProperties = azcosmos.NewManualThroughputProperties(params.Throughput)
   139  		}
   140  		opts = &azcosmos.CreateContainerOptions{ThroughputProperties: &throughputProperties}
   141  	}
   142  
   143  	_, err := dbClient.CreateContainer(ctx,
   144  		azcosmos.ContainerProperties{
   145  			ID: params.Container,
   146  			PartitionKeyDefinition: azcosmos.PartitionKeyDefinition{
   147  				Paths: []string{"/partitionKey"},
   148  			},
   149  			// Excluding the value field from indexing since it is not used in queries and saves RUs for writes.
   150  			// partitionKey is automatically not indexed. The rest of the fields are indexed by default, including id
   151  			// which is unnecessary, but cannot be excluded.
   152  			IndexingPolicy: &azcosmos.IndexingPolicy{
   153  				Automatic:     false,
   154  				IndexingMode:  azcosmos.IndexingModeConsistent,
   155  				IncludedPaths: []azcosmos.IncludedPath{{Path: "/*"}},
   156  				ExcludedPaths: []azcosmos.ExcludedPath{{Path: "/value/?"}},
   157  			},
   158  		}, opts)
   159  	if err != nil {
   160  		if errStatusCode(err) != http.StatusConflict {
   161  			return nil, fmt.Errorf("creating container: %w", err)
   162  		}
   163  	}
   164  	containerClient, err := dbClient.NewContainer(params.Container)
   165  	if err != nil {
   166  		return nil, fmt.Errorf("init container client: %w", err)
   167  	}
   168  	return containerClient, nil
   169  }
   170  
   171  // hashID returns a hash of the key that is used as the document id.
   172  func (s *Store) hashID(key []byte) string {
   173  	return encoding.EncodeToString(ident.NewAddressWriter().MarshalBytes(key).Identity())
   174  }
   175  
// Document is the JSON shape of a single key/value entry as stored in the
// container.
type Document struct {
	// PartitionKey is the encoded partition key.
	PartitionKey string `json:"partitionKey"`
	// ID is the hash of the key. It is used as the document id for lookup of a single item.
	// CosmosDB has a 1023 byte limit on the id, so we hash the key to ensure it fits.
	ID string `json:"id"`
	// Key is the original key, encoded. Scans filter and order by this field, so it
	// is what determines the listing order of items.
	Key   string `json:"key"`
	// Value is the encoded value.
	Value string `json:"value"`
}
   185  
   186  func (s *Store) Get(ctx context.Context, partitionKey, key []byte) (*kv.ValueWithPredicate, error) {
   187  	if len(partitionKey) == 0 {
   188  		return nil, kv.ErrMissingPartitionKey
   189  	}
   190  	if len(key) == 0 {
   191  		return nil, kv.ErrMissingKey
   192  	}
   193  	item := Document{
   194  		PartitionKey: encoding.EncodeToString(partitionKey),
   195  		ID:           s.hashID(key),
   196  	}
   197  	pk := azcosmos.NewPartitionKeyString(item.PartitionKey)
   198  
   199  	// Read an item
   200  	itemResponse, err := s.containerClient.ReadItem(ctx, pk, item.ID, nil)
   201  	if err != nil {
   202  		return nil, convertError(err)
   203  	}
   204  
   205  	var itemResponseBody Document
   206  	err = json.Unmarshal(itemResponse.Value, &itemResponseBody)
   207  	if err != nil {
   208  		return nil, err
   209  	}
   210  
   211  	val, err := encoding.DecodeString(itemResponseBody.Value)
   212  	if err != nil {
   213  		return nil, err
   214  	}
   215  	return &kv.ValueWithPredicate{
   216  		Value:     val,
   217  		Predicate: kv.Predicate([]byte(itemResponse.ETag)),
   218  	}, nil
   219  }
   220  
   221  func errStatusCode(err error) int {
   222  	var respErr *azcore.ResponseError
   223  	if !errors.As(err, &respErr) {
   224  		return -1
   225  	}
   226  	return respErr.StatusCode
   227  }
   228  
   229  func isErrStatusCode(err error, code int) bool {
   230  	return errStatusCode(err) == code
   231  }
   232  
   233  func convertError(err error) error {
   234  	statusCode := errStatusCode(err)
   235  	switch statusCode {
   236  	case http.StatusTooManyRequests:
   237  		return kv.ErrSlowDown
   238  	case http.StatusPreconditionFailed:
   239  		return kv.ErrPredicateFailed
   240  	case http.StatusNotFound:
   241  		return kv.ErrNotFound
   242  	case http.StatusConflict:
   243  		return kv.ErrPredicateFailed
   244  	}
   245  	return err
   246  }
   247  
   248  func (s *Store) Set(ctx context.Context, partitionKey, key, value []byte) error {
   249  	if len(partitionKey) == 0 {
   250  		return kv.ErrMissingPartitionKey
   251  	}
   252  	if len(key) == 0 {
   253  		return kv.ErrMissingKey
   254  	}
   255  	if value == nil {
   256  		return kv.ErrMissingValue
   257  	}
   258  
   259  	// Specifies the value of the partiton key
   260  	item := Document{
   261  		PartitionKey: encoding.EncodeToString(partitionKey),
   262  		ID:           s.hashID(key),
   263  		Key:          encoding.EncodeToString(key),
   264  		Value:        encoding.EncodeToString(value),
   265  	}
   266  
   267  	b, err := json.Marshal(item)
   268  	if err != nil {
   269  		return err
   270  	}
   271  	itemOptions := azcosmos.ItemOptions{
   272  		ConsistencyLevel: s.consistencyLevel.ToPtr(),
   273  	}
   274  	pk := azcosmos.NewPartitionKeyString(item.PartitionKey)
   275  
   276  	_, err = s.containerClient.UpsertItem(ctx, pk, b, &itemOptions)
   277  	return convertError(err)
   278  }
   279  
   280  func (s *Store) SetIf(ctx context.Context, partitionKey, key, value []byte, valuePredicate kv.Predicate) error {
   281  	if len(partitionKey) == 0 {
   282  		return kv.ErrMissingPartitionKey
   283  	}
   284  	if len(key) == 0 {
   285  		return kv.ErrMissingKey
   286  	}
   287  	if value == nil {
   288  		return kv.ErrMissingValue
   289  	}
   290  
   291  	// Specifies the value of the partiton key
   292  	item := Document{
   293  		PartitionKey: encoding.EncodeToString(partitionKey),
   294  		ID:           s.hashID(key),
   295  		Key:          encoding.EncodeToString(key),
   296  		Value:        encoding.EncodeToString(value),
   297  	}
   298  
   299  	b, err := json.Marshal(item)
   300  	if err != nil {
   301  		return err
   302  	}
   303  	itemOptions := azcosmos.ItemOptions{
   304  		ConsistencyLevel: s.consistencyLevel.ToPtr(),
   305  	}
   306  	pk := azcosmos.NewPartitionKeyString(item.PartitionKey)
   307  
   308  	switch valuePredicate {
   309  	case nil:
   310  		_, err = s.containerClient.CreateItem(ctx, pk, b, &itemOptions)
   311  	case kv.PrecondConditionalExists:
   312  		patch := azcosmos.PatchOperations{}
   313  		patch.AppendReplace("/value", item.Value)
   314  		_, err = s.containerClient.PatchItem(
   315  			ctx,
   316  			pk,
   317  			item.ID,
   318  			patch,
   319  			&itemOptions,
   320  		)
   321  		if isErrStatusCode(err, http.StatusNotFound) {
   322  			return kv.ErrPredicateFailed
   323  		}
   324  	default:
   325  		etag := azcore.ETag(valuePredicate.([]byte))
   326  		itemOptions.IfMatchEtag = &etag
   327  		_, err = s.containerClient.UpsertItem(ctx, pk, b, &itemOptions)
   328  	}
   329  	return convertError(err)
   330  }
   331  
   332  func (s *Store) Delete(ctx context.Context, partitionKey, key []byte) error {
   333  	if len(partitionKey) == 0 {
   334  		return kv.ErrMissingPartitionKey
   335  	}
   336  	if len(key) == 0 {
   337  		return kv.ErrMissingKey
   338  	}
   339  	pk := azcosmos.NewPartitionKeyString(encoding.EncodeToString(partitionKey))
   340  
   341  	_, err := s.containerClient.DeleteItem(ctx, pk, s.hashID(key), nil)
   342  	if err != nil {
   343  		err = convertError(err)
   344  	}
   345  	if !errors.Is(err, kv.ErrNotFound) {
   346  		return err
   347  	}
   348  	return nil
   349  }
   350  
   351  func (s *Store) Scan(ctx context.Context, partitionKey []byte, options kv.ScanOptions) (kv.EntriesIterator, error) {
   352  	if len(partitionKey) == 0 {
   353  		return nil, kv.ErrMissingPartitionKey
   354  	}
   355  	it := &EntriesIterator{
   356  		store:        s,
   357  		partitionKey: partitionKey,
   358  		startKey:     options.KeyStart,
   359  		limit:        options.BatchSize,
   360  		queryCtx:     ctx,
   361  		encoding:     encoding,
   362  	}
   363  	if err := it.runQuery(); err != nil {
   364  		return nil, convertError(err)
   365  	}
   366  	return it, nil
   367  }
   368  
// Close is a no-op: the store has nothing to release.
func (s *Store) Close() {
}
   371  
// EntriesIterator iterates over the entries of one partition, page by page,
// using a CosmosDB query pager.
type EntriesIterator struct {
	// store is the Store the iterator was created from.
	store        *Store
	// partitionKey is the raw (unencoded) partition key being scanned.
	partitionKey []byte
	// startKey is the raw key the next query starts from; updated by SeekGE.
	startKey     []byte
	// limit is passed to the query as the page size hint.
	limit        int

	// entry is the entry loaded by the last successful call to Next.
	entry        *kv.Entry
	// err is the first error encountered; once set, Next returns false.
	err          error
	// currEntryIdx is the index in currPage.Items of the current entry, -1 before the first one.
	currEntryIdx int
	// queryPager pages through the results of the current query.
	queryPager   *runtime.Pager[azcosmos.QueryItemsResponse]
	// queryCtx is the context used for fetching pages.
	queryCtx     context.Context
	// currPage is the page of results currently being iterated.
	currPage     azcosmos.QueryItemsResponse
	// currPageSeekedKey is the key we seeked to get this page, will be nil if this page wasn't returned by the query
	// (that is, when it was fetched as a following page by Next).
	currPageSeekedKey []byte
	// encoding is used to decode the key and value of fetched documents.
	encoding *base32.Encoding
}
   389  
   390  func (e *EntriesIterator) getKeyValue(i int) ([]byte, []byte) {
   391  	var itemResponseBody Document
   392  	err := json.Unmarshal(e.currPage.Items[i], &itemResponseBody)
   393  	if err != nil {
   394  		e.err = fmt.Errorf("failed to unmarshal: %w", err)
   395  		return nil, nil
   396  	}
   397  	key, err := e.encoding.DecodeString(itemResponseBody.Key)
   398  	if err != nil {
   399  		e.err = fmt.Errorf("failed to decode id: %w", err)
   400  		return nil, nil
   401  	}
   402  	value, err := e.encoding.DecodeString(itemResponseBody.Value)
   403  	if err != nil {
   404  		e.err = fmt.Errorf("failed to decode value: %w", err)
   405  		return nil, nil
   406  	}
   407  	return key, value
   408  }
   409  
   410  func (e *EntriesIterator) Next() bool {
   411  	if e.err != nil {
   412  		return false
   413  	}
   414  
   415  	if e.currEntryIdx+1 >= len(e.currPage.Items) {
   416  		if !e.queryPager.More() {
   417  			return false
   418  		}
   419  		var err error
   420  		e.currPage, err = e.queryPager.NextPage(e.queryCtx)
   421  		if err != nil {
   422  			e.err = fmt.Errorf("getting next page: %w", convertError(err))
   423  			return false
   424  		}
   425  		if len(e.currPage.Items) == 0 {
   426  			// returned page is empty, no more items
   427  			return false
   428  		}
   429  		e.currPageSeekedKey = nil
   430  		e.currEntryIdx = -1
   431  	}
   432  	e.currEntryIdx++
   433  	key, value := e.getKeyValue(e.currEntryIdx)
   434  	if e.err != nil {
   435  		return false
   436  	}
   437  	e.entry = &kv.Entry{
   438  		Key:   key,
   439  		Value: value,
   440  	}
   441  
   442  	return true
   443  }
   444  
   445  func (e *EntriesIterator) SeekGE(key []byte) {
   446  	e.startKey = key
   447  	if !e.isInRange() {
   448  		if err := e.runQuery(); err != nil {
   449  			e.err = convertError(err)
   450  		}
   451  		return
   452  	}
   453  	idx := sort.Search(len(e.currPage.Items), func(i int) bool {
   454  		currentKey, _ := e.getKeyValue(i)
   455  		if e.err != nil {
   456  			return false
   457  		}
   458  		return bytes.Compare(key, currentKey) <= 0
   459  	})
   460  	if idx == -1 {
   461  		// not found, set to the end
   462  		e.currEntryIdx = len(e.currPage.Items)
   463  	}
   464  	e.currEntryIdx = idx - 1
   465  }
   466  
// Entry returns the entry loaded by the last successful call to Next. It is nil
// after a query was run and before Next is called.
func (e *EntriesIterator) Entry() *kv.Entry {
	return e.entry
}
   470  
// Err returns the error recorded by the iterator, or nil if there is none.
func (e *EntriesIterator) Err() error {
	return e.err
}
   474  
// Close marks the iterator as closed by setting its error to
// kv.ErrClosedEntries, which makes subsequent calls to Next return false.
func (e *EntriesIterator) Close() {
	e.err = kv.ErrClosedEntries
}
   478  
   479  func (e *EntriesIterator) runQuery() error {
   480  	pk := azcosmos.NewPartitionKeyString(encoding.EncodeToString(e.partitionKey))
   481  	e.queryPager = e.store.containerClient.NewQueryItemsPager("select * from c where c.key >= @start order by c.key", pk, &azcosmos.QueryOptions{
   482  		ConsistencyLevel: e.store.consistencyLevel.ToPtr(),
   483  		PageSizeHint:     int32(e.limit),
   484  		QueryParameters: []azcosmos.QueryParameter{{
   485  			Name:  "@start",
   486  			Value: encoding.EncodeToString(e.startKey),
   487  		}},
   488  	})
   489  	currPage, err := e.queryPager.NextPage(e.queryCtx)
   490  	if err != nil {
   491  		return err
   492  	}
   493  	e.currEntryIdx = -1
   494  	e.entry = nil
   495  	e.currPage = currPage
   496  	e.currPageSeekedKey = e.startKey
   497  	return nil
   498  }
   499  
   500  // isInRange checks if e.startKey falls within the range of keys on the current page.
   501  // To optimize range checking:
   502  // - If the current page is a result of a seek operation, the seeked key is used as the minimum key.
   503  // - If the current page is the last page, all keys greater than the minimum key are considered in range.
   504  // This function returns true if e.startKey is within these defined range criteria.
   505  func (e *EntriesIterator) isInRange() bool {
   506  	if e.err != nil {
   507  		return false
   508  	}
   509  	var minKey []byte
   510  	if e.currPageSeekedKey != nil {
   511  		minKey = e.currPageSeekedKey
   512  	} else {
   513  		if len(e.currPage.Items) == 0 {
   514  			return false
   515  		}
   516  		minKey, _ = e.getKeyValue(0)
   517  		if minKey == nil {
   518  			return false
   519  		}
   520  	}
   521  	if bytes.Compare(e.startKey, minKey) < 0 {
   522  		return false
   523  	}
   524  	if !e.queryPager.More() {
   525  		// last page, all keys greater than minKey are considered in range (in order to avoid unnecessary queries)
   526  		return true
   527  	}
   528  	if len(e.currPage.Items) == 0 {
   529  		// cosmosdb returned empty page but has more results, should not happen
   530  		return false
   531  	}
   532  	maxKey, _ := e.getKeyValue(len(e.currPage.Items) - 1)
   533  	return maxKey != nil && bytes.Compare(e.startKey, maxKey) <= 0
   534  }