github.com/treeverse/lakefs@v1.24.1-0.20240520134607-95648127bfb0/pkg/kv/cosmosdb/store.go (about) 1 package cosmosdb 2 3 import ( 4 "bytes" 5 "context" 6 "encoding/base32" 7 "encoding/json" 8 "errors" 9 "fmt" 10 "net/http" 11 "sort" 12 13 "github.com/Azure/azure-sdk-for-go/sdk/azcore" 14 "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" 15 "github.com/Azure/azure-sdk-for-go/sdk/azidentity" 16 "github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos" 17 "github.com/treeverse/lakefs/pkg/ident" 18 "github.com/treeverse/lakefs/pkg/kv" 19 "github.com/treeverse/lakefs/pkg/kv/kvparams" 20 "github.com/treeverse/lakefs/pkg/logging" 21 ) 22 23 type Driver struct{} 24 25 type Store struct { 26 containerClient *azcosmos.ContainerClient 27 consistencyLevel azcosmos.ConsistencyLevel 28 logger logging.Logger 29 } 30 31 const ( 32 DriverName = "cosmosdb" 33 ) 34 35 // encoding is the encoding used to encode the partition keys, ids and values. 36 // Must be an encoding that keeps the strings in-order. 37 var encoding = base32.HexEncoding // Encoding that keeps the strings in-order. 38 39 //nolint:gochecknoinits 40 func init() { 41 kv.Register(DriverName, &Driver{}) 42 } 43 44 // Open - opens and returns a KV store over CosmosDB. This function creates the DB session 45 // and sets up the KV table. 46 func (d *Driver) Open(ctx context.Context, kvParams kvparams.Config) (kv.Store, error) { 47 params := kvParams.CosmosDB 48 if params == nil { 49 return nil, fmt.Errorf("missing %s settings: %w", DriverName, kv.ErrDriverConfiguration) 50 } 51 if params.Endpoint == "" { 52 return nil, fmt.Errorf("missing endpoint: %w", kv.ErrDriverConfiguration) 53 } 54 if params.Database == "" { 55 return nil, fmt.Errorf("missing database: %w", kv.ErrDriverConfiguration) 56 } 57 if params.Container == "" { 58 return nil, fmt.Errorf("missing container: %w", kv.ErrDriverConfiguration) 59 } 60 61 logger := logging.FromContext(ctx).WithField("store", DriverName) 62 logger.Infof("CosmosDB: connecting to %s", params.Endpoint) 63 64 var client *azcosmos.Client 65 if params.Key != "" { 66 cred, err := azcosmos.NewKeyCredential(params.Key) 67 if err != nil { 68 return nil, fmt.Errorf("creating key: %w", err) 69 } 70 71 // hook for using emulator for testing 72 if params.Client == nil { 73 params.Client = http.DefaultClient 74 } 75 // Create a CosmosDB client 76 client, err = azcosmos.NewClientWithKey(params.Endpoint, cred, &azcosmos.ClientOptions{ 77 ClientOptions: azcore.ClientOptions{ 78 Transport: params.Client, 79 }, 80 }) 81 if err != nil { 82 return nil, fmt.Errorf("creating client using access key: %w", err) 83 } 84 } else { 85 cred, err := azidentity.NewDefaultAzureCredential(nil) 86 if err != nil { 87 return nil, fmt.Errorf("default creds: %w", err) 88 } 89 client, err = azcosmos.NewClient(params.Endpoint, cred, nil) 90 if err != nil { 91 return nil, fmt.Errorf("creating client with default creds: %w", err) 92 } 93 } 94 95 dbClient, err := getOrCreateDatabase(ctx, client, params) 96 if err != nil { 97 return nil, err 98 } 99 100 // Create container client 101 containerClient, err := getOrCreateContainer(ctx, dbClient, params) 102 if err != nil { 103 return nil, err 104 } 105 106 cLevel := azcosmos.ConsistencyLevelBoundedStaleness 107 if !params.StrongConsistency { 108 cLevel = azcosmos.ConsistencyLevelSession 109 } 110 return &Store{ 111 containerClient: containerClient, 112 consistencyLevel: cLevel, 113 logger: logger, 114 }, nil 115 } 116 117 func getOrCreateDatabase(ctx context.Context, client *azcosmos.Client, params *kvparams.CosmosDB) (*azcosmos.DatabaseClient, error) { 118 _, err := client.CreateDatabase(ctx, azcosmos.DatabaseProperties{ID: params.Database}, nil) 119 if err != nil { 120 if errStatusCode(err) != http.StatusConflict { 121 return nil, fmt.Errorf("creating database: %w", err) 122 } 123 } 124 dbClient, err := client.NewDatabase(params.Database) 125 if err != nil { 126 return nil, fmt.Errorf("init database client: %w", err) 127 } 128 return dbClient, nil 129 } 130 131 func getOrCreateContainer(ctx context.Context, dbClient *azcosmos.DatabaseClient, params *kvparams.CosmosDB) (*azcosmos.ContainerClient, error) { 132 var opts *azcosmos.CreateContainerOptions 133 if params.Throughput > 0 { 134 var throughputProperties azcosmos.ThroughputProperties 135 if params.Autoscale { 136 throughputProperties = azcosmos.NewAutoscaleThroughputProperties(params.Throughput) 137 } else { 138 throughputProperties = azcosmos.NewManualThroughputProperties(params.Throughput) 139 } 140 opts = &azcosmos.CreateContainerOptions{ThroughputProperties: &throughputProperties} 141 } 142 143 _, err := dbClient.CreateContainer(ctx, 144 azcosmos.ContainerProperties{ 145 ID: params.Container, 146 PartitionKeyDefinition: azcosmos.PartitionKeyDefinition{ 147 Paths: []string{"/partitionKey"}, 148 }, 149 // Excluding the value field from indexing since it is not used in queries and saves RUs for writes. 150 // partitionKey is automatically not indexed. The rest of the fields are indexed by default, including id 151 // which is unnecessary, but cannot be excluded. 152 IndexingPolicy: &azcosmos.IndexingPolicy{ 153 Automatic: false, 154 IndexingMode: azcosmos.IndexingModeConsistent, 155 IncludedPaths: []azcosmos.IncludedPath{{Path: "/*"}}, 156 ExcludedPaths: []azcosmos.ExcludedPath{{Path: "/value/?"}}, 157 }, 158 }, opts) 159 if err != nil { 160 if errStatusCode(err) != http.StatusConflict { 161 return nil, fmt.Errorf("creating container: %w", err) 162 } 163 } 164 containerClient, err := dbClient.NewContainer(params.Container) 165 if err != nil { 166 return nil, fmt.Errorf("init container client: %w", err) 167 } 168 return containerClient, nil 169 } 170 171 // hashID returns a hash of the key that is used as the document id. 172 func (s *Store) hashID(key []byte) string { 173 return encoding.EncodeToString(ident.NewAddressWriter().MarshalBytes(key).Identity()) 174 } 175 176 type Document struct { 177 PartitionKey string `json:"partitionKey"` 178 // ID is the hash of the key. It is used as the document id for lookup of a single item. 179 // CosmosDB has a 1023 byte limit on the id, so we hash the key to ensure it fits. 180 ID string `json:"id"` 181 // Key is the original key. It is not used listing of items by order. 182 Key string `json:"key"` 183 Value string `json:"value"` 184 } 185 186 func (s *Store) Get(ctx context.Context, partitionKey, key []byte) (*kv.ValueWithPredicate, error) { 187 if len(partitionKey) == 0 { 188 return nil, kv.ErrMissingPartitionKey 189 } 190 if len(key) == 0 { 191 return nil, kv.ErrMissingKey 192 } 193 item := Document{ 194 PartitionKey: encoding.EncodeToString(partitionKey), 195 ID: s.hashID(key), 196 } 197 pk := azcosmos.NewPartitionKeyString(item.PartitionKey) 198 199 // Read an item 200 itemResponse, err := s.containerClient.ReadItem(ctx, pk, item.ID, nil) 201 if err != nil { 202 return nil, convertError(err) 203 } 204 205 var itemResponseBody Document 206 err = json.Unmarshal(itemResponse.Value, &itemResponseBody) 207 if err != nil { 208 return nil, err 209 } 210 211 val, err := encoding.DecodeString(itemResponseBody.Value) 212 if err != nil { 213 return nil, err 214 } 215 return &kv.ValueWithPredicate{ 216 Value: val, 217 Predicate: kv.Predicate([]byte(itemResponse.ETag)), 218 }, nil 219 } 220 221 func errStatusCode(err error) int { 222 var respErr *azcore.ResponseError 223 if !errors.As(err, &respErr) { 224 return -1 225 } 226 return respErr.StatusCode 227 } 228 229 func isErrStatusCode(err error, code int) bool { 230 return errStatusCode(err) == code 231 } 232 233 func convertError(err error) error { 234 statusCode := errStatusCode(err) 235 switch statusCode { 236 case http.StatusTooManyRequests: 237 return kv.ErrSlowDown 238 case http.StatusPreconditionFailed: 239 return kv.ErrPredicateFailed 240 case http.StatusNotFound: 241 return kv.ErrNotFound 242 case http.StatusConflict: 243 return kv.ErrPredicateFailed 244 } 245 return err 246 } 247 248 func (s *Store) Set(ctx context.Context, partitionKey, key, value []byte) error { 249 if len(partitionKey) == 0 { 250 return kv.ErrMissingPartitionKey 251 } 252 if len(key) == 0 { 253 return kv.ErrMissingKey 254 } 255 if value == nil { 256 return kv.ErrMissingValue 257 } 258 259 // Specifies the value of the partiton key 260 item := Document{ 261 PartitionKey: encoding.EncodeToString(partitionKey), 262 ID: s.hashID(key), 263 Key: encoding.EncodeToString(key), 264 Value: encoding.EncodeToString(value), 265 } 266 267 b, err := json.Marshal(item) 268 if err != nil { 269 return err 270 } 271 itemOptions := azcosmos.ItemOptions{ 272 ConsistencyLevel: s.consistencyLevel.ToPtr(), 273 } 274 pk := azcosmos.NewPartitionKeyString(item.PartitionKey) 275 276 _, err = s.containerClient.UpsertItem(ctx, pk, b, &itemOptions) 277 return convertError(err) 278 } 279 280 func (s *Store) SetIf(ctx context.Context, partitionKey, key, value []byte, valuePredicate kv.Predicate) error { 281 if len(partitionKey) == 0 { 282 return kv.ErrMissingPartitionKey 283 } 284 if len(key) == 0 { 285 return kv.ErrMissingKey 286 } 287 if value == nil { 288 return kv.ErrMissingValue 289 } 290 291 // Specifies the value of the partiton key 292 item := Document{ 293 PartitionKey: encoding.EncodeToString(partitionKey), 294 ID: s.hashID(key), 295 Key: encoding.EncodeToString(key), 296 Value: encoding.EncodeToString(value), 297 } 298 299 b, err := json.Marshal(item) 300 if err != nil { 301 return err 302 } 303 itemOptions := azcosmos.ItemOptions{ 304 ConsistencyLevel: s.consistencyLevel.ToPtr(), 305 } 306 pk := azcosmos.NewPartitionKeyString(item.PartitionKey) 307 308 switch valuePredicate { 309 case nil: 310 _, err = s.containerClient.CreateItem(ctx, pk, b, &itemOptions) 311 case kv.PrecondConditionalExists: 312 patch := azcosmos.PatchOperations{} 313 patch.AppendReplace("/value", item.Value) 314 _, err = s.containerClient.PatchItem( 315 ctx, 316 pk, 317 item.ID, 318 patch, 319 &itemOptions, 320 ) 321 if isErrStatusCode(err, http.StatusNotFound) { 322 return kv.ErrPredicateFailed 323 } 324 default: 325 etag := azcore.ETag(valuePredicate.([]byte)) 326 itemOptions.IfMatchEtag = &etag 327 _, err = s.containerClient.UpsertItem(ctx, pk, b, &itemOptions) 328 } 329 return convertError(err) 330 } 331 332 func (s *Store) Delete(ctx context.Context, partitionKey, key []byte) error { 333 if len(partitionKey) == 0 { 334 return kv.ErrMissingPartitionKey 335 } 336 if len(key) == 0 { 337 return kv.ErrMissingKey 338 } 339 pk := azcosmos.NewPartitionKeyString(encoding.EncodeToString(partitionKey)) 340 341 _, err := s.containerClient.DeleteItem(ctx, pk, s.hashID(key), nil) 342 if err != nil { 343 err = convertError(err) 344 } 345 if !errors.Is(err, kv.ErrNotFound) { 346 return err 347 } 348 return nil 349 } 350 351 func (s *Store) Scan(ctx context.Context, partitionKey []byte, options kv.ScanOptions) (kv.EntriesIterator, error) { 352 if len(partitionKey) == 0 { 353 return nil, kv.ErrMissingPartitionKey 354 } 355 it := &EntriesIterator{ 356 store: s, 357 partitionKey: partitionKey, 358 startKey: options.KeyStart, 359 limit: options.BatchSize, 360 queryCtx: ctx, 361 encoding: encoding, 362 } 363 if err := it.runQuery(); err != nil { 364 return nil, convertError(err) 365 } 366 return it, nil 367 } 368 369 func (s *Store) Close() { 370 } 371 372 type EntriesIterator struct { 373 store *Store 374 partitionKey []byte 375 startKey []byte 376 limit int 377 378 entry *kv.Entry 379 err error 380 currEntryIdx int 381 queryPager *runtime.Pager[azcosmos.QueryItemsResponse] 382 queryCtx context.Context 383 currPage azcosmos.QueryItemsResponse 384 // currPageSeekedKey is the key we seeked to get this page, will be nil if this page wasn't returned by the query 385 currPageSeekedKey []byte 386 // currPageHasMore is true if the current page has more after it 387 encoding *base32.Encoding 388 } 389 390 func (e *EntriesIterator) getKeyValue(i int) ([]byte, []byte) { 391 var itemResponseBody Document 392 err := json.Unmarshal(e.currPage.Items[i], &itemResponseBody) 393 if err != nil { 394 e.err = fmt.Errorf("failed to unmarshal: %w", err) 395 return nil, nil 396 } 397 key, err := e.encoding.DecodeString(itemResponseBody.Key) 398 if err != nil { 399 e.err = fmt.Errorf("failed to decode id: %w", err) 400 return nil, nil 401 } 402 value, err := e.encoding.DecodeString(itemResponseBody.Value) 403 if err != nil { 404 e.err = fmt.Errorf("failed to decode value: %w", err) 405 return nil, nil 406 } 407 return key, value 408 } 409 410 func (e *EntriesIterator) Next() bool { 411 if e.err != nil { 412 return false 413 } 414 415 if e.currEntryIdx+1 >= len(e.currPage.Items) { 416 if !e.queryPager.More() { 417 return false 418 } 419 var err error 420 e.currPage, err = e.queryPager.NextPage(e.queryCtx) 421 if err != nil { 422 e.err = fmt.Errorf("getting next page: %w", convertError(err)) 423 return false 424 } 425 if len(e.currPage.Items) == 0 { 426 // returned page is empty, no more items 427 return false 428 } 429 e.currPageSeekedKey = nil 430 e.currEntryIdx = -1 431 } 432 e.currEntryIdx++ 433 key, value := e.getKeyValue(e.currEntryIdx) 434 if e.err != nil { 435 return false 436 } 437 e.entry = &kv.Entry{ 438 Key: key, 439 Value: value, 440 } 441 442 return true 443 } 444 445 func (e *EntriesIterator) SeekGE(key []byte) { 446 e.startKey = key 447 if !e.isInRange() { 448 if err := e.runQuery(); err != nil { 449 e.err = convertError(err) 450 } 451 return 452 } 453 idx := sort.Search(len(e.currPage.Items), func(i int) bool { 454 currentKey, _ := e.getKeyValue(i) 455 if e.err != nil { 456 return false 457 } 458 return bytes.Compare(key, currentKey) <= 0 459 }) 460 if idx == -1 { 461 // not found, set to the end 462 e.currEntryIdx = len(e.currPage.Items) 463 } 464 e.currEntryIdx = idx - 1 465 } 466 467 func (e *EntriesIterator) Entry() *kv.Entry { 468 return e.entry 469 } 470 471 func (e *EntriesIterator) Err() error { 472 return e.err 473 } 474 475 func (e *EntriesIterator) Close() { 476 e.err = kv.ErrClosedEntries 477 } 478 479 func (e *EntriesIterator) runQuery() error { 480 pk := azcosmos.NewPartitionKeyString(encoding.EncodeToString(e.partitionKey)) 481 e.queryPager = e.store.containerClient.NewQueryItemsPager("select * from c where c.key >= @start order by c.key", pk, &azcosmos.QueryOptions{ 482 ConsistencyLevel: e.store.consistencyLevel.ToPtr(), 483 PageSizeHint: int32(e.limit), 484 QueryParameters: []azcosmos.QueryParameter{{ 485 Name: "@start", 486 Value: encoding.EncodeToString(e.startKey), 487 }}, 488 }) 489 currPage, err := e.queryPager.NextPage(e.queryCtx) 490 if err != nil { 491 return err 492 } 493 e.currEntryIdx = -1 494 e.entry = nil 495 e.currPage = currPage 496 e.currPageSeekedKey = e.startKey 497 return nil 498 } 499 500 // isInRange checks if e.startKey falls within the range of keys on the current page. 501 // To optimize range checking: 502 // - If the current page is a result of a seek operation, the seeked key is used as the minimum key. 503 // - If the current page is the last page, all keys greater than the minimum key are considered in range. 504 // This function returns true if e.startKey is within these defined range criteria. 505 func (e *EntriesIterator) isInRange() bool { 506 if e.err != nil { 507 return false 508 } 509 var minKey []byte 510 if e.currPageSeekedKey != nil { 511 minKey = e.currPageSeekedKey 512 } else { 513 if len(e.currPage.Items) == 0 { 514 return false 515 } 516 minKey, _ = e.getKeyValue(0) 517 if minKey == nil { 518 return false 519 } 520 } 521 if bytes.Compare(e.startKey, minKey) < 0 { 522 return false 523 } 524 if !e.queryPager.More() { 525 // last page, all keys greater than minKey are considered in range (in order to avoid unnecessary queries) 526 return true 527 } 528 if len(e.currPage.Items) == 0 { 529 // cosmosdb returned empty page but has more results, should not happen 530 return false 531 } 532 maxKey, _ := e.getKeyValue(len(e.currPage.Items) - 1) 533 return maxKey != nil && bytes.Compare(e.startKey, maxKey) <= 0 534 }