gitee.com/eden-framework/sqlx@v0.0.3/mysqlconnector/mysql_connector.go (about)

     1  package mysqlconnector
     2  
import (
	"bytes"
	"context"
	"database/sql/driver"
	"errors"
	"fmt"
	"log"
	"reflect"
	"strconv"

	"gitee.com/eden-framework/sqlx"
	"gitee.com/eden-framework/sqlx/builder"
	"gitee.com/eden-framework/sqlx/migration"
	"github.com/go-sql-driver/mysql"
	"github.com/sirupsen/logrus"
)
    18  
    19  var _ interface {
    20  	driver.Connector
    21  	builder.Dialect
    22  } = (*MysqlConnector)(nil)
    23  
    24  type MysqlConnector struct {
    25  	Host    string
    26  	DBName  string
    27  	Extra   string
    28  	Engine  string
    29  	Charset string
    30  }
    31  
    32  func dsn(host string, dbName string, extra string) string {
    33  	if extra != "" {
    34  		extra = "?" + extra
    35  	}
    36  	return host + "/" + dbName + extra
    37  }
    38  
    39  func (c MysqlConnector) WithDBName(dbName string) driver.Connector {
    40  	c.DBName = dbName
    41  	return &c
    42  }
    43  
    44  func (c *MysqlConnector) Migrate(ctx context.Context, db sqlx.DBExecutor) error {
    45  	opts := migration.MigrationOptsFromContext(ctx)
    46  
    47  	// mysql without schema
    48  	d := db.D().WithSchema("")
    49  	dialect := db.Dialect()
    50  
    51  	prevDB := dbFromInformationSchema(db)
    52  
    53  	if prevDB == nil {
    54  		prevDB = &sqlx.Database{
    55  			Name: d.Name,
    56  		}
    57  		if _, err := db.ExecExpr(dialect.CreateDatabase(d.Name)); err != nil {
    58  			return err
    59  		}
    60  	}
    61  
    62  	for _, name := range d.Tables.TableNames() {
    63  		table := d.Tables.Table(name)
    64  		prevTable := prevDB.Table(name)
    65  
    66  		if prevTable == nil {
    67  			for _, expr := range dialect.CreateTableIsNotExists(table) {
    68  				if _, err := db.ExecExpr(expr); err != nil {
    69  					return err
    70  				}
    71  			}
    72  			continue
    73  		}
    74  
    75  		exprList := table.Diff(prevTable, dialect)
    76  
    77  		for _, expr := range exprList {
    78  			if !expr.IsNil() {
    79  				if opts.DryRun {
    80  					log.Printf(builder.ResolveExpr(expr).Query())
    81  				} else {
    82  					if _, err := db.ExecExpr(expr); err != nil {
    83  						return err
    84  					}
    85  				}
    86  			}
    87  		}
    88  	}
    89  
    90  	return nil
    91  }
    92  
    93  func (c *MysqlConnector) Connect(ctx context.Context) (driver.Conn, error) {
    94  	d := c.Driver()
    95  	conn, err := d.Open(dsn(c.Host, c.DBName, c.Extra))
    96  	if err != nil {
    97  		if c.IsErrorUnknownDatabase(err) {
    98  			conn, err := d.Open(dsn(c.Host, "", c.Extra))
    99  			if err != nil {
   100  				return nil, err
   101  			}
   102  			if _, err := conn.(driver.ExecerContext).
   103  				ExecContext(context.Background(), builder.ResolveExpr(c.CreateDatabase(c.DBName)).Query(), nil); err != nil {
   104  				return nil, err
   105  			}
   106  			if err := conn.Close(); err != nil {
   107  				return nil, err
   108  			}
   109  			return c.Connect(ctx)
   110  		}
   111  		return nil, err
   112  	}
   113  	return conn, nil
   114  }
   115  
   116  func (MysqlConnector) Driver() driver.Driver {
   117  	return &MySqlLoggingDriver{Driver: &mysql.MySQLDriver{}, Logger: logrus.StandardLogger()}
   118  }
   119  
   120  func (MysqlConnector) DriverName() string {
   121  	return "mysql"
   122  }
   123  
   124  func (MysqlConnector) PrimaryKeyName() string {
   125  	return "primary"
   126  }
   127  
   128  func (c MysqlConnector) IsErrorUnknownDatabase(err error) bool {
   129  	if mysqlErr, ok := err.(*mysql.MySQLError); ok && mysqlErr.Number == 1049 {
   130  		return true
   131  	}
   132  	return false
   133  }
   134  
   135  func (c MysqlConnector) IsErrorConflict(err error) bool {
   136  	if mysqlErr, ok := err.(*mysql.MySQLError); ok && mysqlErr.Number == 1062 {
   137  		return true
   138  	}
   139  	return false
   140  }
   141  
   142  func (c *MysqlConnector) CreateDatabase(dbName string) builder.SqlExpr {
   143  	e := builder.Expr("CREATE DATABASE ")
   144  	e.WriteString(dbName)
   145  	e.WriteEnd()
   146  	return e
   147  }
   148  
   149  func (c *MysqlConnector) CreateSchema(schema string) builder.SqlExpr {
   150  	e := builder.Expr("CREATE SCHEMA ")
   151  	e.WriteString(schema)
   152  	e.WriteEnd()
   153  	return e
   154  }
   155  
   156  func (c *MysqlConnector) DropDatabase(dbName string) builder.SqlExpr {
   157  	e := builder.Expr("DROP DATABASE ")
   158  	e.WriteString(dbName)
   159  	e.WriteEnd()
   160  	return e
   161  }
   162  
   163  func (c *MysqlConnector) AddIndex(key *builder.Key) builder.SqlExpr {
   164  	if key.IsPrimary() {
   165  		e := builder.Expr("ALTER TABLE ")
   166  		e.WriteExpr(key.Table)
   167  		e.WriteString(" ADD PRIMARY KEY ")
   168  		e.WriteGroup(func(e *builder.Ex) {
   169  			e.WriteExpr(key.Columns)
   170  		})
   171  		e.WriteEnd()
   172  		return e
   173  	}
   174  
   175  	e := builder.Expr("CREATE ")
   176  	if key.Method == "SPATIAL" {
   177  		e.WriteString("SPATIAL ")
   178  	} else if key.IsUnique {
   179  		e.WriteString("UNIQUE ")
   180  	}
   181  	e.WriteString("INDEX ")
   182  
   183  	e.WriteString(key.Name)
   184  
   185  	e.WriteString(" ON ")
   186  	e.WriteExpr(key.Table)
   187  	e.WriteByte(' ')
   188  	e.WriteGroup(func(e *builder.Ex) {
   189  		e.WriteExpr(key.Columns)
   190  	})
   191  
   192  	if key.Method == "BTREE" || key.Method == "HASH" {
   193  		e.WriteString(" USING ")
   194  		e.WriteString(key.Method)
   195  	}
   196  
   197  	e.WriteEnd()
   198  	return e
   199  }
   200  
   201  func (c *MysqlConnector) DropIndex(key *builder.Key) builder.SqlExpr {
   202  	if key.IsPrimary() {
   203  		e := builder.Expr("ALTER TABLE ")
   204  		e.WriteExpr(key.Table)
   205  		e.WriteString(" DROP PRIMARY KEY")
   206  		e.WriteEnd()
   207  		return e
   208  	}
   209  	e := builder.Expr("DROP ")
   210  
   211  	e.WriteString("INDEX ")
   212  	e.WriteString(key.Name)
   213  
   214  	e.WriteString(" ON ")
   215  	e.WriteExpr(key.Table)
   216  	e.WriteEnd()
   217  
   218  	return e
   219  }
   220  
   221  func (c *MysqlConnector) CreateTableIsNotExists(table *builder.Table) (exprs []builder.SqlExpr) {
   222  	expr := builder.Expr("CREATE TABLE IF NOT EXISTS ")
   223  	expr.WriteExpr(table)
   224  	expr.WriteByte(' ')
   225  	expr.WriteGroup(func(e *builder.Ex) {
   226  		if table.Columns.IsNil() {
   227  			return
   228  		}
   229  
   230  		table.Columns.Range(func(col *builder.Column, idx int) {
   231  			if col.DeprecatedActions != nil {
   232  				return
   233  			}
   234  
   235  			if idx > 0 {
   236  				e.WriteByte(',')
   237  			}
   238  			e.WriteByte('\n')
   239  			e.WriteByte('\t')
   240  
   241  			e.WriteExpr(col)
   242  			e.WriteByte(' ')
   243  			e.WriteExpr(c.DataType(col.ColumnType))
   244  		})
   245  
   246  		table.Keys.Range(func(key *builder.Key, idx int) {
   247  			if key.IsPrimary() {
   248  				e.WriteByte(',')
   249  				e.WriteByte('\n')
   250  				e.WriteByte('\t')
   251  				e.WriteString("PRIMARY KEY ")
   252  				e.WriteGroup(func(e *builder.Ex) {
   253  					e.WriteExpr(key.Columns)
   254  				})
   255  			}
   256  		})
   257  
   258  		expr.WriteByte('\n')
   259  	})
   260  
   261  	expr.WriteString(" ENGINE=")
   262  
   263  	if c.Engine == "" {
   264  		expr.WriteString("InnoDB")
   265  	} else {
   266  		expr.WriteString(c.Engine)
   267  	}
   268  
   269  	expr.WriteString(" CHARSET=")
   270  
   271  	if c.Charset == "" {
   272  		expr.WriteString("utf8mb4")
   273  	} else {
   274  		expr.WriteString(c.Charset)
   275  	}
   276  
   277  	expr.WriteEnd()
   278  	exprs = append(exprs, expr)
   279  
   280  	table.Keys.Range(func(key *builder.Key, idx int) {
   281  		if !key.IsPrimary() {
   282  			exprs = append(exprs, c.AddIndex(key))
   283  		}
   284  	})
   285  
   286  	return
   287  }
   288  
   289  func (c *MysqlConnector) DropTable(t *builder.Table) builder.SqlExpr {
   290  	e := builder.Expr("DROP TABLE IF EXISTS ")
   291  	e.WriteString(t.Name)
   292  	e.WriteEnd()
   293  	return e
   294  }
   295  
   296  func (c *MysqlConnector) TruncateTable(t *builder.Table) builder.SqlExpr {
   297  	e := builder.Expr("TRUNCATE TABLE ")
   298  	e.WriteString(t.Name)
   299  	e.WriteEnd()
   300  	return e
   301  }
   302  
   303  func (c *MysqlConnector) AddColumn(col *builder.Column) builder.SqlExpr {
   304  	e := builder.Expr("ALTER TABLE ")
   305  	e.WriteExpr(col.Table)
   306  	e.WriteString(" ADD COLUMN ")
   307  	e.WriteExpr(col)
   308  	e.WriteByte(' ')
   309  	e.WriteExpr(c.DataType(col.ColumnType))
   310  	e.WriteEnd()
   311  	return e
   312  }
   313  
   314  func (c *MysqlConnector) RenameColumn(col *builder.Column, target *builder.Column) builder.SqlExpr {
   315  	e := builder.Expr("ALTER TABLE ")
   316  	e.WriteExpr(col.Table)
   317  	e.WriteString(" CHANGE ")
   318  	e.WriteExpr(col)
   319  	e.WriteByte(' ')
   320  	e.WriteExpr(target)
   321  	e.WriteByte(' ')
   322  	e.WriteExpr(c.DataType(target.ColumnType))
   323  	e.WriteEnd()
   324  	return e
   325  }
   326  
   327  func (c *MysqlConnector) ModifyColumn(col *builder.Column) builder.SqlExpr {
   328  	e := builder.Expr("ALTER TABLE ")
   329  	e.WriteExpr(col.Table)
   330  	e.WriteString(" MODIFY COLUMN ")
   331  	e.WriteExpr(col)
   332  	e.WriteByte(' ')
   333  	e.WriteExpr(c.DataType(col.ColumnType))
   334  	e.WriteEnd()
   335  	return e
   336  }
   337  
   338  func (c *MysqlConnector) DropColumn(col *builder.Column) builder.SqlExpr {
   339  	e := builder.Expr("ALTER TABLE ")
   340  	e.WriteExpr(col.Table)
   341  	e.WriteString(" DROP COLUMN ")
   342  	e.WriteString(col.Name)
   343  	e.WriteEnd()
   344  	return e
   345  }
   346  
   347  func (c *MysqlConnector) DataType(columnType *builder.ColumnType) builder.SqlExpr {
   348  	return builder.Expr(c.dataType(columnType.Type, columnType) + c.dataTypeModify(columnType))
   349  }
   350  
   351  func (c *MysqlConnector) dataType(typ reflect.Type, columnType *builder.ColumnType) string {
   352  	if columnType.GetDataType != nil {
   353  		return columnType.GetDataType(c.DriverName())
   354  	}
   355  
   356  	switch typ.Kind() {
   357  	case reflect.Ptr:
   358  		return c.dataType(typ.Elem(), columnType)
   359  	case reflect.Bool:
   360  		return "boolean"
   361  	case reflect.Int8:
   362  		return "tinyint"
   363  	case reflect.Uint8:
   364  		return "tinyint unsigned"
   365  	case reflect.Int16:
   366  		return "smallint"
   367  	case reflect.Uint16:
   368  		return "smallint unsigned"
   369  	case reflect.Int, reflect.Int32:
   370  		return "int"
   371  	case reflect.Uint, reflect.Uint32:
   372  		return "int unsigned"
   373  	case reflect.Int64:
   374  		return "bigint"
   375  	case reflect.Uint64:
   376  		return "bigint unsigned"
   377  	case reflect.Float32:
   378  		return "float" + sizeModifier(columnType.Length, columnType.Decimal)
   379  	case reflect.Float64:
   380  		return "double" + sizeModifier(columnType.Length, columnType.Decimal)
   381  	case reflect.String:
   382  		size := columnType.Length
   383  		if size == 0 {
   384  			size = 255
   385  		}
   386  		if size < 65535/3 {
   387  			return "varchar" + sizeModifier(size, 0)
   388  		}
   389  		return "text"
   390  	case reflect.Slice:
   391  		if typ.Elem().Kind() == reflect.Uint8 {
   392  			return "mediumblob"
   393  		}
   394  	}
   395  	switch typ.Name() {
   396  	case "NullInt64":
   397  		return "bigint"
   398  	case "NullFloat64":
   399  		return "double"
   400  	case "NullBool":
   401  		return "tinyint"
   402  	case "Time":
   403  		return "datetime"
   404  	}
   405  	panic(fmt.Errorf("unsupport type %s", typ))
   406  }
   407  
   408  func (c *MysqlConnector) dataTypeModify(columnType *builder.ColumnType) string {
   409  	buf := bytes.NewBuffer(nil)
   410  
   411  	if !columnType.Null {
   412  		buf.WriteString(" NOT NULL")
   413  	}
   414  
   415  	if columnType.AutoIncrement {
   416  		buf.WriteString(" AUTO_INCREMENT")
   417  	}
   418  
   419  	if columnType.Default != nil {
   420  		buf.WriteString(" DEFAULT ")
   421  		buf.WriteString(*columnType.Default)
   422  	}
   423  
   424  	if columnType.Comment != "" {
   425  		buf.WriteString(" COMMENT ")
   426  		buf.WriteString(strconv.Quote(columnType.Comment))
   427  	}
   428  
   429  	return buf.String()
   430  }
   431  
   432  func sizeModifier(length uint64, decimal uint64) string {
   433  	if length > 0 {
   434  		size := strconv.FormatUint(length, 10)
   435  		if decimal > 0 {
   436  			return "(" + size + "," + strconv.FormatUint(decimal, 10) + ")"
   437  		}
   438  		return "(" + size + ")"
   439  	}
   440  	return ""
   441  }