gitlab.com/beacon-software/gadget@v0.0.0-20181217202115-54565ea1ed5e/database/qb/update.go (about) 1 package qb 2 3 import ( 4 "fmt" 5 "strings" 6 7 "gitlab.com/beacon-software/gadget/errors" 8 ) 9 10 /* 11 UPDATE [LOW_PRIORITY] [IGNORE] table_reference 12 SET assignment_list 13 [WHERE where_condition] 14 [ORDER BY ...] 15 [LIMIT row_count] 16 17 value: 18 {expr | DEFAULT} 19 20 assignment: 21 col_name = value 22 23 assignment_list: 24 assignment [, assignment] ... 25 */ 26 27 // UpdateQuery represents a query to update rows in a database 28 // Currently only supports single table, to change this the tableReference would have to be built out more. 29 type UpdateQuery struct { 30 tableReference Table 31 assignments []comparisonExpression 32 where *whereCondition 33 orderBy *orderBy 34 err error 35 } 36 37 // GetAlias returns the alias for the passed tablename used in this query. 38 func (q *UpdateQuery) GetAlias(tableName string) string { 39 // no aliasing in update 40 return tableName 41 } 42 43 // Set adds a assignment to this update query. 44 func (q *UpdateQuery) Set(field TableField, value interface{}) *UpdateQuery { 45 if field.Table != q.tableReference.GetName() { 46 q.err = errors.New("field table does not match table reference on update query") 47 } else { 48 q.assignments = append(q.assignments, binaryExpression{left: field, comparison: Equal, right: newUnion(value)}) 49 } 50 return q 51 } 52 53 // SetParam adds a parameterized assignment to this update query. 54 func (q *UpdateQuery) SetParam(field TableField) *UpdateQuery { 55 if field.Table != q.tableReference.GetName() { 56 q.err = errors.New("field table does not match table reference on update query") 57 } else { 58 q.assignments = append(q.assignments, parameterExpression{left: field, comparison: Equal}) 59 } 60 return q 61 } 62 63 // Where determines the conditions by which the assignments in this query apply 64 func (q *UpdateQuery) Where(condition *ConditionExpression) *UpdateQuery { 65 q.where.expression = condition 66 return q 67 } 68 69 // OrderBy the passed field and direction. 70 func (q *UpdateQuery) OrderBy(field TableField, direction OrderDirection) *UpdateQuery { 71 q.orderBy.addExpression(field, direction) 72 return q 73 } 74 75 // SQL representation of this query. 76 func (q *UpdateQuery) SQL(limit int) (string, []interface{}, error) { 77 if nil != q.err { 78 return "", nil, q.err 79 } 80 if len(q.assignments) == 0 { 81 return "", nil, errors.New("no assignments in update query") 82 } 83 sql := []string{fmt.Sprintf("UPDATE `%s` SET ", q.tableReference.GetName())} 84 alines := []string{} 85 values := []interface{}{} 86 for _, assignment := range q.assignments { 87 s, v := assignment.SQL() 88 alines = append(alines, s) 89 values = append(values, v...) 90 } 91 sql = append(sql, strings.Join(alines, ", ")) 92 // WHERE 93 if where, whereValues, ok := q.where.sql(); ok { 94 sql = append(sql, "WHERE", where) 95 values = append(values, whereValues...) 96 } 97 // ORDER BY 98 if s, ok := q.orderBy.sql(); ok { 99 sql = append(sql, s) 100 } 101 // LIMIT 102 if NoLimit != limit { 103 sql = append(sql, fmt.Sprintf("LIMIT %d", limit)) 104 } 105 return strings.Join(sql, " "), values, q.err 106 } 107 108 // ParameterizedSQL representation of this query. 109 func (q *UpdateQuery) ParameterizedSQL(limit int) (string, error) { 110 if nil != q.err { 111 return "", q.err 112 } 113 if len(q.assignments) == 0 { 114 return "", errors.New("no assignments in update query") 115 } 116 sql := []string{fmt.Sprintf("UPDATE `%s` SET ", q.tableReference.GetName())} 117 alines := []string{} 118 values := []interface{}{} 119 for _, assignment := range q.assignments { 120 s, v := assignment.SQL() 121 alines = append(alines, s) 122 values = append(values, v...) 123 } 124 sql = append(sql, strings.Join(alines, ", ")) 125 // WHERE 126 if where, _, ok := q.where.sql(); ok { 127 sql = append(sql, "WHERE", where) 128 } 129 // ORDER BY 130 if s, ok := q.orderBy.sql(); ok { 131 sql = append(sql, s) 132 } 133 // LIMIT 134 if NoLimit != limit { 135 sql = append(sql, fmt.Sprintf("LIMIT %d", limit)) 136 } 137 return strings.Join(sql, " "), q.err 138 }