github.com/dkishere/pop/v6@v6.103.1/sql_builder.go (about)

     1  package pop
     2  
     3  import (
     4  	"fmt"
     5  	"regexp"
     6  	"strings"
     7  	"sync"
     8  
     9  	"github.com/dkishere/pop/v6/columns"
    10  	"github.com/dkishere/pop/v6/logging"
    11  	"github.com/jmoiron/sqlx"
    12  )
    13  
    14  type sqlBuilder struct {
    15  	Query      Query
    16  	Model      *Model
    17  	AddColumns []string
    18  	sql        string
    19  	args       []interface{}
    20  }
    21  
    22  func newSQLBuilder(q Query, m *Model, addColumns ...string) *sqlBuilder {
    23  	return &sqlBuilder{
    24  		Query:      q,
    25  		Model:      m,
    26  		AddColumns: addColumns,
    27  		args:       []interface{}{},
    28  	}
    29  }
    30  
    31  var (
    32  	regexpMatchLimit    = regexp.MustCompile(`(?i).*\s+limit\s+[0-9]*(\s?,\s?[0-9]*)?$`)
    33  	regexpMatchOffset   = regexp.MustCompile(`(?i).*\s+offset\s+[0-9]*$`)
    34  	regexpMatchRowsOnly = regexp.MustCompile(`(?i).*\s+rows only`)
    35  	regexpMatchNames    = regexp.MustCompile("(?i).*;+.*") // https://play.golang.org/p/FAmre5Sjin5
    36  )
    37  
    38  func hasLimitOrOffset(sqlString string) bool {
    39  	trimmedSQL := strings.TrimSpace(sqlString)
    40  	if regexpMatchLimit.MatchString(trimmedSQL) {
    41  		return true
    42  	}
    43  
    44  	if regexpMatchOffset.MatchString(trimmedSQL) {
    45  		return true
    46  	}
    47  
    48  	if regexpMatchRowsOnly.MatchString(trimmedSQL) {
    49  		return true
    50  	}
    51  
    52  	return false
    53  }
    54  
    55  func (sq *sqlBuilder) String() string {
    56  	if sq.sql == "" {
    57  		sq.compile()
    58  	}
    59  	return sq.sql
    60  }
    61  
    62  func (sq *sqlBuilder) Args() []interface{} {
    63  	if len(sq.args) == 0 {
    64  		if len(sq.Query.RawSQL.Arguments) > 0 {
    65  			sq.args = sq.Query.RawSQL.Arguments
    66  		} else {
    67  			sq.compile()
    68  		}
    69  	}
    70  	return sq.args
    71  }
    72  
    73  var inRegex = regexp.MustCompile(`(?i)in\s*\(\s*\?\s*\)`)
    74  
    75  func (sq *sqlBuilder) compile() {
    76  	if sq.sql == "" {
    77  		if sq.Query.RawSQL.Fragment != "" {
    78  			if sq.Query.Paginator != nil && !hasLimitOrOffset(sq.Query.RawSQL.Fragment) {
    79  				sq.sql = sq.buildPaginationClauses(sq.Query.RawSQL.Fragment)
    80  			} else {
    81  				if sq.Query.Paginator != nil {
    82  					log(logging.Warn, "Query already contains pagination")
    83  				}
    84  				sq.sql = sq.Query.RawSQL.Fragment
    85  			}
    86  		} else {
    87  			switch sq.Query.Operation {
    88  			case Select:
    89  				sq.sql = sq.buildSelectSQL()
    90  			case Delete:
    91  				sq.sql = sq.buildDeleteSQL()
    92  			default:
    93  				panic("unexpected query operation " + sq.Query.Operation)
    94  			}
    95  		}
    96  
    97  		if inRegex.MatchString(sq.sql) {
    98  			s, args, err := sqlx.In(sq.sql, sq.Args()...)
    99  			if err == nil {
   100  				sq.sql = s
   101  				sq.args = args
   102  			}
   103  		}
   104  		sq.sql = sq.Query.Connection.Dialect.TranslateSQL(sq.sql)
   105  	}
   106  }
   107  
   108  func (sq *sqlBuilder) buildSelectSQL() string {
   109  	cols := sq.buildColumns()
   110  
   111  	fc := sq.buildfromClauses()
   112  
   113  	sql := fmt.Sprintf("SELECT %s FROM %s", cols.Readable().SelectString(), fc)
   114  
   115  	sql = sq.buildJoinClauses(sql)
   116  	sql = sq.buildWhereClauses(sql)
   117  	sql = sq.buildGroupClauses(sql)
   118  	sql = sq.buildOrderClauses(sql)
   119  	sql = sq.buildPaginationClauses(sql)
   120  
   121  	return sql
   122  }
   123  
   124  func (sq *sqlBuilder) buildDeleteSQL() string {
   125  	fc := sq.buildfromClauses()
   126  
   127  	sql := fmt.Sprintf("DELETE FROM %s", fc)
   128  
   129  	sql = sq.buildWhereClauses(sql)
   130  
   131  	// paginated delete supported by sqlite and mysql
   132  	// > If SQLite is compiled with the SQLITE_ENABLE_UPDATE_DELETE_LIMIT compile-time option [...] - from https://www.sqlite.org/lang_delete.html
   133  	//
   134  	// not generic enough IMO, therefore excluded
   135  	//
   136  	//switch sq.Query.Connection.Dialect.Name() {
   137  	//case nameMySQL, nameSQLite3:
   138  	//	sql = sq.buildOrderClauses(sql)
   139  	//	sql = sq.buildPaginationClauses(sql)
   140  	//}
   141  
   142  	return sql
   143  }
   144  
   145  func (sq *sqlBuilder) buildfromClauses() fromClauses {
   146  	models := []*Model{
   147  		sq.Model,
   148  	}
   149  	for _, mc := range sq.Query.belongsToThroughClauses {
   150  		models = append(models, mc.Through)
   151  	}
   152  
   153  	fc := sq.Query.fromClauses
   154  	for _, m := range models {
   155  		tableName := m.TableName()
   156  		asName := m.Alias()
   157  		fc = append(fc, fromClause{
   158  			From: tableName,
   159  			As:   asName,
   160  		})
   161  	}
   162  
   163  	return fc
   164  }
   165  
   166  func (sq *sqlBuilder) buildWhereClauses(sql string) string {
   167  	mcs := sq.Query.belongsToThroughClauses
   168  	for _, mc := range mcs {
   169  		sq.Query.Where(fmt.Sprintf("%s.%s = ?", mc.Through.TableName(), mc.BelongsTo.associationName()), mc.BelongsTo.ID())
   170  		sq.Query.Where(fmt.Sprintf("%s.id = %s.%s", sq.Model.TableName(), mc.Through.TableName(), sq.Model.associationName()))
   171  	}
   172  
   173  	wc := sq.Query.whereClauses
   174  	if len(wc) > 0 {
   175  		sql = fmt.Sprintf("%s WHERE %s", sql, wc.Join(" AND "))
   176  		sq.args = append(sq.args, wc.Args()...)
   177  	}
   178  	return sql
   179  }
   180  
   181  func (sq *sqlBuilder) buildJoinClauses(sql string) string {
   182  	oc := sq.Query.joinClauses
   183  	if len(oc) > 0 {
   184  		sql += " " + oc.String()
   185  		for i := range oc {
   186  			sq.args = append(sq.args, oc[i].Arguments...)
   187  		}
   188  	}
   189  
   190  	return sql
   191  }
   192  
   193  func (sq *sqlBuilder) buildGroupClauses(sql string) string {
   194  	gc := sq.Query.groupClauses
   195  	if len(gc) > 0 {
   196  		sql = fmt.Sprintf("%s GROUP BY %s", sql, gc.String())
   197  
   198  		hc := sq.Query.havingClauses
   199  		if len(hc) > 0 {
   200  			sql = fmt.Sprintf("%s HAVING %s", sql, hc.String())
   201  		}
   202  
   203  		for i := range hc {
   204  			sq.args = append(sq.args, hc[i].Arguments...)
   205  		}
   206  	}
   207  
   208  	return sql
   209  }
   210  
   211  func (sq *sqlBuilder) buildOrderClauses(sql string) string {
   212  	oc := sq.Query.orderClauses
   213  	if len(oc) > 0 {
   214  		orderSQL := oc.Join(", ")
   215  		if regexpMatchNames.MatchString(orderSQL) {
   216  			warningMsg := fmt.Sprintf("Order clause(s) contains invalid characters: %s", orderSQL)
   217  			log(logging.Warn, warningMsg)
   218  			return sql
   219  		}
   220  
   221  		sql = fmt.Sprintf("%s ORDER BY %s", sql, orderSQL)
   222  		sq.args = append(sq.args, oc.Args()...)
   223  	}
   224  	return sql
   225  }
   226  
   227  func (sq *sqlBuilder) buildPaginationClauses(sql string) string {
   228  	if sq.Query.limitResults > 0 && sq.Query.Paginator == nil {
   229  		sql = fmt.Sprintf("%s LIMIT %d", sql, sq.Query.limitResults)
   230  	}
   231  	if sq.Query.Paginator != nil {
   232  		sql = fmt.Sprintf("%s LIMIT %d", sql, sq.Query.Paginator.PerPage)
   233  		sql = fmt.Sprintf("%s OFFSET %d", sql, sq.Query.Paginator.Offset)
   234  	}
   235  	return sql
   236  }
   237  
   238  // columnCache is used to prevent columns rebuilding.
   239  var columnCache = map[string]columns.Columns{}
   240  var columnCacheMutex = sync.RWMutex{}
   241  
   242  func (sq *sqlBuilder) buildColumns() columns.Columns {
   243  	tableName := sq.Model.TableName()
   244  	asName := sq.Model.Alias()
   245  	acl := len(sq.AddColumns)
   246  	if acl == 0 {
   247  		columnCacheMutex.RLock()
   248  		cols, ok := columnCache[tableName]
   249  		columnCacheMutex.RUnlock()
   250  		// if alias is the same, don't remake columns
   251  		if ok && cols.TableAlias == asName {
   252  			return cols
   253  		}
   254  		cols = columns.ForStructWithAlias(sq.Model.Value, tableName, asName, sq.Model.IDField())
   255  		columnCacheMutex.Lock()
   256  		columnCache[tableName] = cols
   257  		columnCacheMutex.Unlock()
   258  		return cols
   259  	}
   260  
   261  	// acl > 0
   262  	cols := columns.NewColumns("", sq.Model.IDField())
   263  	cols.Add(sq.AddColumns...)
   264  	return cols
   265  }