github.com/astaxie/beego@v1.12.3/orm/db_mysql.go (about)

     1  // Copyright 2014 beego Author. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package orm
    16  
    17  import (
    18  	"fmt"
    19  	"reflect"
    20  	"strings"
    21  )
    22  
    23  // mysql operators.
    24  var mysqlOperators = map[string]string{
    25  	"exact":       "= ?",
    26  	"iexact":      "LIKE ?",
    27  	"strictexact": "= BINARY ?",
    28  	"contains":    "LIKE BINARY ?",
    29  	"icontains":   "LIKE ?",
    30  	// "regex":       "REGEXP BINARY ?",
    31  	// "iregex":      "REGEXP ?",
    32  	"gt":          "> ?",
    33  	"gte":         ">= ?",
    34  	"lt":          "< ?",
    35  	"lte":         "<= ?",
    36  	"eq":          "= ?",
    37  	"ne":          "!= ?",
    38  	"startswith":  "LIKE BINARY ?",
    39  	"endswith":    "LIKE BINARY ?",
    40  	"istartswith": "LIKE ?",
    41  	"iendswith":   "LIKE ?",
    42  }
    43  
    44  // mysql column field types.
    45  var mysqlTypes = map[string]string{
    46  	"auto":            "AUTO_INCREMENT NOT NULL PRIMARY KEY",
    47  	"pk":              "NOT NULL PRIMARY KEY",
    48  	"bool":            "bool",
    49  	"string":          "varchar(%d)",
    50  	"string-char":     "char(%d)",
    51  	"string-text":     "longtext",
    52  	"time.Time-date":  "date",
    53  	"time.Time":       "datetime",
    54  	"int8":            "tinyint",
    55  	"int16":           "smallint",
    56  	"int32":           "integer",
    57  	"int64":           "bigint",
    58  	"uint8":           "tinyint unsigned",
    59  	"uint16":          "smallint unsigned",
    60  	"uint32":          "integer unsigned",
    61  	"uint64":          "bigint unsigned",
    62  	"float64":         "double precision",
    63  	"float64-decimal": "numeric(%d, %d)",
    64  }
    65  
    66  // mysql dbBaser implementation.
    67  type dbBaseMysql struct {
    68  	dbBase
    69  }
    70  
    71  var _ dbBaser = new(dbBaseMysql)
    72  
    73  // get mysql operator.
    74  func (d *dbBaseMysql) OperatorSQL(operator string) string {
    75  	return mysqlOperators[operator]
    76  }
    77  
    78  // get mysql table field types.
    79  func (d *dbBaseMysql) DbTypes() map[string]string {
    80  	return mysqlTypes
    81  }
    82  
    83  // show table sql for mysql.
    84  func (d *dbBaseMysql) ShowTablesQuery() string {
    85  	return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()"
    86  }
    87  
    88  // show columns sql of table for mysql.
    89  func (d *dbBaseMysql) ShowColumnsQuery(table string) string {
    90  	return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+
    91  		"WHERE table_schema = DATABASE() AND table_name = '%s'", table)
    92  }
    93  
    94  // execute sql to check index exist.
    95  func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool {
    96  	row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+
    97  		"WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name)
    98  	var cnt int
    99  	row.Scan(&cnt)
   100  	return cnt > 0
   101  }
   102  
   103  // InsertOrUpdate a row
   104  // If your primary key or unique column conflict will update
   105  // If no will insert
   106  // Add "`" for mysql sql building
   107  func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) {
   108  	var iouStr string
   109  	argsMap := map[string]string{}
   110  
   111  	iouStr = "ON DUPLICATE KEY UPDATE"
   112  
   113  	//Get on the key-value pairs
   114  	for _, v := range args {
   115  		kv := strings.Split(v, "=")
   116  		if len(kv) == 2 {
   117  			argsMap[strings.ToLower(kv[0])] = kv[1]
   118  		}
   119  	}
   120  
   121  	isMulti := false
   122  	names := make([]string, 0, len(mi.fields.dbcols)-1)
   123  	Q := d.ins.TableQuote()
   124  	values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, a.TZ)
   125  
   126  	if err != nil {
   127  		return 0, err
   128  	}
   129  
   130  	marks := make([]string, len(names))
   131  	updateValues := make([]interface{}, 0)
   132  	updates := make([]string, len(names))
   133  
   134  	for i, v := range names {
   135  		marks[i] = "?"
   136  		valueStr := argsMap[strings.ToLower(v)]
   137  		if valueStr != "" {
   138  			updates[i] = "`" + v + "`" + "=" + valueStr
   139  		} else {
   140  			updates[i] = "`" + v + "`" + "=?"
   141  			updateValues = append(updateValues, values[i])
   142  		}
   143  	}
   144  
   145  	values = append(values, updateValues...)
   146  
   147  	sep := fmt.Sprintf("%s, %s", Q, Q)
   148  	qmarks := strings.Join(marks, ", ")
   149  	qupdates := strings.Join(updates, ", ")
   150  	columns := strings.Join(names, sep)
   151  
   152  	multi := len(values) / len(names)
   153  
   154  	if isMulti {
   155  		qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks
   156  	}
   157  	//conflitValue maybe is a int,can`t use fmt.Sprintf
   158  	query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s) %s "+qupdates, Q, mi.table, Q, Q, columns, Q, qmarks, iouStr)
   159  
   160  	d.ins.ReplaceMarks(&query)
   161  
   162  	if isMulti || !d.ins.HasReturningID(mi, &query) {
   163  		res, err := q.Exec(query, values...)
   164  		if err == nil {
   165  			if isMulti {
   166  				return res.RowsAffected()
   167  			}
   168  			return res.LastInsertId()
   169  		}
   170  		return 0, err
   171  	}
   172  
   173  	row := q.QueryRow(query, values...)
   174  	var id int64
   175  	err = row.Scan(&id)
   176  	return id, err
   177  }
   178  
   179  // create new mysql dbBaser.
   180  func newdbBaseMysql() dbBaser {
   181  	b := new(dbBaseMysql)
   182  	b.ins = b
   183  	return b
   184  }