github.com/mdaxf/iac@v0.0.0-20240519030858-58a061660378/databases/orm/db_mysql.go (about) 1 // The original package is migrated from beego and modified, you can find orignal from following link: 2 // "github.com/beego/beego/" 3 // 4 // Copyright 2023 IAC. All Rights Reserved. 5 // 6 // Licensed under the Apache License, Version 2.0 (the "License"); 7 // you may not use this file except in compliance with the License. 8 // You may obtain a copy of the License at 9 // 10 // http://www.apache.org/licenses/LICENSE-2.0 11 // 12 // Unless required by applicable law or agreed to in writing, software 13 // distributed under the License is distributed on an "AS IS" BASIS, 14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 // See the License for the specific language governing permissions and 16 // limitations under the License. 17 18 package orm 19 20 import ( 21 "context" 22 "fmt" 23 "reflect" 24 "strings" 25 ) 26 27 // mysql operators. 28 var mysqlOperators = map[string]string{ 29 "exact": "= ?", 30 "iexact": "LIKE ?", 31 "strictexact": "= BINARY ?", 32 "contains": "LIKE BINARY ?", 33 "icontains": "LIKE ?", 34 // "regex": "REGEXP BINARY ?", 35 // "iregex": "REGEXP ?", 36 "gt": "> ?", 37 "gte": ">= ?", 38 "lt": "< ?", 39 "lte": "<= ?", 40 "eq": "= ?", 41 "ne": "!= ?", 42 "startswith": "LIKE BINARY ?", 43 "endswith": "LIKE BINARY ?", 44 "istartswith": "LIKE ?", 45 "iendswith": "LIKE ?", 46 } 47 48 // mysql column field types. 49 var mysqlTypes = map[string]string{ 50 "auto": "AUTO_INCREMENT NOT NULL PRIMARY KEY", 51 "pk": "NOT NULL PRIMARY KEY", 52 "bool": "bool", 53 "string": "varchar(%d)", 54 "string-char": "char(%d)", 55 "string-text": "longtext", 56 "time.Time-date": "date", 57 "time.Time": "datetime", 58 "int8": "tinyint", 59 "int16": "smallint", 60 "int32": "integer", 61 "int64": "bigint", 62 "uint8": "tinyint unsigned", 63 "uint16": "smallint unsigned", 64 "uint32": "integer unsigned", 65 "uint64": "bigint unsigned", 66 "float64": "double precision", 67 "float64-decimal": "numeric(%d, %d)", 68 "time.Time-precision": "datetime(%d)", 69 } 70 71 // mysql dbBaser implementation. 72 type dbBaseMysql struct { 73 dbBase 74 } 75 76 var _ dbBaser = new(dbBaseMysql) 77 78 // get mysql operator. 79 func (d *dbBaseMysql) OperatorSQL(operator string) string { 80 return mysqlOperators[operator] 81 } 82 83 // get mysql table field types. 84 func (d *dbBaseMysql) DbTypes() map[string]string { 85 return mysqlTypes 86 } 87 88 // show table sql for mysql. 89 func (d *dbBaseMysql) ShowTablesQuery() string { 90 return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()" 91 } 92 93 // show columns sql of table for mysql. 94 func (d *dbBaseMysql) ShowColumnsQuery(table string) string { 95 return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+ 96 "WHERE table_schema = DATABASE() AND table_name = '%s'", table) 97 } 98 99 // execute sql to check index exist. 100 func (d *dbBaseMysql) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool { 101 row := db.QueryRowContext(ctx, "SELECT count(*) FROM information_schema.statistics "+ 102 "WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name) 103 var cnt int 104 row.Scan(&cnt) 105 return cnt > 0 106 } 107 108 // InsertOrUpdate a row 109 // If your primary key or unique column conflict will update 110 // If no will insert 111 // Add "`" for mysql sql building 112 func (d *dbBaseMysql) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { 113 var iouStr string 114 argsMap := map[string]string{} 115 116 iouStr = "ON DUPLICATE KEY UPDATE" 117 118 // Get on the key-value pairs 119 for _, v := range args { 120 kv := strings.Split(v, "=") 121 if len(kv) == 2 { 122 argsMap[strings.ToLower(kv[0])] = kv[1] 123 } 124 } 125 126 isMulti := false 127 names := make([]string, 0, len(mi.fields.dbcols)-1) 128 Q := d.ins.TableQuote() 129 values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, a.TZ) 130 if err != nil { 131 return 0, err 132 } 133 134 marks := make([]string, len(names)) 135 updateValues := make([]interface{}, 0) 136 updates := make([]string, len(names)) 137 138 for i, v := range names { 139 marks[i] = "?" 140 valueStr := argsMap[strings.ToLower(v)] 141 if valueStr != "" { 142 updates[i] = "`" + v + "`" + "=" + valueStr 143 } else { 144 updates[i] = "`" + v + "`" + "=?" 145 updateValues = append(updateValues, values[i]) 146 } 147 } 148 149 values = append(values, updateValues...) 150 151 sep := fmt.Sprintf("%s, %s", Q, Q) 152 qmarks := strings.Join(marks, ", ") 153 qupdates := strings.Join(updates, ", ") 154 columns := strings.Join(names, sep) 155 156 multi := len(values) / len(names) 157 158 if isMulti { 159 qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks 160 } 161 // conflitValue maybe is an int,can`t use fmt.Sprintf 162 query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s) %s "+qupdates, Q, mi.table, Q, Q, columns, Q, qmarks, iouStr) 163 164 d.ins.ReplaceMarks(&query) 165 166 if isMulti || !d.ins.HasReturningID(mi, &query) { 167 res, err := q.ExecContext(ctx, query, values...) 168 if err == nil { 169 if isMulti { 170 return res.RowsAffected() 171 } 172 173 lastInsertId, err := res.LastInsertId() 174 if err != nil { 175 DebugLog.Println(ErrLastInsertIdUnavailable, ':', err) 176 return lastInsertId, ErrLastInsertIdUnavailable 177 } else { 178 return lastInsertId, nil 179 } 180 } 181 return 0, err 182 } 183 184 row := q.QueryRowContext(ctx, query, values...) 185 var id int64 186 err = row.Scan(&id) 187 return id, err 188 } 189 190 // create new mysql dbBaser. 191 func newdbBaseMysql() dbBaser { 192 b := new(dbBaseMysql) 193 b.ins = b 194 return b 195 }