github.com/go-courier/sqlx/v2@v2.23.13/connectors/mysql/mysql_connector.go (about)

     1  package mysql
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"database/sql/driver"
     7  	"fmt"
     8  	"io"
     9  	"reflect"
    10  	"strconv"
    11  	"strings"
    12  
    13  	typex "github.com/go-courier/x/types"
    14  
    15  	"github.com/go-courier/sqlx/v2"
    16  	"github.com/go-courier/sqlx/v2/builder"
    17  	"github.com/go-courier/sqlx/v2/migration"
    18  	"github.com/go-sql-driver/mysql"
    19  )
    20  
    21  var _ interface {
    22  	driver.Connector
    23  	builder.Dialect
    24  } = (*MysqlConnector)(nil)
    25  
    26  type MysqlConnector struct {
    27  	Host    string
    28  	DBName  string
    29  	Extra   string
    30  	Engine  string
    31  	Charset string
    32  }
    33  
    34  func dsn(host string, dbName string, extra string) string {
    35  	if extra != "" {
    36  		extra = "?" + extra
    37  	}
    38  	return host + "/" + dbName + extra
    39  }
    40  
    41  func (c MysqlConnector) WithDBName(dbName string) driver.Connector {
    42  	c.DBName = dbName
    43  	return &c
    44  }
    45  
    46  func (c *MysqlConnector) Migrate(ctx context.Context, db sqlx.DBExecutor) error {
    47  	output := migration.MigrationOutputFromContext(ctx)
    48  
    49  	// mysql without schema
    50  	d := db.D().WithSchema("")
    51  	dialect := db.Dialect()
    52  
    53  	prevDB, err := dbFromInformationSchema(db)
    54  	if err != nil {
    55  		return err
    56  	}
    57  
    58  	exec := func(expr builder.SqlExpr) error {
    59  		if expr == nil || expr.IsNil() {
    60  			return nil
    61  		}
    62  
    63  		if output != nil {
    64  			_, _ = io.WriteString(output, builder.ResolveExpr(expr).Query())
    65  			_, _ = io.WriteString(output, "\n")
    66  			return nil
    67  		}
    68  
    69  		_, err := db.ExecExpr(expr)
    70  		return err
    71  	}
    72  
    73  	if prevDB == nil {
    74  		prevDB = &sqlx.Database{
    75  			Name: d.Name,
    76  		}
    77  
    78  		if err := exec(dialect.CreateDatabase(d.Name)); err != nil {
    79  			return err
    80  		}
    81  	}
    82  
    83  	for _, name := range d.Tables.TableNames() {
    84  		table := d.Tables.Table(name)
    85  		prevTable := prevDB.Table(name)
    86  
    87  		if prevTable == nil {
    88  			for _, expr := range dialect.CreateTableIsNotExists(table) {
    89  				if err := exec(expr); err != nil {
    90  					return err
    91  				}
    92  			}
    93  			continue
    94  		}
    95  
    96  		exprList := table.Diff(prevTable, dialect)
    97  
    98  		for _, expr := range exprList {
    99  			if err := exec(expr); err != nil {
   100  				return err
   101  			}
   102  		}
   103  	}
   104  
   105  	return nil
   106  }
   107  
   108  func (c *MysqlConnector) Connect(ctx context.Context) (driver.Conn, error) {
   109  	d := c.Driver()
   110  
   111  	conn, err := d.Open(dsn(c.Host, c.DBName, c.Extra))
   112  	if err != nil {
   113  		if c.IsErrorUnknownDatabase(err) {
   114  			conn, err := d.Open(dsn(c.Host, "", c.Extra))
   115  			if err != nil {
   116  				return nil, err
   117  			}
   118  			if _, err := conn.(driver.ExecerContext).ExecContext(context.Background(), builder.ResolveExpr(c.CreateDatabase(c.DBName)).Query(), nil); err != nil {
   119  				return nil, err
   120  			}
   121  			if err := conn.Close(); err != nil {
   122  				return nil, err
   123  			}
   124  			return c.Connect(ctx)
   125  		}
   126  		return nil, err
   127  	}
   128  	return conn, nil
   129  }
   130  
   131  func (c MysqlConnector) Driver() driver.Driver {
   132  	return (&MySqlLoggingDriver{}).Driver()
   133  }
   134  
   135  func (MysqlConnector) DriverName() string {
   136  	return "mysql"
   137  }
   138  
   139  func (MysqlConnector) PrimaryKeyName() string {
   140  	return "primary"
   141  }
   142  
   143  func (c MysqlConnector) IsErrorUnknownDatabase(err error) bool {
   144  	if mysqlErr, ok := sqlx.UnwrapAll(err).(*mysql.MySQLError); ok && mysqlErr.Number == 1049 {
   145  		return true
   146  	}
   147  	return false
   148  }
   149  
   150  func (c MysqlConnector) IsErrorConflict(err error) bool {
   151  	if mysqlErr, ok := sqlx.UnwrapAll(err).(*mysql.MySQLError); ok && mysqlErr.Number == 1062 {
   152  		return true
   153  	}
   154  	return false
   155  }
   156  
   157  func quoteString(name string) string {
   158  	if len(name) < 2 ||
   159  		(name[0] == '`' && name[len(name)-1] == '`') {
   160  		return name
   161  	}
   162  
   163  	return "`" + name + "`"
   164  }
   165  
   166  func (c *MysqlConnector) CreateDatabase(dbName string) builder.SqlExpr {
   167  	e := builder.Expr("CREATE DATABASE ")
   168  	e.WriteQuery(quoteString(dbName))
   169  	e.WriteEnd()
   170  	return e
   171  }
   172  
   173  func (c *MysqlConnector) CreateSchema(schema string) builder.SqlExpr {
   174  	e := builder.Expr("CREATE SCHEMA ")
   175  	e.WriteQuery(schema)
   176  	e.WriteEnd()
   177  	return e
   178  }
   179  
   180  func (c *MysqlConnector) DropDatabase(dbName string) builder.SqlExpr {
   181  	e := builder.Expr("DROP DATABASE ")
   182  	e.WriteQuery(quoteString(dbName))
   183  	e.WriteEnd()
   184  	return e
   185  }
   186  
   187  func (c *MysqlConnector) AddIndex(key *builder.Key) builder.SqlExpr {
   188  	if key.IsPrimary() {
   189  		e := builder.Expr("ALTER TABLE ")
   190  		e.WriteExpr(key.Table)
   191  		e.WriteQuery(" ADD PRIMARY KEY ")
   192  		e.WriteExpr(key.Def.TableExpr(key.Table))
   193  		e.WriteEnd()
   194  		return e
   195  	}
   196  
   197  	e := builder.Expr("CREATE ")
   198  	if key.Method == "SPATIAL" {
   199  		e.WriteQuery("SPATIAL ")
   200  	} else if key.IsUnique {
   201  		e.WriteQuery("UNIQUE ")
   202  	}
   203  	e.WriteQuery("INDEX ")
   204  
   205  	e.WriteQuery(key.Name)
   206  
   207  	if key.Method == "BTREE" || key.Method == "HASH" {
   208  		e.WriteQuery(" USING ")
   209  		e.WriteQuery(key.Method)
   210  	}
   211  
   212  	e.WriteQuery(" ON ")
   213  	e.WriteExpr(key.Table)
   214  
   215  	e.WriteQueryByte(' ')
   216  	e.WriteExpr(key.Def.TableExpr(key.Table))
   217  
   218  	e.WriteEnd()
   219  	return e
   220  }
   221  
   222  func (c *MysqlConnector) DropIndex(key *builder.Key) builder.SqlExpr {
   223  	if key.IsPrimary() {
   224  		e := builder.Expr("ALTER TABLE ")
   225  		e.WriteExpr(key.Table)
   226  		e.WriteQuery(" DROP PRIMARY KEY")
   227  		e.WriteEnd()
   228  		return e
   229  	}
   230  	e := builder.Expr("DROP ")
   231  
   232  	e.WriteQuery("INDEX ")
   233  	e.WriteQuery(key.Name)
   234  
   235  	e.WriteQuery(" ON ")
   236  	e.WriteExpr(key.Table)
   237  	e.WriteEnd()
   238  
   239  	return e
   240  }
   241  
   242  func (c *MysqlConnector) CreateTableIsNotExists(table *builder.Table) (exprs []builder.SqlExpr) {
   243  	expr := builder.Expr("CREATE TABLE IF NOT EXISTS ")
   244  	expr.WriteExpr(table)
   245  	expr.WriteQueryByte(' ')
   246  	expr.WriteGroup(func(e *builder.Ex) {
   247  		if table.Columns.IsNil() {
   248  			return
   249  		}
   250  
   251  		table.Columns.Range(func(col *builder.Column, idx int) {
   252  			if col.DeprecatedActions != nil {
   253  				return
   254  			}
   255  
   256  			if idx > 0 {
   257  				e.WriteQueryByte(',')
   258  			}
   259  			e.WriteQueryByte('\n')
   260  			e.WriteQueryByte('\t')
   261  
   262  			e.WriteExpr(col)
   263  			e.WriteQueryByte(' ')
   264  			e.WriteExpr(c.DataType(col.ColumnType))
   265  		})
   266  
   267  		table.Keys.Range(func(key *builder.Key, idx int) {
   268  			if key.IsPrimary() {
   269  				e.WriteQueryByte(',')
   270  				e.WriteQueryByte('\n')
   271  				e.WriteQueryByte('\t')
   272  				e.WriteQuery("PRIMARY KEY ")
   273  				e.WriteExpr(key.Def.TableExpr(key.Table))
   274  			}
   275  		})
   276  
   277  		expr.WriteQueryByte('\n')
   278  	})
   279  
   280  	expr.WriteQuery(" ENGINE=")
   281  
   282  	if c.Engine == "" {
   283  		expr.WriteQuery("InnoDB")
   284  	} else {
   285  		expr.WriteQuery(c.Engine)
   286  	}
   287  
   288  	expr.WriteQuery(" CHARSET=")
   289  
   290  	if c.Charset == "" {
   291  		expr.WriteQuery("utf8mb4")
   292  	} else {
   293  		expr.WriteQuery(c.Charset)
   294  	}
   295  
   296  	expr.WriteEnd()
   297  	exprs = append(exprs, expr)
   298  
   299  	table.Keys.Range(func(key *builder.Key, idx int) {
   300  		if !key.IsPrimary() {
   301  			exprs = append(exprs, c.AddIndex(key))
   302  		}
   303  	})
   304  
   305  	return
   306  }
   307  
   308  func (c *MysqlConnector) DropTable(t *builder.Table) builder.SqlExpr {
   309  	e := builder.Expr("DROP TABLE IF EXISTS ")
   310  	e.WriteQuery(t.Name)
   311  	e.WriteEnd()
   312  	return e
   313  }
   314  
   315  func (c *MysqlConnector) TruncateTable(t *builder.Table) builder.SqlExpr {
   316  	e := builder.Expr("TRUNCATE TABLE ")
   317  	e.WriteQuery(t.Name)
   318  	e.WriteEnd()
   319  	return e
   320  }
   321  
   322  func (c *MysqlConnector) AddColumn(col *builder.Column) builder.SqlExpr {
   323  	e := builder.Expr("ALTER TABLE ")
   324  	e.WriteExpr(col.Table)
   325  	e.WriteQuery(" ADD COLUMN ")
   326  	e.WriteExpr(col)
   327  	e.WriteQueryByte(' ')
   328  	e.WriteExpr(c.DataType(col.ColumnType))
   329  	e.WriteEnd()
   330  	return e
   331  }
   332  
   333  func (c *MysqlConnector) RenameColumn(col *builder.Column, target *builder.Column) builder.SqlExpr {
   334  	e := builder.Expr("ALTER TABLE ")
   335  	e.WriteExpr(col.Table)
   336  	e.WriteQuery(" CHANGE ")
   337  	e.WriteExpr(col)
   338  	e.WriteQueryByte(' ')
   339  	e.WriteExpr(target)
   340  	e.WriteQueryByte(' ')
   341  	e.WriteExpr(c.DataType(target.ColumnType))
   342  	e.WriteEnd()
   343  	return e
   344  }
   345  
   346  func (c *MysqlConnector) ModifyColumn(col *builder.Column, prev *builder.Column) builder.SqlExpr {
   347  	e := builder.Expr("ALTER TABLE ")
   348  	e.WriteExpr(col.Table)
   349  	e.WriteQuery(" MODIFY COLUMN ")
   350  	e.WriteExpr(col)
   351  	e.WriteQueryByte(' ')
   352  	e.WriteExpr(c.DataType(col.ColumnType))
   353  
   354  	e.WriteQuery(" /* FROM")
   355  	e.WriteExpr(c.DataType(prev.ColumnType))
   356  	e.WriteQuery(" */")
   357  
   358  	e.WriteEnd()
   359  	return e
   360  }
   361  
   362  func (c *MysqlConnector) DropColumn(col *builder.Column) builder.SqlExpr {
   363  	e := builder.Expr("ALTER TABLE ")
   364  	e.WriteExpr(col.Table)
   365  	e.WriteQuery(" DROP COLUMN ")
   366  	e.WriteQuery(col.Name)
   367  	e.WriteEnd()
   368  	return e
   369  }
   370  
   371  func (c *MysqlConnector) DataType(columnType *builder.ColumnType) builder.SqlExpr {
   372  	dbDataType := dealias(c.dbDataType(columnType.Type, columnType))
   373  	return builder.Expr(dbDataType + autocompleteSize(dbDataType, columnType) + c.dataTypeModify(columnType))
   374  }
   375  
   376  func (c *MysqlConnector) dataType(typ typex.Type, columnType *builder.ColumnType) string {
   377  	dbDataType := dealias(c.dbDataType(typ, columnType))
   378  	return dbDataType + autocompleteSize(dbDataType, columnType)
   379  }
   380  
   381  func (c *MysqlConnector) dbDataType(typ typex.Type, columnType *builder.ColumnType) string {
   382  	if columnType.DataType != "" {
   383  		return columnType.DataType
   384  	}
   385  
   386  	if rv, ok := typex.TryNew(typ); ok {
   387  		if dtd, ok := rv.Interface().(builder.DataTypeDescriber); ok {
   388  			return dtd.DataType(c.DriverName())
   389  		}
   390  	}
   391  
   392  	switch typ.Kind() {
   393  	case reflect.Ptr:
   394  		return c.dataType(typ.Elem(), columnType)
   395  	case reflect.Bool:
   396  		return "boolean"
   397  	case reflect.Int8:
   398  		return "tinyint"
   399  	case reflect.Uint8:
   400  		return "tinyint unsigned"
   401  	case reflect.Int16:
   402  		return "smallint"
   403  	case reflect.Uint16:
   404  		return "smallint unsigned"
   405  	case reflect.Int, reflect.Int32:
   406  		return "int"
   407  	case reflect.Uint, reflect.Uint32:
   408  		return "int unsigned"
   409  	case reflect.Int64:
   410  		return "bigint"
   411  	case reflect.Uint64:
   412  		return "bigint unsigned"
   413  	case reflect.Float32:
   414  		return "float"
   415  	case reflect.Float64:
   416  		return "double"
   417  	case reflect.String:
   418  		size := columnType.Length
   419  		if size < 65535/3 {
   420  			return "varchar"
   421  		}
   422  		return "text"
   423  	case reflect.Slice:
   424  		if typ.Elem().Kind() == reflect.Uint8 {
   425  			return "mediumblob"
   426  		}
   427  	}
   428  	switch typ.Name() {
   429  	case "NullInt64":
   430  		return "bigint"
   431  	case "NullFloat64":
   432  		return "double"
   433  	case "NullBool":
   434  		return "tinyint"
   435  	case "Time":
   436  		return "datetime"
   437  	}
   438  	panic(fmt.Errorf("unsupport type %s", typ))
   439  }
   440  
   441  func (c *MysqlConnector) dataTypeModify(columnType *builder.ColumnType) string {
   442  	buf := bytes.NewBuffer(nil)
   443  
   444  	if !columnType.Null {
   445  		buf.WriteString(" NOT NULL")
   446  	}
   447  
   448  	if columnType.AutoIncrement {
   449  		buf.WriteString(" AUTO_INCREMENT")
   450  	}
   451  
   452  	if columnType.Default != nil {
   453  		buf.WriteString(" DEFAULT ")
   454  		buf.WriteString(*columnType.Default)
   455  	}
   456  
   457  	if columnType.OnUpdate != nil {
   458  		buf.WriteString(" ON UPDATE ")
   459  		buf.WriteString(*columnType.OnUpdate)
   460  	}
   461  
   462  	return buf.String()
   463  }
   464  
   465  func autocompleteSize(dataType string, columnType *builder.ColumnType) string {
   466  	switch strings.ToLower(dataType) {
   467  	case "varchar":
   468  		size := columnType.Length
   469  		if size == 0 {
   470  			size = 255
   471  		}
   472  		return sizeModifier(size, columnType.Decimal)
   473  	case "float", "double", "decimal":
   474  		if columnType.Length > 0 {
   475  			return sizeModifier(columnType.Length, columnType.Decimal)
   476  		}
   477  	}
   478  	return ""
   479  }
   480  
   481  func dealias(dataType string) string {
   482  	return dataType
   483  }
   484  
   485  func sizeModifier(length uint64, decimal uint64) string {
   486  	if length > 0 {
   487  		size := strconv.FormatUint(length, 10)
   488  		if decimal > 0 {
   489  			return "(" + size + "," + strconv.FormatUint(decimal, 10) + ")"
   490  		}
   491  		return "(" + size + ")"
   492  	}
   493  	return ""
   494  }