github.com/astaxie/beego@v1.12.3/orm/cmd_utils.go (about) 1 // Copyright 2014 beego Author. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package orm 16 17 import ( 18 "fmt" 19 "os" 20 "strings" 21 ) 22 23 type dbIndex struct { 24 Table string 25 Name string 26 SQL string 27 } 28 29 // create database drop sql. 30 func getDbDropSQL(al *alias) (sqls []string) { 31 if len(modelCache.cache) == 0 { 32 fmt.Println("no Model found, need register your model") 33 os.Exit(2) 34 } 35 36 Q := al.DbBaser.TableQuote() 37 38 for _, mi := range modelCache.allOrdered() { 39 sqls = append(sqls, fmt.Sprintf(`DROP TABLE IF EXISTS %s%s%s`, Q, mi.table, Q)) 40 } 41 return sqls 42 } 43 44 // get database column type string. 45 func getColumnTyp(al *alias, fi *fieldInfo) (col string) { 46 T := al.DbBaser.DbTypes() 47 fieldType := fi.fieldType 48 fieldSize := fi.size 49 50 checkColumn: 51 switch fieldType { 52 case TypeBooleanField: 53 col = T["bool"] 54 case TypeVarCharField: 55 if al.Driver == DRPostgres && fi.toText { 56 col = T["string-text"] 57 } else { 58 col = fmt.Sprintf(T["string"], fieldSize) 59 } 60 case TypeCharField: 61 col = fmt.Sprintf(T["string-char"], fieldSize) 62 case TypeTextField: 63 col = T["string-text"] 64 case TypeTimeField: 65 col = T["time.Time-clock"] 66 case TypeDateField: 67 col = T["time.Time-date"] 68 case TypeDateTimeField: 69 col = T["time.Time"] 70 case TypeBitField: 71 col = T["int8"] 72 case TypeSmallIntegerField: 73 col = T["int16"] 74 case TypeIntegerField: 75 col = T["int32"] 76 case TypeBigIntegerField: 77 if al.Driver == DRSqlite { 78 fieldType = TypeIntegerField 79 goto checkColumn 80 } 81 col = T["int64"] 82 case TypePositiveBitField: 83 col = T["uint8"] 84 case TypePositiveSmallIntegerField: 85 col = T["uint16"] 86 case TypePositiveIntegerField: 87 col = T["uint32"] 88 case TypePositiveBigIntegerField: 89 col = T["uint64"] 90 case TypeFloatField: 91 col = T["float64"] 92 case TypeDecimalField: 93 s := T["float64-decimal"] 94 if !strings.Contains(s, "%d") { 95 col = s 96 } else { 97 col = fmt.Sprintf(s, fi.digits, fi.decimals) 98 } 99 case TypeJSONField: 100 if al.Driver != DRPostgres { 101 fieldType = TypeVarCharField 102 goto checkColumn 103 } 104 col = T["json"] 105 case TypeJsonbField: 106 if al.Driver != DRPostgres { 107 fieldType = TypeVarCharField 108 goto checkColumn 109 } 110 col = T["jsonb"] 111 case RelForeignKey, RelOneToOne: 112 fieldType = fi.relModelInfo.fields.pk.fieldType 113 fieldSize = fi.relModelInfo.fields.pk.size 114 goto checkColumn 115 } 116 117 return 118 } 119 120 // create alter sql string. 121 func getColumnAddQuery(al *alias, fi *fieldInfo) string { 122 Q := al.DbBaser.TableQuote() 123 typ := getColumnTyp(al, fi) 124 125 if !fi.null { 126 typ += " " + "NOT NULL" 127 } 128 129 return fmt.Sprintf("ALTER TABLE %s%s%s ADD COLUMN %s%s%s %s %s", 130 Q, fi.mi.table, Q, 131 Q, fi.column, Q, 132 typ, getColumnDefault(fi), 133 ) 134 } 135 136 // create database creation string. 137 func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) { 138 if len(modelCache.cache) == 0 { 139 fmt.Println("no Model found, need register your model") 140 os.Exit(2) 141 } 142 143 Q := al.DbBaser.TableQuote() 144 T := al.DbBaser.DbTypes() 145 sep := fmt.Sprintf("%s, %s", Q, Q) 146 147 tableIndexes = make(map[string][]dbIndex) 148 149 for _, mi := range modelCache.allOrdered() { 150 sql := fmt.Sprintf("-- %s\n", strings.Repeat("-", 50)) 151 sql += fmt.Sprintf("-- Table Structure for `%s`\n", mi.fullName) 152 sql += fmt.Sprintf("-- %s\n", strings.Repeat("-", 50)) 153 154 sql += fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s%s%s (\n", Q, mi.table, Q) 155 156 columns := make([]string, 0, len(mi.fields.fieldsDB)) 157 158 sqlIndexes := [][]string{} 159 160 for _, fi := range mi.fields.fieldsDB { 161 162 column := fmt.Sprintf(" %s%s%s ", Q, fi.column, Q) 163 col := getColumnTyp(al, fi) 164 165 if fi.auto { 166 switch al.Driver { 167 case DRSqlite, DRPostgres: 168 column += T["auto"] 169 default: 170 column += col + " " + T["auto"] 171 } 172 } else if fi.pk { 173 column += col + " " + T["pk"] 174 } else { 175 column += col 176 177 if !fi.null { 178 column += " " + "NOT NULL" 179 } 180 181 //if fi.initial.String() != "" { 182 // column += " DEFAULT " + fi.initial.String() 183 //} 184 185 // Append attribute DEFAULT 186 column += getColumnDefault(fi) 187 188 if fi.unique { 189 column += " " + "UNIQUE" 190 } 191 192 if fi.index { 193 sqlIndexes = append(sqlIndexes, []string{fi.column}) 194 } 195 } 196 197 if strings.Contains(column, "%COL%") { 198 column = strings.Replace(column, "%COL%", fi.column, -1) 199 } 200 201 if fi.description != "" && al.Driver != DRSqlite { 202 column += " " + fmt.Sprintf("COMMENT '%s'", fi.description) 203 } 204 205 columns = append(columns, column) 206 } 207 208 if mi.model != nil { 209 allnames := getTableUnique(mi.addrField) 210 if !mi.manual && len(mi.uniques) > 0 { 211 allnames = append(allnames, mi.uniques) 212 } 213 for _, names := range allnames { 214 cols := make([]string, 0, len(names)) 215 for _, name := range names { 216 if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol { 217 cols = append(cols, fi.column) 218 } else { 219 panic(fmt.Errorf("cannot found column `%s` when parse UNIQUE in `%s.TableUnique`", name, mi.fullName)) 220 } 221 } 222 column := fmt.Sprintf(" UNIQUE (%s%s%s)", Q, strings.Join(cols, sep), Q) 223 columns = append(columns, column) 224 } 225 } 226 227 sql += strings.Join(columns, ",\n") 228 sql += "\n)" 229 230 if al.Driver == DRMySQL { 231 var engine string 232 if mi.model != nil { 233 engine = getTableEngine(mi.addrField) 234 } 235 if engine == "" { 236 engine = al.Engine 237 } 238 sql += " ENGINE=" + engine 239 } 240 241 sql += ";" 242 sqls = append(sqls, sql) 243 244 if mi.model != nil { 245 for _, names := range getTableIndex(mi.addrField) { 246 cols := make([]string, 0, len(names)) 247 for _, name := range names { 248 if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol { 249 cols = append(cols, fi.column) 250 } else { 251 panic(fmt.Errorf("cannot found column `%s` when parse INDEX in `%s.TableIndex`", name, mi.fullName)) 252 } 253 } 254 sqlIndexes = append(sqlIndexes, cols) 255 } 256 } 257 258 for _, names := range sqlIndexes { 259 name := mi.table + "_" + strings.Join(names, "_") 260 cols := strings.Join(names, sep) 261 sql := fmt.Sprintf("CREATE INDEX %s%s%s ON %s%s%s (%s%s%s);", Q, name, Q, Q, mi.table, Q, Q, cols, Q) 262 263 index := dbIndex{} 264 index.Table = mi.table 265 index.Name = name 266 index.SQL = sql 267 268 tableIndexes[mi.table] = append(tableIndexes[mi.table], index) 269 } 270 271 } 272 273 return 274 } 275 276 // Get string value for the attribute "DEFAULT" for the CREATE, ALTER commands 277 func getColumnDefault(fi *fieldInfo) string { 278 var ( 279 v, t, d string 280 ) 281 282 // Skip default attribute if field is in relations 283 if fi.rel || fi.reverse { 284 return v 285 } 286 287 t = " DEFAULT '%s' " 288 289 // These defaults will be useful if there no config value orm:"default" and NOT NULL is on 290 switch fi.fieldType { 291 case TypeTimeField, TypeDateField, TypeDateTimeField, TypeTextField: 292 return v 293 294 case TypeBitField, TypeSmallIntegerField, TypeIntegerField, 295 TypeBigIntegerField, TypePositiveBitField, TypePositiveSmallIntegerField, 296 TypePositiveIntegerField, TypePositiveBigIntegerField, TypeFloatField, 297 TypeDecimalField: 298 t = " DEFAULT %s " 299 d = "0" 300 case TypeBooleanField: 301 t = " DEFAULT %s " 302 d = "FALSE" 303 case TypeJSONField, TypeJsonbField: 304 d = "{}" 305 } 306 307 if fi.colDefault { 308 if !fi.initial.Exist() { 309 v = fmt.Sprintf(t, "") 310 } else { 311 v = fmt.Sprintf(t, fi.initial.String()) 312 } 313 } else { 314 if !fi.null { 315 v = fmt.Sprintf(t, d) 316 } 317 } 318 319 return v 320 }