gitlab.com/beacon-software/gadget@v0.0.0-20181217202115-54565ea1ed5e/database/qb/expression.go (about) 1 package qb 2 3 import ( 4 "fmt" 5 "strings" 6 ) 7 8 const ( 9 // SQLNow is the SQL NOW() function for use as a value in expressions. 10 SQLNow = "NOW()" 11 12 // SQLNull is the SQL representation of NULL 13 SQLNull = "NULL" 14 ) 15 16 type expressionUnion struct { 17 value interface{} 18 field *TableField 19 multi []expressionUnion 20 } 21 22 func newUnion(values ...interface{}) expressionUnion { 23 if len(values) == 1 { 24 tf, ok := values[0].(TableField) 25 if ok { 26 return expressionUnion{field: &tf} 27 } 28 return expressionUnion{value: values[0]} 29 } 30 multi := make([]expressionUnion, len(values)) 31 for i, obj := range values { 32 multi[i] = newUnion(obj) 33 } 34 return expressionUnion{multi: multi} 35 } 36 37 func (union expressionUnion) isString() bool { 38 _, ok := union.value.(string) 39 return ok 40 } 41 42 func (union expressionUnion) isField() bool { 43 return nil != union.field 44 } 45 46 func (union expressionUnion) isMulti() bool { 47 return nil != union.multi 48 } 49 50 func (union expressionUnion) getTables() []string { 51 if union.isField() { 52 return union.field.GetTables() 53 } else if union.isMulti() { 54 tables := []string{} 55 for _, exp := range union.multi { 56 tables = append(tables, exp.getTables()...) 57 } 58 return tables 59 } else { 60 return []string{} 61 } 62 } 63 64 func (union expressionUnion) sql() (string, []interface{}) { 65 var sql string 66 values := []interface{}{} 67 if union.isMulti() { 68 sa := make([]string, len(union.multi)) 69 var subvalues []interface{} 70 for i, exp := range union.multi { 71 sa[i], subvalues = exp.sql() 72 values = append(values, subvalues...) 73 } 74 sql = "(" + strings.Join(sa, ", ") + ")" 75 } else if union.isField() { 76 sql = union.field.SQL() 77 } else if SQLNow == union.value || SQLNull == union.value { 78 sql = fmt.Sprintf("%s", union.value) 79 } else if union.isString() && strings.HasPrefix(union.value.(string), ":") { 80 sql = fmt.Sprintf("%s", union.value) 81 } else { 82 sql = "?" 83 values = append(values, union.value) 84 } 85 return sql, values 86 } 87 88 type comparisonExpression interface { 89 SQL() (string, []interface{}) 90 } 91 92 type parameterExpression struct { 93 left TableField 94 comparison Comparison 95 } 96 97 func (be parameterExpression) SQL() (string, []interface{}) { 98 left := be.left.SQL() 99 return fmt.Sprintf("%s %s :%s", left, be.comparison, be.left.GetName()), make([]interface{}, 0) 100 } 101 102 type binaryExpression struct { 103 left TableField 104 comparison Comparison 105 right expressionUnion 106 } 107 108 func (be binaryExpression) SQL() (string, []interface{}) { 109 left := be.left.SQL() 110 right, values := be.right.sql() 111 return fmt.Sprintf("%s %s %s", left, be.comparison, right), values 112 } 113 114 // ConditionExpression represents an expression that can be used as a condition in a where or join on. 115 type ConditionExpression struct { 116 binary *binaryExpression 117 left *ConditionExpression 118 operator string 119 right *ConditionExpression 120 } 121 122 // Tables that are used in this expression or it's sub expressions. 123 func (exp *ConditionExpression) Tables() []string { 124 tables := []string{} 125 if nil != exp.binary { 126 tables = append(tables, exp.binary.left.GetTables()...) 127 tables = append(tables, exp.binary.right.getTables()...) 128 } else { 129 tables = append(tables, exp.left.Tables()...) 130 tables = append(tables, exp.right.Tables()...) 131 } 132 return tables 133 } 134 135 // FieldComparison to another field or a discrete value. 136 func FieldComparison(left TableField, comparison Comparison, right interface{}) *ConditionExpression { 137 if nil == right { 138 right = SQLNull 139 } 140 return &ConditionExpression{binary: &binaryExpression{left: left, comparison: comparison, right: newUnion(right)}} 141 } 142 143 // FieldIn a series of TableFields and/or values 144 func FieldIn(left TableField, in ...interface{}) *ConditionExpression { 145 // swap any 'nils' for sql null 146 rightValues := make([]interface{}, len(in)) 147 for i, value := range in { 148 if nil == value { 149 value = SQLNull 150 } 151 rightValues[i] = value 152 } 153 comparison := In 154 if len(rightValues) == 1 { 155 comparison = Equal 156 } 157 return &ConditionExpression{binary: &binaryExpression{left: left, comparison: comparison, right: newUnion(rightValues...)}} 158 } 159 160 // And creates an expression with this and the passed expression with an AND conjunction. 161 func (exp *ConditionExpression) And(expression *ConditionExpression) *ConditionExpression { 162 ptr := &ConditionExpression{} 163 *ptr = *exp 164 wrap := &ConditionExpression{left: ptr, right: expression, operator: And} 165 *exp = *wrap 166 return exp 167 } 168 169 // Or creates an expression with this and the passed expression with an OR conjunction. 170 func (exp *ConditionExpression) Or(expression *ConditionExpression) *ConditionExpression { 171 ptr := &ConditionExpression{} 172 *ptr = *exp 173 wrap := &ConditionExpression{left: ptr, right: expression, operator: Or} 174 *exp = *wrap 175 return exp 176 } 177 178 // XOr creates an expression with this and the passed expression with an XOr conjunction. 179 func (exp *ConditionExpression) XOr(expression *ConditionExpression) *ConditionExpression { 180 ptr := &ConditionExpression{} 181 *ptr = *exp 182 wrap := &ConditionExpression{left: ptr, right: expression, operator: XOr} 183 *exp = *wrap 184 return exp 185 } 186 187 // SQL returns this condition expression as a SQL expression. 188 func (exp *ConditionExpression) SQL() (string, []interface{}) { 189 if nil != exp.binary { 190 return exp.binary.SQL() 191 } 192 lsql, values := exp.left.SQL() 193 rsql, rvalues := exp.right.SQL() 194 values = append(values, rvalues...) 195 return fmt.Sprintf("(%s %s %s)", lsql, exp.operator, rsql), values 196 }