github.com/paweljw/pop/v5@v5.4.6/sql_builder.go (about) 1 package pop 2 3 import ( 4 "fmt" 5 "regexp" 6 "strings" 7 "sync" 8 9 "github.com/gobuffalo/pop/v5/columns" 10 "github.com/gobuffalo/pop/v5/logging" 11 "github.com/jmoiron/sqlx" 12 ) 13 14 type sqlBuilder struct { 15 Query Query 16 Model *Model 17 AddColumns []string 18 sql string 19 args []interface{} 20 } 21 22 func newSQLBuilder(q Query, m *Model, addColumns ...string) *sqlBuilder { 23 return &sqlBuilder{ 24 Query: q, 25 Model: m, 26 AddColumns: addColumns, 27 args: []interface{}{}, 28 } 29 } 30 31 var ( 32 regexpMatchLimit = regexp.MustCompile(`(?i).*\s+limit\s+[0-9]*(\s?,\s?[0-9]*)?$`) 33 regexpMatchOffset = regexp.MustCompile(`(?i).*\s+offset\s+[0-9]*$`) 34 regexpMatchRowsOnly = regexp.MustCompile(`(?i).*\s+rows only`) 35 regexpMatchNames = regexp.MustCompile("(?i).*;+.*") // https://play.golang.org/p/FAmre5Sjin5 36 ) 37 38 func hasLimitOrOffset(sqlString string) bool { 39 trimmedSQL := strings.TrimSpace(sqlString) 40 if regexpMatchLimit.MatchString(trimmedSQL) { 41 return true 42 } 43 44 if regexpMatchOffset.MatchString(trimmedSQL) { 45 return true 46 } 47 48 if regexpMatchRowsOnly.MatchString(trimmedSQL) { 49 return true 50 } 51 52 return false 53 } 54 55 func (sq *sqlBuilder) String() string { 56 if sq.sql == "" { 57 sq.compile() 58 } 59 return sq.sql 60 } 61 62 func (sq *sqlBuilder) Args() []interface{} { 63 if len(sq.args) == 0 { 64 if len(sq.Query.RawSQL.Arguments) > 0 { 65 sq.args = sq.Query.RawSQL.Arguments 66 } else { 67 sq.compile() 68 } 69 } 70 return sq.args 71 } 72 73 var inRegex = regexp.MustCompile(`(?i)in\s*\(\s*\?\s*\)`) 74 75 func (sq *sqlBuilder) compile() { 76 if sq.sql == "" { 77 if sq.Query.RawSQL.Fragment != "" { 78 if sq.Query.Paginator != nil && !hasLimitOrOffset(sq.Query.RawSQL.Fragment) { 79 sq.sql = sq.buildPaginationClauses(sq.Query.RawSQL.Fragment) 80 } else { 81 if sq.Query.Paginator != nil { 82 log(logging.Warn, "Query already contains pagination") 83 } 84 sq.sql = sq.Query.RawSQL.Fragment 85 } 86 } else { 87 sq.sql = sq.buildSelectSQL() 88 } 89 90 if inRegex.MatchString(sq.sql) { 91 s, args, err := sqlx.In(sq.sql, sq.Args()...) 92 if err == nil { 93 sq.sql = s 94 sq.args = args 95 } 96 } 97 sq.sql = sq.Query.Connection.Dialect.TranslateSQL(sq.sql) 98 } 99 } 100 101 func (sq *sqlBuilder) buildSelectSQL() string { 102 cols := sq.buildColumns() 103 104 fc := sq.buildfromClauses() 105 106 sql := fmt.Sprintf("SELECT %s FROM %s", cols.Readable().SelectString(), fc) 107 108 sql = sq.buildJoinClauses(sql) 109 sql = sq.buildWhereClauses(sql) 110 sql = sq.buildGroupClauses(sql) 111 sql = sq.buildOrderClauses(sql) 112 sql = sq.buildPaginationClauses(sql) 113 114 return sql 115 } 116 117 func (sq *sqlBuilder) buildfromClauses() fromClauses { 118 models := []*Model{ 119 sq.Model, 120 } 121 for _, mc := range sq.Query.belongsToThroughClauses { 122 models = append(models, mc.Through) 123 } 124 125 fc := sq.Query.fromClauses 126 for _, m := range models { 127 tableName := m.TableName() 128 asName := m.alias() 129 fc = append(fc, fromClause{ 130 From: tableName, 131 As: asName, 132 }) 133 } 134 135 return fc 136 } 137 138 func (sq *sqlBuilder) buildWhereClauses(sql string) string { 139 mcs := sq.Query.belongsToThroughClauses 140 for _, mc := range mcs { 141 sq.Query.Where(fmt.Sprintf("%s.%s = ?", mc.Through.TableName(), mc.BelongsTo.associationName()), mc.BelongsTo.ID()) 142 sq.Query.Where(fmt.Sprintf("%s.id = %s.%s", sq.Model.TableName(), mc.Through.TableName(), sq.Model.associationName())) 143 } 144 145 wc := sq.Query.whereClauses 146 if len(wc) > 0 { 147 sql = fmt.Sprintf("%s WHERE %s", sql, wc.Join(" AND ")) 148 sq.args = append(sq.args, wc.Args()...) 149 } 150 return sql 151 } 152 153 func (sq *sqlBuilder) buildJoinClauses(sql string) string { 154 oc := sq.Query.joinClauses 155 if len(oc) > 0 { 156 sql += " " + oc.String() 157 for i := range oc { 158 sq.args = append(sq.args, oc[i].Arguments...) 159 } 160 } 161 162 return sql 163 } 164 165 func (sq *sqlBuilder) buildGroupClauses(sql string) string { 166 gc := sq.Query.groupClauses 167 if len(gc) > 0 { 168 sql = fmt.Sprintf("%s GROUP BY %s", sql, gc.String()) 169 170 hc := sq.Query.havingClauses 171 if len(hc) > 0 { 172 sql = fmt.Sprintf("%s HAVING %s", sql, hc.String()) 173 } 174 175 for i := range hc { 176 sq.args = append(sq.args, hc[i].Arguments...) 177 } 178 } 179 180 return sql 181 } 182 183 func (sq *sqlBuilder) buildOrderClauses(sql string) string { 184 oc := sq.Query.orderClauses 185 if len(oc) > 0 { 186 orderSQL := oc.Join(", ") 187 if regexpMatchNames.MatchString(orderSQL) { 188 warningMsg := fmt.Sprintf("Order clause(s) contains invalid characters: %s", orderSQL) 189 log(logging.Warn, warningMsg) 190 return sql 191 } 192 193 sql = fmt.Sprintf("%s ORDER BY %s", sql, orderSQL) 194 sq.args = append(sq.args, oc.Args()...) 195 } 196 return sql 197 } 198 199 func (sq *sqlBuilder) buildPaginationClauses(sql string) string { 200 if sq.Query.limitResults > 0 && sq.Query.Paginator == nil { 201 sql = fmt.Sprintf("%s LIMIT %d", sql, sq.Query.limitResults) 202 } 203 if sq.Query.Paginator != nil { 204 sql = fmt.Sprintf("%s LIMIT %d", sql, sq.Query.Paginator.PerPage) 205 sql = fmt.Sprintf("%s OFFSET %d", sql, sq.Query.Paginator.Offset) 206 } 207 return sql 208 } 209 210 // columnCache is used to prevent columns rebuilding. 211 var columnCache = map[string]columns.Columns{} 212 var columnCacheMutex = sync.RWMutex{} 213 214 func (sq *sqlBuilder) buildColumns() columns.Columns { 215 tableName := sq.Model.TableName() 216 asName := sq.Model.alias() 217 acl := len(sq.AddColumns) 218 if acl == 0 { 219 columnCacheMutex.RLock() 220 cols, ok := columnCache[tableName] 221 columnCacheMutex.RUnlock() 222 // if alias is the same, don't remake columns 223 if ok && cols.TableAlias == asName { 224 return cols 225 } 226 cols = columns.ForStructWithAlias(sq.Model.Value, tableName, asName, sq.Model.IDField()) 227 columnCacheMutex.Lock() 228 columnCache[tableName] = cols 229 columnCacheMutex.Unlock() 230 return cols 231 } 232 233 // acl > 0 234 cols := columns.NewColumns("", sq.Model.IDField()) 235 cols.Add(sq.AddColumns...) 236 return cols 237 }