github.com/mdaxf/iac@v0.0.0-20240519030858-58a061660378/databases/orm/db_mysql.go (about)

// The original package was migrated from beego and modified; the original can be found at the following link:
     2  //    "github.com/beego/beego/"
     3  //
     4  // Copyright 2023 IAC. All Rights Reserved.
     5  //
     6  // Licensed under the Apache License, Version 2.0 (the "License");
     7  // you may not use this file except in compliance with the License.
     8  // You may obtain a copy of the License at
     9  //
    10  //      http://www.apache.org/licenses/LICENSE-2.0
    11  //
    12  // Unless required by applicable law or agreed to in writing, software
    13  // distributed under the License is distributed on an "AS IS" BASIS,
    14  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    15  // See the License for the specific language governing permissions and
    16  // limitations under the License.
    17  
    18  package orm
    19  
    20  import (
    21  	"context"
    22  	"fmt"
    23  	"reflect"
    24  	"strings"
    25  )
    26  
    27  // mysql operators.
    28  var mysqlOperators = map[string]string{
    29  	"exact":       "= ?",
    30  	"iexact":      "LIKE ?",
    31  	"strictexact": "= BINARY ?",
    32  	"contains":    "LIKE BINARY ?",
    33  	"icontains":   "LIKE ?",
    34  	// "regex":       "REGEXP BINARY ?",
    35  	// "iregex":      "REGEXP ?",
    36  	"gt":          "> ?",
    37  	"gte":         ">= ?",
    38  	"lt":          "< ?",
    39  	"lte":         "<= ?",
    40  	"eq":          "= ?",
    41  	"ne":          "!= ?",
    42  	"startswith":  "LIKE BINARY ?",
    43  	"endswith":    "LIKE BINARY ?",
    44  	"istartswith": "LIKE ?",
    45  	"iendswith":   "LIKE ?",
    46  }
    47  
    48  // mysql column field types.
    49  var mysqlTypes = map[string]string{
    50  	"auto":                "AUTO_INCREMENT NOT NULL PRIMARY KEY",
    51  	"pk":                  "NOT NULL PRIMARY KEY",
    52  	"bool":                "bool",
    53  	"string":              "varchar(%d)",
    54  	"string-char":         "char(%d)",
    55  	"string-text":         "longtext",
    56  	"time.Time-date":      "date",
    57  	"time.Time":           "datetime",
    58  	"int8":                "tinyint",
    59  	"int16":               "smallint",
    60  	"int32":               "integer",
    61  	"int64":               "bigint",
    62  	"uint8":               "tinyint unsigned",
    63  	"uint16":              "smallint unsigned",
    64  	"uint32":              "integer unsigned",
    65  	"uint64":              "bigint unsigned",
    66  	"float64":             "double precision",
    67  	"float64-decimal":     "numeric(%d, %d)",
    68  	"time.Time-precision": "datetime(%d)",
    69  }
    70  
    71  // mysql dbBaser implementation.
    72  type dbBaseMysql struct {
    73  	dbBase
    74  }
    75  
    76  var _ dbBaser = new(dbBaseMysql)
    77  
    78  // get mysql operator.
    79  func (d *dbBaseMysql) OperatorSQL(operator string) string {
    80  	return mysqlOperators[operator]
    81  }
    82  
    83  // get mysql table field types.
    84  func (d *dbBaseMysql) DbTypes() map[string]string {
    85  	return mysqlTypes
    86  }
    87  
    88  // show table sql for mysql.
    89  func (d *dbBaseMysql) ShowTablesQuery() string {
    90  	return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()"
    91  }
    92  
    93  // show columns sql of table for mysql.
    94  func (d *dbBaseMysql) ShowColumnsQuery(table string) string {
    95  	return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+
    96  		"WHERE table_schema = DATABASE() AND table_name = '%s'", table)
    97  }
    98  
    99  // execute sql to check index exist.
   100  func (d *dbBaseMysql) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool {
   101  	row := db.QueryRowContext(ctx, "SELECT count(*) FROM information_schema.statistics "+
   102  		"WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name)
   103  	var cnt int
   104  	row.Scan(&cnt)
   105  	return cnt > 0
   106  }
   107  
   108  // InsertOrUpdate a row
   109  // If your primary key or unique column conflict will update
   110  // If no will insert
   111  // Add "`" for mysql sql building
   112  func (d *dbBaseMysql) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) {
   113  	var iouStr string
   114  	argsMap := map[string]string{}
   115  
   116  	iouStr = "ON DUPLICATE KEY UPDATE"
   117  
   118  	// Get on the key-value pairs
   119  	for _, v := range args {
   120  		kv := strings.Split(v, "=")
   121  		if len(kv) == 2 {
   122  			argsMap[strings.ToLower(kv[0])] = kv[1]
   123  		}
   124  	}
   125  
   126  	isMulti := false
   127  	names := make([]string, 0, len(mi.fields.dbcols)-1)
   128  	Q := d.ins.TableQuote()
   129  	values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, a.TZ)
   130  	if err != nil {
   131  		return 0, err
   132  	}
   133  
   134  	marks := make([]string, len(names))
   135  	updateValues := make([]interface{}, 0)
   136  	updates := make([]string, len(names))
   137  
   138  	for i, v := range names {
   139  		marks[i] = "?"
   140  		valueStr := argsMap[strings.ToLower(v)]
   141  		if valueStr != "" {
   142  			updates[i] = "`" + v + "`" + "=" + valueStr
   143  		} else {
   144  			updates[i] = "`" + v + "`" + "=?"
   145  			updateValues = append(updateValues, values[i])
   146  		}
   147  	}
   148  
   149  	values = append(values, updateValues...)
   150  
   151  	sep := fmt.Sprintf("%s, %s", Q, Q)
   152  	qmarks := strings.Join(marks, ", ")
   153  	qupdates := strings.Join(updates, ", ")
   154  	columns := strings.Join(names, sep)
   155  
   156  	multi := len(values) / len(names)
   157  
   158  	if isMulti {
   159  		qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks
   160  	}
   161  	// conflitValue maybe is an int,can`t use fmt.Sprintf
   162  	query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s) %s "+qupdates, Q, mi.table, Q, Q, columns, Q, qmarks, iouStr)
   163  
   164  	d.ins.ReplaceMarks(&query)
   165  
   166  	if isMulti || !d.ins.HasReturningID(mi, &query) {
   167  		res, err := q.ExecContext(ctx, query, values...)
   168  		if err == nil {
   169  			if isMulti {
   170  				return res.RowsAffected()
   171  			}
   172  
   173  			lastInsertId, err := res.LastInsertId()
   174  			if err != nil {
   175  				DebugLog.Println(ErrLastInsertIdUnavailable, ':', err)
   176  				return lastInsertId, ErrLastInsertIdUnavailable
   177  			} else {
   178  				return lastInsertId, nil
   179  			}
   180  		}
   181  		return 0, err
   182  	}
   183  
   184  	row := q.QueryRowContext(ctx, query, values...)
   185  	var id int64
   186  	err = row.Scan(&id)
   187  	return id, err
   188  }
   189  
   190  // create new mysql dbBaser.
   191  func newdbBaseMysql() dbBaser {
   192  	b := new(dbBaseMysql)
   193  	b.ins = b
   194  	return b
   195  }