github.com/thiagoyeds/go-cloud@v0.26.0/docstore/awsdynamodb/query.go (about) 1 // Copyright 2019 The Go Cloud Development Kit Authors 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // https://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package awsdynamodb 16 17 import ( 18 "bytes" 19 "context" 20 "errors" 21 "fmt" 22 "io" 23 "sort" 24 "strings" 25 "time" 26 27 "github.com/aws/aws-sdk-go/aws" 28 dyn "github.com/aws/aws-sdk-go/service/dynamodb" 29 "github.com/aws/aws-sdk-go/service/dynamodb/expression" 30 "gocloud.dev/docstore/driver" 31 "gocloud.dev/gcerrors" 32 "gocloud.dev/internal/gcerr" 33 ) 34 35 // TODO: support parallel scans (http://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Scan.html#Scan.ParallelScan) 36 37 // TODO(jba): support an empty item slice returned from an RPC: "A Query operation can 38 // return an empty result set and a LastEvaluatedKey if all the items read for the 39 // page of results are filtered out." 40 41 type avmap = map[string]*dyn.AttributeValue 42 43 func (c *collection) RunGetQuery(ctx context.Context, q *driver.Query) (driver.DocumentIterator, error) { 44 qr, err := c.planQuery(q) 45 if err != nil { 46 if gcerrors.Code(err) == gcerrors.Unimplemented && c.opts.RunQueryFallback != nil { 47 return c.opts.RunQueryFallback(ctx, q, c.RunGetQuery) 48 } 49 return nil, err 50 } 51 if err := c.checkPlan(qr); err != nil { 52 return nil, err 53 } 54 it := &documentIterator{ 55 qr: qr, 56 limit: q.Limit, 57 count: 0, // manually count limit since dynamodb uses "limit" as scan limit before filtering 58 } 59 it.items, it.last, it.asFunc, err = it.qr.run(ctx, nil) 60 if err != nil { 61 return nil, err 62 } 63 return it, nil 64 } 65 66 func (c *collection) checkPlan(qr *queryRunner) error { 67 if qr.scanIn != nil && qr.scanIn.FilterExpression != nil && !c.opts.AllowScans { 68 return gcerr.Newf(gcerr.InvalidArgument, nil, "query requires a table scan; set Options.AllowScans to true to enable") 69 } 70 return nil 71 } 72 73 func (c *collection) planQuery(q *driver.Query) (*queryRunner, error) { 74 var cb expression.Builder 75 cbUsed := false // It's an error to build an empty Builder. 76 // Set up the projection expression. 77 if len(q.FieldPaths) > 0 { 78 var pb expression.ProjectionBuilder 79 hasFields := map[string]bool{} 80 for _, fp := range q.FieldPaths { 81 if len(fp) == 1 { 82 hasFields[fp[0]] = true 83 } 84 pb = pb.AddNames(expression.Name(strings.Join(fp, "."))) 85 } 86 // Always include the keys. 87 for _, f := range []string{c.partitionKey, c.sortKey} { 88 if f != "" && !hasFields[f] { 89 pb = pb.AddNames(expression.Name(f)) 90 q.FieldPaths = append(q.FieldPaths, []string{f}) 91 } 92 } 93 cb = cb.WithProjection(pb) 94 cbUsed = true 95 } 96 97 // Find the best thing to query (table or index). 98 indexName, pkey, skey := c.bestQueryable(q) 99 if indexName == nil && pkey == "" { 100 // No query can be done: fall back to scanning. 101 if q.OrderByField != "" { 102 // Scans are unordered, so we can't run this query. 103 // TODO(jba): If the user specifies all the partition keys, and there is a global 104 // secondary index whose sort key is the order-by field, then we can query that index 105 // for every value of the partition key and merge the results. 106 // TODO(jba): If the query has a reasonable limit N, then we can run a scan and keep 107 // the top N documents in memory. 108 return nil, gcerr.Newf(gcerr.Unimplemented, nil, "query requires a table scan, but has an ordering requirement; add an index or provide Options.RunQueryFallback") 109 } 110 if len(q.Filters) > 0 { 111 cb = cb.WithFilter(filtersToConditionBuilder(q.Filters)) 112 cbUsed = true 113 } 114 in := &dyn.ScanInput{ 115 TableName: &c.table, 116 ConsistentRead: aws.Bool(c.opts.ConsistentRead), 117 } 118 if cbUsed { 119 ce, err := cb.Build() 120 if err != nil { 121 return nil, err 122 } 123 in.ExpressionAttributeNames = ce.Names() 124 in.ExpressionAttributeValues = ce.Values() 125 in.FilterExpression = ce.Filter() 126 in.ProjectionExpression = ce.Projection() 127 } 128 return &queryRunner{c: c, scanIn: in, beforeRun: q.BeforeQuery}, nil 129 } 130 131 // Do a query. 132 cb = processFilters(cb, q.Filters, pkey, skey) 133 ce, err := cb.Build() 134 if err != nil { 135 return nil, err 136 } 137 qIn := &dyn.QueryInput{ 138 TableName: &c.table, 139 IndexName: indexName, 140 ExpressionAttributeNames: ce.Names(), 141 ExpressionAttributeValues: ce.Values(), 142 KeyConditionExpression: ce.KeyCondition(), 143 FilterExpression: ce.Filter(), 144 ProjectionExpression: ce.Projection(), 145 ConsistentRead: aws.Bool(c.opts.ConsistentRead), 146 } 147 if q.OrderByField != "" && !q.OrderAscending { 148 qIn.ScanIndexForward = &q.OrderAscending 149 } 150 return &queryRunner{ 151 c: c, 152 queryIn: qIn, 153 beforeRun: q.BeforeQuery, 154 }, nil 155 } 156 157 // Return the best choice of queryable (table or index) for this query. 158 // How to interpret the return values: 159 // - If indexName is nil but pkey is not empty, then use the table. 160 // - If all return values are zero, no query will work: do a scan. 161 func (c *collection) bestQueryable(q *driver.Query) (indexName *string, pkey, skey string) { 162 // If the query has an "=" filter on the table's partition key, look at the table 163 // and local indexes. 164 if hasEqualityFilter(q, c.partitionKey) { 165 // If the table has a sort key that's in the query, and the ordering 166 // constraint works with the sort key, use the table. 167 // (Query results are always ordered by sort key.) 168 if hasFilter(q, c.sortKey) && orderingConsistent(q, c.sortKey) { 169 return nil, c.partitionKey, c.sortKey 170 } 171 // Look at local indexes. They all have the same partition key as the base table. 172 // If one has a sort key in the query, use it. 173 for _, li := range c.description.LocalSecondaryIndexes { 174 pkey, skey := keyAttributes(li.KeySchema) 175 if hasFilter(q, skey) && localFieldsIncluded(q, li) && orderingConsistent(q, skey) { 176 return li.IndexName, pkey, skey 177 } 178 } 179 } 180 // Consider the global indexes: if one has a matching partition and sort key, and 181 // the projected fields of the index include those of the query, use it. 182 for _, gi := range c.description.GlobalSecondaryIndexes { 183 pkey, skey := keyAttributes(gi.KeySchema) 184 if skey == "" { 185 continue // We'll visit global indexes without a sort key later. 186 } 187 if hasEqualityFilter(q, pkey) && hasFilter(q, skey) && c.globalFieldsIncluded(q, gi) && orderingConsistent(q, skey) { 188 return gi.IndexName, pkey, skey 189 } 190 } 191 // There are no matches for both partition and sort key. Now consider matches on partition key only. 192 // That will still be better than a scan. 193 // First, check the table itself. 194 if hasEqualityFilter(q, c.partitionKey) && orderingConsistent(q, c.sortKey) { 195 return nil, c.partitionKey, c.sortKey 196 } 197 // No point checking local indexes: they have the same partition key as the table. 198 // Check the global indexes. 199 for _, gi := range c.description.GlobalSecondaryIndexes { 200 pkey, skey := keyAttributes(gi.KeySchema) 201 if hasEqualityFilter(q, pkey) && c.globalFieldsIncluded(q, gi) && orderingConsistent(q, skey) { 202 return gi.IndexName, pkey, skey 203 } 204 } 205 // We cannot do a query. 206 // TODO: return the reason why we couldn't. At a minimum, distinguish failure due to keys 207 // from failure due to projection (i.e. a global index had the right partition and sort key, 208 // but didn't project the necessary fields). 209 return nil, "", "" 210 } 211 212 // localFieldsIncluded reports whether a local index supports all the selected fields 213 // of a query. Since DynamoDB will read explicitly provided fields from the table if 214 // they are not projected into the index, the only case where a local index cannot 215 // be used is when the query wants all the fields, and the index projection is not ALL. 216 // See https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/LSI.html#LSI.Projections. 217 func localFieldsIncluded(q *driver.Query, li *dyn.LocalSecondaryIndexDescription) bool { 218 return len(q.FieldPaths) > 0 || *li.Projection.ProjectionType == "ALL" 219 } 220 221 // orderingConsistent reports whether the ordering constraint is consistent with the sort key field. 222 // That is, either there is no OrderBy clause, or the clause specifies the sort field. 223 func orderingConsistent(q *driver.Query, sortField string) bool { 224 return q.OrderByField == "" || q.OrderByField == sortField 225 } 226 227 // globalFieldsIncluded reports whether the fields selected by the query are 228 // projected into (that is, contained directly in) the global index. We need this 229 // check before using the index, because if a global index doesn't have all the 230 // desired fields, then a separate RPC for each returned item would be necessary to 231 // retrieve those fields, and we'd rather scan than do that. 232 func (c *collection) globalFieldsIncluded(q *driver.Query, gi *dyn.GlobalSecondaryIndexDescription) bool { 233 proj := gi.Projection 234 if *proj.ProjectionType == "ALL" { 235 // The index has all the fields of the table: we're good. 236 return true 237 } 238 if len(q.FieldPaths) == 0 { 239 // The query wants all the fields of the table, but we can't be sure that the 240 // index has them. 241 return false 242 } 243 // The table's keys and the index's keys are always in the index. 244 pkey, skey := keyAttributes(gi.KeySchema) 245 indexFields := map[string]bool{c.partitionKey: true, pkey: true} 246 if c.sortKey != "" { 247 indexFields[c.sortKey] = true 248 } 249 if skey != "" { 250 indexFields[skey] = true 251 } 252 for _, nka := range proj.NonKeyAttributes { 253 indexFields[*nka] = true 254 } 255 // Every field path in the query must be in the index. 256 for _, fp := range q.FieldPaths { 257 if !indexFields[strings.Join(fp, ".")] { 258 return false 259 } 260 } 261 return true 262 } 263 264 // Extract the names of the partition and sort key attributes from the schema of a 265 // table or index. 266 func keyAttributes(ks []*dyn.KeySchemaElement) (pkey, skey string) { 267 for _, k := range ks { 268 switch *k.KeyType { 269 case "HASH": 270 pkey = *k.AttributeName 271 case "RANGE": 272 skey = *k.AttributeName 273 default: 274 panic("bad key type: " + *k.KeyType) 275 } 276 } 277 return pkey, skey 278 } 279 280 // Reports whether q has a filter that mentions the top-level field. 281 func hasFilter(q *driver.Query, field string) bool { 282 if field == "" { 283 return false 284 } 285 for _, f := range q.Filters { 286 if driver.FieldPathEqualsField(f.FieldPath, field) { 287 return true 288 } 289 } 290 return false 291 } 292 293 // Reports whether q has a filter that checks if the top-level field is equal to something. 294 func hasEqualityFilter(q *driver.Query, field string) bool { 295 for _, f := range q.Filters { 296 if f.Op == driver.EqualOp && driver.FieldPathEqualsField(f.FieldPath, field) { 297 return true 298 } 299 } 300 return false 301 } 302 303 type queryRunner struct { 304 c *collection 305 scanIn *dyn.ScanInput 306 queryIn *dyn.QueryInput 307 beforeRun func(asFunc func(i interface{}) bool) error 308 } 309 310 func (qr *queryRunner) run(ctx context.Context, startAfter avmap) (items []avmap, last avmap, asFunc func(i interface{}) bool, err error) { 311 if qr.scanIn != nil { 312 qr.scanIn.ExclusiveStartKey = startAfter 313 if qr.beforeRun != nil { 314 asFunc := func(i interface{}) bool { 315 p, ok := i.(**dyn.ScanInput) 316 if !ok { 317 return false 318 } 319 *p = qr.scanIn 320 return true 321 } 322 if err := qr.beforeRun(asFunc); err != nil { 323 return nil, nil, nil, err 324 } 325 } 326 out, err := qr.c.db.ScanWithContext(ctx, qr.scanIn) 327 if err != nil { 328 return nil, nil, nil, err 329 } 330 return out.Items, out.LastEvaluatedKey, 331 func(i interface{}) bool { 332 p, ok := i.(**dyn.ScanOutput) 333 if !ok { 334 return false 335 } 336 *p = out 337 return true 338 }, nil 339 } 340 qr.queryIn.ExclusiveStartKey = startAfter 341 if qr.beforeRun != nil { 342 asFunc := func(i interface{}) bool { 343 p, ok := i.(**dyn.QueryInput) 344 if !ok { 345 return false 346 } 347 *p = qr.queryIn 348 return true 349 } 350 if err := qr.beforeRun(asFunc); err != nil { 351 return nil, nil, nil, err 352 } 353 } 354 out, err := qr.c.db.QueryWithContext(ctx, qr.queryIn) 355 if err != nil { 356 return nil, nil, nil, err 357 } 358 return out.Items, out.LastEvaluatedKey, 359 func(i interface{}) bool { 360 p, ok := i.(**dyn.QueryOutput) 361 if !ok { 362 return false 363 } 364 *p = out 365 return true 366 }, nil 367 } 368 369 func processFilters(cb expression.Builder, fs []driver.Filter, pkey, skey string) expression.Builder { 370 var kbs []expression.KeyConditionBuilder 371 var cfs []driver.Filter 372 for _, f := range fs { 373 if kb, ok := toKeyCondition(f, pkey, skey); ok { 374 kbs = append(kbs, kb) 375 continue 376 } 377 cfs = append(cfs, f) 378 } 379 keyBuilder := kbs[0] 380 for i := 1; i < len(kbs); i++ { 381 keyBuilder = keyBuilder.And(kbs[i]) 382 } 383 cb = cb.WithKeyCondition(keyBuilder) 384 if len(cfs) > 0 { 385 cb = cb.WithFilter(filtersToConditionBuilder(cfs)) 386 } 387 return cb 388 } 389 390 func filtersToConditionBuilder(fs []driver.Filter) expression.ConditionBuilder { 391 if len(fs) == 0 { 392 panic("no filters") 393 } 394 var cb expression.ConditionBuilder 395 cb = toFilter(fs[0]) 396 for _, f := range fs[1:] { 397 cb = cb.And(toFilter(f)) 398 } 399 return cb 400 } 401 402 func toKeyCondition(f driver.Filter, pkey, skey string) (expression.KeyConditionBuilder, bool) { 403 kp := strings.Join(f.FieldPath, ".") 404 if kp == pkey || kp == skey { 405 key := expression.Key(kp) 406 val := expression.Value(f.Value) 407 switch f.Op { 408 case "<": 409 return expression.KeyLessThan(key, val), true 410 case "<=": 411 return expression.KeyLessThanEqual(key, val), true 412 case driver.EqualOp: 413 return expression.KeyEqual(key, val), true 414 case ">=": 415 return expression.KeyGreaterThanEqual(key, val), true 416 case ">": 417 return expression.KeyGreaterThan(key, val), true 418 default: 419 panic(fmt.Sprint("invalid filter operation:", f.Op)) 420 } 421 } 422 return expression.KeyConditionBuilder{}, false 423 } 424 425 func toFilter(f driver.Filter) expression.ConditionBuilder { 426 name := expression.Name(strings.Join(f.FieldPath, ".")) 427 val := expression.Value(f.Value) 428 switch f.Op { 429 case "<": 430 return expression.LessThan(name, val) 431 case "<=": 432 return expression.LessThanEqual(name, val) 433 case driver.EqualOp: 434 return expression.Equal(name, val) 435 case ">=": 436 return expression.GreaterThanEqual(name, val) 437 case ">": 438 return expression.GreaterThan(name, val) 439 default: 440 panic(fmt.Sprint("invalid filter operation:", f.Op)) 441 } 442 } 443 444 type documentIterator struct { 445 qr *queryRunner 446 items []map[string]*dyn.AttributeValue 447 curr int 448 limit int 449 count int // number of items returned 450 last map[string]*dyn.AttributeValue 451 asFunc func(i interface{}) bool 452 } 453 454 func (it *documentIterator) Next(ctx context.Context, doc driver.Document) error { 455 if it.limit > 0 && it.count >= it.limit || it.curr >= len(it.items) && it.last == nil { 456 return io.EOF 457 } 458 if it.curr >= len(it.items) { 459 // Make a new query request at the end of this page. 460 var err error 461 it.items, it.last, it.asFunc, err = it.qr.run(ctx, it.last) 462 if err != nil { 463 return err 464 } 465 it.curr = 0 466 } 467 if err := decodeDoc(&dyn.AttributeValue{M: it.items[it.curr]}, doc); err != nil { 468 return err 469 } 470 it.curr++ 471 it.count++ 472 return nil 473 } 474 475 func (it *documentIterator) Stop() { 476 it.items = nil 477 it.last = nil 478 } 479 480 func (it *documentIterator) As(i interface{}) bool { 481 return it.asFunc(i) 482 } 483 484 func (c *collection) QueryPlan(q *driver.Query) (string, error) { 485 qr, err := c.planQuery(q) 486 if err != nil { 487 return "", err 488 } 489 return qr.queryPlan(), nil 490 } 491 492 func (qr *queryRunner) queryPlan() string { 493 if qr.scanIn != nil { 494 return "Scan" 495 } 496 if qr.queryIn.IndexName != nil { 497 return fmt.Sprintf("Index: %q", *qr.queryIn.IndexName) 498 } 499 return "Table" 500 } 501 502 // InMemorySortFallback returns a query fallback function for Options.RunQueryFallback. 503 // The function accepts a query with an OrderBy clause. It runs the query without that clause, 504 // reading all documents into memory, then sorts the documents according to the OrderBy clause. 505 // 506 // Only string, numeric, time and binary ([]byte) fields can be sorted. 507 // 508 // createDocument should create an empty document to be passed to DocumentIterator.Next. 509 // The DocumentIterator returned by the FallbackFunc will also expect the same type of document. 510 // If nil, then a map[string]interface{} will be used. 511 func InMemorySortFallback(createDocument func() interface{}) FallbackFunc { 512 if createDocument == nil { 513 createDocument = func() interface{} { return map[string]interface{}{} } 514 } 515 return func(ctx context.Context, q *driver.Query, run RunQueryFunc) (driver.DocumentIterator, error) { 516 if q.OrderByField == "" { 517 return nil, errors.New("InMemorySortFallback expects an OrderBy query") 518 } 519 // Run the query without the OrderBy. 520 orderByField := q.OrderByField 521 q.OrderByField = "" 522 iter, err := run(ctx, q) 523 if err != nil { 524 return nil, err 525 } 526 defer iter.Stop() 527 // Collect the results into a slice. 528 var docs []driver.Document 529 for { 530 doc, err := driver.NewDocument(createDocument()) 531 if err != nil { 532 return nil, err 533 } 534 err = iter.Next(ctx, doc) 535 if err == io.EOF { 536 break 537 } 538 if err != nil { 539 return nil, err 540 } 541 docs = append(docs, doc) 542 } 543 // Sort the documents. 544 // OrderByField is a single field, not a field path. 545 // First, put the field values in another slice, so we can 546 // return on error. 547 sortValues := make([]interface{}, len(docs)) 548 for i, doc := range docs { 549 v, err := doc.GetField(orderByField) 550 if err != nil { 551 return nil, err 552 } 553 sortValues[i] = v 554 } 555 sort.Sort(docsForSorting{docs, sortValues, q.OrderAscending}) 556 return &sliceIterator{docs: docs}, nil 557 } 558 } 559 560 type docsForSorting struct { 561 docs []driver.Document 562 vals []interface{} 563 ascending bool 564 } 565 566 func (d docsForSorting) Len() int { return len(d.docs) } 567 568 func (d docsForSorting) Swap(i, j int) { 569 d.docs[i], d.docs[j] = d.docs[j], d.docs[i] 570 d.vals[i], d.vals[j] = d.vals[j], d.vals[i] 571 } 572 573 func (d docsForSorting) Less(i, j int) bool { 574 c := compare(d.vals[i], d.vals[j]) 575 if d.ascending { 576 return c < 0 577 } else { 578 return c > 0 579 } 580 } 581 582 // compare returns -1 if v1 < v2, 0 if v1 == v2 and 1 if v1 > v2. 583 // 584 // Arbitrarily decide that strings < times < []byte < numbers. 585 // TODO(jba): find and use the actual sort order that DynamoDB uses. 586 func compare(v1, v2 interface{}) int { 587 switch v1 := v1.(type) { 588 case string: 589 if v2, ok := v2.(string); ok { 590 return strings.Compare(v1, v2) 591 } 592 return -1 593 594 case time.Time: 595 if v2, ok := v2.(time.Time); ok { 596 return driver.CompareTimes(v1, v2) 597 } 598 if _, ok := v2.(string); ok { 599 return 1 600 } 601 return -1 602 603 case []byte: 604 if v2, ok := v2.([]byte); ok { 605 return bytes.Compare(v1, v2) 606 } 607 if _, ok := v2.(string); ok { 608 return 1 609 } 610 if _, ok := v2.(time.Time); ok { 611 return 1 612 } 613 return -1 614 615 default: 616 cmp, err := driver.CompareNumbers(v1, v2) 617 if err != nil { 618 return -1 619 } 620 return cmp 621 } 622 } 623 624 type sliceIterator struct { 625 docs []driver.Document 626 next int 627 } 628 629 func (it *sliceIterator) Next(ctx context.Context, doc driver.Document) error { 630 if it.next >= len(it.docs) { 631 return io.EOF 632 } 633 it.next++ 634 return copyTopLevel(doc, it.docs[it.next-1]) 635 } 636 637 // Copy the top-level fields of src into dest. 638 func copyTopLevel(dest, src driver.Document) error { 639 for _, f := range src.FieldNames() { 640 v, err := src.GetField(f) 641 if err != nil { 642 return err 643 } 644 if err := dest.SetField(f, v); err != nil { 645 return err 646 } 647 } 648 return nil 649 } 650 651 func (*sliceIterator) Stop() {} 652 func (*sliceIterator) As(interface{}) bool { return false }