github.com/tsmith1024/pop@v4.12.2+incompatible/sql_builder.go (about) 1 package pop 2 3 import ( 4 "fmt" 5 "regexp" 6 "strings" 7 "sync" 8 9 "github.com/gobuffalo/pop/columns" 10 "github.com/gobuffalo/pop/logging" 11 "github.com/jmoiron/sqlx" 12 ) 13 14 type sqlBuilder struct { 15 Query Query 16 Model *Model 17 AddColumns []string 18 sql string 19 args []interface{} 20 } 21 22 func newSQLBuilder(q Query, m *Model, addColumns ...string) *sqlBuilder { 23 return &sqlBuilder{ 24 Query: q, 25 Model: m, 26 AddColumns: addColumns, 27 args: []interface{}{}, 28 } 29 } 30 31 var ( 32 regexpMatchLimit = regexp.MustCompile(`(?i).*\s+limit\s+[0-9]*(\s?,\s?[0-9]*)?$`) 33 regexpMatchOffset = regexp.MustCompile(`(?i).*\s+offset\s+[0-9]*$`) 34 regexpMatchRowsOnly = regexp.MustCompile(`(?i).*\s+rows only`) 35 regexpMatchNames = regexp.MustCompile("(?i).*;+.*") // https://play.golang.org/p/FAmre5Sjin5 36 ) 37 38 func hasLimitOrOffset(sqlString string) bool { 39 trimmedSQL := strings.TrimSpace(sqlString) 40 if regexpMatchLimit.MatchString(trimmedSQL) { 41 return true 42 } 43 44 if regexpMatchOffset.MatchString(trimmedSQL) { 45 return true 46 } 47 48 if regexpMatchRowsOnly.MatchString(trimmedSQL) { 49 return true 50 } 51 52 return false 53 } 54 55 func (sq *sqlBuilder) String() string { 56 if sq.sql == "" { 57 sq.compile() 58 } 59 return sq.sql 60 } 61 62 func (sq *sqlBuilder) Args() []interface{} { 63 if len(sq.args) == 0 { 64 if len(sq.Query.RawSQL.Arguments) > 0 { 65 sq.args = sq.Query.RawSQL.Arguments 66 } else { 67 sq.compile() 68 } 69 } 70 return sq.args 71 } 72 73 var inRegex = regexp.MustCompile(`(?i)in\s*\(\s*\?\s*\)`) 74 75 func (sq *sqlBuilder) compile() { 76 if sq.sql == "" { 77 if sq.Query.RawSQL.Fragment != "" { 78 if sq.Query.Paginator != nil && !hasLimitOrOffset(sq.Query.RawSQL.Fragment) { 79 sq.sql = sq.buildPaginationClauses(sq.Query.RawSQL.Fragment) 80 } else { 81 if sq.Query.Paginator != nil { 82 log(logging.Warn, "Query already contains pagination") 83 } 84 sq.sql = sq.Query.RawSQL.Fragment 85 } 86 } else { 87 sq.sql = sq.buildSelectSQL() 88 } 89 90 if inRegex.MatchString(sq.sql) { 91 s, args, err := sqlx.In(sq.sql, sq.Args()...) 92 if err == nil { 93 sq.sql = s 94 sq.args = args 95 } 96 } 97 sq.sql = sq.Query.Connection.Dialect.TranslateSQL(sq.sql) 98 } 99 } 100 101 func (sq *sqlBuilder) buildSelectSQL() string { 102 cols := sq.buildColumns() 103 104 fc := sq.buildfromClauses() 105 106 sql := fmt.Sprintf("SELECT %s FROM %s", cols.Readable().SelectString(), fc) 107 108 sql = sq.buildJoinClauses(sql) 109 sql = sq.buildWhereClauses(sql) 110 sql = sq.buildGroupClauses(sql) 111 sql = sq.buildOrderClauses(sql) 112 sql = sq.buildPaginationClauses(sql) 113 114 return sql 115 } 116 117 func (sq *sqlBuilder) buildfromClauses() fromClauses { 118 models := []*Model{ 119 sq.Model, 120 } 121 for _, mc := range sq.Query.belongsToThroughClauses { 122 models = append(models, mc.Through) 123 } 124 125 fc := sq.Query.fromClauses 126 for _, m := range models { 127 tableName := m.TableName() 128 asName := m.As 129 if asName == "" { 130 asName = strings.Replace(tableName, ".", "_", -1) 131 } 132 fc = append(fc, fromClause{ 133 From: tableName, 134 As: asName, 135 }) 136 } 137 138 return fc 139 } 140 141 func (sq *sqlBuilder) buildWhereClauses(sql string) string { 142 mcs := sq.Query.belongsToThroughClauses 143 for _, mc := range mcs { 144 sq.Query.Where(fmt.Sprintf("%s.%s = ?", mc.Through.TableName(), mc.BelongsTo.associationName()), mc.BelongsTo.ID()) 145 sq.Query.Where(fmt.Sprintf("%s.id = %s.%s", sq.Model.TableName(), mc.Through.TableName(), sq.Model.associationName())) 146 } 147 148 wc := sq.Query.whereClauses 149 if len(wc) > 0 { 150 sql = fmt.Sprintf("%s WHERE %s", sql, wc.Join(" AND ")) 151 sq.args = append(sq.args, wc.Args()...) 152 } 153 return sql 154 } 155 156 func (sq *sqlBuilder) buildJoinClauses(sql string) string { 157 oc := sq.Query.joinClauses 158 if len(oc) > 0 { 159 sql += " " + oc.String() 160 for i := range oc { 161 sq.args = append(sq.args, oc[i].Arguments...) 162 } 163 } 164 165 return sql 166 } 167 168 func (sq *sqlBuilder) buildGroupClauses(sql string) string { 169 gc := sq.Query.groupClauses 170 if len(gc) > 0 { 171 sql = fmt.Sprintf("%s GROUP BY %s", sql, gc.String()) 172 173 hc := sq.Query.havingClauses 174 if len(hc) > 0 { 175 sql = fmt.Sprintf("%s HAVING %s", sql, hc.String()) 176 } 177 178 for i := range hc { 179 sq.args = append(sq.args, hc[i].Arguments...) 180 } 181 } 182 183 return sql 184 } 185 186 func (sq *sqlBuilder) buildOrderClauses(sql string) string { 187 oc := sq.Query.orderClauses 188 if len(oc) > 0 { 189 orderSQL := oc.Join(", ") 190 if regexpMatchNames.MatchString(orderSQL) { 191 warningMsg := fmt.Sprintf("Order clause(s) contains invalid characters: %s", orderSQL) 192 log(logging.Warn, warningMsg) 193 return sql 194 } 195 196 sql = fmt.Sprintf("%s ORDER BY %s", sql, orderSQL) 197 sq.args = append(sq.args, oc.Args()...) 198 } 199 return sql 200 } 201 202 func (sq *sqlBuilder) buildPaginationClauses(sql string) string { 203 if sq.Query.limitResults > 0 && sq.Query.Paginator == nil { 204 sql = fmt.Sprintf("%s LIMIT %d", sql, sq.Query.limitResults) 205 } 206 if sq.Query.Paginator != nil { 207 sql = fmt.Sprintf("%s LIMIT %d", sql, sq.Query.Paginator.PerPage) 208 sql = fmt.Sprintf("%s OFFSET %d", sql, sq.Query.Paginator.Offset) 209 } 210 return sql 211 } 212 213 // columnCache is used to prevent columns rebuilding. 214 var columnCache = map[string]columns.Columns{} 215 var columnCacheMutex = sync.RWMutex{} 216 217 func (sq *sqlBuilder) buildColumns() columns.Columns { 218 tableName := sq.Model.TableName() 219 asName := sq.Model.As 220 if asName == "" { 221 asName = strings.Replace(tableName, ".", "_", -1) 222 } 223 acl := len(sq.AddColumns) 224 if acl == 0 { 225 columnCacheMutex.RLock() 226 cols, ok := columnCache[tableName] 227 columnCacheMutex.RUnlock() 228 // if alias is the same, don't remake columns 229 if ok && cols.TableAlias == asName { 230 return cols 231 } 232 cols = columns.ForStructWithAlias(sq.Model.Value, tableName, asName) 233 columnCacheMutex.Lock() 234 columnCache[tableName] = cols 235 columnCacheMutex.Unlock() 236 return cols 237 } 238 239 // acl > 0 240 cols := columns.NewColumns("") 241 cols.Add(sq.AddColumns...) 242 return cols 243 }