github.com/nshntarora/pop@v0.1.2/sql_builder.go (about) 1 package pop 2 3 import ( 4 "fmt" 5 "regexp" 6 "strings" 7 "sync" 8 9 "github.com/jmoiron/sqlx" 10 "github.com/nshntarora/pop/columns" 11 "github.com/nshntarora/pop/logging" 12 ) 13 14 type sqlBuilder struct { 15 Query Query 16 Model *Model 17 AddColumns []string 18 sql string 19 args []interface{} 20 isCompiled bool 21 err error 22 } 23 24 func newSQLBuilder(q Query, m *Model, addColumns ...string) *sqlBuilder { 25 return &sqlBuilder{ 26 Query: q, 27 Model: m, 28 AddColumns: addColumns, 29 args: []interface{}{}, 30 isCompiled: false, 31 } 32 } 33 34 var ( 35 regexpMatchLimit = regexp.MustCompile(`(?i).*\s+limit\s+[0-9]*(\s?,\s?[0-9]*)?$`) 36 regexpMatchOffset = regexp.MustCompile(`(?i).*\s+offset\s+[0-9]*$`) 37 regexpMatchRowsOnly = regexp.MustCompile(`(?i).*\s+rows only`) 38 regexpMatchNames = regexp.MustCompile("(?i).*;+.*") // https://play.golang.org/p/FAmre5Sjin5 39 ) 40 41 func hasLimitOrOffset(sqlString string) bool { 42 trimmedSQL := strings.TrimSpace(sqlString) 43 if regexpMatchLimit.MatchString(trimmedSQL) { 44 return true 45 } 46 47 if regexpMatchOffset.MatchString(trimmedSQL) { 48 return true 49 } 50 51 if regexpMatchRowsOnly.MatchString(trimmedSQL) { 52 return true 53 } 54 55 return false 56 } 57 58 func (sq *sqlBuilder) String() string { 59 if !sq.isCompiled { 60 sq.compile() 61 } 62 return sq.sql 63 } 64 65 func (sq *sqlBuilder) Args() []interface{} { 66 if !sq.isCompiled { 67 sq.compile() 68 } 69 return sq.args 70 } 71 72 var inRegex = regexp.MustCompile(`(?i)in\s*\(\s*\?\s*\)`) 73 74 func (sq *sqlBuilder) compile() { 75 if sq.sql == "" { 76 if sq.Query.RawSQL.Fragment != "" { 77 if sq.Query.Paginator != nil && !hasLimitOrOffset(sq.Query.RawSQL.Fragment) { 78 sq.sql = sq.buildPaginationClauses(sq.Query.RawSQL.Fragment) 79 } else { 80 if sq.Query.Paginator != nil { 81 log(logging.Warn, "Query already contains pagination") 82 } 83 sq.sql = sq.Query.RawSQL.Fragment 84 } 85 sq.args = sq.Query.RawSQL.Arguments 86 } else { 87 if sq.Model == nil { 88 sq.err = fmt.Errorf("sqlBuilder.compile() called but no RawSQL and Model specified") 89 return 90 } 91 switch sq.Query.Operation { 92 case Select: 93 sq.sql = sq.buildSelectSQL() 94 case Delete: 95 sq.sql = sq.buildDeleteSQL() 96 default: 97 panic("unexpected query operation " + sq.Query.Operation) 98 } 99 } 100 101 if inRegex.MatchString(sq.sql) { 102 s, args, err := sqlx.In(sq.sql, sq.Args()...) 103 if err == nil { 104 sq.sql = s 105 sq.args = args 106 } 107 } 108 sq.sql = sq.Query.Connection.Dialect.TranslateSQL(sq.sql) 109 } 110 } 111 112 func (sq *sqlBuilder) buildSelectSQL() string { 113 cols := sq.buildColumns() 114 115 fc := sq.buildfromClauses() 116 117 sql := fmt.Sprintf("SELECT %s FROM %s", cols.Readable().SelectString(), fc) 118 119 sql = sq.buildJoinClauses(sql) 120 sql = sq.buildWhereClauses(sql) 121 sql = sq.buildGroupClauses(sql) 122 sql = sq.buildOrderClauses(sql) 123 sql = sq.buildPaginationClauses(sql) 124 125 return sql 126 } 127 128 func (sq *sqlBuilder) buildDeleteSQL() string { 129 fc := sq.buildfromClauses() 130 131 sql := fmt.Sprintf("DELETE FROM %s", fc) 132 133 sql = sq.buildWhereClauses(sql) 134 135 // paginated delete supported by sqlite and mysql 136 // > If SQLite is compiled with the SQLITE_ENABLE_UPDATE_DELETE_LIMIT compile-time option [...] - from https://www.sqlite.org/lang_delete.html 137 // 138 // not generic enough IMO, therefore excluded 139 // 140 //switch sq.Query.Connection.Dialect.Name() { 141 //case nameMySQL, nameSQLite3: 142 // sql = sq.buildOrderClauses(sql) 143 // sql = sq.buildPaginationClauses(sql) 144 //} 145 146 return sql 147 } 148 149 func (sq *sqlBuilder) buildfromClauses() fromClauses { 150 models := []*Model{ 151 sq.Model, 152 } 153 for _, mc := range sq.Query.belongsToThroughClauses { 154 models = append(models, mc.Through) 155 } 156 157 fc := sq.Query.fromClauses 158 for _, m := range models { 159 tableName := m.TableName() 160 asName := m.Alias() 161 fc = append(fc, fromClause{ 162 From: tableName, 163 As: asName, 164 }) 165 } 166 167 return fc 168 } 169 170 func (sq *sqlBuilder) buildWhereClauses(sql string) string { 171 mcs := sq.Query.belongsToThroughClauses 172 for _, mc := range mcs { 173 sq.Query.Where(fmt.Sprintf("%s.%s = ?", mc.Through.TableName(), mc.BelongsTo.associationName()), mc.BelongsTo.ID()) 174 sq.Query.Where(fmt.Sprintf("%s.id = %s.%s", sq.Model.TableName(), mc.Through.TableName(), sq.Model.associationName())) 175 } 176 177 wc := sq.Query.whereClauses 178 if len(wc) > 0 { 179 sql = fmt.Sprintf("%s WHERE %s", sql, wc.Join(" AND ")) 180 sq.args = append(sq.args, wc.Args()...) 181 } 182 return sql 183 } 184 185 func (sq *sqlBuilder) buildJoinClauses(sql string) string { 186 oc := sq.Query.joinClauses 187 if len(oc) > 0 { 188 sql += " " + oc.String() 189 for i := range oc { 190 sq.args = append(sq.args, oc[i].Arguments...) 191 } 192 } 193 194 return sql 195 } 196 197 func (sq *sqlBuilder) buildGroupClauses(sql string) string { 198 gc := sq.Query.groupClauses 199 if len(gc) > 0 { 200 sql = fmt.Sprintf("%s GROUP BY %s", sql, gc.String()) 201 202 hc := sq.Query.havingClauses 203 if len(hc) > 0 { 204 sql = fmt.Sprintf("%s HAVING %s", sql, hc.String()) 205 } 206 207 for i := range hc { 208 sq.args = append(sq.args, hc[i].Arguments...) 209 } 210 } 211 212 return sql 213 } 214 215 func (sq *sqlBuilder) buildOrderClauses(sql string) string { 216 oc := sq.Query.orderClauses 217 if len(oc) > 0 { 218 orderSQL := oc.Join(", ") 219 if regexpMatchNames.MatchString(orderSQL) { 220 warningMsg := fmt.Sprintf("Order clause(s) contains invalid characters: %s", orderSQL) 221 log(logging.Warn, warningMsg) 222 return sql 223 } 224 225 sql = fmt.Sprintf("%s ORDER BY %s", sql, orderSQL) 226 sq.args = append(sq.args, oc.Args()...) 227 } 228 return sql 229 } 230 231 func (sq *sqlBuilder) buildPaginationClauses(sql string) string { 232 if sq.Query.limitResults > 0 && sq.Query.Paginator == nil { 233 sql = fmt.Sprintf("%s LIMIT %d", sql, sq.Query.limitResults) 234 } 235 if sq.Query.Paginator != nil { 236 sql = fmt.Sprintf("%s LIMIT %d", sql, sq.Query.Paginator.PerPage) 237 sql = fmt.Sprintf("%s OFFSET %d", sql, sq.Query.Paginator.Offset) 238 } 239 return sql 240 } 241 242 // columnCache is used to prevent columns rebuilding. 243 var columnCache = map[string]columns.Columns{} 244 var columnCacheMutex = sync.RWMutex{} 245 246 func (sq *sqlBuilder) buildColumns() columns.Columns { 247 tableName := sq.Model.TableName() 248 asName := sq.Model.Alias() 249 acl := len(sq.AddColumns) 250 if acl == 0 { 251 columnCacheMutex.RLock() 252 cols, ok := columnCache[tableName] 253 columnCacheMutex.RUnlock() 254 // if alias is the same, don't remake columns 255 if ok && cols.TableAlias == asName { 256 return cols 257 } 258 cols = columns.ForStructWithAlias(sq.Model.Value, tableName, asName, sq.Model.IDField()) 259 columnCacheMutex.Lock() 260 columnCache[tableName] = cols 261 columnCacheMutex.Unlock() 262 return cols 263 } 264 265 // acl > 0 266 cols := columns.NewColumns("", sq.Model.IDField()) 267 cols.Add(sq.AddColumns...) 268 return cols 269 }