github.com/Accefy/pop@v0.0.0-20230428174248-e9f677eab5b9/sql_builder.go (about)

     1  package pop
     2  
     3  import (
     4  	"fmt"
     5  	"regexp"
     6  	"strings"
     7  	"sync"
     8  
     9  	"github.com/gobuffalo/pop/v6/columns"
    10  	"github.com/gobuffalo/pop/v6/logging"
    11  	"github.com/jmoiron/sqlx"
    12  )
    13  
    14  type sqlBuilder struct {
    15  	Query      Query
    16  	Model      *Model
    17  	AddColumns []string
    18  	sql        string
    19  	args       []interface{}
    20  	isCompiled bool
    21  	err        error
    22  }
    23  
    24  func newSQLBuilder(q Query, m *Model, addColumns ...string) *sqlBuilder {
    25  	return &sqlBuilder{
    26  		Query:      q,
    27  		Model:      m,
    28  		AddColumns: addColumns,
    29  		args:       []interface{}{},
    30  		isCompiled: false,
    31  	}
    32  }
    33  
    34  var (
    35  	regexpMatchLimit    = regexp.MustCompile(`(?i).*\s+limit\s+[0-9]*(\s?,\s?[0-9]*)?$`)
    36  	regexpMatchOffset   = regexp.MustCompile(`(?i).*\s+offset\s+[0-9]*$`)
    37  	regexpMatchRowsOnly = regexp.MustCompile(`(?i).*\s+rows only`)
    38  	regexpMatchNames    = regexp.MustCompile("(?i).*;+.*") // https://play.golang.org/p/FAmre5Sjin5
    39  )
    40  
    41  func hasLimitOrOffset(sqlString string) bool {
    42  	trimmedSQL := strings.TrimSpace(sqlString)
    43  	if regexpMatchLimit.MatchString(trimmedSQL) {
    44  		return true
    45  	}
    46  
    47  	if regexpMatchOffset.MatchString(trimmedSQL) {
    48  		return true
    49  	}
    50  
    51  	if regexpMatchRowsOnly.MatchString(trimmedSQL) {
    52  		return true
    53  	}
    54  
    55  	return false
    56  }
    57  
    58  func (sq *sqlBuilder) String() string {
    59  	if !sq.isCompiled {
    60  		sq.compile()
    61  	}
    62  	return sq.sql
    63  }
    64  
    65  func (sq *sqlBuilder) Args() []interface{} {
    66  	if !sq.isCompiled {
    67  		sq.compile()
    68  	}
    69  	return sq.args
    70  }
    71  
    72  var inRegex = regexp.MustCompile(`(?i)in\s*\(\s*\?\s*\)`)
    73  
    74  func (sq *sqlBuilder) compile() {
    75  	if sq.sql == "" {
    76  		if sq.Query.RawSQL.Fragment != "" {
    77  			if sq.Query.Paginator != nil && !hasLimitOrOffset(sq.Query.RawSQL.Fragment) {
    78  				sq.sql = sq.buildPaginationClauses(sq.Query.RawSQL.Fragment)
    79  			} else {
    80  				if sq.Query.Paginator != nil {
    81  					log(logging.Warn, "Query already contains pagination")
    82  				}
    83  				sq.sql = sq.Query.RawSQL.Fragment
    84  			}
    85  			sq.args = sq.Query.RawSQL.Arguments
    86  		} else {
    87  			if sq.Model == nil {
    88  				sq.err = fmt.Errorf("sqlBuilder.compile() called but no RawSQL and Model specified")
    89  				return
    90  			}
    91  			switch sq.Query.Operation {
    92  			case Select:
    93  				sq.sql = sq.buildSelectSQL()
    94  			case Delete:
    95  				sq.sql = sq.buildDeleteSQL()
    96  			default:
    97  				panic("unexpected query operation " + sq.Query.Operation)
    98  			}
    99  		}
   100  
   101  		if inRegex.MatchString(sq.sql) {
   102  			s, args, err := sqlx.In(sq.sql, sq.Args()...)
   103  			if err == nil {
   104  				sq.sql = s
   105  				sq.args = args
   106  			}
   107  		}
   108  		sq.sql = sq.Query.Connection.Dialect.TranslateSQL(sq.sql)
   109  	}
   110  }
   111  
   112  func (sq *sqlBuilder) buildSelectSQL() string {
   113  	cols := sq.buildColumns()
   114  
   115  	fc := sq.buildfromClauses()
   116  
   117  	sql := fmt.Sprintf("SELECT %s FROM %s", cols.Readable().SelectString(), fc)
   118  
   119  	sql = sq.buildJoinClauses(sql)
   120  	sql = sq.buildWhereClauses(sql)
   121  	sql = sq.buildGroupClauses(sql)
   122  	sql = sq.buildOrderClauses(sql)
   123  	sql = sq.buildPaginationClauses(sql)
   124  
   125  	return sql
   126  }
   127  
   128  func (sq *sqlBuilder) buildDeleteSQL() string {
   129  	fc := sq.buildfromClauses()
   130  
   131  	sql := fmt.Sprintf("DELETE FROM %s", fc)
   132  
   133  	sql = sq.buildWhereClauses(sql)
   134  
   135  	// paginated delete supported by sqlite and mysql
   136  	// > If SQLite is compiled with the SQLITE_ENABLE_UPDATE_DELETE_LIMIT compile-time option [...] - from https://www.sqlite.org/lang_delete.html
   137  	//
   138  	// not generic enough IMO, therefore excluded
   139  	//
   140  	//switch sq.Query.Connection.Dialect.Name() {
   141  	//case nameMySQL, nameSQLite3:
   142  	//	sql = sq.buildOrderClauses(sql)
   143  	//	sql = sq.buildPaginationClauses(sql)
   144  	//}
   145  
   146  	return sql
   147  }
   148  
   149  func (sq *sqlBuilder) buildfromClauses() fromClauses {
   150  	models := []*Model{
   151  		sq.Model,
   152  	}
   153  	for _, mc := range sq.Query.belongsToThroughClauses {
   154  		models = append(models, mc.Through)
   155  	}
   156  
   157  	fc := sq.Query.fromClauses
   158  	for _, m := range models {
   159  		tableName := m.TableName()
   160  		asName := m.Alias()
   161  		fc = append(fc, fromClause{
   162  			From: tableName,
   163  			As:   asName,
   164  		})
   165  	}
   166  
   167  	return fc
   168  }
   169  
   170  func (sq *sqlBuilder) buildWhereClauses(sql string) string {
   171  	mcs := sq.Query.belongsToThroughClauses
   172  	for _, mc := range mcs {
   173  		sq.Query.Where(fmt.Sprintf("%s.%s = ?", mc.Through.TableName(), mc.BelongsTo.associationName()), mc.BelongsTo.ID())
   174  		sq.Query.Where(fmt.Sprintf("%s.id = %s.%s", sq.Model.TableName(), mc.Through.TableName(), sq.Model.associationName()))
   175  	}
   176  
   177  	wc := sq.Query.whereClauses
   178  	if len(wc) > 0 {
   179  		sql = fmt.Sprintf("%s WHERE %s", sql, wc.Join(" AND "))
   180  		sq.args = append(sq.args, wc.Args()...)
   181  	}
   182  	return sql
   183  }
   184  
   185  func (sq *sqlBuilder) buildJoinClauses(sql string) string {
   186  	oc := sq.Query.joinClauses
   187  	if len(oc) > 0 {
   188  		sql += " " + oc.String()
   189  		for i := range oc {
   190  			sq.args = append(sq.args, oc[i].Arguments...)
   191  		}
   192  	}
   193  
   194  	return sql
   195  }
   196  
   197  func (sq *sqlBuilder) buildGroupClauses(sql string) string {
   198  	gc := sq.Query.groupClauses
   199  	if len(gc) > 0 {
   200  		sql = fmt.Sprintf("%s GROUP BY %s", sql, gc.String())
   201  
   202  		hc := sq.Query.havingClauses
   203  		if len(hc) > 0 {
   204  			sql = fmt.Sprintf("%s HAVING %s", sql, hc.String())
   205  		}
   206  
   207  		for i := range hc {
   208  			sq.args = append(sq.args, hc[i].Arguments...)
   209  		}
   210  	}
   211  
   212  	return sql
   213  }
   214  
   215  func (sq *sqlBuilder) buildOrderClauses(sql string) string {
   216  	oc := sq.Query.orderClauses
   217  	if len(oc) > 0 {
   218  		orderSQL := oc.Join(", ")
   219  		if regexpMatchNames.MatchString(orderSQL) {
   220  			warningMsg := fmt.Sprintf("Order clause(s) contains invalid characters: %s", orderSQL)
   221  			log(logging.Warn, warningMsg)
   222  			return sql
   223  		}
   224  
   225  		sql = fmt.Sprintf("%s ORDER BY %s", sql, orderSQL)
   226  		sq.args = append(sq.args, oc.Args()...)
   227  	}
   228  	return sql
   229  }
   230  
   231  func (sq *sqlBuilder) buildPaginationClauses(sql string) string {
   232  	if sq.Query.limitResults > 0 && sq.Query.Paginator == nil {
   233  		sql = fmt.Sprintf("%s LIMIT %d", sql, sq.Query.limitResults)
   234  	}
   235  	if sq.Query.Paginator != nil {
   236  		sql = fmt.Sprintf("%s LIMIT %d", sql, sq.Query.Paginator.PerPage)
   237  		sql = fmt.Sprintf("%s OFFSET %d", sql, sq.Query.Paginator.Offset)
   238  	}
   239  	return sql
   240  }
   241  
   242  // columnCache is used to prevent columns rebuilding.
   243  var columnCache = map[string]columns.Columns{}
   244  var columnCacheMutex = sync.RWMutex{}
   245  
   246  func (sq *sqlBuilder) buildColumns() columns.Columns {
   247  	tableName := sq.Model.TableName()
   248  	asName := sq.Model.Alias()
   249  	acl := len(sq.AddColumns)
   250  	if acl == 0 {
   251  		columnCacheMutex.RLock()
   252  		cols, ok := columnCache[tableName]
   253  		columnCacheMutex.RUnlock()
   254  		// if alias is the same, don't remake columns
   255  		if ok && cols.TableAlias == asName {
   256  			return cols
   257  		}
   258  		cols = columns.ForStructWithAlias(sq.Model.Value, tableName, asName, sq.Model.IDField())
   259  		columnCacheMutex.Lock()
   260  		columnCache[tableName] = cols
   261  		columnCacheMutex.Unlock()
   262  		return cols
   263  	}
   264  
   265  	// acl > 0
   266  	cols := columns.NewColumns("", sq.Model.IDField())
   267  	cols.Add(sq.AddColumns...)
   268  	return cols
   269  }