github.com/go-courier/sqlx/v2@v2.23.13/connectors/mysql/mysql_connector.go (about) 1 package mysql 2 3 import ( 4 "bytes" 5 "context" 6 "database/sql/driver" 7 "fmt" 8 "io" 9 "reflect" 10 "strconv" 11 "strings" 12 13 typex "github.com/go-courier/x/types" 14 15 "github.com/go-courier/sqlx/v2" 16 "github.com/go-courier/sqlx/v2/builder" 17 "github.com/go-courier/sqlx/v2/migration" 18 "github.com/go-sql-driver/mysql" 19 ) 20 21 var _ interface { 22 driver.Connector 23 builder.Dialect 24 } = (*MysqlConnector)(nil) 25 26 type MysqlConnector struct { 27 Host string 28 DBName string 29 Extra string 30 Engine string 31 Charset string 32 } 33 34 func dsn(host string, dbName string, extra string) string { 35 if extra != "" { 36 extra = "?" + extra 37 } 38 return host + "/" + dbName + extra 39 } 40 41 func (c MysqlConnector) WithDBName(dbName string) driver.Connector { 42 c.DBName = dbName 43 return &c 44 } 45 46 func (c *MysqlConnector) Migrate(ctx context.Context, db sqlx.DBExecutor) error { 47 output := migration.MigrationOutputFromContext(ctx) 48 49 // mysql without schema 50 d := db.D().WithSchema("") 51 dialect := db.Dialect() 52 53 prevDB, err := dbFromInformationSchema(db) 54 if err != nil { 55 return err 56 } 57 58 exec := func(expr builder.SqlExpr) error { 59 if expr == nil || expr.IsNil() { 60 return nil 61 } 62 63 if output != nil { 64 _, _ = io.WriteString(output, builder.ResolveExpr(expr).Query()) 65 _, _ = io.WriteString(output, "\n") 66 return nil 67 } 68 69 _, err := db.ExecExpr(expr) 70 return err 71 } 72 73 if prevDB == nil { 74 prevDB = &sqlx.Database{ 75 Name: d.Name, 76 } 77 78 if err := exec(dialect.CreateDatabase(d.Name)); err != nil { 79 return err 80 } 81 } 82 83 for _, name := range d.Tables.TableNames() { 84 table := d.Tables.Table(name) 85 prevTable := prevDB.Table(name) 86 87 if prevTable == nil { 88 for _, expr := range dialect.CreateTableIsNotExists(table) { 89 if err := exec(expr); err != nil { 90 return err 91 } 92 } 93 continue 94 } 95 96 exprList := table.Diff(prevTable, dialect) 97 98 for _, expr := range exprList { 99 if err := exec(expr); err != nil { 100 return err 101 } 102 } 103 } 104 105 return nil 106 } 107 108 func (c *MysqlConnector) Connect(ctx context.Context) (driver.Conn, error) { 109 d := c.Driver() 110 111 conn, err := d.Open(dsn(c.Host, c.DBName, c.Extra)) 112 if err != nil { 113 if c.IsErrorUnknownDatabase(err) { 114 conn, err := d.Open(dsn(c.Host, "", c.Extra)) 115 if err != nil { 116 return nil, err 117 } 118 if _, err := conn.(driver.ExecerContext).ExecContext(context.Background(), builder.ResolveExpr(c.CreateDatabase(c.DBName)).Query(), nil); err != nil { 119 return nil, err 120 } 121 if err := conn.Close(); err != nil { 122 return nil, err 123 } 124 return c.Connect(ctx) 125 } 126 return nil, err 127 } 128 return conn, nil 129 } 130 131 func (c MysqlConnector) Driver() driver.Driver { 132 return (&MySqlLoggingDriver{}).Driver() 133 } 134 135 func (MysqlConnector) DriverName() string { 136 return "mysql" 137 } 138 139 func (MysqlConnector) PrimaryKeyName() string { 140 return "primary" 141 } 142 143 func (c MysqlConnector) IsErrorUnknownDatabase(err error) bool { 144 if mysqlErr, ok := sqlx.UnwrapAll(err).(*mysql.MySQLError); ok && mysqlErr.Number == 1049 { 145 return true 146 } 147 return false 148 } 149 150 func (c MysqlConnector) IsErrorConflict(err error) bool { 151 if mysqlErr, ok := sqlx.UnwrapAll(err).(*mysql.MySQLError); ok && mysqlErr.Number == 1062 { 152 return true 153 } 154 return false 155 } 156 157 func quoteString(name string) string { 158 if len(name) < 2 || 159 (name[0] == '`' && name[len(name)-1] == '`') { 160 return name 161 } 162 163 return "`" + name + "`" 164 } 165 166 func (c *MysqlConnector) CreateDatabase(dbName string) builder.SqlExpr { 167 e := builder.Expr("CREATE DATABASE ") 168 e.WriteQuery(quoteString(dbName)) 169 e.WriteEnd() 170 return e 171 } 172 173 func (c *MysqlConnector) CreateSchema(schema string) builder.SqlExpr { 174 e := builder.Expr("CREATE SCHEMA ") 175 e.WriteQuery(schema) 176 e.WriteEnd() 177 return e 178 } 179 180 func (c *MysqlConnector) DropDatabase(dbName string) builder.SqlExpr { 181 e := builder.Expr("DROP DATABASE ") 182 e.WriteQuery(quoteString(dbName)) 183 e.WriteEnd() 184 return e 185 } 186 187 func (c *MysqlConnector) AddIndex(key *builder.Key) builder.SqlExpr { 188 if key.IsPrimary() { 189 e := builder.Expr("ALTER TABLE ") 190 e.WriteExpr(key.Table) 191 e.WriteQuery(" ADD PRIMARY KEY ") 192 e.WriteExpr(key.Def.TableExpr(key.Table)) 193 e.WriteEnd() 194 return e 195 } 196 197 e := builder.Expr("CREATE ") 198 if key.Method == "SPATIAL" { 199 e.WriteQuery("SPATIAL ") 200 } else if key.IsUnique { 201 e.WriteQuery("UNIQUE ") 202 } 203 e.WriteQuery("INDEX ") 204 205 e.WriteQuery(key.Name) 206 207 if key.Method == "BTREE" || key.Method == "HASH" { 208 e.WriteQuery(" USING ") 209 e.WriteQuery(key.Method) 210 } 211 212 e.WriteQuery(" ON ") 213 e.WriteExpr(key.Table) 214 215 e.WriteQueryByte(' ') 216 e.WriteExpr(key.Def.TableExpr(key.Table)) 217 218 e.WriteEnd() 219 return e 220 } 221 222 func (c *MysqlConnector) DropIndex(key *builder.Key) builder.SqlExpr { 223 if key.IsPrimary() { 224 e := builder.Expr("ALTER TABLE ") 225 e.WriteExpr(key.Table) 226 e.WriteQuery(" DROP PRIMARY KEY") 227 e.WriteEnd() 228 return e 229 } 230 e := builder.Expr("DROP ") 231 232 e.WriteQuery("INDEX ") 233 e.WriteQuery(key.Name) 234 235 e.WriteQuery(" ON ") 236 e.WriteExpr(key.Table) 237 e.WriteEnd() 238 239 return e 240 } 241 242 func (c *MysqlConnector) CreateTableIsNotExists(table *builder.Table) (exprs []builder.SqlExpr) { 243 expr := builder.Expr("CREATE TABLE IF NOT EXISTS ") 244 expr.WriteExpr(table) 245 expr.WriteQueryByte(' ') 246 expr.WriteGroup(func(e *builder.Ex) { 247 if table.Columns.IsNil() { 248 return 249 } 250 251 table.Columns.Range(func(col *builder.Column, idx int) { 252 if col.DeprecatedActions != nil { 253 return 254 } 255 256 if idx > 0 { 257 e.WriteQueryByte(',') 258 } 259 e.WriteQueryByte('\n') 260 e.WriteQueryByte('\t') 261 262 e.WriteExpr(col) 263 e.WriteQueryByte(' ') 264 e.WriteExpr(c.DataType(col.ColumnType)) 265 }) 266 267 table.Keys.Range(func(key *builder.Key, idx int) { 268 if key.IsPrimary() { 269 e.WriteQueryByte(',') 270 e.WriteQueryByte('\n') 271 e.WriteQueryByte('\t') 272 e.WriteQuery("PRIMARY KEY ") 273 e.WriteExpr(key.Def.TableExpr(key.Table)) 274 } 275 }) 276 277 expr.WriteQueryByte('\n') 278 }) 279 280 expr.WriteQuery(" ENGINE=") 281 282 if c.Engine == "" { 283 expr.WriteQuery("InnoDB") 284 } else { 285 expr.WriteQuery(c.Engine) 286 } 287 288 expr.WriteQuery(" CHARSET=") 289 290 if c.Charset == "" { 291 expr.WriteQuery("utf8mb4") 292 } else { 293 expr.WriteQuery(c.Charset) 294 } 295 296 expr.WriteEnd() 297 exprs = append(exprs, expr) 298 299 table.Keys.Range(func(key *builder.Key, idx int) { 300 if !key.IsPrimary() { 301 exprs = append(exprs, c.AddIndex(key)) 302 } 303 }) 304 305 return 306 } 307 308 func (c *MysqlConnector) DropTable(t *builder.Table) builder.SqlExpr { 309 e := builder.Expr("DROP TABLE IF EXISTS ") 310 e.WriteQuery(t.Name) 311 e.WriteEnd() 312 return e 313 } 314 315 func (c *MysqlConnector) TruncateTable(t *builder.Table) builder.SqlExpr { 316 e := builder.Expr("TRUNCATE TABLE ") 317 e.WriteQuery(t.Name) 318 e.WriteEnd() 319 return e 320 } 321 322 func (c *MysqlConnector) AddColumn(col *builder.Column) builder.SqlExpr { 323 e := builder.Expr("ALTER TABLE ") 324 e.WriteExpr(col.Table) 325 e.WriteQuery(" ADD COLUMN ") 326 e.WriteExpr(col) 327 e.WriteQueryByte(' ') 328 e.WriteExpr(c.DataType(col.ColumnType)) 329 e.WriteEnd() 330 return e 331 } 332 333 func (c *MysqlConnector) RenameColumn(col *builder.Column, target *builder.Column) builder.SqlExpr { 334 e := builder.Expr("ALTER TABLE ") 335 e.WriteExpr(col.Table) 336 e.WriteQuery(" CHANGE ") 337 e.WriteExpr(col) 338 e.WriteQueryByte(' ') 339 e.WriteExpr(target) 340 e.WriteQueryByte(' ') 341 e.WriteExpr(c.DataType(target.ColumnType)) 342 e.WriteEnd() 343 return e 344 } 345 346 func (c *MysqlConnector) ModifyColumn(col *builder.Column, prev *builder.Column) builder.SqlExpr { 347 e := builder.Expr("ALTER TABLE ") 348 e.WriteExpr(col.Table) 349 e.WriteQuery(" MODIFY COLUMN ") 350 e.WriteExpr(col) 351 e.WriteQueryByte(' ') 352 e.WriteExpr(c.DataType(col.ColumnType)) 353 354 e.WriteQuery(" /* FROM") 355 e.WriteExpr(c.DataType(prev.ColumnType)) 356 e.WriteQuery(" */") 357 358 e.WriteEnd() 359 return e 360 } 361 362 func (c *MysqlConnector) DropColumn(col *builder.Column) builder.SqlExpr { 363 e := builder.Expr("ALTER TABLE ") 364 e.WriteExpr(col.Table) 365 e.WriteQuery(" DROP COLUMN ") 366 e.WriteQuery(col.Name) 367 e.WriteEnd() 368 return e 369 } 370 371 func (c *MysqlConnector) DataType(columnType *builder.ColumnType) builder.SqlExpr { 372 dbDataType := dealias(c.dbDataType(columnType.Type, columnType)) 373 return builder.Expr(dbDataType + autocompleteSize(dbDataType, columnType) + c.dataTypeModify(columnType)) 374 } 375 376 func (c *MysqlConnector) dataType(typ typex.Type, columnType *builder.ColumnType) string { 377 dbDataType := dealias(c.dbDataType(typ, columnType)) 378 return dbDataType + autocompleteSize(dbDataType, columnType) 379 } 380 381 func (c *MysqlConnector) dbDataType(typ typex.Type, columnType *builder.ColumnType) string { 382 if columnType.DataType != "" { 383 return columnType.DataType 384 } 385 386 if rv, ok := typex.TryNew(typ); ok { 387 if dtd, ok := rv.Interface().(builder.DataTypeDescriber); ok { 388 return dtd.DataType(c.DriverName()) 389 } 390 } 391 392 switch typ.Kind() { 393 case reflect.Ptr: 394 return c.dataType(typ.Elem(), columnType) 395 case reflect.Bool: 396 return "boolean" 397 case reflect.Int8: 398 return "tinyint" 399 case reflect.Uint8: 400 return "tinyint unsigned" 401 case reflect.Int16: 402 return "smallint" 403 case reflect.Uint16: 404 return "smallint unsigned" 405 case reflect.Int, reflect.Int32: 406 return "int" 407 case reflect.Uint, reflect.Uint32: 408 return "int unsigned" 409 case reflect.Int64: 410 return "bigint" 411 case reflect.Uint64: 412 return "bigint unsigned" 413 case reflect.Float32: 414 return "float" 415 case reflect.Float64: 416 return "double" 417 case reflect.String: 418 size := columnType.Length 419 if size < 65535/3 { 420 return "varchar" 421 } 422 return "text" 423 case reflect.Slice: 424 if typ.Elem().Kind() == reflect.Uint8 { 425 return "mediumblob" 426 } 427 } 428 switch typ.Name() { 429 case "NullInt64": 430 return "bigint" 431 case "NullFloat64": 432 return "double" 433 case "NullBool": 434 return "tinyint" 435 case "Time": 436 return "datetime" 437 } 438 panic(fmt.Errorf("unsupport type %s", typ)) 439 } 440 441 func (c *MysqlConnector) dataTypeModify(columnType *builder.ColumnType) string { 442 buf := bytes.NewBuffer(nil) 443 444 if !columnType.Null { 445 buf.WriteString(" NOT NULL") 446 } 447 448 if columnType.AutoIncrement { 449 buf.WriteString(" AUTO_INCREMENT") 450 } 451 452 if columnType.Default != nil { 453 buf.WriteString(" DEFAULT ") 454 buf.WriteString(*columnType.Default) 455 } 456 457 if columnType.OnUpdate != nil { 458 buf.WriteString(" ON UPDATE ") 459 buf.WriteString(*columnType.OnUpdate) 460 } 461 462 return buf.String() 463 } 464 465 func autocompleteSize(dataType string, columnType *builder.ColumnType) string { 466 switch strings.ToLower(dataType) { 467 case "varchar": 468 size := columnType.Length 469 if size == 0 { 470 size = 255 471 } 472 return sizeModifier(size, columnType.Decimal) 473 case "float", "double", "decimal": 474 if columnType.Length > 0 { 475 return sizeModifier(columnType.Length, columnType.Decimal) 476 } 477 } 478 return "" 479 } 480 481 func dealias(dataType string) string { 482 return dataType 483 } 484 485 func sizeModifier(length uint64, decimal uint64) string { 486 if length > 0 { 487 size := strconv.FormatUint(length, 10) 488 if decimal > 0 { 489 return "(" + size + "," + strconv.FormatUint(decimal, 10) + ")" 490 } 491 return "(" + size + ")" 492 } 493 return "" 494 }