github.com/fragmenta/query@v1.5.3/query.go (about) 1 // Package query lets you build and execute SQL chainable queries against a database of your choice, and defer execution of SQL until you wish to extract a count or array of models. 2 3 // NB in order to allow cross-compilation, we exlude sqlite drivers by default 4 // uncomment them to allow use of sqlite 5 6 package query 7 8 import ( 9 "database/sql" 10 "fmt" 11 "sort" 12 "strconv" 13 "strings" 14 ) 15 16 // FIXME - this package global should in theory be protected by a mutex, even if it is only for debugging 17 18 // Debug sets whether we output debug statements for SQL 19 var Debug bool 20 21 func init() { 22 Debug = false // default to false 23 } 24 25 // Result holds the results of a query as map[string]interface{} 26 type Result map[string]interface{} 27 28 // Func is a function which applies effects to queries 29 type Func func(q *Query) *Query 30 31 // Query provides all the chainable relational query builder methods 32 type Query struct { 33 34 // Database - database name and primary key, set with New() 35 tablename string 36 primarykey string 37 38 // SQL - Private fields used to store sql before building sql query 39 sql string 40 sel string 41 join string 42 where string 43 group string 44 having string 45 order string 46 offset string 47 limit string 48 49 // Extra args to be substituted in the *where* clause 50 args []interface{} 51 } 52 53 // New builds a new Query, given the table and primary key 54 func New(t string, pk string) *Query { 55 56 // If we have no db, return nil 57 if database == nil { 58 return nil 59 } 60 61 q := &Query{ 62 tablename: t, 63 primarykey: pk, 64 } 65 66 return q 67 } 68 69 // Exec the given sql and args against the database directly 70 // Returning sql.Result (NB not rows) 71 func Exec(sql string, args ...interface{}) (sql.Result, error) { 72 results, err := database.Exec(sql, args...) 73 return results, err 74 } 75 76 // Rows executes the given sql and args against the database directly 77 // Returning sql.Rows 78 func Rows(sql string, args ...interface{}) (*sql.Rows, error) { 79 results, err := database.Query(sql, args...) 80 return results, err 81 } 82 83 // Copy returns a new copy of this query which can be mutated without affecting the original 84 func (q *Query) Copy() *Query { 85 return &Query{ 86 tablename: q.tablename, 87 primarykey: q.primarykey, 88 sql: q.sql, 89 sel: q.sel, 90 join: q.join, 91 where: q.where, 92 group: q.group, 93 having: q.having, 94 order: q.order, 95 offset: q.offset, 96 limit: q.limit, 97 args: q.args, 98 } 99 } 100 101 // TODO: These should instead be something like query.New("table_name").Join(a,b).Insert() and just have one multiple function? 102 103 // InsertJoin inserts a join clause on the query 104 func (q *Query) InsertJoin(a int64, b int64) error { 105 return q.InsertJoins([]int64{a}, []int64{b}) 106 } 107 108 // InsertJoins using an array of ids (more general version of above) 109 // This inserts joins for every possible relation between the ids 110 func (q *Query) InsertJoins(a []int64, b []int64) error { 111 112 // Make sure we have some data 113 if len(a) == 0 || len(b) == 0 { 114 return fmt.Errorf("Null data for joins insert %s", q.table()) 115 } 116 117 // Check for null entries in start of data - this is not a good idea. 118 // if a[0] == 0 || b[0] == 0 { 119 // return fmt.Errorf("Zero data for joins insert %s", q.table()) 120 // } 121 122 values := "" 123 for _, av := range a { 124 for _, bv := range b { 125 // NB no zero values allowed, we simply ignore zero values 126 if av != 0 && bv != 0 { 127 values += fmt.Sprintf("(%d,%d),", av, bv) 128 } 129 130 } 131 } 132 133 values = strings.TrimRight(values, ",") 134 135 sql := fmt.Sprintf("INSERT into %s VALUES %s;", q.table(), values) 136 137 if Debug { 138 fmt.Printf("JOINS SQL:%s\n", sql) 139 } 140 141 _, err := database.Exec(sql) 142 return err 143 } 144 145 // UpdateJoins updates the given joins, using the given id to clear joins first 146 func (q *Query) UpdateJoins(id int64, a []int64, b []int64) error { 147 148 if Debug { 149 fmt.Printf("SetJoins %s %s=%d: %v %v \n", q.table(), q.pk(), id, a, b) 150 } 151 152 // First delete any existing joins 153 err := q.Where(fmt.Sprintf("%s=?", q.pk()), id).Delete() 154 if err != nil { 155 return err 156 } 157 158 // Now join all a's with all b's by generating joins for each possible combination 159 160 // Make sure we have data in both cases, otherwise do not attempt insert any joins 161 if len(a) > 0 && len(b) > 0 { 162 // Now insert all new ids - NB the order of arguments here MUST match the order in the table 163 err = q.InsertJoins(a, b) 164 if err != nil { 165 return err 166 } 167 } 168 169 return nil 170 } 171 172 // Insert inserts a record in the database 173 func (q *Query) Insert(params map[string]string) (int64, error) { 174 175 // Insert and retrieve ID in one step from db 176 sql := q.insertSQL(params) 177 178 if Debug { 179 fmt.Printf("INSERT SQL:%s %v\n", sql, valuesFromParams(params)) 180 } 181 182 id, err := database.Insert(sql, valuesFromParams(params)...) 183 if err != nil { 184 return 0, err 185 } 186 187 return id, nil 188 } 189 190 // insertSQL sets the insert sql for update statements, turn params into sql i.e. "col"=? 191 // NB we always use parameterized queries, never string values. 192 func (q *Query) insertSQL(params map[string]string) string { 193 var cols, vals []string 194 195 for i, k := range sortedParamKeys(params) { 196 cols = append(cols, database.QuoteField(k)) 197 vals = append(vals, database.Placeholder(i+1)) 198 } 199 query := fmt.Sprintf("INSERT INTO %s (%s) VALUES(%s) %s;", q.table(), strings.Join(cols, ","), strings.Join(vals, ","), database.InsertSQL(q.pk())) 200 201 return query 202 } 203 204 // Update one model specified in this query - the column names MUST be verified in the model 205 func (q *Query) Update(params map[string]string) error { 206 // We should check the query has a where limitation to avoid updating all? 207 // pq unfortunately does not accept limit(1) here 208 return q.UpdateAll(params) 209 } 210 211 // Delete one model specified in this relation 212 func (q *Query) Delete() error { 213 // We should check the query has a where limitation? 214 return q.DeleteAll() 215 } 216 217 // UpdateAll updates all models specified in this relation 218 func (q *Query) UpdateAll(params map[string]string) error { 219 // Create sql for update from ALL params 220 q.Select(fmt.Sprintf("UPDATE %s SET %s", q.table(), querySQL(params))) 221 222 // Execute, after PREpending params to args 223 // in an update statement, the where comes at the end 224 q.args = append(valuesFromParams(params), q.args...) 225 226 if Debug { 227 fmt.Printf("UPDATE SQL:%s\n%v\n", q.QueryString(), valuesFromParams(params)) 228 } 229 230 _, err := q.Result() 231 232 return err 233 } 234 235 // DeleteAll delets *all* models specified in this relation 236 func (q *Query) DeleteAll() error { 237 238 q.Select(fmt.Sprintf("DELETE FROM %s", q.table())) 239 240 if Debug { 241 fmt.Printf("DELETE SQL:%s <= %v\n", q.QueryString(), q.args) 242 } 243 244 // Execute 245 _, err := q.Result() 246 247 return err 248 } 249 250 // Count fetches a count of model objects (executes SQL). 251 func (q *Query) Count() (int64, error) { 252 253 // In order to get consistent results, we use the same query builder 254 // but reset select to simple count select 255 256 // Store the previous select and set 257 s := q.sel 258 countSelect := fmt.Sprintf("SELECT COUNT(%s) FROM %s", q.pk(), q.table()) 259 q.Select(countSelect) 260 261 // Store the previous order (minus order by) and set to empty 262 // Order must be blank on count because of limited select 263 o := strings.Replace(q.order, "ORDER BY ", "", 1) 264 q.order = "" 265 266 // Fetch count from db for our sql with count select and no order set 267 var count int64 268 rows, err := q.Rows() 269 if err != nil { 270 return 0, fmt.Errorf("Error querying database for count: %s\nQuery:%s", err, q.QueryString()) 271 } 272 273 // We expect just one row, with one column (count) 274 defer rows.Close() 275 for rows.Next() { 276 err = rows.Scan(&count) 277 if err != nil { 278 return 0, err 279 } 280 } 281 282 // Reset select after getting count query 283 q.Select(s) 284 q.Order(o) 285 q.reset() 286 287 return count, err 288 } 289 290 // Result executes the query against the database, returning sql.Result, and error (no rows) 291 // (Executes SQL) 292 func (q *Query) Result() (sql.Result, error) { 293 results, err := database.Exec(q.QueryString(), q.args...) 294 return results, err 295 } 296 297 // Rows executes the query against the database, and return the sql rows result for this query 298 // (Executes SQL) 299 func (q *Query) Rows() (*sql.Rows, error) { 300 results, err := database.Query(q.QueryString(), q.args...) 301 return results, err 302 } 303 304 // FirstResult executes the SQL and returrns the first result 305 func (q *Query) FirstResult() (Result, error) { 306 307 // Set a limit on the query 308 q.Limit(1) 309 310 // Fetch all results (1) 311 results, err := q.Results() 312 if err != nil { 313 return nil, err 314 } 315 316 if len(results) == 0 { 317 return nil, fmt.Errorf("No results found for Query:%s", q.QueryString()) 318 } 319 320 // Return the first result 321 return results[0], nil 322 } 323 324 // ResultInt64 returns the first result from a query stored in the column named col as an int64. 325 func (q *Query) ResultInt64(c string) (int64, error) { 326 result, err := q.FirstResult() 327 if err != nil || result[c] == nil { 328 return 0, err 329 } 330 var i int64 331 switch result[c].(type) { 332 case int64: 333 i = result[c].(int64) 334 case int: 335 i = int64(result[c].(int)) 336 case float64: 337 i = int64(result[c].(float64)) 338 case string: 339 f, err := strconv.ParseFloat(result[c].(string), 64) 340 if err != nil { 341 return i, err 342 } 343 i = int64(f) 344 } 345 346 return i, nil 347 } 348 349 // ResultFloat64 returns the first result from a query stored in the column named col as a float64. 350 func (q *Query) ResultFloat64(c string) (float64, error) { 351 result, err := q.FirstResult() 352 if err != nil || result[c] == nil { 353 return 0, err 354 } 355 var f float64 356 switch result[c].(type) { 357 case float64: 358 f = result[c].(float64) 359 case int: 360 f = float64(result[c].(int)) 361 case int64: 362 f = float64(result[c].(int)) 363 case string: 364 f, err = strconv.ParseFloat(result[c].(string), 64) 365 if err != nil { 366 return f, err 367 } 368 } 369 370 return f, nil 371 } 372 373 // Results returns an array of results 374 func (q *Query) Results() ([]Result, error) { 375 376 // Make an empty result set map 377 var results []Result 378 379 // Fetch rows from db for our sql 380 rows, err := q.Rows() 381 382 if err != nil { 383 return results, fmt.Errorf("Error querying database for rows: %s\nQUERY:%s", err, q) 384 } 385 386 // Close rows before returning 387 defer rows.Close() 388 389 // Fetch the columns from the database 390 cols, err := rows.Columns() 391 if err != nil { 392 return results, fmt.Errorf("Error fetching columns: %s\nQUERY:%s\nCOLS:%s", err, q, cols) 393 } 394 395 // For each row, construct an entry in results with a map of column string keys to values 396 for rows.Next() { 397 result, err := scanRow(cols, rows) 398 if err != nil { 399 return results, fmt.Errorf("Error fetching row: %s\nQUERY:%s\nCOLS:%s", err, q, cols) 400 } 401 results = append(results, result) 402 } 403 404 return results, nil 405 } 406 407 // ResultIDs returns an array of ids as the result of a query 408 // FIXME - this should really use the query primary key, not "id" hardcoded 409 func (q *Query) ResultIDs() []int64 { 410 var ids []int64 411 if Debug { 412 fmt.Printf("#info ResultIDs:%s\n", q.DebugString()) 413 } 414 results, err := q.Results() 415 if err != nil { 416 return ids 417 } 418 419 for _, r := range results { 420 if r["id"] != nil { 421 ids = append(ids, r["id"].(int64)) 422 } 423 } 424 425 return ids 426 } 427 428 // ResultIDSets returns a map from a values to arrays of b values, the order of a,b is respected not the table key order 429 func (q *Query) ResultIDSets(a, b string) map[int64][]int64 { 430 idSets := make(map[int64][]int64, 0) 431 432 results, err := q.Results() 433 if err != nil { 434 return idSets 435 } 436 437 for _, r := range results { 438 if r[a] != nil && r[b] != nil { 439 av := r[a].(int64) 440 bv := r[b].(int64) 441 idSets[av] = append(idSets[av], bv) 442 } 443 } 444 if Debug { 445 fmt.Printf("#info ResultIDSets:%s\n", q.DebugString()) 446 } 447 return idSets 448 } 449 450 // QueryString builds a query string to use for results 451 func (q *Query) QueryString() string { 452 453 if q.sql == "" { 454 455 // if we have arguments override the selector 456 if q.sel == "" { 457 // Note q.table() etc perform quoting on field names 458 q.sel = fmt.Sprintf("SELECT %s.* FROM %s", q.table(), q.table()) 459 } 460 461 q.sql = fmt.Sprintf("%s %s %s %s %s %s %s %s", q.sel, q.join, q.where, q.group, q.having, q.order, q.offset, q.limit) 462 q.sql = strings.TrimRight(q.sql, " ") 463 q.sql = strings.Replace(q.sql, " ", " ", -1) 464 q.sql = strings.Replace(q.sql, " ", " ", -1) 465 466 // Replace ? with whatever placeholder db prefers 467 q.replaceArgPlaceholders() 468 469 q.sql = q.sql + ";" 470 } 471 472 return q.sql 473 } 474 475 // CHAINABLE FINDERS 476 477 // Apply the Func to this query, and return the modified Query 478 // This allows chainable finders from other packages 479 // e.g. q.Apply(status.Published) where status.Published is a Func 480 func (q *Query) Apply(f Func) *Query { 481 return f(q) 482 } 483 484 // Conditions applies a series of query funcs to a query 485 func (q *Query) Conditions(funcs ...Func) *Query { 486 for _, f := range funcs { 487 q = f(q) 488 } 489 return q 490 } 491 492 // SQL defines sql manually and overrides all other setters 493 // Completely replaces all stored sql 494 func (q *Query) SQL(sql string) *Query { 495 q.sql = sql 496 q.reset() 497 return q 498 } 499 500 // Limit sets the sql LIMIT with an int 501 func (q *Query) Limit(limit int) *Query { 502 q.limit = fmt.Sprintf("LIMIT %d", limit) 503 q.reset() 504 return q 505 } 506 507 // Offset sets the sql OFFSET with an int 508 func (q *Query) Offset(offset int) *Query { 509 q.offset = fmt.Sprintf("OFFSET %d", offset) 510 q.reset() 511 return q 512 } 513 514 // Where defines a WHERE clause on SQL - Additional calls add WHERE () AND () clauses 515 func (q *Query) Where(sql string, args ...interface{}) *Query { 516 517 if len(q.where) > 0 { 518 q.where = fmt.Sprintf("%s AND (%s)", q.where, sql) 519 } else { 520 q.where = fmt.Sprintf("WHERE (%s)", sql) 521 } 522 523 // NB this assumes that args are only supplied for where clauses 524 // this may be an incorrect assumption! 525 if args != nil { 526 if q.args == nil { 527 q.args = args 528 } else { 529 q.args = append(q.args, args...) 530 } 531 } 532 533 q.reset() 534 return q 535 } 536 537 // OrWhere defines a where clause on SQL - Additional calls add WHERE () OR () clauses 538 func (q *Query) OrWhere(sql string, args ...interface{}) *Query { 539 540 if len(q.where) > 0 { 541 q.where = fmt.Sprintf("%s OR (%s)", q.where, sql) 542 } else { 543 q.where = fmt.Sprintf("WHERE (%s)", sql) 544 } 545 546 if args != nil { 547 if q.args == nil { 548 q.args = args 549 } else { 550 q.args = append(q.args, args...) 551 } 552 } 553 554 q.reset() 555 return q 556 } 557 558 // WhereIn adds a Where clause which selects records IN() the given array 559 // If IDs is an empty array, the query limit is set to 0 560 func (q *Query) WhereIn(col string, IDs []int64) *Query { 561 // Return no results, so that when chaining callers 562 // don't have to check for empty arrays 563 if len(IDs) == 0 { 564 q.Limit(0) 565 q.reset() 566 return q 567 } 568 569 in := "" 570 for _, ID := range IDs { 571 in = fmt.Sprintf("%s%d,", in, ID) 572 } 573 in = strings.TrimRight(in, ",") 574 sql := fmt.Sprintf("%s IN (%s)", col, in) 575 576 if len(q.where) > 0 { 577 q.where = fmt.Sprintf("%s AND (%s)", q.where, sql) 578 } else { 579 q.where = fmt.Sprintf("WHERE (%s)", sql) 580 } 581 582 q.reset() 583 return q 584 } 585 586 // Define a join clause on SQL - we create an inner join like this: 587 // INNER JOIN extras_seasons ON extras.id = extra_id 588 // q.Select("SELECT units.* FROM units INNER JOIN sites ON units.site_id = sites.id") 589 590 // rails join example 591 // INNER JOIN "posts_tags" ON "posts_tags"."tag_id" = "tags"."id" WHERE "posts_tags"."post_id" = 111 592 593 // Join adds an inner join to the query 594 func (q *Query) Join(otherModel string) *Query { 595 modelTable := q.tablename 596 597 tables := []string{ 598 modelTable, 599 ToPlural(otherModel), 600 } 601 sort.Strings(tables) 602 joinTable := fmt.Sprintf("%s_%s", tables[0], tables[1]) 603 604 sql := fmt.Sprintf("INNER JOIN %s ON %s.id = %s.%s_id", database.QuoteField(joinTable), database.QuoteField(modelTable), database.QuoteField(joinTable), ToSingular(modelTable)) 605 606 if len(q.join) > 0 { 607 q.join = fmt.Sprintf("%s %s", q.join, sql) 608 } else { 609 q.join = fmt.Sprintf("%s", sql) 610 } 611 612 q.reset() 613 return q 614 } 615 616 // Order defines ORDER BY sql 617 func (q *Query) Order(sql string) *Query { 618 if sql == "" { 619 q.order = "" 620 } else { 621 q.order = fmt.Sprintf("ORDER BY %s", sql) 622 } 623 q.reset() 624 625 return q 626 } 627 628 // Group defines GROUP BY sql 629 func (q *Query) Group(sql string) *Query { 630 if sql == "" { 631 q.group = "" 632 } else { 633 q.group = fmt.Sprintf("GROUP BY %s", sql) 634 } 635 q.reset() 636 return q 637 } 638 639 // Having defines HAVING sql 640 func (q *Query) Having(sql string) *Query { 641 if sql == "" { 642 q.having = "" 643 } else { 644 q.having = fmt.Sprintf("HAVING %s", sql) 645 } 646 q.reset() 647 return q 648 } 649 650 // Select defines SELECT sql 651 func (q *Query) Select(sql string) *Query { 652 q.sel = sql 653 q.reset() 654 return q 655 } 656 657 // DebugString returns a query representation string useful for debugging 658 func (q *Query) DebugString() string { 659 return fmt.Sprintf("--\nQuery-SQL:%s\nARGS:%s\n--", q.QueryString(), q.argString()) 660 } 661 662 // Clear sql/query caches 663 func (q *Query) reset() { 664 // Perhaps later clear cached compiled representation of query too 665 666 // clear stored sql 667 q.sql = "" 668 } 669 670 // Return an arg string (for debugging) 671 func (q *Query) argString() string { 672 output := "-" 673 674 for _, a := range q.args { 675 output = output + fmt.Sprintf("'%s',", q.argToString(a)) 676 } 677 output = strings.TrimRight(output, ",") 678 output = output + "" 679 680 return output 681 } 682 683 // Convert arguments to string - used only for debug argument strings 684 // Not to be exported or used to try to escape strings... 685 func (q *Query) argToString(arg interface{}) string { 686 switch arg.(type) { 687 case string: 688 return arg.(string) 689 case []byte: 690 return string(arg.([]byte)) 691 case int, int8, int16, int32, uint, uint8, uint16, uint32: 692 return fmt.Sprintf("%d", arg) 693 case int64, uint64: 694 return fmt.Sprintf("%d", arg) 695 case float32, float64: 696 return fmt.Sprintf("%f", arg) 697 case bool: 698 return fmt.Sprintf("%d", arg) 699 default: 700 return fmt.Sprintf("%v", arg) 701 702 } 703 704 } 705 706 // Ask model for primary key name to use 707 func (q *Query) pk() string { 708 return database.QuoteField(q.primarykey) 709 } 710 711 // Ask model for table name to use 712 func (q *Query) table() string { 713 return database.QuoteField(q.tablename) 714 } 715 716 // Replace ? with whatever database prefers (psql uses numbered args) 717 func (q *Query) replaceArgPlaceholders() { 718 // Match ? and replace with argument placeholder from database 719 for i := range q.args { 720 q.sql = strings.Replace(q.sql, "?", database.Placeholder(i+1), 1) 721 } 722 } 723 724 // Sorts the param names given - map iteration order is explicitly random in Go 725 // but we need params in a defined order to avoid unexpected results. 726 func sortedParamKeys(params map[string]string) []string { 727 sortedKeys := make([]string, len(params)) 728 i := 0 729 for k := range params { 730 sortedKeys[i] = k 731 i++ 732 } 733 sort.Strings(sortedKeys) 734 735 return sortedKeys 736 } 737 738 // Generate a set of values for the params in order 739 func valuesFromParams(params map[string]string) []interface{} { 740 741 // NB DO NOT DEPEND ON PARAMS ORDER - see note on SortedParamKeys 742 var values []interface{} 743 for _, key := range sortedParamKeys(params) { 744 values = append(values, params[key]) 745 } 746 return values 747 } 748 749 // Used for update statements, turn params into sql i.e. "col"=? 750 func querySQL(params map[string]string) string { 751 var output []string 752 for _, k := range sortedParamKeys(params) { 753 output = append(output, fmt.Sprintf("%s=?", database.QuoteField(k))) 754 } 755 return strings.Join(output, ",") 756 } 757 758 func scanRow(cols []string, rows *sql.Rows) (Result, error) { 759 760 // We return a map[string]interface{} for each row scanned 761 result := Result{} 762 763 values := make([]interface{}, len(cols)) 764 for i := 0; i < len(cols); i++ { 765 var col interface{} 766 values[i] = &col 767 } 768 769 // Scan results into these interfaces 770 err := rows.Scan(values...) 771 if err != nil { 772 return nil, fmt.Errorf("Error scanning row: %s", err) 773 } 774 775 // Make a string => interface map and hand off to caller 776 // We fix up a few types which the pq driver returns as less handy equivalents 777 // We enforce usage of int64 at all times as all our records use int64 778 for i := 0; i < len(cols); i++ { 779 v := *values[i].(*interface{}) 780 if values[i] != nil { 781 switch v.(type) { 782 default: 783 result[cols[i]] = v 784 case bool: 785 result[cols[i]] = v.(bool) 786 case int: 787 result[cols[i]] = int64(v.(int)) 788 case []byte: // text cols are given as bytes 789 result[cols[i]] = string(v.([]byte)) 790 case int64: 791 result[cols[i]] = v.(int64) 792 } 793 } 794 795 } 796 797 return result, nil 798 }