github.com/paweljw/pop/v5@v5.4.6/sql_builder.go (about)

     1  package pop
     2  
import (
	"fmt"
	"reflect"
	"regexp"
	"strings"
	"sync"

	"github.com/gobuffalo/pop/v5/columns"
	"github.com/gobuffalo/pop/v5/logging"
	"github.com/jmoiron/sqlx"
)
    13  
    14  type sqlBuilder struct {
    15  	Query      Query
    16  	Model      *Model
    17  	AddColumns []string
    18  	sql        string
    19  	args       []interface{}
    20  }
    21  
    22  func newSQLBuilder(q Query, m *Model, addColumns ...string) *sqlBuilder {
    23  	return &sqlBuilder{
    24  		Query:      q,
    25  		Model:      m,
    26  		AddColumns: addColumns,
    27  		args:       []interface{}{},
    28  	}
    29  }
    30  
    31  var (
    32  	regexpMatchLimit    = regexp.MustCompile(`(?i).*\s+limit\s+[0-9]*(\s?,\s?[0-9]*)?$`)
    33  	regexpMatchOffset   = regexp.MustCompile(`(?i).*\s+offset\s+[0-9]*$`)
    34  	regexpMatchRowsOnly = regexp.MustCompile(`(?i).*\s+rows only`)
    35  	regexpMatchNames    = regexp.MustCompile("(?i).*;+.*") // https://play.golang.org/p/FAmre5Sjin5
    36  )
    37  
    38  func hasLimitOrOffset(sqlString string) bool {
    39  	trimmedSQL := strings.TrimSpace(sqlString)
    40  	if regexpMatchLimit.MatchString(trimmedSQL) {
    41  		return true
    42  	}
    43  
    44  	if regexpMatchOffset.MatchString(trimmedSQL) {
    45  		return true
    46  	}
    47  
    48  	if regexpMatchRowsOnly.MatchString(trimmedSQL) {
    49  		return true
    50  	}
    51  
    52  	return false
    53  }
    54  
    55  func (sq *sqlBuilder) String() string {
    56  	if sq.sql == "" {
    57  		sq.compile()
    58  	}
    59  	return sq.sql
    60  }
    61  
    62  func (sq *sqlBuilder) Args() []interface{} {
    63  	if len(sq.args) == 0 {
    64  		if len(sq.Query.RawSQL.Arguments) > 0 {
    65  			sq.args = sq.Query.RawSQL.Arguments
    66  		} else {
    67  			sq.compile()
    68  		}
    69  	}
    70  	return sq.args
    71  }
    72  
    73  var inRegex = regexp.MustCompile(`(?i)in\s*\(\s*\?\s*\)`)
    74  
    75  func (sq *sqlBuilder) compile() {
    76  	if sq.sql == "" {
    77  		if sq.Query.RawSQL.Fragment != "" {
    78  			if sq.Query.Paginator != nil && !hasLimitOrOffset(sq.Query.RawSQL.Fragment) {
    79  				sq.sql = sq.buildPaginationClauses(sq.Query.RawSQL.Fragment)
    80  			} else {
    81  				if sq.Query.Paginator != nil {
    82  					log(logging.Warn, "Query already contains pagination")
    83  				}
    84  				sq.sql = sq.Query.RawSQL.Fragment
    85  			}
    86  		} else {
    87  			sq.sql = sq.buildSelectSQL()
    88  		}
    89  
    90  		if inRegex.MatchString(sq.sql) {
    91  			s, args, err := sqlx.In(sq.sql, sq.Args()...)
    92  			if err == nil {
    93  				sq.sql = s
    94  				sq.args = args
    95  			}
    96  		}
    97  		sq.sql = sq.Query.Connection.Dialect.TranslateSQL(sq.sql)
    98  	}
    99  }
   100  
   101  func (sq *sqlBuilder) buildSelectSQL() string {
   102  	cols := sq.buildColumns()
   103  
   104  	fc := sq.buildfromClauses()
   105  
   106  	sql := fmt.Sprintf("SELECT %s FROM %s", cols.Readable().SelectString(), fc)
   107  
   108  	sql = sq.buildJoinClauses(sql)
   109  	sql = sq.buildWhereClauses(sql)
   110  	sql = sq.buildGroupClauses(sql)
   111  	sql = sq.buildOrderClauses(sql)
   112  	sql = sq.buildPaginationClauses(sql)
   113  
   114  	return sql
   115  }
   116  
   117  func (sq *sqlBuilder) buildfromClauses() fromClauses {
   118  	models := []*Model{
   119  		sq.Model,
   120  	}
   121  	for _, mc := range sq.Query.belongsToThroughClauses {
   122  		models = append(models, mc.Through)
   123  	}
   124  
   125  	fc := sq.Query.fromClauses
   126  	for _, m := range models {
   127  		tableName := m.TableName()
   128  		asName := m.alias()
   129  		fc = append(fc, fromClause{
   130  			From: tableName,
   131  			As:   asName,
   132  		})
   133  	}
   134  
   135  	return fc
   136  }
   137  
   138  func (sq *sqlBuilder) buildWhereClauses(sql string) string {
   139  	mcs := sq.Query.belongsToThroughClauses
   140  	for _, mc := range mcs {
   141  		sq.Query.Where(fmt.Sprintf("%s.%s = ?", mc.Through.TableName(), mc.BelongsTo.associationName()), mc.BelongsTo.ID())
   142  		sq.Query.Where(fmt.Sprintf("%s.id = %s.%s", sq.Model.TableName(), mc.Through.TableName(), sq.Model.associationName()))
   143  	}
   144  
   145  	wc := sq.Query.whereClauses
   146  	if len(wc) > 0 {
   147  		sql = fmt.Sprintf("%s WHERE %s", sql, wc.Join(" AND "))
   148  		sq.args = append(sq.args, wc.Args()...)
   149  	}
   150  	return sql
   151  }
   152  
   153  func (sq *sqlBuilder) buildJoinClauses(sql string) string {
   154  	oc := sq.Query.joinClauses
   155  	if len(oc) > 0 {
   156  		sql += " " + oc.String()
   157  		for i := range oc {
   158  			sq.args = append(sq.args, oc[i].Arguments...)
   159  		}
   160  	}
   161  
   162  	return sql
   163  }
   164  
   165  func (sq *sqlBuilder) buildGroupClauses(sql string) string {
   166  	gc := sq.Query.groupClauses
   167  	if len(gc) > 0 {
   168  		sql = fmt.Sprintf("%s GROUP BY %s", sql, gc.String())
   169  
   170  		hc := sq.Query.havingClauses
   171  		if len(hc) > 0 {
   172  			sql = fmt.Sprintf("%s HAVING %s", sql, hc.String())
   173  		}
   174  
   175  		for i := range hc {
   176  			sq.args = append(sq.args, hc[i].Arguments...)
   177  		}
   178  	}
   179  
   180  	return sql
   181  }
   182  
   183  func (sq *sqlBuilder) buildOrderClauses(sql string) string {
   184  	oc := sq.Query.orderClauses
   185  	if len(oc) > 0 {
   186  		orderSQL := oc.Join(", ")
   187  		if regexpMatchNames.MatchString(orderSQL) {
   188  			warningMsg := fmt.Sprintf("Order clause(s) contains invalid characters: %s", orderSQL)
   189  			log(logging.Warn, warningMsg)
   190  			return sql
   191  		}
   192  
   193  		sql = fmt.Sprintf("%s ORDER BY %s", sql, orderSQL)
   194  		sq.args = append(sq.args, oc.Args()...)
   195  	}
   196  	return sql
   197  }
   198  
   199  func (sq *sqlBuilder) buildPaginationClauses(sql string) string {
   200  	if sq.Query.limitResults > 0 && sq.Query.Paginator == nil {
   201  		sql = fmt.Sprintf("%s LIMIT %d", sql, sq.Query.limitResults)
   202  	}
   203  	if sq.Query.Paginator != nil {
   204  		sql = fmt.Sprintf("%s LIMIT %d", sql, sq.Query.Paginator.PerPage)
   205  		sql = fmt.Sprintf("%s OFFSET %d", sql, sq.Query.Paginator.Offset)
   206  	}
   207  	return sql
   208  }
   209  
   210  // columnCache is used to prevent columns rebuilding.
   211  var columnCache = map[string]columns.Columns{}
   212  var columnCacheMutex = sync.RWMutex{}
   213  
   214  func (sq *sqlBuilder) buildColumns() columns.Columns {
   215  	tableName := sq.Model.TableName()
   216  	asName := sq.Model.alias()
   217  	acl := len(sq.AddColumns)
   218  	if acl == 0 {
   219  		columnCacheMutex.RLock()
   220  		cols, ok := columnCache[tableName]
   221  		columnCacheMutex.RUnlock()
   222  		// if alias is the same, don't remake columns
   223  		if ok && cols.TableAlias == asName {
   224  			return cols
   225  		}
   226  		cols = columns.ForStructWithAlias(sq.Model.Value, tableName, asName, sq.Model.IDField())
   227  		columnCacheMutex.Lock()
   228  		columnCache[tableName] = cols
   229  		columnCacheMutex.Unlock()
   230  		return cols
   231  	}
   232  
   233  	// acl > 0
   234  	cols := columns.NewColumns("", sq.Model.IDField())
   235  	cols.Add(sq.AddColumns...)
   236  	return cols
   237  }