github.com/dkishere/pop/v6@v6.103.1/sql_builder.go (about) 1 package pop 2 3 import ( 4 "fmt" 5 "regexp" 6 "strings" 7 "sync" 8 9 "github.com/dkishere/pop/v6/columns" 10 "github.com/dkishere/pop/v6/logging" 11 "github.com/jmoiron/sqlx" 12 ) 13 14 type sqlBuilder struct { 15 Query Query 16 Model *Model 17 AddColumns []string 18 sql string 19 args []interface{} 20 } 21 22 func newSQLBuilder(q Query, m *Model, addColumns ...string) *sqlBuilder { 23 return &sqlBuilder{ 24 Query: q, 25 Model: m, 26 AddColumns: addColumns, 27 args: []interface{}{}, 28 } 29 } 30 31 var ( 32 regexpMatchLimit = regexp.MustCompile(`(?i).*\s+limit\s+[0-9]*(\s?,\s?[0-9]*)?$`) 33 regexpMatchOffset = regexp.MustCompile(`(?i).*\s+offset\s+[0-9]*$`) 34 regexpMatchRowsOnly = regexp.MustCompile(`(?i).*\s+rows only`) 35 regexpMatchNames = regexp.MustCompile("(?i).*;+.*") // https://play.golang.org/p/FAmre5Sjin5 36 ) 37 38 func hasLimitOrOffset(sqlString string) bool { 39 trimmedSQL := strings.TrimSpace(sqlString) 40 if regexpMatchLimit.MatchString(trimmedSQL) { 41 return true 42 } 43 44 if regexpMatchOffset.MatchString(trimmedSQL) { 45 return true 46 } 47 48 if regexpMatchRowsOnly.MatchString(trimmedSQL) { 49 return true 50 } 51 52 return false 53 } 54 55 func (sq *sqlBuilder) String() string { 56 if sq.sql == "" { 57 sq.compile() 58 } 59 return sq.sql 60 } 61 62 func (sq *sqlBuilder) Args() []interface{} { 63 if len(sq.args) == 0 { 64 if len(sq.Query.RawSQL.Arguments) > 0 { 65 sq.args = sq.Query.RawSQL.Arguments 66 } else { 67 sq.compile() 68 } 69 } 70 return sq.args 71 } 72 73 var inRegex = regexp.MustCompile(`(?i)in\s*\(\s*\?\s*\)`) 74 75 func (sq *sqlBuilder) compile() { 76 if sq.sql == "" { 77 if sq.Query.RawSQL.Fragment != "" { 78 if sq.Query.Paginator != nil && !hasLimitOrOffset(sq.Query.RawSQL.Fragment) { 79 sq.sql = sq.buildPaginationClauses(sq.Query.RawSQL.Fragment) 80 } else { 81 if sq.Query.Paginator != nil { 82 log(logging.Warn, "Query already contains pagination") 83 } 84 sq.sql = sq.Query.RawSQL.Fragment 85 } 86 } else { 87 switch sq.Query.Operation { 88 case Select: 89 sq.sql = sq.buildSelectSQL() 90 case Delete: 91 sq.sql = sq.buildDeleteSQL() 92 default: 93 panic("unexpected query operation " + sq.Query.Operation) 94 } 95 } 96 97 if inRegex.MatchString(sq.sql) { 98 s, args, err := sqlx.In(sq.sql, sq.Args()...) 99 if err == nil { 100 sq.sql = s 101 sq.args = args 102 } 103 } 104 sq.sql = sq.Query.Connection.Dialect.TranslateSQL(sq.sql) 105 } 106 } 107 108 func (sq *sqlBuilder) buildSelectSQL() string { 109 cols := sq.buildColumns() 110 111 fc := sq.buildfromClauses() 112 113 sql := fmt.Sprintf("SELECT %s FROM %s", cols.Readable().SelectString(), fc) 114 115 sql = sq.buildJoinClauses(sql) 116 sql = sq.buildWhereClauses(sql) 117 sql = sq.buildGroupClauses(sql) 118 sql = sq.buildOrderClauses(sql) 119 sql = sq.buildPaginationClauses(sql) 120 121 return sql 122 } 123 124 func (sq *sqlBuilder) buildDeleteSQL() string { 125 fc := sq.buildfromClauses() 126 127 sql := fmt.Sprintf("DELETE FROM %s", fc) 128 129 sql = sq.buildWhereClauses(sql) 130 131 // paginated delete supported by sqlite and mysql 132 // > If SQLite is compiled with the SQLITE_ENABLE_UPDATE_DELETE_LIMIT compile-time option [...] - from https://www.sqlite.org/lang_delete.html 133 // 134 // not generic enough IMO, therefore excluded 135 // 136 //switch sq.Query.Connection.Dialect.Name() { 137 //case nameMySQL, nameSQLite3: 138 // sql = sq.buildOrderClauses(sql) 139 // sql = sq.buildPaginationClauses(sql) 140 //} 141 142 return sql 143 } 144 145 func (sq *sqlBuilder) buildfromClauses() fromClauses { 146 models := []*Model{ 147 sq.Model, 148 } 149 for _, mc := range sq.Query.belongsToThroughClauses { 150 models = append(models, mc.Through) 151 } 152 153 fc := sq.Query.fromClauses 154 for _, m := range models { 155 tableName := m.TableName() 156 asName := m.Alias() 157 fc = append(fc, fromClause{ 158 From: tableName, 159 As: asName, 160 }) 161 } 162 163 return fc 164 } 165 166 func (sq *sqlBuilder) buildWhereClauses(sql string) string { 167 mcs := sq.Query.belongsToThroughClauses 168 for _, mc := range mcs { 169 sq.Query.Where(fmt.Sprintf("%s.%s = ?", mc.Through.TableName(), mc.BelongsTo.associationName()), mc.BelongsTo.ID()) 170 sq.Query.Where(fmt.Sprintf("%s.id = %s.%s", sq.Model.TableName(), mc.Through.TableName(), sq.Model.associationName())) 171 } 172 173 wc := sq.Query.whereClauses 174 if len(wc) > 0 { 175 sql = fmt.Sprintf("%s WHERE %s", sql, wc.Join(" AND ")) 176 sq.args = append(sq.args, wc.Args()...) 177 } 178 return sql 179 } 180 181 func (sq *sqlBuilder) buildJoinClauses(sql string) string { 182 oc := sq.Query.joinClauses 183 if len(oc) > 0 { 184 sql += " " + oc.String() 185 for i := range oc { 186 sq.args = append(sq.args, oc[i].Arguments...) 187 } 188 } 189 190 return sql 191 } 192 193 func (sq *sqlBuilder) buildGroupClauses(sql string) string { 194 gc := sq.Query.groupClauses 195 if len(gc) > 0 { 196 sql = fmt.Sprintf("%s GROUP BY %s", sql, gc.String()) 197 198 hc := sq.Query.havingClauses 199 if len(hc) > 0 { 200 sql = fmt.Sprintf("%s HAVING %s", sql, hc.String()) 201 } 202 203 for i := range hc { 204 sq.args = append(sq.args, hc[i].Arguments...) 205 } 206 } 207 208 return sql 209 } 210 211 func (sq *sqlBuilder) buildOrderClauses(sql string) string { 212 oc := sq.Query.orderClauses 213 if len(oc) > 0 { 214 orderSQL := oc.Join(", ") 215 if regexpMatchNames.MatchString(orderSQL) { 216 warningMsg := fmt.Sprintf("Order clause(s) contains invalid characters: %s", orderSQL) 217 log(logging.Warn, warningMsg) 218 return sql 219 } 220 221 sql = fmt.Sprintf("%s ORDER BY %s", sql, orderSQL) 222 sq.args = append(sq.args, oc.Args()...) 223 } 224 return sql 225 } 226 227 func (sq *sqlBuilder) buildPaginationClauses(sql string) string { 228 if sq.Query.limitResults > 0 && sq.Query.Paginator == nil { 229 sql = fmt.Sprintf("%s LIMIT %d", sql, sq.Query.limitResults) 230 } 231 if sq.Query.Paginator != nil { 232 sql = fmt.Sprintf("%s LIMIT %d", sql, sq.Query.Paginator.PerPage) 233 sql = fmt.Sprintf("%s OFFSET %d", sql, sq.Query.Paginator.Offset) 234 } 235 return sql 236 } 237 238 // columnCache is used to prevent columns rebuilding. 239 var columnCache = map[string]columns.Columns{} 240 var columnCacheMutex = sync.RWMutex{} 241 242 func (sq *sqlBuilder) buildColumns() columns.Columns { 243 tableName := sq.Model.TableName() 244 asName := sq.Model.Alias() 245 acl := len(sq.AddColumns) 246 if acl == 0 { 247 columnCacheMutex.RLock() 248 cols, ok := columnCache[tableName] 249 columnCacheMutex.RUnlock() 250 // if alias is the same, don't remake columns 251 if ok && cols.TableAlias == asName { 252 return cols 253 } 254 cols = columns.ForStructWithAlias(sq.Model.Value, tableName, asName, sq.Model.IDField()) 255 columnCacheMutex.Lock() 256 columnCache[tableName] = cols 257 columnCacheMutex.Unlock() 258 return cols 259 } 260 261 // acl > 0 262 cols := columns.NewColumns("", sq.Model.IDField()) 263 cols.Add(sq.AddColumns...) 264 return cols 265 }