github.com/astaxie/beego@v1.12.3/orm/cmd_utils.go (about)

     1  // Copyright 2014 beego Author. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package orm
    16  
    17  import (
    18  	"fmt"
    19  	"os"
    20  	"strings"
    21  )
    22  
    23  type dbIndex struct {
    24  	Table string
    25  	Name  string
    26  	SQL   string
    27  }
    28  
    29  // create database drop sql.
    30  func getDbDropSQL(al *alias) (sqls []string) {
    31  	if len(modelCache.cache) == 0 {
    32  		fmt.Println("no Model found, need register your model")
    33  		os.Exit(2)
    34  	}
    35  
    36  	Q := al.DbBaser.TableQuote()
    37  
    38  	for _, mi := range modelCache.allOrdered() {
    39  		sqls = append(sqls, fmt.Sprintf(`DROP TABLE IF EXISTS %s%s%s`, Q, mi.table, Q))
    40  	}
    41  	return sqls
    42  }
    43  
    44  // get database column type string.
    45  func getColumnTyp(al *alias, fi *fieldInfo) (col string) {
    46  	T := al.DbBaser.DbTypes()
    47  	fieldType := fi.fieldType
    48  	fieldSize := fi.size
    49  
    50  checkColumn:
    51  	switch fieldType {
    52  	case TypeBooleanField:
    53  		col = T["bool"]
    54  	case TypeVarCharField:
    55  		if al.Driver == DRPostgres && fi.toText {
    56  			col = T["string-text"]
    57  		} else {
    58  			col = fmt.Sprintf(T["string"], fieldSize)
    59  		}
    60  	case TypeCharField:
    61  		col = fmt.Sprintf(T["string-char"], fieldSize)
    62  	case TypeTextField:
    63  		col = T["string-text"]
    64  	case TypeTimeField:
    65  		col = T["time.Time-clock"]
    66  	case TypeDateField:
    67  		col = T["time.Time-date"]
    68  	case TypeDateTimeField:
    69  		col = T["time.Time"]
    70  	case TypeBitField:
    71  		col = T["int8"]
    72  	case TypeSmallIntegerField:
    73  		col = T["int16"]
    74  	case TypeIntegerField:
    75  		col = T["int32"]
    76  	case TypeBigIntegerField:
    77  		if al.Driver == DRSqlite {
    78  			fieldType = TypeIntegerField
    79  			goto checkColumn
    80  		}
    81  		col = T["int64"]
    82  	case TypePositiveBitField:
    83  		col = T["uint8"]
    84  	case TypePositiveSmallIntegerField:
    85  		col = T["uint16"]
    86  	case TypePositiveIntegerField:
    87  		col = T["uint32"]
    88  	case TypePositiveBigIntegerField:
    89  		col = T["uint64"]
    90  	case TypeFloatField:
    91  		col = T["float64"]
    92  	case TypeDecimalField:
    93  		s := T["float64-decimal"]
    94  		if !strings.Contains(s, "%d") {
    95  			col = s
    96  		} else {
    97  			col = fmt.Sprintf(s, fi.digits, fi.decimals)
    98  		}
    99  	case TypeJSONField:
   100  		if al.Driver != DRPostgres {
   101  			fieldType = TypeVarCharField
   102  			goto checkColumn
   103  		}
   104  		col = T["json"]
   105  	case TypeJsonbField:
   106  		if al.Driver != DRPostgres {
   107  			fieldType = TypeVarCharField
   108  			goto checkColumn
   109  		}
   110  		col = T["jsonb"]
   111  	case RelForeignKey, RelOneToOne:
   112  		fieldType = fi.relModelInfo.fields.pk.fieldType
   113  		fieldSize = fi.relModelInfo.fields.pk.size
   114  		goto checkColumn
   115  	}
   116  
   117  	return
   118  }
   119  
   120  // create alter sql string.
   121  func getColumnAddQuery(al *alias, fi *fieldInfo) string {
   122  	Q := al.DbBaser.TableQuote()
   123  	typ := getColumnTyp(al, fi)
   124  
   125  	if !fi.null {
   126  		typ += " " + "NOT NULL"
   127  	}
   128  
   129  	return fmt.Sprintf("ALTER TABLE %s%s%s ADD COLUMN %s%s%s %s %s",
   130  		Q, fi.mi.table, Q,
   131  		Q, fi.column, Q,
   132  		typ, getColumnDefault(fi),
   133  	)
   134  }
   135  
   136  // create database creation string.
   137  func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) {
   138  	if len(modelCache.cache) == 0 {
   139  		fmt.Println("no Model found, need register your model")
   140  		os.Exit(2)
   141  	}
   142  
   143  	Q := al.DbBaser.TableQuote()
   144  	T := al.DbBaser.DbTypes()
   145  	sep := fmt.Sprintf("%s, %s", Q, Q)
   146  
   147  	tableIndexes = make(map[string][]dbIndex)
   148  
   149  	for _, mi := range modelCache.allOrdered() {
   150  		sql := fmt.Sprintf("-- %s\n", strings.Repeat("-", 50))
   151  		sql += fmt.Sprintf("--  Table Structure for `%s`\n", mi.fullName)
   152  		sql += fmt.Sprintf("-- %s\n", strings.Repeat("-", 50))
   153  
   154  		sql += fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s%s%s (\n", Q, mi.table, Q)
   155  
   156  		columns := make([]string, 0, len(mi.fields.fieldsDB))
   157  
   158  		sqlIndexes := [][]string{}
   159  
   160  		for _, fi := range mi.fields.fieldsDB {
   161  
   162  			column := fmt.Sprintf("    %s%s%s ", Q, fi.column, Q)
   163  			col := getColumnTyp(al, fi)
   164  
   165  			if fi.auto {
   166  				switch al.Driver {
   167  				case DRSqlite, DRPostgres:
   168  					column += T["auto"]
   169  				default:
   170  					column += col + " " + T["auto"]
   171  				}
   172  			} else if fi.pk {
   173  				column += col + " " + T["pk"]
   174  			} else {
   175  				column += col
   176  
   177  				if !fi.null {
   178  					column += " " + "NOT NULL"
   179  				}
   180  
   181  				//if fi.initial.String() != "" {
   182  				//	column += " DEFAULT " + fi.initial.String()
   183  				//}
   184  
   185  				// Append attribute DEFAULT
   186  				column += getColumnDefault(fi)
   187  
   188  				if fi.unique {
   189  					column += " " + "UNIQUE"
   190  				}
   191  
   192  				if fi.index {
   193  					sqlIndexes = append(sqlIndexes, []string{fi.column})
   194  				}
   195  			}
   196  
   197  			if strings.Contains(column, "%COL%") {
   198  				column = strings.Replace(column, "%COL%", fi.column, -1)
   199  			}
   200  
   201  			if fi.description != "" && al.Driver != DRSqlite {
   202  				column += " " + fmt.Sprintf("COMMENT '%s'", fi.description)
   203  			}
   204  
   205  			columns = append(columns, column)
   206  		}
   207  
   208  		if mi.model != nil {
   209  			allnames := getTableUnique(mi.addrField)
   210  			if !mi.manual && len(mi.uniques) > 0 {
   211  				allnames = append(allnames, mi.uniques)
   212  			}
   213  			for _, names := range allnames {
   214  				cols := make([]string, 0, len(names))
   215  				for _, name := range names {
   216  					if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol {
   217  						cols = append(cols, fi.column)
   218  					} else {
   219  						panic(fmt.Errorf("cannot found column `%s` when parse UNIQUE in `%s.TableUnique`", name, mi.fullName))
   220  					}
   221  				}
   222  				column := fmt.Sprintf("    UNIQUE (%s%s%s)", Q, strings.Join(cols, sep), Q)
   223  				columns = append(columns, column)
   224  			}
   225  		}
   226  
   227  		sql += strings.Join(columns, ",\n")
   228  		sql += "\n)"
   229  
   230  		if al.Driver == DRMySQL {
   231  			var engine string
   232  			if mi.model != nil {
   233  				engine = getTableEngine(mi.addrField)
   234  			}
   235  			if engine == "" {
   236  				engine = al.Engine
   237  			}
   238  			sql += " ENGINE=" + engine
   239  		}
   240  
   241  		sql += ";"
   242  		sqls = append(sqls, sql)
   243  
   244  		if mi.model != nil {
   245  			for _, names := range getTableIndex(mi.addrField) {
   246  				cols := make([]string, 0, len(names))
   247  				for _, name := range names {
   248  					if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol {
   249  						cols = append(cols, fi.column)
   250  					} else {
   251  						panic(fmt.Errorf("cannot found column `%s` when parse INDEX in `%s.TableIndex`", name, mi.fullName))
   252  					}
   253  				}
   254  				sqlIndexes = append(sqlIndexes, cols)
   255  			}
   256  		}
   257  
   258  		for _, names := range sqlIndexes {
   259  			name := mi.table + "_" + strings.Join(names, "_")
   260  			cols := strings.Join(names, sep)
   261  			sql := fmt.Sprintf("CREATE INDEX %s%s%s ON %s%s%s (%s%s%s);", Q, name, Q, Q, mi.table, Q, Q, cols, Q)
   262  
   263  			index := dbIndex{}
   264  			index.Table = mi.table
   265  			index.Name = name
   266  			index.SQL = sql
   267  
   268  			tableIndexes[mi.table] = append(tableIndexes[mi.table], index)
   269  		}
   270  
   271  	}
   272  
   273  	return
   274  }
   275  
   276  // Get string value for the attribute "DEFAULT" for the CREATE, ALTER commands
   277  func getColumnDefault(fi *fieldInfo) string {
   278  	var (
   279  		v, t, d string
   280  	)
   281  
   282  	// Skip default attribute if field is in relations
   283  	if fi.rel || fi.reverse {
   284  		return v
   285  	}
   286  
   287  	t = " DEFAULT '%s' "
   288  
   289  	// These defaults will be useful if there no config value orm:"default" and NOT NULL is on
   290  	switch fi.fieldType {
   291  	case TypeTimeField, TypeDateField, TypeDateTimeField, TypeTextField:
   292  		return v
   293  
   294  	case TypeBitField, TypeSmallIntegerField, TypeIntegerField,
   295  		TypeBigIntegerField, TypePositiveBitField, TypePositiveSmallIntegerField,
   296  		TypePositiveIntegerField, TypePositiveBigIntegerField, TypeFloatField,
   297  		TypeDecimalField:
   298  		t = " DEFAULT %s "
   299  		d = "0"
   300  	case TypeBooleanField:
   301  		t = " DEFAULT %s "
   302  		d = "FALSE"
   303  	case TypeJSONField, TypeJsonbField:
   304  		d = "{}"
   305  	}
   306  
   307  	if fi.colDefault {
   308  		if !fi.initial.Exist() {
   309  			v = fmt.Sprintf(t, "")
   310  		} else {
   311  			v = fmt.Sprintf(t, fi.initial.String())
   312  		}
   313  	} else {
   314  		if !fi.null {
   315  			v = fmt.Sprintf(t, d)
   316  		}
   317  	}
   318  
   319  	return v
   320  }