gitlab.com/beacon-software/gadget@v0.0.0-20181217202115-54565ea1ed5e/database/qb/select.go (about) 1 package qb 2 3 import ( 4 "fmt" 5 "strings" 6 ) 7 8 // SelectExpression for use in identifying the fields desired in a select query. 9 type SelectExpression interface { 10 // GetName that can be used to reference this expression 11 GetName() string 12 // GetTables that are used in this expression 13 GetTables() []string 14 // SQL that represents this SelectExpression 15 SQL() string 16 } 17 18 type alias struct { 19 field TableField 20 alias string 21 } 22 23 func (a alias) GetName() string { 24 return a.alias 25 } 26 27 func (a alias) GetTables() []string { 28 return a.field.GetTables() 29 } 30 31 func (a alias) SQL() string { 32 return fmt.Sprintf("%s AS `%s`", a.field.SQL(), a.alias) 33 } 34 35 // Alias the passed table field for use in or as a SelectExpression 36 func Alias(tableField TableField, aliasName string) SelectExpression { 37 return alias{field: tableField, alias: aliasName} 38 } 39 40 type notNull struct { 41 field TableField 42 alias string 43 } 44 45 func (nn notNull) GetName() string { 46 return nn.alias 47 } 48 49 func (nn notNull) GetTables() []string { 50 return nn.field.GetTables() 51 } 52 53 func (nn notNull) SQL() string { 54 return fmt.Sprintf("(%s IS NOT NULL) AS `%s`", nn.field.SQL(), nn.alias) 55 } 56 57 // NotNull field for use as a select expression 58 func NotNull(tableField TableField, alias string) SelectExpression { 59 return notNull{field: tableField, alias: alias} 60 } 61 62 type coalesce struct { 63 field TableField 64 def string 65 name string 66 } 67 68 func (c coalesce) GetName() string { 69 return c.name 70 } 71 72 func (c coalesce) GetTables() []string { 73 return c.field.GetTables() 74 } 75 76 func (c coalesce) SQL() string { 77 return fmt.Sprintf("COALESCE(%s, '%s') AS `%s`", c.field.SQL(), c.def, c.name) 78 } 79 80 // Coalesce creates a SQL coalesce that can be used as a SelectExpression 81 func Coalesce(column TableField, defaultValue string, alias string) SelectExpression { 82 return coalesce{field: column, def: defaultValue, name: alias} 83 } 84 85 // SelectQuery for retrieving data from a database table. 86 type SelectQuery struct { 87 distinct bool 88 from Table 89 selectExps []SelectExpression 90 joins []*Join 91 orderBy *orderBy 92 groupBy []SelectExpression 93 where *whereCondition 94 Seperator string 95 err error 96 } 97 98 // GetAlias of the passed table name in this query. 99 func (q *SelectQuery) GetAlias(tableName string) string { 100 return tableName 101 } 102 103 // From sets the primary table the query will get values from. 104 func (q *SelectQuery) From(table Table) *SelectQuery { 105 q.from = table 106 return q 107 } 108 109 // InnerJoin with another table in the database. 110 func (q *SelectQuery) InnerJoin(table Table) *Join { 111 join := NewJoin(Inner, Right, table) 112 q.joins = append(q.joins, join) 113 return join 114 } 115 116 // OuterJoin with another table in the database. 117 func (q *SelectQuery) OuterJoin(direction JoinDirection, table Table) *Join { 118 join := NewJoin(Outer, direction, table) 119 q.joins = append(q.joins, join) 120 return join 121 } 122 123 // Where the comparison between the two tablefields evaluates to true. 124 func (q *SelectQuery) Where(condition *ConditionExpression) *SelectQuery { 125 q.where.expression = condition 126 return q 127 } 128 129 // OrderBy the passed field and direction. 130 func (q *SelectQuery) OrderBy(field TableField, direction OrderDirection) *SelectQuery { 131 q.orderBy.addExpression(field, direction) 132 return q 133 } 134 135 // GroupBy the passed table field. 136 func (q *SelectQuery) GroupBy(expressions ...SelectExpression) *SelectQuery { 137 q.groupBy = append(q.groupBy, expressions...) 138 return q 139 } 140 141 func (q *SelectQuery) selectExpressionsSQL() string { 142 var prefix string 143 if q.distinct { 144 prefix = "SELECT DISTINCT" 145 } else { 146 prefix = "SELECT" 147 } 148 expressions := make([]string, len(q.selectExps)) 149 for i, exp := range q.selectExps { 150 expressions[i] = exp.SQL() 151 } 152 return fmt.Sprintf("%s %s", prefix, strings.Join(expressions, ", ")) 153 } 154 155 // Validate that this query can be executed. 156 func (q *SelectQuery) Validate() bool { 157 q.err = nil 158 // gather up all the tables that must be present in the from or in a join 159 tablesRequired := make(map[string]bool) 160 // check the select 161 for _, exp := range q.selectExps { 162 for _, table := range exp.GetTables() { 163 tablesRequired[table] = true 164 } 165 } 166 // check the where expressions 167 for _, table := range q.where.tables() { 168 tablesRequired[table] = true 169 } 170 // now get the tables from the order by 171 for _, table := range q.orderBy.getTables() { 172 tablesRequired[table] = true 173 } 174 // grab the tables from the group by 175 for _, tf := range q.groupBy { 176 for _, table := range tf.GetTables() { 177 tablesRequired[table] = true 178 } 179 } 180 // check that the from table is set 181 if nil == q.from { 182 q.err = NewValidationFromNotSetError() 183 return false 184 } 185 delete(tablesRequired, q.from.GetAlias()) 186 187 for _, join := range q.joins { 188 delete(tablesRequired, join.table.GetAlias()) 189 if nil != join.err { 190 q.err = join.err 191 return false 192 } 193 } 194 195 if len(tablesRequired) > 0 { 196 missingTables := make([]string, len(tablesRequired)) 197 i := 0 198 for key := range tablesRequired { 199 missingTables[i] = key 200 i++ 201 } 202 q.err = NewMissingTablesError(missingTables) 203 return false 204 } 205 206 return true 207 } 208 209 // SQL statement corresponding to this query. 210 func (q *SelectQuery) SQL(limit, offset uint) (string, []interface{}, error) { 211 if !q.Validate() { 212 return "", []interface{}{}, q.err 213 } 214 // SELECT 215 lines := []string{q.selectExpressionsSQL()} 216 var values []interface{} 217 218 // FROM 219 from := fmt.Sprintf("FROM `%s` AS `%s`", 220 q.from.GetName(), 221 q.from.GetAlias()) 222 lines = append(lines, from) 223 224 // JOIN 225 for _, join := range q.joins { 226 joinSQL, joinValues := join.SQL() 227 lines = append(lines, joinSQL) 228 values = append(values, joinValues...) 229 } 230 231 // WHERE 232 if where, whereValues, ok := q.where.sql(); ok { 233 lines = append(lines, "WHERE", where) 234 values = append(values, whereValues...) 235 } 236 237 // GROUP BY 238 if len(q.groupBy) > 0 { 239 groupByLines := []string{} 240 for _, tf := range q.groupBy { 241 groupByLines = append(groupByLines, tf.SQL()) 242 } 243 groupByStatement := "GROUP BY " + strings.Join(groupByLines, ", ") 244 lines = append(lines, groupByStatement) 245 } 246 247 // ORDER BY 248 if orderby, ok := q.orderBy.sql(); ok { 249 lines = append(lines, orderby) 250 } 251 252 // LIMIT, OFFSET 253 if NoLimit != limit { 254 lines = append(lines, fmt.Sprintf("LIMIT %d OFFSET %d", limit, offset)) 255 } 256 return strings.Join(lines, q.Seperator), values, q.err 257 }