github.com/kotovmak/go-admin@v1.1.1/modules/db/dialect/dialect.go (about) 1 // Copyright 2019 GoAdmin Core Team. All rights reserved. 2 // Use of this source code is governed by a Apache-2.0 style 3 // license that can be found in the LICENSE file. 4 5 package dialect 6 7 import ( 8 "strings" 9 10 "github.com/kotovmak/go-admin/modules/config" 11 ) 12 13 // Dialect is methods set of different driver. 14 type Dialect interface { 15 // GetName get dialect's name 16 GetName() string 17 18 // ShowColumns show columns of specified table 19 ShowColumns(table string) string 20 21 // ShowTables show tables of database 22 ShowTables() string 23 24 // Insert 25 Insert(comp *SQLComponent) string 26 27 // Delete 28 Delete(comp *SQLComponent) string 29 30 // Update 31 Update(comp *SQLComponent) string 32 33 // Select 34 Select(comp *SQLComponent) string 35 36 // GetDelimiter return the delimiter of Dialect. 37 GetDelimiter() string 38 } 39 40 // GetDialect return the default Dialect. 41 func GetDialect() Dialect { 42 return GetDialectByDriver(config.GetDatabases().GetDefault().Driver) 43 } 44 45 // GetDialectByDriver return the Dialect of given driver. 46 func GetDialectByDriver(driver string) Dialect { 47 switch driver { 48 case "mysql": 49 return mysql{ 50 commonDialect: commonDialect{delimiter: "`", delimiter2: "`"}, 51 } 52 case "mssql": 53 return mssql{ 54 commonDialect: commonDialect{delimiter: "[", delimiter2: "]"}, 55 } 56 case "postgresql": 57 return postgresql{ 58 commonDialect: commonDialect{delimiter: `"`, delimiter2: `"`}, 59 } 60 case "sqlite": 61 return sqlite{ 62 commonDialect: commonDialect{delimiter: "`", delimiter2: "`"}, 63 } 64 case "oceanbase": 65 return oceanbase{ 66 commonDialect: commonDialect{delimiter: "`", delimiter2: "`"}, 67 } 68 default: 69 return commonDialect{delimiter: "`", delimiter2: "`"} 70 } 71 } 72 73 // H is a shorthand of map. 74 type H map[string]interface{} 75 76 // SQLComponent is a sql components set. 77 type SQLComponent struct { 78 Fields []string 79 Functions []string 80 TableName string 81 Wheres []Where 82 Leftjoins []Join 83 Args []interface{} 84 Order string 85 Offset string 86 Limit string 87 WhereRaws string 88 UpdateRaws []RawUpdate 89 Group string 90 Statement string 91 Values H 92 } 93 94 // Where contains the operation and field. 95 type Where struct { 96 Operation string 97 Field string 98 Qmark string 99 } 100 101 // Join contains the table and field and operation. 102 type Join struct { 103 Table string 104 FieldA string 105 Operation string 106 FieldB string 107 } 108 109 // RawUpdate contains the expression and arguments. 110 type RawUpdate struct { 111 Expression string 112 Args []interface{} 113 } 114 115 // ******************************* 116 // internal help function 117 // ******************************* 118 119 func (sql *SQLComponent) getLimit() string { 120 if sql.Limit == "" { 121 return "" 122 } 123 return " limit " + sql.Limit + " " 124 } 125 126 func (sql *SQLComponent) getOffset() string { 127 if sql.Offset == "" { 128 return "" 129 } 130 return " offset " + sql.Offset + " " 131 } 132 133 func (sql *SQLComponent) getOrderBy() string { 134 if sql.Order == "" { 135 return "" 136 } 137 return " order by " + sql.Order + " " 138 } 139 140 func (sql *SQLComponent) getGroupBy() string { 141 if sql.Group == "" { 142 return "" 143 } 144 return " group by " + sql.Group + " " 145 } 146 147 func (sql *SQLComponent) getJoins(delimiter, delimiter2 string) string { 148 if len(sql.Leftjoins) == 0 { 149 return "" 150 } 151 joins := "" 152 for _, join := range sql.Leftjoins { 153 joins += " left join " + wrap(delimiter, delimiter2, join.Table) + " on " + 154 sql.processLeftJoinField(join.FieldA, delimiter, delimiter2) + " " + join.Operation + " " + 155 sql.processLeftJoinField(join.FieldB, delimiter, delimiter2) + " " 156 } 157 return joins 158 } 159 160 func (sql *SQLComponent) processLeftJoinField(field, delimiter, delimiter2 string) string { 161 arr := strings.Split(field, ".") 162 if len(arr) > 0 { 163 return delimiter + arr[0] + delimiter2 + "." + delimiter + arr[1] + delimiter2 164 } 165 return field 166 } 167 168 func (sql *SQLComponent) getFields(delimiter, delimiter2 string) string { 169 if len(sql.Fields) == 0 { 170 return "*" 171 } 172 fields := "" 173 if len(sql.Leftjoins) == 0 { 174 for k, field := range sql.Fields { 175 if sql.Functions[k] != "" { 176 fields += sql.Functions[k] + "(" + wrap(delimiter, delimiter2, field) + ")," 177 } else { 178 fields += wrap(delimiter, delimiter2, field) + "," 179 } 180 } 181 } else { 182 for _, field := range sql.Fields { 183 arr := strings.Split(field, ".") 184 if len(arr) > 1 { 185 fields += wrap(delimiter, delimiter2, arr[0]) + "." + wrap(delimiter, delimiter2, arr[1]) + "," 186 } else { 187 fields += wrap(delimiter, delimiter2, field) + "," 188 } 189 } 190 } 191 return fields[:len(fields)-1] 192 } 193 194 func wrap(delimiter, delimiter2, field string) string { 195 if field == "*" { 196 return "*" 197 } 198 return delimiter + field + delimiter2 199 } 200 201 func (sql *SQLComponent) getWheres(delimiter, delimiter2 string) string { 202 if len(sql.Wheres) == 0 { 203 if sql.WhereRaws != "" { 204 return " where " + sql.WhereRaws 205 } 206 return "" 207 } 208 wheres := " where " 209 var arr []string 210 for _, where := range sql.Wheres { 211 arr = strings.Split(where.Field, ".") 212 if len(arr) > 1 { 213 wheres += arr[0] + "." + wrap(delimiter, delimiter2, arr[1]) + " " + where.Operation + " " + where.Qmark + " and " 214 } else { 215 wheres += wrap(delimiter, delimiter2, where.Field) + " " + where.Operation + " " + where.Qmark + " and " 216 } 217 } 218 219 if sql.WhereRaws != "" { 220 return wheres + sql.WhereRaws 221 } 222 return wheres[:len(wheres)-5] 223 } 224 225 func (sql *SQLComponent) prepareUpdate(delimiter, delimiter2 string) { 226 fields := "" 227 args := make([]interface{}, 0) 228 229 if len(sql.Values) != 0 { 230 231 for key, value := range sql.Values { 232 fields += wrap(delimiter, delimiter2, key) + " = ?, " 233 args = append(args, value) 234 } 235 236 if len(sql.UpdateRaws) == 0 { 237 fields = fields[:len(fields)-2] 238 } else { 239 for i := 0; i < len(sql.UpdateRaws); i++ { 240 if i == len(sql.UpdateRaws)-1 { 241 fields += sql.UpdateRaws[i].Expression + " " 242 } else { 243 fields += sql.UpdateRaws[i].Expression + "," 244 } 245 args = append(args, sql.UpdateRaws[i].Args...) 246 } 247 } 248 249 sql.Args = append(args, sql.Args...) 250 } else { 251 if len(sql.UpdateRaws) == 0 { 252 panic("prepareUpdate: wrong parameter") 253 } else { 254 for i := 0; i < len(sql.UpdateRaws); i++ { 255 if i == len(sql.UpdateRaws)-1 { 256 fields += sql.UpdateRaws[i].Expression + " " 257 } else { 258 fields += sql.UpdateRaws[i].Expression + "," 259 } 260 args = append(args, sql.UpdateRaws[i].Args...) 261 } 262 } 263 sql.Args = append(args, sql.Args...) 264 } 265 266 sql.Statement = "update " + delimiter + sql.TableName + delimiter2 + " set " + fields + sql.getWheres(delimiter, delimiter2) 267 } 268 269 func (sql *SQLComponent) prepareInsert(delimiter, delimiter2 string) { 270 fields := " (" 271 quesMark := "(" 272 273 for key, value := range sql.Values { 274 fields += wrap(delimiter, delimiter2, key) + "," 275 quesMark += "?," 276 sql.Args = append(sql.Args, value) 277 } 278 fields = fields[:len(fields)-1] + ")" 279 quesMark = quesMark[:len(quesMark)-1] + ")" 280 281 sql.Statement = "insert into " + delimiter + sql.TableName + delimiter2 + fields + " values " + quesMark 282 }