github.com/pjdufour-truss/pop@v4.11.2-0.20190705085848-4c90b0ff4d5a+incompatible/sql_builder.go (about)

     1  package pop
     2  
     3  import (
     4  	"fmt"
     5  	"regexp"
     6  	"strings"
     7  	"sync"
     8  
     9  	"github.com/gobuffalo/pop/columns"
    10  	"github.com/gobuffalo/pop/logging"
    11  	"github.com/jmoiron/sqlx"
    12  )
    13  
    14  type sqlBuilder struct {
    15  	Query      Query
    16  	Model      *Model
    17  	AddColumns []string
    18  	sql        string
    19  	args       []interface{}
    20  }
    21  
    22  func newSQLBuilder(q Query, m *Model, addColumns ...string) *sqlBuilder {
    23  	return &sqlBuilder{
    24  		Query:      q,
    25  		Model:      m,
    26  		AddColumns: addColumns,
    27  		args:       []interface{}{},
    28  	}
    29  }
    30  
    31  var (
    32  	regexpMatchLimit    = regexp.MustCompile(`(?i).*\s+limit\s+[0-9]*(\s?,\s?[0-9]*)?$`)
    33  	regexpMatchOffset   = regexp.MustCompile(`(?i).*\s+offset\s+[0-9]*$`)
    34  	regexpMatchRowsOnly = regexp.MustCompile(`(?i).*\s+rows only`)
    35  	regexpMatchNames    = regexp.MustCompile("(?i).*;+.*") // https://play.golang.org/p/FAmre5Sjin5
    36  )
    37  
    38  func hasLimitOrOffset(sqlString string) bool {
    39  	trimmedSQL := strings.TrimSpace(sqlString)
    40  	if regexpMatchLimit.MatchString(trimmedSQL) {
    41  		return true
    42  	}
    43  
    44  	if regexpMatchOffset.MatchString(trimmedSQL) {
    45  		return true
    46  	}
    47  
    48  	if regexpMatchRowsOnly.MatchString(trimmedSQL) {
    49  		return true
    50  	}
    51  
    52  	return false
    53  }
    54  
    55  func (sq *sqlBuilder) String() string {
    56  	if sq.sql == "" {
    57  		sq.compile()
    58  	}
    59  	return sq.sql
    60  }
    61  
    62  func (sq *sqlBuilder) Args() []interface{} {
    63  	if len(sq.args) == 0 {
    64  		if len(sq.Query.RawSQL.Arguments) > 0 {
    65  			sq.args = sq.Query.RawSQL.Arguments
    66  		} else {
    67  			sq.compile()
    68  		}
    69  	}
    70  	return sq.args
    71  }
    72  
    73  var inRegex = regexp.MustCompile(`(?i)in\s*\(\s*\?\s*\)`)
    74  
    75  func (sq *sqlBuilder) compile() {
    76  	if sq.sql == "" {
    77  		if sq.Query.RawSQL.Fragment != "" {
    78  			if sq.Query.Paginator != nil && !hasLimitOrOffset(sq.Query.RawSQL.Fragment) {
    79  				sq.sql = sq.buildPaginationClauses(sq.Query.RawSQL.Fragment)
    80  			} else {
    81  				if sq.Query.Paginator != nil {
    82  					log(logging.Warn, "Query already contains pagination")
    83  				}
    84  				sq.sql = sq.Query.RawSQL.Fragment
    85  			}
    86  		} else {
    87  			sq.sql = sq.buildSelectSQL()
    88  		}
    89  
    90  		if inRegex.MatchString(sq.sql) {
    91  			s, args, err := sqlx.In(sq.sql, sq.Args()...)
    92  			if err == nil {
    93  				sq.sql = s
    94  				sq.args = args
    95  			}
    96  		}
    97  		sq.sql = sq.Query.Connection.Dialect.TranslateSQL(sq.sql)
    98  	}
    99  }
   100  
   101  func (sq *sqlBuilder) buildSelectSQL() string {
   102  	cols := sq.buildColumns()
   103  
   104  	fc := sq.buildfromClauses()
   105  
   106  	sql := fmt.Sprintf("SELECT %s FROM %s", cols.Readable().SelectString(), fc)
   107  
   108  	sql = sq.buildJoinClauses(sql)
   109  	sql = sq.buildWhereClauses(sql)
   110  	sql = sq.buildGroupClauses(sql)
   111  	sql = sq.buildOrderClauses(sql)
   112  	sql = sq.buildPaginationClauses(sql)
   113  
   114  	return sql
   115  }
   116  
   117  func (sq *sqlBuilder) buildfromClauses() fromClauses {
   118  	models := []*Model{
   119  		sq.Model,
   120  	}
   121  	for _, mc := range sq.Query.belongsToThroughClauses {
   122  		models = append(models, mc.Through)
   123  	}
   124  
   125  	fc := sq.Query.fromClauses
   126  	for _, m := range models {
   127  		tableName := m.TableName()
   128  		asName := m.As
   129  		if asName == "" {
   130  			asName = strings.Replace(tableName, ".", "_", -1)
   131  		}
   132  		fc = append(fc, fromClause{
   133  			From: tableName,
   134  			As:   asName,
   135  		})
   136  	}
   137  
   138  	return fc
   139  }
   140  
   141  func (sq *sqlBuilder) buildWhereClauses(sql string) string {
   142  	mcs := sq.Query.belongsToThroughClauses
   143  	for _, mc := range mcs {
   144  		sq.Query.Where(fmt.Sprintf("%s.%s = ?", mc.Through.TableName(), mc.BelongsTo.associationName()), mc.BelongsTo.ID())
   145  		sq.Query.Where(fmt.Sprintf("%s.id = %s.%s", sq.Model.TableName(), mc.Through.TableName(), sq.Model.associationName()))
   146  	}
   147  
   148  	wc := sq.Query.whereClauses
   149  	if len(wc) > 0 {
   150  		sql = fmt.Sprintf("%s WHERE %s", sql, wc.Join(" AND "))
   151  		sq.args = append(sq.args, wc.Args()...)
   152  	}
   153  	return sql
   154  }
   155  
   156  func (sq *sqlBuilder) buildJoinClauses(sql string) string {
   157  	oc := sq.Query.joinClauses
   158  	if len(oc) > 0 {
   159  		sql += " " + oc.String()
   160  		for i := range oc {
   161  			sq.args = append(sq.args, oc[i].Arguments...)
   162  		}
   163  	}
   164  
   165  	return sql
   166  }
   167  
   168  func (sq *sqlBuilder) buildGroupClauses(sql string) string {
   169  	gc := sq.Query.groupClauses
   170  	if len(gc) > 0 {
   171  		sql = fmt.Sprintf("%s GROUP BY %s", sql, gc.String())
   172  
   173  		hc := sq.Query.havingClauses
   174  		if len(hc) > 0 {
   175  			sql = fmt.Sprintf("%s HAVING %s", sql, hc.String())
   176  		}
   177  
   178  		for i := range hc {
   179  			sq.args = append(sq.args, hc[i].Arguments...)
   180  		}
   181  	}
   182  
   183  	return sql
   184  }
   185  
   186  func (sq *sqlBuilder) buildOrderClauses(sql string) string {
   187  	oc := sq.Query.orderClauses
   188  	if len(oc) > 0 {
   189  		orderSQL := oc.Join(", ")
   190  		if regexpMatchNames.MatchString(orderSQL) {
   191  			warningMsg := fmt.Sprintf("Order clause(s) contains invalid characters: %s", orderSQL)
   192  			log(logging.Warn, warningMsg)
   193  			return sql
   194  		}
   195  
   196  		sql = fmt.Sprintf("%s ORDER BY %s", sql, orderSQL)
   197  		sq.args = append(sq.args, oc.Args()...)
   198  	}
   199  	return sql
   200  }
   201  
   202  func (sq *sqlBuilder) buildPaginationClauses(sql string) string {
   203  	if sq.Query.limitResults > 0 && sq.Query.Paginator == nil {
   204  		sql = fmt.Sprintf("%s LIMIT %d", sql, sq.Query.limitResults)
   205  	}
   206  	if sq.Query.Paginator != nil {
   207  		sql = fmt.Sprintf("%s LIMIT %d", sql, sq.Query.Paginator.PerPage)
   208  		sql = fmt.Sprintf("%s OFFSET %d", sql, sq.Query.Paginator.Offset)
   209  	}
   210  	return sql
   211  }
   212  
   213  // columnCache is used to prevent columns rebuilding.
   214  var columnCache = map[string]columns.Columns{}
   215  var columnCacheMutex = sync.RWMutex{}
   216  
   217  func (sq *sqlBuilder) buildColumns() columns.Columns {
   218  	tableName := sq.Model.TableName()
   219  	asName := sq.Model.As
   220  	if asName == "" {
   221  		asName = strings.Replace(tableName, ".", "_", -1)
   222  	}
   223  	acl := len(sq.AddColumns)
   224  	if acl == 0 {
   225  		columnCacheMutex.RLock()
   226  		cols, ok := columnCache[tableName]
   227  		columnCacheMutex.RUnlock()
   228  		// if alias is the same, don't remake columns
   229  		if ok && cols.TableAlias == asName {
   230  			return cols
   231  		}
   232  		cols = columns.ForStructWithAlias(sq.Model.Value, tableName, asName)
   233  		columnCacheMutex.Lock()
   234  		columnCache[tableName] = cols
   235  		columnCacheMutex.Unlock()
   236  		return cols
   237  	}
   238  
   239  	// acl > 0
   240  	cols := columns.NewColumns("")
   241  	cols.Add(sq.AddColumns...)
   242  	return cols
   243  }