github.com/Ali-iotechsys/sqlboiler/v4@v4.0.0-20221208124957-6aec9a5f1f71/queries/query_builders.go (about) 1 package queries 2 3 import ( 4 "bytes" 5 "fmt" 6 "regexp" 7 "sort" 8 "strings" 9 10 "github.com/volatiletech/strmangle" 11 ) 12 13 var ( 14 rgxIdentifier = regexp.MustCompile(`^(?i)"?[a-z_][_a-z0-9]*"?(?:\."?[_a-z][_a-z0-9]*"?)*$`) 15 rgxInClause = regexp.MustCompile(`^(?i)(.*[\s|\)|\?])IN([\s|\(|\?].*)$`) 16 rgxNotInClause = regexp.MustCompile(`^(?i)(.*[\s|\)|\?])NOT\s+IN([\s|\(|\?].*)$`) 17 ) 18 19 // BuildQuery builds a query object into the query string 20 // and it's accompanying arguments. Using this method 21 // allows query building without immediate execution. 22 func BuildQuery(q *Query) (string, []interface{}) { 23 var buf *bytes.Buffer 24 var args []interface{} 25 26 q.removeSoftDeleteWhere() 27 28 switch { 29 case len(q.rawSQL.sql) != 0: 30 return q.rawSQL.sql, q.rawSQL.args 31 case q.delete: 32 buf, args = buildDeleteQuery(q) 33 case len(q.update) > 0: 34 buf, args = buildUpdateQuery(q) 35 default: 36 buf, args = buildSelectQuery(q) 37 } 38 39 defer strmangle.PutBuffer(buf) 40 41 // Cache the generated query for query object re-use 42 bufStr := buf.String() 43 q.rawSQL.sql = bufStr 44 q.rawSQL.args = args 45 46 return bufStr, args 47 } 48 49 func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) { 50 buf := strmangle.GetBuffer() 51 var args []interface{} 52 53 writeComment(q, buf) 54 writeCTEs(q, buf, &args) 55 56 buf.WriteString("SELECT ") 57 58 if q.dialect.UseTopClause { 59 if q.limit != nil && q.offset == 0 { 60 fmt.Fprintf(buf, " TOP (%d) ", *q.limit) 61 } 62 } 63 64 if q.count { 65 buf.WriteString("COUNT(") 66 } 67 68 hasSelectCols := len(q.selectCols) != 0 69 hasJoins := len(q.joins) != 0 70 hasDistinct := q.distinct != "" 71 if hasDistinct { 72 buf.WriteString("DISTINCT ") 73 if q.count { 74 buf.WriteString("(") 75 } 76 buf.WriteString(q.distinct) 77 if q.count { 78 buf.WriteString(")") 79 } 80 } else if hasJoins && hasSelectCols && !q.count { 81 selectColsWithAs := writeAsStatements(q) 82 // Don't identQuoteSlice - writeAsStatements does this 83 buf.WriteString(strings.Join(selectColsWithAs, ", ")) 84 } else if hasSelectCols { 85 buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.dialect.LQ, q.dialect.RQ, q.selectCols), ", ")) 86 } else if hasJoins && !q.count { 87 selectColsWithStars := writeStars(q) 88 buf.WriteString(strings.Join(selectColsWithStars, ", ")) 89 } else { 90 buf.WriteByte('*') 91 } 92 93 // close SQL COUNT function 94 if q.count { 95 buf.WriteByte(')') 96 } 97 98 fmt.Fprintf(buf, " FROM %s", strings.Join(strmangle.IdentQuoteSlice(q.dialect.LQ, q.dialect.RQ, q.from), ", ")) 99 100 if len(q.joins) > 0 { 101 argsLen := len(args) 102 joinBuf := strmangle.GetBuffer() 103 for _, j := range q.joins { 104 switch j.kind { 105 case JoinInner: 106 fmt.Fprintf(joinBuf, " INNER JOIN %s", j.clause) 107 case JoinOuterLeft: 108 fmt.Fprintf(joinBuf, " LEFT JOIN %s", j.clause) 109 case JoinOuterRight: 110 fmt.Fprintf(joinBuf, " RIGHT JOIN %s", j.clause) 111 case JoinOuterFull: 112 fmt.Fprintf(joinBuf, " FULL JOIN %s", j.clause) 113 default: 114 panic(fmt.Sprintf("Unsupported join of kind %v", j.kind)) 115 } 116 args = append(args, j.args...) 117 } 118 var resp string 119 if q.dialect.UseIndexPlaceholders { 120 resp, _ = convertQuestionMarks(joinBuf.String(), argsLen+1) 121 } else { 122 resp = joinBuf.String() 123 } 124 fmt.Fprintf(buf, resp) 125 strmangle.PutBuffer(joinBuf) 126 } 127 128 where, whereArgs := whereClause(q, len(args)+1) 129 buf.WriteString(where) 130 if len(whereArgs) != 0 { 131 args = append(args, whereArgs...) 132 } 133 134 writeModifiers(q, buf, &args) 135 136 buf.WriteByte(';') 137 return buf, args 138 } 139 140 func buildDeleteQuery(q *Query) (*bytes.Buffer, []interface{}) { 141 var args []interface{} 142 buf := strmangle.GetBuffer() 143 144 writeComment(q, buf) 145 writeCTEs(q, buf, &args) 146 147 buf.WriteString("DELETE FROM ") 148 buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.dialect.LQ, q.dialect.RQ, q.from), ", ")) 149 150 where, whereArgs := whereClause(q, 1) 151 if len(whereArgs) != 0 { 152 args = append(args, whereArgs...) 153 } 154 buf.WriteString(where) 155 156 writeModifiers(q, buf, &args) 157 158 buf.WriteByte(';') 159 160 return buf, args 161 } 162 163 func buildUpdateQuery(q *Query) (*bytes.Buffer, []interface{}) { 164 buf := strmangle.GetBuffer() 165 var args []interface{} 166 167 writeComment(q, buf) 168 writeCTEs(q, buf, &args) 169 170 buf.WriteString("UPDATE ") 171 buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.dialect.LQ, q.dialect.RQ, q.from), ", ")) 172 173 cols := make(sort.StringSlice, len(q.update)) 174 175 count := 0 176 for name := range q.update { 177 cols[count] = name 178 count++ 179 } 180 181 cols.Sort() 182 183 for i := 0; i < len(cols); i++ { 184 args = append(args, q.update[cols[i]]) 185 cols[i] = strmangle.IdentQuote(q.dialect.LQ, q.dialect.RQ, cols[i]) 186 } 187 188 setSlice := make([]string, len(cols)) 189 for index, col := range cols { 190 setSlice[index] = fmt.Sprintf("%s = %s", col, strmangle.Placeholders(q.dialect.UseIndexPlaceholders, 1, index+1, 1)) 191 } 192 fmt.Fprintf(buf, " SET %s", strings.Join(setSlice, ", ")) 193 194 where, whereArgs := whereClause(q, len(args)+1) 195 if len(whereArgs) != 0 { 196 args = append(args, whereArgs...) 197 } 198 buf.WriteString(where) 199 200 writeModifiers(q, buf, &args) 201 202 buf.WriteByte(';') 203 204 return buf, args 205 } 206 207 func writeParameterizedModifiers(q *Query, buf *bytes.Buffer, args *[]interface{}, keyword, delim string, clauses []argClause) { 208 argsLen := len(*args) 209 modBuf := strmangle.GetBuffer() 210 fmt.Fprintf(modBuf, keyword) 211 212 for i, j := range clauses { 213 if i > 0 { 214 modBuf.WriteString(delim) 215 } 216 modBuf.WriteString(j.clause) 217 *args = append(*args, j.args...) 218 } 219 220 var resp string 221 if q.dialect.UseIndexPlaceholders { 222 resp, _ = convertQuestionMarks(modBuf.String(), argsLen+1) 223 } else { 224 resp = modBuf.String() 225 } 226 227 buf.WriteString(resp) 228 strmangle.PutBuffer(modBuf) 229 } 230 231 func writeModifiers(q *Query, buf *bytes.Buffer, args *[]interface{}) { 232 if len(q.groupBy) != 0 { 233 fmt.Fprintf(buf, " GROUP BY %s", strings.Join(q.groupBy, ", ")) 234 } 235 236 if len(q.having) != 0 { 237 writeParameterizedModifiers(q, buf, args, " HAVING ", " AND ", q.having) 238 } 239 240 if len(q.orderBy) != 0 { 241 writeParameterizedModifiers(q, buf, args, " ORDER BY ", ", ", q.orderBy) 242 } 243 244 if !q.dialect.UseTopClause { 245 if q.limit != nil { 246 fmt.Fprintf(buf, " LIMIT %d", *q.limit) 247 } 248 249 if q.offset != 0 { 250 fmt.Fprintf(buf, " OFFSET %d", q.offset) 251 } 252 } else { 253 // From MS SQL 2012 and above: https://technet.microsoft.com/en-us/library/ms188385(v=sql.110).aspx 254 // ORDER BY ... 255 // OFFSET N ROWS 256 // FETCH NEXT M ROWS ONLY 257 if q.offset != 0 { 258 259 // Hack from https://www.microsoftpressstore.com/articles/article.aspx?p=2314819 260 // ... 261 // As mentioned, the OFFSET-FETCH filter requires an ORDER BY clause. If you want to use arbitrary order, 262 // like TOP without an ORDER BY clause, you can use the trick with ORDER BY (SELECT NULL) 263 // ... 264 if len(q.orderBy) == 0 { 265 buf.WriteString(" ORDER BY (SELECT NULL)") 266 } 267 268 // This seems to be the latest version of mssql's syntax for offset 269 // (the suffix ROWS) 270 // This is true for latest sql server as well as their cloud offerings & the upcoming sql server 2019 271 // https://docs.microsoft.com/en-us/sql/t-sql/queries/select-order-by-clause-transact-sql?view=sql-server-2017 272 // https://docs.microsoft.com/en-us/sql/t-sql/queries/select-order-by-clause-transact-sql?view=sql-server-ver15 273 fmt.Fprintf(buf, " OFFSET %d ROWS", q.offset) 274 275 if q.limit != nil { 276 fmt.Fprintf(buf, " FETCH NEXT %d ROWS ONLY", *q.limit) 277 } 278 } 279 } 280 281 if len(q.forlock) != 0 { 282 fmt.Fprintf(buf, " FOR %s", q.forlock) 283 } 284 } 285 286 func writeStars(q *Query) []string { 287 cols := make([]string, len(q.from)) 288 for i, f := range q.from { 289 toks := strings.Split(f, " ") 290 if len(toks) == 1 { 291 cols[i] = fmt.Sprintf(`%s.*`, strmangle.IdentQuote(q.dialect.LQ, q.dialect.RQ, toks[0])) 292 continue 293 } 294 295 alias, name, ok := parseFromClause(toks) 296 if !ok { 297 return nil 298 } 299 300 if len(alias) != 0 { 301 name = alias 302 } 303 cols[i] = fmt.Sprintf(`%s.*`, strmangle.IdentQuote(q.dialect.LQ, q.dialect.RQ, name)) 304 } 305 306 return cols 307 } 308 309 func writeAsStatements(q *Query) []string { 310 cols := make([]string, len(q.selectCols)) 311 for i, col := range q.selectCols { 312 if !rgxIdentifier.MatchString(col) { 313 cols[i] = col 314 continue 315 } 316 317 toks := strings.Split(col, ".") 318 if len(toks) == 1 { 319 cols[i] = strmangle.IdentQuote(q.dialect.LQ, q.dialect.RQ, col) 320 continue 321 } 322 323 asParts := make([]string, len(toks)) 324 for j, tok := range toks { 325 asParts[j] = strings.Trim(tok, `"`) 326 } 327 328 cols[i] = fmt.Sprintf(`%s as "%s"`, strmangle.IdentQuote(q.dialect.LQ, q.dialect.RQ, col), strings.Join(asParts, ".")) 329 } 330 331 return cols 332 } 333 334 // whereClause parses a where slice and converts it into a 335 // single WHERE clause like: 336 // WHERE (a=$1) AND (b=$2) AND (a,b) in (($3, $4), ($5, $6)) 337 // 338 // startAt specifies what number placeholders start at 339 func whereClause(q *Query, startAt int) (string, []interface{}) { 340 if len(q.where) == 0 { 341 return "", nil 342 } 343 344 manualParens := false 345 ManualParen: 346 for _, w := range q.where { 347 switch w.kind { 348 case whereKindLeftParen, whereKindRightParen: 349 manualParens = true 350 break ManualParen 351 } 352 } 353 354 buf := strmangle.GetBuffer() 355 defer strmangle.PutBuffer(buf) 356 var args []interface{} 357 358 notFirstExpression := false 359 buf.WriteString(" WHERE ") 360 for _, where := range q.where { 361 if notFirstExpression && where.kind != whereKindRightParen { 362 if where.orSeparator { 363 buf.WriteString(" OR ") 364 } else { 365 buf.WriteString(" AND ") 366 } 367 } else { 368 notFirstExpression = true 369 } 370 371 switch where.kind { 372 case whereKindNormal: 373 if !manualParens { 374 buf.WriteByte('(') 375 } 376 if q.dialect.UseIndexPlaceholders { 377 replaced, n := convertQuestionMarks(where.clause, startAt) 378 buf.WriteString(replaced) 379 startAt += n 380 } else { 381 buf.WriteString(where.clause) 382 } 383 if !manualParens { 384 buf.WriteByte(')') 385 } 386 args = append(args, where.args...) 387 case whereKindLeftParen: 388 buf.WriteByte('(') 389 notFirstExpression = false 390 case whereKindRightParen: 391 buf.WriteByte(')') 392 case whereKindIn, whereKindNotIn: 393 ln := len(where.args) 394 // WHERE IN () is invalid sql, so it is difficult to simply run code like: 395 // for _, u := range model.Users(qm.WhereIn("id IN ?",uids...)).AllP(db) { 396 // ... 397 // } 398 // instead when we see empty IN we produce 1=0 so it can still be chained 399 // with other queries 400 if ln == 0 { 401 if where.kind == whereKindIn { 402 buf.WriteString("(1=0)") 403 } else if where.kind == whereKindNotIn { 404 buf.WriteString("(1=1)") 405 } 406 break 407 } 408 409 var matches []string 410 if where.kind == whereKindIn { 411 matches = rgxInClause.FindStringSubmatch(where.clause) 412 } else { 413 matches = rgxNotInClause.FindStringSubmatch(where.clause) 414 } 415 416 // If we can't find any matches attempt a simple replace with 1 group. 417 // Clauses that fit this criteria will not be able to contain ? in their 418 // column name side, however if this case is being hit then the regexp 419 // probably needs adjustment, or the user is passing in invalid clauses. 420 if matches == nil { 421 clause, count := convertInQuestionMarks(q.dialect.UseIndexPlaceholders, where.clause, startAt, 1, ln) 422 if !manualParens { 423 buf.WriteByte('(') 424 } 425 buf.WriteString(clause) 426 if !manualParens { 427 buf.WriteByte(')') 428 } 429 args = append(args, where.args...) 430 startAt += count 431 break 432 } 433 434 leftSide := strings.TrimSpace(matches[1]) 435 rightSide := strings.TrimSpace(matches[2]) 436 // If matches are found, we have to parse the left side (column side) 437 // of the clause to determine how many columns they are using. 438 // This number determines the groupAt for the convert function. 439 cols := strings.Split(leftSide, ",") 440 cols = strmangle.IdentQuoteSlice(q.dialect.LQ, q.dialect.RQ, cols) 441 groupAt := len(cols) 442 443 var leftClause string 444 var leftCount int 445 if q.dialect.UseIndexPlaceholders { 446 leftClause, leftCount = convertQuestionMarks(strings.Join(cols, ","), startAt) 447 } else { 448 // Count the number of cols that are question marks, so we know 449 // how much to offset convertInQuestionMarks by 450 for _, v := range cols { 451 if v == "?" { 452 leftCount++ 453 } 454 } 455 leftClause = strings.Join(cols, ",") 456 } 457 rightClause, rightCount := convertInQuestionMarks(q.dialect.UseIndexPlaceholders, rightSide, startAt+leftCount, groupAt, ln-leftCount) 458 if !manualParens { 459 buf.WriteByte('(') 460 } 461 buf.WriteString(leftClause) 462 if where.kind == whereKindIn { 463 buf.WriteString(" IN ") 464 } else if where.kind == whereKindNotIn { 465 buf.WriteString(" NOT IN ") 466 } 467 buf.WriteString(rightClause) 468 if !manualParens { 469 buf.WriteByte(')') 470 } 471 startAt += leftCount + rightCount 472 args = append(args, where.args...) 473 default: 474 panic("unknown where type") 475 } 476 } 477 478 return buf.String(), args 479 } 480 481 // convertInQuestionMarks finds the first unescaped occurrence of ? and swaps it 482 // with a list of numbered placeholders, starting at startAt. 483 // It uses groupAt to determine how many placeholders should be in each group, 484 // for example, groupAt 2 would result in: (($1,$2),($3,$4)) 485 // and groupAt 1 would result in ($1,$2,$3,$4) 486 func convertInQuestionMarks(UseIndexPlaceholders bool, clause string, startAt, groupAt, total int) (string, int) { 487 if startAt == 0 || len(clause) == 0 { 488 panic("Not a valid start number.") 489 } 490 491 paramBuf := strmangle.GetBuffer() 492 defer strmangle.PutBuffer(paramBuf) 493 494 foundAt := -1 495 for i := 0; i < len(clause); i++ { 496 if (i == 0 && clause[i] == '?') || (clause[i] == '?' && clause[i-1] != '\\') { 497 foundAt = i 498 break 499 } 500 } 501 502 if foundAt == -1 { 503 return strings.ReplaceAll(clause, `\?`, "?"), 0 504 } 505 506 paramBuf.WriteString(clause[:foundAt]) 507 paramBuf.WriteByte('(') 508 paramBuf.WriteString(strmangle.Placeholders(UseIndexPlaceholders, total, startAt, groupAt)) 509 paramBuf.WriteByte(')') 510 paramBuf.WriteString(clause[foundAt+1:]) 511 512 // Remove all backslashes from escaped question-marks 513 ret := strings.ReplaceAll(paramBuf.String(), `\?`, "?") 514 return ret, total 515 } 516 517 // convertQuestionMarks converts each occurrence of ? with $<number> 518 // where <number> is an incrementing digit starting at startAt. 519 // If question-mark (?) is escaped using back-slash (\), it will be ignored. 520 func convertQuestionMarks(clause string, startAt int) (string, int) { 521 if startAt == 0 { 522 panic("Not a valid start number.") 523 } 524 525 paramBuf := strmangle.GetBuffer() 526 defer strmangle.PutBuffer(paramBuf) 527 paramIndex := 0 528 total := 0 529 530 for { 531 if paramIndex >= len(clause) { 532 break 533 } 534 535 clause = clause[paramIndex:] 536 paramIndex = strings.IndexByte(clause, '?') 537 538 if paramIndex == -1 { 539 paramBuf.WriteString(clause) 540 break 541 } 542 543 escapeIndex := strings.Index(clause, `\?`) 544 if escapeIndex != -1 && paramIndex > escapeIndex { 545 paramBuf.WriteString(clause[:escapeIndex] + "?") 546 paramIndex++ 547 continue 548 } 549 550 paramBuf.WriteString(clause[:paramIndex] + fmt.Sprintf("$%d", startAt)) 551 total++ 552 startAt++ 553 paramIndex++ 554 } 555 556 return paramBuf.String(), total 557 } 558 559 // parseFromClause will parse something that looks like 560 // a 561 // a b 562 // a as b 563 func parseFromClause(toks []string) (alias, name string, ok bool) { 564 if len(toks) > 3 { 565 toks = toks[:3] 566 } 567 568 sawIdent, sawAs := false, false 569 for _, tok := range toks { 570 if t := strings.ToLower(tok); sawIdent && t == "as" { 571 sawAs = true 572 continue 573 } else if sawIdent && t == "on" { 574 break 575 } 576 577 if !rgxIdentifier.MatchString(tok) { 578 break 579 } 580 581 if sawIdent || sawAs { 582 alias = strings.Trim(tok, `"`) 583 break 584 } 585 586 name = strings.Trim(tok, `"`) 587 sawIdent = true 588 ok = true 589 } 590 591 return alias, name, ok 592 } 593 594 func writeComment(q *Query, buf *bytes.Buffer) { 595 if len(q.comment) == 0 { 596 return 597 } 598 599 lines := strings.Split(q.comment, "\n") 600 for _, line := range lines { 601 buf.WriteString("-- ") 602 buf.WriteString(line) 603 buf.WriteByte('\n') 604 } 605 } 606 607 func writeCTEs(q *Query, buf *bytes.Buffer, args *[]interface{}) { 608 if len(q.withs) == 0 { 609 return 610 } 611 612 buf.WriteString("WITH") 613 argsLen := len(*args) 614 withBuf := strmangle.GetBuffer() 615 lastPos := len(q.withs) - 1 616 for i, w := range q.withs { 617 fmt.Fprintf(withBuf, " %s", w.clause) 618 if i >= 0 && i < lastPos { 619 withBuf.WriteByte(',') 620 } 621 *args = append(*args, w.args...) 622 } 623 withBuf.WriteByte(' ') 624 var resp string 625 if q.dialect.UseIndexPlaceholders { 626 resp, _ = convertQuestionMarks(withBuf.String(), argsLen+1) 627 } else { 628 resp = withBuf.String() 629 } 630 fmt.Fprintf(buf, resp) 631 strmangle.PutBuffer(withBuf) 632 }