gitee.com/go-genie/sqlx@v1.0.3/connectors/mysql/mysql_connector.go (about)

     1  package mysql
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"database/sql/driver"
     7  	"fmt"
     8  	"io"
     9  	"reflect"
    10  	"strconv"
    11  	"strings"
    12  
    13  	typex "gitee.com/go-genie/xx/types"
    14  
    15  	"gitee.com/go-genie/sqlx"
    16  	"gitee.com/go-genie/sqlx/builder"
    17  	"gitee.com/go-genie/sqlx/command"
    18  	"github.com/go-sql-driver/mysql"
    19  )
    20  
    21  var _ interface {
    22  	driver.Connector
    23  	builder.Dialect
    24  } = (*MysqlConnector)(nil)
    25  
    26  type MysqlConnector struct {
    27  	Host    string
    28  	DBName  string
    29  	Extra   string
    30  	Engine  string
    31  	Charset string
    32  }
    33  
    34  func dsn(host string, dbName string, extra string) string {
    35  	if extra != "" {
    36  		extra = "?" + extra
    37  	}
    38  	return host + "/" + dbName + extra
    39  }
    40  
    41  func (c MysqlConnector) WithDBName(dbName string) driver.Connector {
    42  	c.DBName = dbName
    43  	return &c
    44  }
    45  
    46  func (c *MysqlConnector) Generate(ctx context.Context, db sqlx.DBExecutor) error {
    47  	return nil
    48  }
    49  
    50  func (c *MysqlConnector) Migrate(ctx context.Context, db sqlx.DBExecutor) error {
    51  	output := command.MigrationOutputFromContext(ctx)
    52  
    53  	// mysql without schema
    54  	d := db.D().WithSchema("")
    55  	dialect := db.Dialect()
    56  
    57  	prevDB, err := dbFromInformationSchema(db)
    58  	if err != nil {
    59  		return err
    60  	}
    61  
    62  	exec := func(expr builder.SqlExpr) error {
    63  		if expr == nil || expr.IsNil() {
    64  			return nil
    65  		}
    66  
    67  		if output != nil {
    68  			_, _ = io.WriteString(output, builder.ResolveExpr(expr).Query())
    69  			_, _ = io.WriteString(output, "\n")
    70  			return nil
    71  		}
    72  
    73  		_, err := db.ExecExpr(expr)
    74  		return err
    75  	}
    76  
    77  	if prevDB == nil {
    78  		prevDB = &sqlx.Database{
    79  			Name: d.Name,
    80  		}
    81  
    82  		if err := exec(dialect.CreateDatabase(d.Name)); err != nil {
    83  			return err
    84  		}
    85  	}
    86  
    87  	for _, name := range d.Tables.TableNames() {
    88  		table := d.Tables.Table(name)
    89  		prevTable := prevDB.Table(name)
    90  
    91  		if prevTable == nil {
    92  			for _, expr := range dialect.CreateTableIsNotExists(table) {
    93  				if err := exec(expr); err != nil {
    94  					return err
    95  				}
    96  			}
    97  			continue
    98  		}
    99  
   100  		exprList := table.Diff(prevTable, dialect)
   101  
   102  		for _, expr := range exprList {
   103  			if err := exec(expr); err != nil {
   104  				return err
   105  			}
   106  		}
   107  	}
   108  
   109  	return nil
   110  }
   111  
   112  func (c *MysqlConnector) Connect(ctx context.Context) (driver.Conn, error) {
   113  	d := c.Driver()
   114  
   115  	conn, err := d.Open(dsn(c.Host, c.DBName, c.Extra))
   116  	if err != nil {
   117  		if c.IsErrorUnknownDatabase(err) {
   118  			conn, err := d.Open(dsn(c.Host, "", c.Extra))
   119  			if err != nil {
   120  				return nil, err
   121  			}
   122  			if _, err := conn.(driver.ExecerContext).ExecContext(context.Background(), builder.ResolveExpr(c.CreateDatabase(c.DBName)).Query(), nil); err != nil {
   123  				return nil, err
   124  			}
   125  			if err := conn.Close(); err != nil {
   126  				return nil, err
   127  			}
   128  			return c.Connect(ctx)
   129  		}
   130  		return nil, err
   131  	}
   132  	return conn, nil
   133  }
   134  
   135  func (c MysqlConnector) Driver() driver.Driver {
   136  	return (&MySqlLoggingDriver{}).Driver()
   137  }
   138  
   139  func (MysqlConnector) DriverName() string {
   140  	return "mysql"
   141  }
   142  
   143  func (MysqlConnector) PrimaryKeyName() string {
   144  	return "primary"
   145  }
   146  
   147  func (c MysqlConnector) IsErrorUnknownDatabase(err error) bool {
   148  	if mysqlErr, ok := sqlx.UnwrapAll(err).(*mysql.MySQLError); ok && mysqlErr.Number == 1049 {
   149  		return true
   150  	}
   151  	return false
   152  }
   153  
   154  func (c MysqlConnector) IsErrorConflict(err error) bool {
   155  	if mysqlErr, ok := sqlx.UnwrapAll(err).(*mysql.MySQLError); ok && mysqlErr.Number == 1062 {
   156  		return true
   157  	}
   158  	return false
   159  }
   160  
   161  func quoteString(name string) string {
   162  	if len(name) < 2 ||
   163  		(name[0] == '`' && name[len(name)-1] == '`') {
   164  		return name
   165  	}
   166  
   167  	return "`" + name + "`"
   168  }
   169  
   170  func (c *MysqlConnector) CreateDatabase(dbName string) builder.SqlExpr {
   171  	e := builder.Expr("CREATE DATABASE ")
   172  	e.WriteQuery(quoteString(dbName))
   173  	e.WriteEnd()
   174  	return e
   175  }
   176  
   177  func (c *MysqlConnector) CreateSchema(schema string) builder.SqlExpr {
   178  	e := builder.Expr("CREATE SCHEMA ")
   179  	e.WriteQuery(schema)
   180  	e.WriteEnd()
   181  	return e
   182  }
   183  
   184  func (c *MysqlConnector) DropDatabase(dbName string) builder.SqlExpr {
   185  	e := builder.Expr("DROP DATABASE ")
   186  	e.WriteQuery(quoteString(dbName))
   187  	e.WriteEnd()
   188  	return e
   189  }
   190  
   191  func (c *MysqlConnector) AddIndex(key *builder.Key) builder.SqlExpr {
   192  	if key.IsPrimary() {
   193  		e := builder.Expr("ALTER TABLE ")
   194  		e.WriteExpr(key.Table)
   195  		e.WriteQuery(" ADD PRIMARY KEY ")
   196  		e.WriteExpr(key.Def.TableExpr(key.Table))
   197  		e.WriteEnd()
   198  		return e
   199  	}
   200  
   201  	e := builder.Expr("CREATE ")
   202  	if key.Method == "SPATIAL" {
   203  		e.WriteQuery("SPATIAL ")
   204  	} else if key.IsUnique {
   205  		e.WriteQuery("UNIQUE ")
   206  	}
   207  	e.WriteQuery("INDEX ")
   208  
   209  	e.WriteQuery(key.Name)
   210  
   211  	if key.Method == "BTREE" || key.Method == "HASH" {
   212  		e.WriteQuery(" USING ")
   213  		e.WriteQuery(key.Method)
   214  	}
   215  
   216  	e.WriteQuery(" ON ")
   217  	e.WriteExpr(key.Table)
   218  
   219  	e.WriteQueryByte(' ')
   220  	e.WriteExpr(key.Def.TableExpr(key.Table))
   221  
   222  	e.WriteEnd()
   223  	return e
   224  }
   225  
   226  func (c *MysqlConnector) DropIndex(key *builder.Key) builder.SqlExpr {
   227  	if key.IsPrimary() {
   228  		e := builder.Expr("ALTER TABLE ")
   229  		e.WriteExpr(key.Table)
   230  		e.WriteQuery(" DROP PRIMARY KEY")
   231  		e.WriteEnd()
   232  		return e
   233  	}
   234  	e := builder.Expr("DROP ")
   235  
   236  	e.WriteQuery("INDEX ")
   237  	e.WriteQuery(key.Name)
   238  
   239  	e.WriteQuery(" ON ")
   240  	e.WriteExpr(key.Table)
   241  	e.WriteEnd()
   242  
   243  	return e
   244  }
   245  
   246  func (c *MysqlConnector) CreateTableIsNotExists(table *builder.Table) (exprs []builder.SqlExpr) {
   247  	expr := builder.Expr("CREATE TABLE IF NOT EXISTS ")
   248  	expr.WriteExpr(table)
   249  	expr.WriteQueryByte(' ')
   250  	expr.WriteGroup(func(e *builder.Ex) {
   251  		if table.Columns.IsNil() {
   252  			return
   253  		}
   254  
   255  		table.Columns.Range(func(col *builder.Column, idx int) {
   256  			if col.DeprecatedActions != nil {
   257  				return
   258  			}
   259  
   260  			if idx > 0 {
   261  				e.WriteQueryByte(',')
   262  			}
   263  			e.WriteQueryByte('\n')
   264  			e.WriteQueryByte('\t')
   265  
   266  			e.WriteExpr(col)
   267  			e.WriteQueryByte(' ')
   268  			e.WriteExpr(c.DataType(col.ColumnType))
   269  		})
   270  
   271  		table.Keys.Range(func(key *builder.Key, idx int) {
   272  			if key.IsPrimary() {
   273  				e.WriteQueryByte(',')
   274  				e.WriteQueryByte('\n')
   275  				e.WriteQueryByte('\t')
   276  				e.WriteQuery("PRIMARY KEY ")
   277  				e.WriteExpr(key.Def.TableExpr(key.Table))
   278  			}
   279  		})
   280  
   281  		expr.WriteQueryByte('\n')
   282  	})
   283  
   284  	expr.WriteQuery(" ENGINE=")
   285  
   286  	if c.Engine == "" {
   287  		expr.WriteQuery("InnoDB")
   288  	} else {
   289  		expr.WriteQuery(c.Engine)
   290  	}
   291  
   292  	expr.WriteQuery(" CHARSET=")
   293  
   294  	if c.Charset == "" {
   295  		expr.WriteQuery("utf8mb4")
   296  	} else {
   297  		expr.WriteQuery(c.Charset)
   298  	}
   299  
   300  	expr.WriteEnd()
   301  	exprs = append(exprs, expr)
   302  
   303  	table.Keys.Range(func(key *builder.Key, idx int) {
   304  		if !key.IsPrimary() {
   305  			exprs = append(exprs, c.AddIndex(key))
   306  		}
   307  	})
   308  
   309  	return
   310  }
   311  
   312  func (c *MysqlConnector) DropTable(t *builder.Table) builder.SqlExpr {
   313  	e := builder.Expr("DROP TABLE IF EXISTS ")
   314  	e.WriteQuery(t.Name)
   315  	e.WriteEnd()
   316  	return e
   317  }
   318  
   319  func (c *MysqlConnector) TruncateTable(t *builder.Table) builder.SqlExpr {
   320  	e := builder.Expr("TRUNCATE TABLE ")
   321  	e.WriteQuery(t.Name)
   322  	e.WriteEnd()
   323  	return e
   324  }
   325  
   326  func (c *MysqlConnector) AddColumn(col *builder.Column) builder.SqlExpr {
   327  	e := builder.Expr("ALTER TABLE ")
   328  	e.WriteExpr(col.Table)
   329  	e.WriteQuery(" ADD COLUMN ")
   330  	e.WriteExpr(col)
   331  	e.WriteQueryByte(' ')
   332  	e.WriteExpr(c.DataType(col.ColumnType))
   333  	e.WriteEnd()
   334  	return e
   335  }
   336  
   337  func (c *MysqlConnector) RenameColumn(col *builder.Column, target *builder.Column) builder.SqlExpr {
   338  	e := builder.Expr("ALTER TABLE ")
   339  	e.WriteExpr(col.Table)
   340  	e.WriteQuery(" CHANGE ")
   341  	e.WriteExpr(col)
   342  	e.WriteQueryByte(' ')
   343  	e.WriteExpr(target)
   344  	e.WriteQueryByte(' ')
   345  	e.WriteExpr(c.DataType(target.ColumnType))
   346  	e.WriteEnd()
   347  	return e
   348  }
   349  
   350  func (c *MysqlConnector) ModifyColumn(col *builder.Column, prev *builder.Column) builder.SqlExpr {
   351  	e := builder.Expr("ALTER TABLE ")
   352  	e.WriteExpr(col.Table)
   353  	e.WriteQuery(" MODIFY COLUMN ")
   354  	e.WriteExpr(col)
   355  	e.WriteQueryByte(' ')
   356  	e.WriteExpr(c.DataType(col.ColumnType))
   357  
   358  	e.WriteQuery(" /* FROM")
   359  	e.WriteExpr(c.DataType(prev.ColumnType))
   360  	e.WriteQuery(" */")
   361  
   362  	e.WriteEnd()
   363  	return e
   364  }
   365  
   366  func (c *MysqlConnector) DropColumn(col *builder.Column) builder.SqlExpr {
   367  	e := builder.Expr("ALTER TABLE ")
   368  	e.WriteExpr(col.Table)
   369  	e.WriteQuery(" DROP COLUMN ")
   370  	e.WriteQuery(col.Name)
   371  	e.WriteEnd()
   372  	return e
   373  }
   374  
   375  func (c *MysqlConnector) DataType(columnType *builder.ColumnType) builder.SqlExpr {
   376  	dbDataType := dealias(c.dbDataType(columnType.Type, columnType))
   377  	return builder.Expr(dbDataType + autocompleteSize(dbDataType, columnType) + c.dataTypeModify(columnType))
   378  }
   379  
   380  func (c *MysqlConnector) dataType(typ typex.Type, columnType *builder.ColumnType) string {
   381  	dbDataType := dealias(c.dbDataType(typ, columnType))
   382  	return dbDataType + autocompleteSize(dbDataType, columnType)
   383  }
   384  
   385  func (c *MysqlConnector) dbDataType(typ typex.Type, columnType *builder.ColumnType) string {
   386  	if columnType.DataType != "" {
   387  		return columnType.DataType
   388  	}
   389  
   390  	if rv, ok := typex.TryNew(typ); ok {
   391  		if dtd, ok := rv.Interface().(builder.DataTypeDescriber); ok {
   392  			return dtd.DataType(c.DriverName())
   393  		}
   394  	}
   395  
   396  	switch typ.Kind() {
   397  	case reflect.Ptr:
   398  		return c.dataType(typ.Elem(), columnType)
   399  	case reflect.Bool:
   400  		return "boolean"
   401  	case reflect.Int8:
   402  		return "tinyint"
   403  	case reflect.Uint8:
   404  		return "tinyint unsigned"
   405  	case reflect.Int16:
   406  		return "smallint"
   407  	case reflect.Uint16:
   408  		return "smallint unsigned"
   409  	case reflect.Int, reflect.Int32:
   410  		return "int"
   411  	case reflect.Uint, reflect.Uint32:
   412  		return "int unsigned"
   413  	case reflect.Int64:
   414  		return "bigint"
   415  	case reflect.Uint64:
   416  		return "bigint unsigned"
   417  	case reflect.Float32:
   418  		return "float"
   419  	case reflect.Float64:
   420  		return "double"
   421  	case reflect.String:
   422  		size := columnType.Length
   423  		if size < 65535/3 {
   424  			return "varchar"
   425  		}
   426  		return "text"
   427  	case reflect.Slice:
   428  		if typ.Elem().Kind() == reflect.Uint8 {
   429  			return "mediumblob"
   430  		}
   431  	}
   432  	switch typ.Name() {
   433  	case "NullInt64":
   434  		return "bigint"
   435  	case "NullFloat64":
   436  		return "double"
   437  	case "NullBool":
   438  		return "tinyint"
   439  	case "Time":
   440  		return "datetime"
   441  	}
   442  	panic(fmt.Errorf("unsupport type %s", typ))
   443  }
   444  
   445  func (c *MysqlConnector) dataTypeModify(columnType *builder.ColumnType) string {
   446  	buf := bytes.NewBuffer(nil)
   447  
   448  	if !columnType.Null {
   449  		buf.WriteString(" NOT NULL")
   450  	}
   451  
   452  	if columnType.AutoIncrement {
   453  		buf.WriteString(" AUTO_INCREMENT")
   454  	}
   455  
   456  	if columnType.Default != nil {
   457  		buf.WriteString(" DEFAULT ")
   458  		buf.WriteString(*columnType.Default)
   459  	}
   460  
   461  	if columnType.OnUpdate != nil {
   462  		buf.WriteString(" ON UPDATE ")
   463  		buf.WriteString(*columnType.OnUpdate)
   464  	}
   465  
   466  	return buf.String()
   467  }
   468  
   469  func autocompleteSize(dataType string, columnType *builder.ColumnType) string {
   470  	switch strings.ToLower(dataType) {
   471  	case "varchar":
   472  		size := columnType.Length
   473  		if size == 0 {
   474  			size = 255
   475  		}
   476  		return sizeModifier(size, columnType.Decimal)
   477  	case "float", "double", "decimal":
   478  		if columnType.Length > 0 {
   479  			return sizeModifier(columnType.Length, columnType.Decimal)
   480  		}
   481  	}
   482  	return ""
   483  }
   484  
   485  func dealias(dataType string) string {
   486  	return dataType
   487  }
   488  
   489  func sizeModifier(length uint64, decimal uint64) string {
   490  	if length > 0 {
   491  		size := strconv.FormatUint(length, 10)
   492  		if decimal > 0 {
   493  			return "(" + size + "," + strconv.FormatUint(decimal, 10) + ")"
   494  		}
   495  		return "(" + size + ")"
   496  	}
   497  	return ""
   498  }