github.com/Ali-iotechsys/sqlboiler/v4@v4.0.0-20221208124957-6aec9a5f1f71/queries/query_builders.go (about)

     1  package queries
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"regexp"
     7  	"sort"
     8  	"strings"
     9  
    10  	"github.com/volatiletech/strmangle"
    11  )
    12  
    13  var (
    14  	rgxIdentifier  = regexp.MustCompile(`^(?i)"?[a-z_][_a-z0-9]*"?(?:\."?[_a-z][_a-z0-9]*"?)*$`)
    15  	rgxInClause    = regexp.MustCompile(`^(?i)(.*[\s|\)|\?])IN([\s|\(|\?].*)$`)
    16  	rgxNotInClause = regexp.MustCompile(`^(?i)(.*[\s|\)|\?])NOT\s+IN([\s|\(|\?].*)$`)
    17  )
    18  
    19  // BuildQuery builds a query object into the query string
    20  // and it's accompanying arguments. Using this method
    21  // allows query building without immediate execution.
    22  func BuildQuery(q *Query) (string, []interface{}) {
    23  	var buf *bytes.Buffer
    24  	var args []interface{}
    25  
    26  	q.removeSoftDeleteWhere()
    27  
    28  	switch {
    29  	case len(q.rawSQL.sql) != 0:
    30  		return q.rawSQL.sql, q.rawSQL.args
    31  	case q.delete:
    32  		buf, args = buildDeleteQuery(q)
    33  	case len(q.update) > 0:
    34  		buf, args = buildUpdateQuery(q)
    35  	default:
    36  		buf, args = buildSelectQuery(q)
    37  	}
    38  
    39  	defer strmangle.PutBuffer(buf)
    40  
    41  	// Cache the generated query for query object re-use
    42  	bufStr := buf.String()
    43  	q.rawSQL.sql = bufStr
    44  	q.rawSQL.args = args
    45  
    46  	return bufStr, args
    47  }
    48  
    49  func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) {
    50  	buf := strmangle.GetBuffer()
    51  	var args []interface{}
    52  
    53  	writeComment(q, buf)
    54  	writeCTEs(q, buf, &args)
    55  
    56  	buf.WriteString("SELECT ")
    57  
    58  	if q.dialect.UseTopClause {
    59  		if q.limit != nil && q.offset == 0 {
    60  			fmt.Fprintf(buf, " TOP (%d) ", *q.limit)
    61  		}
    62  	}
    63  
    64  	if q.count {
    65  		buf.WriteString("COUNT(")
    66  	}
    67  
    68  	hasSelectCols := len(q.selectCols) != 0
    69  	hasJoins := len(q.joins) != 0
    70  	hasDistinct := q.distinct != ""
    71  	if hasDistinct {
    72  		buf.WriteString("DISTINCT ")
    73  		if q.count {
    74  			buf.WriteString("(")
    75  		}
    76  		buf.WriteString(q.distinct)
    77  		if q.count {
    78  			buf.WriteString(")")
    79  		}
    80  	} else if hasJoins && hasSelectCols && !q.count {
    81  		selectColsWithAs := writeAsStatements(q)
    82  		// Don't identQuoteSlice - writeAsStatements does this
    83  		buf.WriteString(strings.Join(selectColsWithAs, ", "))
    84  	} else if hasSelectCols {
    85  		buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.dialect.LQ, q.dialect.RQ, q.selectCols), ", "))
    86  	} else if hasJoins && !q.count {
    87  		selectColsWithStars := writeStars(q)
    88  		buf.WriteString(strings.Join(selectColsWithStars, ", "))
    89  	} else {
    90  		buf.WriteByte('*')
    91  	}
    92  
    93  	// close SQL COUNT function
    94  	if q.count {
    95  		buf.WriteByte(')')
    96  	}
    97  
    98  	fmt.Fprintf(buf, " FROM %s", strings.Join(strmangle.IdentQuoteSlice(q.dialect.LQ, q.dialect.RQ, q.from), ", "))
    99  
   100  	if len(q.joins) > 0 {
   101  		argsLen := len(args)
   102  		joinBuf := strmangle.GetBuffer()
   103  		for _, j := range q.joins {
   104  			switch j.kind {
   105  			case JoinInner:
   106  				fmt.Fprintf(joinBuf, " INNER JOIN %s", j.clause)
   107  			case JoinOuterLeft:
   108  				fmt.Fprintf(joinBuf, " LEFT JOIN %s", j.clause)
   109  			case JoinOuterRight:
   110  				fmt.Fprintf(joinBuf, " RIGHT JOIN %s", j.clause)
   111  			case JoinOuterFull:
   112  				fmt.Fprintf(joinBuf, " FULL JOIN %s", j.clause)
   113  			default:
   114  				panic(fmt.Sprintf("Unsupported join of kind %v", j.kind))
   115  			}
   116  			args = append(args, j.args...)
   117  		}
   118  		var resp string
   119  		if q.dialect.UseIndexPlaceholders {
   120  			resp, _ = convertQuestionMarks(joinBuf.String(), argsLen+1)
   121  		} else {
   122  			resp = joinBuf.String()
   123  		}
   124  		fmt.Fprintf(buf, resp)
   125  		strmangle.PutBuffer(joinBuf)
   126  	}
   127  
   128  	where, whereArgs := whereClause(q, len(args)+1)
   129  	buf.WriteString(where)
   130  	if len(whereArgs) != 0 {
   131  		args = append(args, whereArgs...)
   132  	}
   133  
   134  	writeModifiers(q, buf, &args)
   135  
   136  	buf.WriteByte(';')
   137  	return buf, args
   138  }
   139  
   140  func buildDeleteQuery(q *Query) (*bytes.Buffer, []interface{}) {
   141  	var args []interface{}
   142  	buf := strmangle.GetBuffer()
   143  
   144  	writeComment(q, buf)
   145  	writeCTEs(q, buf, &args)
   146  
   147  	buf.WriteString("DELETE FROM ")
   148  	buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.dialect.LQ, q.dialect.RQ, q.from), ", "))
   149  
   150  	where, whereArgs := whereClause(q, 1)
   151  	if len(whereArgs) != 0 {
   152  		args = append(args, whereArgs...)
   153  	}
   154  	buf.WriteString(where)
   155  
   156  	writeModifiers(q, buf, &args)
   157  
   158  	buf.WriteByte(';')
   159  
   160  	return buf, args
   161  }
   162  
   163  func buildUpdateQuery(q *Query) (*bytes.Buffer, []interface{}) {
   164  	buf := strmangle.GetBuffer()
   165  	var args []interface{}
   166  
   167  	writeComment(q, buf)
   168  	writeCTEs(q, buf, &args)
   169  
   170  	buf.WriteString("UPDATE ")
   171  	buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.dialect.LQ, q.dialect.RQ, q.from), ", "))
   172  
   173  	cols := make(sort.StringSlice, len(q.update))
   174  
   175  	count := 0
   176  	for name := range q.update {
   177  		cols[count] = name
   178  		count++
   179  	}
   180  
   181  	cols.Sort()
   182  
   183  	for i := 0; i < len(cols); i++ {
   184  		args = append(args, q.update[cols[i]])
   185  		cols[i] = strmangle.IdentQuote(q.dialect.LQ, q.dialect.RQ, cols[i])
   186  	}
   187  
   188  	setSlice := make([]string, len(cols))
   189  	for index, col := range cols {
   190  		setSlice[index] = fmt.Sprintf("%s = %s", col, strmangle.Placeholders(q.dialect.UseIndexPlaceholders, 1, index+1, 1))
   191  	}
   192  	fmt.Fprintf(buf, " SET %s", strings.Join(setSlice, ", "))
   193  
   194  	where, whereArgs := whereClause(q, len(args)+1)
   195  	if len(whereArgs) != 0 {
   196  		args = append(args, whereArgs...)
   197  	}
   198  	buf.WriteString(where)
   199  
   200  	writeModifiers(q, buf, &args)
   201  
   202  	buf.WriteByte(';')
   203  
   204  	return buf, args
   205  }
   206  
   207  func writeParameterizedModifiers(q *Query, buf *bytes.Buffer, args *[]interface{}, keyword, delim string, clauses []argClause) {
   208  	argsLen := len(*args)
   209  	modBuf := strmangle.GetBuffer()
   210  	fmt.Fprintf(modBuf, keyword)
   211  
   212  	for i, j := range clauses {
   213  		if i > 0 {
   214  			modBuf.WriteString(delim)
   215  		}
   216  		modBuf.WriteString(j.clause)
   217  		*args = append(*args, j.args...)
   218  	}
   219  
   220  	var resp string
   221  	if q.dialect.UseIndexPlaceholders {
   222  		resp, _ = convertQuestionMarks(modBuf.String(), argsLen+1)
   223  	} else {
   224  		resp = modBuf.String()
   225  	}
   226  
   227  	buf.WriteString(resp)
   228  	strmangle.PutBuffer(modBuf)
   229  }
   230  
   231  func writeModifiers(q *Query, buf *bytes.Buffer, args *[]interface{}) {
   232  	if len(q.groupBy) != 0 {
   233  		fmt.Fprintf(buf, " GROUP BY %s", strings.Join(q.groupBy, ", "))
   234  	}
   235  
   236  	if len(q.having) != 0 {
   237  		writeParameterizedModifiers(q, buf, args, " HAVING ", " AND ", q.having)
   238  	}
   239  
   240  	if len(q.orderBy) != 0 {
   241  		writeParameterizedModifiers(q, buf, args, " ORDER BY ", ", ", q.orderBy)
   242  	}
   243  
   244  	if !q.dialect.UseTopClause {
   245  		if q.limit != nil {
   246  			fmt.Fprintf(buf, " LIMIT %d", *q.limit)
   247  		}
   248  
   249  		if q.offset != 0 {
   250  			fmt.Fprintf(buf, " OFFSET %d", q.offset)
   251  		}
   252  	} else {
   253  		// From MS SQL 2012 and above: https://technet.microsoft.com/en-us/library/ms188385(v=sql.110).aspx
   254  		// ORDER BY ...
   255  		// OFFSET N ROWS
   256  		// FETCH NEXT M ROWS ONLY
   257  		if q.offset != 0 {
   258  
   259  			// Hack from https://www.microsoftpressstore.com/articles/article.aspx?p=2314819
   260  			// ...
   261  			// As mentioned, the OFFSET-FETCH filter requires an ORDER BY clause. If you want to use arbitrary order,
   262  			// like TOP without an ORDER BY clause, you can use the trick with ORDER BY (SELECT NULL)
   263  			// ...
   264  			if len(q.orderBy) == 0 {
   265  				buf.WriteString(" ORDER BY (SELECT NULL)")
   266  			}
   267  
   268  			// This seems to be the latest version of mssql's syntax for offset
   269  			// (the suffix ROWS)
   270  			// This is true for latest sql server as well as their cloud offerings & the upcoming sql server 2019
   271  			// https://docs.microsoft.com/en-us/sql/t-sql/queries/select-order-by-clause-transact-sql?view=sql-server-2017
   272  			// https://docs.microsoft.com/en-us/sql/t-sql/queries/select-order-by-clause-transact-sql?view=sql-server-ver15
   273  			fmt.Fprintf(buf, " OFFSET %d ROWS", q.offset)
   274  
   275  			if q.limit != nil {
   276  				fmt.Fprintf(buf, " FETCH NEXT %d ROWS ONLY", *q.limit)
   277  			}
   278  		}
   279  	}
   280  
   281  	if len(q.forlock) != 0 {
   282  		fmt.Fprintf(buf, " FOR %s", q.forlock)
   283  	}
   284  }
   285  
   286  func writeStars(q *Query) []string {
   287  	cols := make([]string, len(q.from))
   288  	for i, f := range q.from {
   289  		toks := strings.Split(f, " ")
   290  		if len(toks) == 1 {
   291  			cols[i] = fmt.Sprintf(`%s.*`, strmangle.IdentQuote(q.dialect.LQ, q.dialect.RQ, toks[0]))
   292  			continue
   293  		}
   294  
   295  		alias, name, ok := parseFromClause(toks)
   296  		if !ok {
   297  			return nil
   298  		}
   299  
   300  		if len(alias) != 0 {
   301  			name = alias
   302  		}
   303  		cols[i] = fmt.Sprintf(`%s.*`, strmangle.IdentQuote(q.dialect.LQ, q.dialect.RQ, name))
   304  	}
   305  
   306  	return cols
   307  }
   308  
   309  func writeAsStatements(q *Query) []string {
   310  	cols := make([]string, len(q.selectCols))
   311  	for i, col := range q.selectCols {
   312  		if !rgxIdentifier.MatchString(col) {
   313  			cols[i] = col
   314  			continue
   315  		}
   316  
   317  		toks := strings.Split(col, ".")
   318  		if len(toks) == 1 {
   319  			cols[i] = strmangle.IdentQuote(q.dialect.LQ, q.dialect.RQ, col)
   320  			continue
   321  		}
   322  
   323  		asParts := make([]string, len(toks))
   324  		for j, tok := range toks {
   325  			asParts[j] = strings.Trim(tok, `"`)
   326  		}
   327  
   328  		cols[i] = fmt.Sprintf(`%s as "%s"`, strmangle.IdentQuote(q.dialect.LQ, q.dialect.RQ, col), strings.Join(asParts, "."))
   329  	}
   330  
   331  	return cols
   332  }
   333  
   334  // whereClause parses a where slice and converts it into a
   335  // single WHERE clause like:
   336  // WHERE (a=$1) AND (b=$2) AND (a,b) in (($3, $4), ($5, $6))
   337  //
   338  // startAt specifies what number placeholders start at
   339  func whereClause(q *Query, startAt int) (string, []interface{}) {
   340  	if len(q.where) == 0 {
   341  		return "", nil
   342  	}
   343  
   344  	manualParens := false
   345  ManualParen:
   346  	for _, w := range q.where {
   347  		switch w.kind {
   348  		case whereKindLeftParen, whereKindRightParen:
   349  			manualParens = true
   350  			break ManualParen
   351  		}
   352  	}
   353  
   354  	buf := strmangle.GetBuffer()
   355  	defer strmangle.PutBuffer(buf)
   356  	var args []interface{}
   357  
   358  	notFirstExpression := false
   359  	buf.WriteString(" WHERE ")
   360  	for _, where := range q.where {
   361  		if notFirstExpression && where.kind != whereKindRightParen {
   362  			if where.orSeparator {
   363  				buf.WriteString(" OR ")
   364  			} else {
   365  				buf.WriteString(" AND ")
   366  			}
   367  		} else {
   368  			notFirstExpression = true
   369  		}
   370  
   371  		switch where.kind {
   372  		case whereKindNormal:
   373  			if !manualParens {
   374  				buf.WriteByte('(')
   375  			}
   376  			if q.dialect.UseIndexPlaceholders {
   377  				replaced, n := convertQuestionMarks(where.clause, startAt)
   378  				buf.WriteString(replaced)
   379  				startAt += n
   380  			} else {
   381  				buf.WriteString(where.clause)
   382  			}
   383  			if !manualParens {
   384  				buf.WriteByte(')')
   385  			}
   386  			args = append(args, where.args...)
   387  		case whereKindLeftParen:
   388  			buf.WriteByte('(')
   389  			notFirstExpression = false
   390  		case whereKindRightParen:
   391  			buf.WriteByte(')')
   392  		case whereKindIn, whereKindNotIn:
   393  			ln := len(where.args)
   394  			// WHERE IN () is invalid sql, so it is difficult to simply run code like:
   395  			// for _, u := range model.Users(qm.WhereIn("id IN ?",uids...)).AllP(db) {
   396  			//    ...
   397  			// }
   398  			// instead when we see empty IN we produce 1=0 so it can still be chained
   399  			// with other queries
   400  			if ln == 0 {
   401  				if where.kind == whereKindIn {
   402  					buf.WriteString("(1=0)")
   403  				} else if where.kind == whereKindNotIn {
   404  					buf.WriteString("(1=1)")
   405  				}
   406  				break
   407  			}
   408  
   409  			var matches []string
   410  			if where.kind == whereKindIn {
   411  				matches = rgxInClause.FindStringSubmatch(where.clause)
   412  			} else {
   413  				matches = rgxNotInClause.FindStringSubmatch(where.clause)
   414  			}
   415  
   416  			// If we can't find any matches attempt a simple replace with 1 group.
   417  			// Clauses that fit this criteria will not be able to contain ? in their
   418  			// column name side, however if this case is being hit then the regexp
   419  			// probably needs adjustment, or the user is passing in invalid clauses.
   420  			if matches == nil {
   421  				clause, count := convertInQuestionMarks(q.dialect.UseIndexPlaceholders, where.clause, startAt, 1, ln)
   422  				if !manualParens {
   423  					buf.WriteByte('(')
   424  				}
   425  				buf.WriteString(clause)
   426  				if !manualParens {
   427  					buf.WriteByte(')')
   428  				}
   429  				args = append(args, where.args...)
   430  				startAt += count
   431  				break
   432  			}
   433  
   434  			leftSide := strings.TrimSpace(matches[1])
   435  			rightSide := strings.TrimSpace(matches[2])
   436  			// If matches are found, we have to parse the left side (column side)
   437  			// of the clause to determine how many columns they are using.
   438  			// This number determines the groupAt for the convert function.
   439  			cols := strings.Split(leftSide, ",")
   440  			cols = strmangle.IdentQuoteSlice(q.dialect.LQ, q.dialect.RQ, cols)
   441  			groupAt := len(cols)
   442  
   443  			var leftClause string
   444  			var leftCount int
   445  			if q.dialect.UseIndexPlaceholders {
   446  				leftClause, leftCount = convertQuestionMarks(strings.Join(cols, ","), startAt)
   447  			} else {
   448  				// Count the number of cols that are question marks, so we know
   449  				// how much to offset convertInQuestionMarks by
   450  				for _, v := range cols {
   451  					if v == "?" {
   452  						leftCount++
   453  					}
   454  				}
   455  				leftClause = strings.Join(cols, ",")
   456  			}
   457  			rightClause, rightCount := convertInQuestionMarks(q.dialect.UseIndexPlaceholders, rightSide, startAt+leftCount, groupAt, ln-leftCount)
   458  			if !manualParens {
   459  				buf.WriteByte('(')
   460  			}
   461  			buf.WriteString(leftClause)
   462  			if where.kind == whereKindIn {
   463  				buf.WriteString(" IN ")
   464  			} else if where.kind == whereKindNotIn {
   465  				buf.WriteString(" NOT IN ")
   466  			}
   467  			buf.WriteString(rightClause)
   468  			if !manualParens {
   469  				buf.WriteByte(')')
   470  			}
   471  			startAt += leftCount + rightCount
   472  			args = append(args, where.args...)
   473  		default:
   474  			panic("unknown where type")
   475  		}
   476  	}
   477  
   478  	return buf.String(), args
   479  }
   480  
   481  // convertInQuestionMarks finds the first unescaped occurrence of ? and swaps it
   482  // with a list of numbered placeholders, starting at startAt.
   483  // It uses groupAt to determine how many placeholders should be in each group,
   484  // for example, groupAt 2 would result in: (($1,$2),($3,$4))
   485  // and groupAt 1 would result in ($1,$2,$3,$4)
   486  func convertInQuestionMarks(UseIndexPlaceholders bool, clause string, startAt, groupAt, total int) (string, int) {
   487  	if startAt == 0 || len(clause) == 0 {
   488  		panic("Not a valid start number.")
   489  	}
   490  
   491  	paramBuf := strmangle.GetBuffer()
   492  	defer strmangle.PutBuffer(paramBuf)
   493  
   494  	foundAt := -1
   495  	for i := 0; i < len(clause); i++ {
   496  		if (i == 0 && clause[i] == '?') || (clause[i] == '?' && clause[i-1] != '\\') {
   497  			foundAt = i
   498  			break
   499  		}
   500  	}
   501  
   502  	if foundAt == -1 {
   503  		return strings.ReplaceAll(clause, `\?`, "?"), 0
   504  	}
   505  
   506  	paramBuf.WriteString(clause[:foundAt])
   507  	paramBuf.WriteByte('(')
   508  	paramBuf.WriteString(strmangle.Placeholders(UseIndexPlaceholders, total, startAt, groupAt))
   509  	paramBuf.WriteByte(')')
   510  	paramBuf.WriteString(clause[foundAt+1:])
   511  
   512  	// Remove all backslashes from escaped question-marks
   513  	ret := strings.ReplaceAll(paramBuf.String(), `\?`, "?")
   514  	return ret, total
   515  }
   516  
   517  // convertQuestionMarks converts each occurrence of ? with $<number>
   518  // where <number> is an incrementing digit starting at startAt.
   519  // If question-mark (?) is escaped using back-slash (\), it will be ignored.
   520  func convertQuestionMarks(clause string, startAt int) (string, int) {
   521  	if startAt == 0 {
   522  		panic("Not a valid start number.")
   523  	}
   524  
   525  	paramBuf := strmangle.GetBuffer()
   526  	defer strmangle.PutBuffer(paramBuf)
   527  	paramIndex := 0
   528  	total := 0
   529  
   530  	for {
   531  		if paramIndex >= len(clause) {
   532  			break
   533  		}
   534  
   535  		clause = clause[paramIndex:]
   536  		paramIndex = strings.IndexByte(clause, '?')
   537  
   538  		if paramIndex == -1 {
   539  			paramBuf.WriteString(clause)
   540  			break
   541  		}
   542  
   543  		escapeIndex := strings.Index(clause, `\?`)
   544  		if escapeIndex != -1 && paramIndex > escapeIndex {
   545  			paramBuf.WriteString(clause[:escapeIndex] + "?")
   546  			paramIndex++
   547  			continue
   548  		}
   549  
   550  		paramBuf.WriteString(clause[:paramIndex] + fmt.Sprintf("$%d", startAt))
   551  		total++
   552  		startAt++
   553  		paramIndex++
   554  	}
   555  
   556  	return paramBuf.String(), total
   557  }
   558  
   559  // parseFromClause will parse something that looks like
   560  // a
   561  // a b
   562  // a as b
   563  func parseFromClause(toks []string) (alias, name string, ok bool) {
   564  	if len(toks) > 3 {
   565  		toks = toks[:3]
   566  	}
   567  
   568  	sawIdent, sawAs := false, false
   569  	for _, tok := range toks {
   570  		if t := strings.ToLower(tok); sawIdent && t == "as" {
   571  			sawAs = true
   572  			continue
   573  		} else if sawIdent && t == "on" {
   574  			break
   575  		}
   576  
   577  		if !rgxIdentifier.MatchString(tok) {
   578  			break
   579  		}
   580  
   581  		if sawIdent || sawAs {
   582  			alias = strings.Trim(tok, `"`)
   583  			break
   584  		}
   585  
   586  		name = strings.Trim(tok, `"`)
   587  		sawIdent = true
   588  		ok = true
   589  	}
   590  
   591  	return alias, name, ok
   592  }
   593  
   594  func writeComment(q *Query, buf *bytes.Buffer) {
   595  	if len(q.comment) == 0 {
   596  		return
   597  	}
   598  
   599  	lines := strings.Split(q.comment, "\n")
   600  	for _, line := range lines {
   601  		buf.WriteString("-- ")
   602  		buf.WriteString(line)
   603  		buf.WriteByte('\n')
   604  	}
   605  }
   606  
   607  func writeCTEs(q *Query, buf *bytes.Buffer, args *[]interface{}) {
   608  	if len(q.withs) == 0 {
   609  		return
   610  	}
   611  
   612  	buf.WriteString("WITH")
   613  	argsLen := len(*args)
   614  	withBuf := strmangle.GetBuffer()
   615  	lastPos := len(q.withs) - 1
   616  	for i, w := range q.withs {
   617  		fmt.Fprintf(withBuf, " %s", w.clause)
   618  		if i >= 0 && i < lastPos {
   619  			withBuf.WriteByte(',')
   620  		}
   621  		*args = append(*args, w.args...)
   622  	}
   623  	withBuf.WriteByte(' ')
   624  	var resp string
   625  	if q.dialect.UseIndexPlaceholders {
   626  		resp, _ = convertQuestionMarks(withBuf.String(), argsLen+1)
   627  	} else {
   628  		resp = withBuf.String()
   629  	}
   630  	fmt.Fprintf(buf, resp)
   631  	strmangle.PutBuffer(withBuf)
   632  }