github.com/unionj-cloud/go-doudou@v1.3.8-0.20221011095552-0088008e5b31/cmd/internal/ddl/table/ddl.go (about) 1 package table 2 3 import ( 4 "context" 5 "fmt" 6 mapset "github.com/deckarep/golang-set" 7 "github.com/iancoleman/strcase" 8 "github.com/jmoiron/sqlx" 9 "github.com/kelseyhightower/envconfig" 10 "github.com/pkg/errors" 11 "github.com/rs/zerolog" 12 "github.com/testcontainers/testcontainers-go" 13 "github.com/testcontainers/testcontainers-go/wait" 14 "github.com/unionj-cloud/go-doudou/cmd/internal/astutils" 15 "github.com/unionj-cloud/go-doudou/cmd/internal/ddl/columnenum" 16 "github.com/unionj-cloud/go-doudou/cmd/internal/ddl/config" 17 "github.com/unionj-cloud/go-doudou/cmd/internal/ddl/ddlast" 18 "github.com/unionj-cloud/go-doudou/cmd/internal/ddl/extraenum" 19 "github.com/unionj-cloud/go-doudou/cmd/internal/ddl/sortenum" 20 "github.com/unionj-cloud/go-doudou/toolkit/caller" 21 "github.com/unionj-cloud/go-doudou/toolkit/pathutils" 22 "github.com/unionj-cloud/go-doudou/toolkit/sliceutils" 23 "github.com/unionj-cloud/go-doudou/toolkit/sqlext/wrapper" 24 "github.com/unionj-cloud/go-doudou/toolkit/stringutils" 25 "github.com/unionj-cloud/go-doudou/toolkit/zlogger" 26 "go/ast" 27 "go/parser" 28 "go/token" 29 "os" 30 "path/filepath" 31 "reflect" 32 "strings" 33 "time" 34 ) 35 36 // CreateTable create table from Table 37 func CreateTable(ctx context.Context, db wrapper.Querier, t Table) error { 38 var ( 39 statement string 40 err error 41 ) 42 if statement, err = t.CreateSql(); err != nil { 43 return err 44 } 45 fmt.Println(statement) 46 if _, err = db.ExecContext(ctx, statement); err != nil { 47 return err 48 } 49 return err 50 } 51 52 // ChangeColumn change a column definition by Column 53 func ChangeColumn(ctx context.Context, db wrapper.Querier, col Column) error { 54 var ( 55 statement string 56 err error 57 ) 58 if statement, err = col.ChangeColumnSql(); err != nil { 59 return err 60 } 61 fmt.Println(statement) 62 if _, err = db.ExecContext(ctx, statement); err != nil { 63 return err 64 } 65 return err 66 } 67 68 // AddColumn add a column by Column 69 func AddColumn(ctx context.Context, db wrapper.Querier, col Column) error { 70 var ( 71 statement string 72 err error 73 ) 74 if statement, err = col.AddColumnSql(); err != nil { 75 return err 76 } 77 fmt.Println(statement) 78 if _, err = db.ExecContext(ctx, statement); err != nil { 79 return err 80 } 81 return err 82 } 83 84 // dropAddIndex drop and then add an existing index with the same key_name 85 func dropAddIndex(ctx context.Context, db wrapper.Querier, idx Index) error { 86 var err error 87 if err = dropIndex(ctx, db, idx); err != nil { 88 return errors.Wrap(err, caller.NewCaller().String()) 89 } 90 if err = addIndex(ctx, db, idx); err != nil { 91 return errors.Wrap(err, caller.NewCaller().String()) 92 } 93 return nil 94 } 95 96 // addIndex add a new index 97 func addIndex(ctx context.Context, db wrapper.Querier, idx Index) error { 98 var ( 99 statement string 100 err error 101 ) 102 if statement, err = idx.AddIndexSql(); err != nil { 103 return errors.Wrap(err, caller.NewCaller().String()) 104 } 105 fmt.Println(statement) 106 if _, err = db.ExecContext(ctx, statement); err != nil { 107 return errors.Wrap(err, caller.NewCaller().String()) 108 } 109 return nil 110 } 111 112 // dropIndex drop an existing index 113 func dropIndex(ctx context.Context, db wrapper.Querier, idx Index) error { 114 var ( 115 statement string 116 err error 117 ) 118 if statement, err = idx.DropIndexSql(); err != nil { 119 return errors.Wrap(err, caller.NewCaller().String()) 120 } 121 fmt.Println(statement) 122 if _, err = db.ExecContext(ctx, statement); err != nil { 123 return errors.Wrap(err, caller.NewCaller().String()) 124 } 125 return nil 126 } 127 128 // dropAddFk drop and then add an existing foreign key with the same constraint 129 func dropAddFk(ctx context.Context, db wrapper.Querier, fk ForeignKey) error { 130 var err error 131 if err = dropFk(ctx, db, fk); err != nil { 132 return errors.Wrap(err, caller.NewCaller().String()) 133 } 134 if err = addFk(ctx, db, fk); err != nil { 135 return errors.Wrap(err, caller.NewCaller().String()) 136 } 137 return nil 138 } 139 140 // addFk add a new foreign key 141 func addFk(ctx context.Context, db wrapper.Querier, fk ForeignKey) error { 142 var ( 143 statement string 144 err error 145 ) 146 if statement, err = fk.AddFkSql(); err != nil { 147 return errors.Wrap(err, caller.NewCaller().String()) 148 } 149 fmt.Println(statement) 150 if _, err = db.ExecContext(ctx, statement); err != nil { 151 return errors.Wrap(err, caller.NewCaller().String()) 152 } 153 return nil 154 } 155 156 // dropFk drop an existing foreign key 157 func dropFk(ctx context.Context, db wrapper.Querier, fk ForeignKey) error { 158 var ( 159 statement string 160 err error 161 ) 162 if statement, err = fk.DropFkSql(); err != nil { 163 return errors.Wrap(err, caller.NewCaller().String()) 164 } 165 fmt.Println(statement) 166 if _, err = db.ExecContext(ctx, statement); err != nil { 167 return errors.Wrap(err, caller.NewCaller().String()) 168 } 169 return nil 170 } 171 172 func Table2struct(ctx context.Context, pre, schema string, existTables []string, db *sqlx.DB) (tables []Table) { 173 var err error 174 for _, t := range existTables { 175 if stringutils.IsNotEmpty(pre) && !strings.HasPrefix(t, pre) { 176 continue 177 } 178 var dbIndice []DbIndex 179 if err = db.SelectContext(ctx, &dbIndice, fmt.Sprintf("SHOW INDEXES FROM %s", t)); err != nil { 180 panic(errors.Wrap(err, caller.NewCaller().String())) 181 } 182 183 idxMap := make(map[string][]DbIndex) 184 185 for _, idx := range dbIndice { 186 if val, exists := idxMap[idx.KeyName]; exists { 187 val = append(val, idx) 188 idxMap[idx.KeyName] = val 189 } else { 190 idxMap[idx.KeyName] = []DbIndex{ 191 idx, 192 } 193 } 194 } 195 196 indexes, colIdxMap := idxListAndMap(idxMap) 197 198 var columns []DbColumn 199 if err = db.SelectContext(ctx, &columns, fmt.Sprintf("SHOW FULL COLUMNS FROM %s", t)); err != nil { 200 panic(errors.Wrap(err, caller.NewCaller().String())) 201 } 202 203 fks := foreignKeys(ctx, db, schema, t) 204 fkMap := make(map[string]ForeignKey) 205 for _, item := range fks { 206 fkMap[item.Fk] = item 207 } 208 209 var cols []Column 210 var fields []astutils.FieldMeta 211 for _, item := range columns { 212 col := dbColumn2Column(item, colIdxMap, t, fkMap[item.Field]) 213 fields = append(fields, col.Meta) 214 cols = append(cols, col) 215 } 216 217 domain := astutils.StructMeta{ 218 Name: strcase.ToCamel(strings.TrimPrefix(t, pre)), 219 Fields: fields, 220 } 221 222 var pkColumn Column 223 for _, column := range cols { 224 if column.Pk { 225 pkColumn = column 226 break 227 } 228 } 229 230 tables = append(tables, Table{ 231 Name: t, 232 Columns: cols, 233 Pk: pkColumn.Name, 234 Indexes: indexes, 235 Meta: domain, 236 Fks: fks, 237 }) 238 } 239 return 240 } 241 242 func foreignKeys(ctx context.Context, db wrapper.Querier, schema, t string) (fks []ForeignKey) { 243 var ( 244 dbForeignKeys []DbForeignKey 245 err error 246 ) 247 rawSql := ` 248 SELECT TABLE_NAME,COLUMN_NAME,CONSTRAINT_NAME, REFERENCED_TABLE_NAME,REFERENCED_COLUMN_NAME 249 FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE 250 WHERE TABLE_SCHEMA = ? AND REFERENCED_TABLE_SCHEMA = ? AND TABLE_NAME = ? 251 ` 252 if err = db.SelectContext(ctx, &dbForeignKeys, db.Rebind(rawSql), schema, schema, t); err != nil { 253 panic(errors.Wrap(err, caller.NewCaller().String())) 254 } 255 for _, item := range dbForeignKeys { 256 var ( 257 dbActions []DbAction 258 dbAction DbAction 259 ) 260 rawSql = ` 261 select CONSTRAINT_NAME, UPDATE_RULE, DELETE_RULE, TABLE_NAME, REFERENCED_TABLE_NAME 262 from information_schema.REFERENTIAL_CONSTRAINTS 263 where CONSTRAINT_SCHEMA=? and TABLE_NAME=? and CONSTRAINT_NAME=? 264 ` 265 if err = db.SelectContext(ctx, &dbActions, db.Rebind(rawSql), schema, t, item.ConstraintName); err != nil { 266 panic(errors.Wrap(err, caller.NewCaller().String())) 267 } 268 if len(dbActions) > 0 { 269 dbAction = dbActions[0] 270 } 271 var rules []string 272 if stringutils.IsNotEmpty(dbAction.DeleteRule) { 273 rules = append(rules, fmt.Sprintf("ON DELETE %s", dbAction.DeleteRule)) 274 } 275 if stringutils.IsNotEmpty(dbAction.UpdateRule) { 276 rules = append(rules, fmt.Sprintf("ON UPDATE %s", dbAction.UpdateRule)) 277 } 278 var fullRule string 279 if len(rules) > 0 { 280 fullRule = strings.Join(rules, " ") 281 } 282 fks = append(fks, ForeignKey{ 283 Table: t, 284 Constraint: item.ConstraintName, 285 Fk: item.ColumnName, 286 ReferencedTable: item.ReferencedTableName, 287 ReferencedCol: item.ReferencedColumnName, 288 UpdateRule: dbAction.UpdateRule, 289 DeleteRule: dbAction.DeleteRule, 290 FullRule: fullRule, 291 }) 292 } 293 return 294 } 295 296 func idxListAndMap(idxMap map[string][]DbIndex) ([]Index, map[string][]IndexItem) { 297 var indexes []Index 298 colIdxMap := make(map[string][]IndexItem) 299 for k, v := range idxMap { 300 if len(v) == 0 { 301 continue 302 } 303 items := make([]IndexItem, len(v)) 304 for i, idx := range v { 305 var sor sortenum.Sort 306 if idx.Collation == "B" { 307 sor = sortenum.Desc 308 } else { 309 sor = sortenum.Asc 310 } 311 items[i] = IndexItem{ 312 Unique: !v[0].NonUnique, 313 Name: k, 314 Column: idx.ColumnName, 315 Order: idx.SeqInIndex, 316 Sort: sor, 317 } 318 if val, exists := colIdxMap[idx.ColumnName]; exists { 319 val = append(val, items[i]) 320 colIdxMap[idx.ColumnName] = val 321 } else { 322 colIdxMap[idx.ColumnName] = []IndexItem{ 323 items[i], 324 } 325 } 326 } 327 indexes = append(indexes, Index{ 328 Unique: !v[0].NonUnique, 329 Name: k, 330 Items: items, 331 }) 332 } 333 return indexes, colIdxMap 334 } 335 336 func dbColumn2Column(item DbColumn, colIdxMap map[string][]IndexItem, t string, fk ForeignKey) Column { 337 extra := item.Extra 338 if strings.Contains(extra, "auto_increment") { 339 extra = "" 340 } 341 extra = strings.TrimSpace(strings.TrimPrefix(extra, "DEFAULT_GENERATED")) 342 if stringutils.IsNotEmpty(item.Comment) { 343 extra += fmt.Sprintf(" comment '%s'", item.Comment) 344 } 345 extra = strings.TrimSpace(extra) 346 var defaultVal string 347 if item.Default != nil { 348 defaultVal = *item.Default 349 } 350 col := Column{ 351 Table: t, 352 Name: item.Field, 353 Type: columnenum.ColumnType(item.Type), 354 Default: defaultVal, 355 Pk: CheckPk(item.Key), 356 Nullable: CheckNull(item.Null), 357 Unsigned: CheckUnsigned(item.Type), 358 Autoincrement: CheckAutoincrement(item.Extra), 359 Extra: extraenum.Extra(extra), 360 AutoSet: CheckAutoSet(defaultVal), 361 Indexes: colIdxMap[item.Field], 362 Fk: fk, 363 } 364 col.Meta = NewFieldFromColumn(col) 365 return col 366 } 367 368 func Struct2Table(ctx context.Context, dir, pre string, existTables []string, db *sqlx.DB, schema string) (tables []Table) { 369 var ( 370 files []string 371 err error 372 tx *sqlx.Tx 373 root *ast.File 374 ) 375 if err = filepath.Walk(dir, astutils.Visit(&files)); err != nil { 376 panic(errors.Wrap(err, caller.NewCaller().String())) 377 } 378 sc := astutils.NewStructCollector(astutils.ExprString) 379 for _, file := range files { 380 fset := token.NewFileSet() 381 if root, err = parser.ParseFile(fset, file, nil, parser.ParseComments); err != nil { 382 panic(errors.Wrap(err, caller.NewCaller().String())) 383 } 384 ast.Walk(sc, root) 385 } 386 387 flattened := ddlast.FlatEmbed(sc.Structs) 388 for _, sm := range flattened { 389 tables = append(tables, NewTableFromStruct(sm, pre)) 390 } 391 392 if tx, err = db.BeginTxx(ctx, nil); err != nil { 393 panic(errors.Wrap(err, caller.NewCaller().String())) 394 } 395 defer func() { 396 if r := recover(); r != nil { 397 if _err := tx.Rollback(); _err != nil { 398 err = errors.Wrap(_err, "") 399 } 400 panic(errors.Wrap(err, caller.NewCaller().String())) 401 } 402 }() 403 404 if _, err = tx.ExecContext(ctx, `SET @OLD_FOREIGN_KEY_CHECKS=@@FOREIGN_KEY_CHECKS, FOREIGN_KEY_CHECKS=0;`); err != nil { 405 panic(errors.Wrap(err, caller.NewCaller().String())) 406 } 407 408 for _, t := range tables { 409 if sliceutils.StringContains(existTables, t.Name) { 410 var columns []DbColumn 411 if err = tx.SelectContext(ctx, &columns, fmt.Sprintf("desc %s", t.Name)); err != nil { 412 panic(errors.Wrap(err, caller.NewCaller().String())) 413 } 414 var existColumnNames []interface{} 415 for _, dbCol := range columns { 416 existColumnNames = append(existColumnNames, dbCol.Field) 417 } 418 existColSet := mapset.NewSetFromSlice(existColumnNames) 419 420 for _, col := range t.Columns { 421 if existColSet.Contains(col.Name) { 422 if err = ChangeColumn(ctx, tx, col); err != nil { 423 panic(errors.Wrap(err, caller.NewCaller().String())) 424 } 425 } else { 426 if err = AddColumn(ctx, tx, col); err != nil { 427 panic(errors.Wrap(err, caller.NewCaller().String())) 428 } 429 } 430 } 431 fks := foreignKeys(ctx, tx, schema, t.Name) 432 updateIndexFromStruct(ctx, tx, t, fks) 433 updateFkFromStruct(ctx, tx, t, fks) 434 } else { 435 if err = CreateTable(ctx, tx, t); err != nil { 436 panic(errors.Wrap(err, caller.NewCaller().String())) 437 } 438 } 439 } 440 441 if _, err = tx.ExecContext(ctx, `SET FOREIGN_KEY_CHECKS=@OLD_FOREIGN_KEY_CHECKS;`); err != nil { 442 panic(errors.Wrap(err, caller.NewCaller().String())) 443 } 444 _ = tx.Commit() 445 return 446 } 447 448 func updateFkFromStruct(ctx context.Context, tx *sqlx.Tx, t Table, fks []ForeignKey) { 449 fkMap := make(map[string]ForeignKey) 450 for _, fk := range fks { 451 fkMap[fk.Constraint] = fk 452 } 453 for _, fk := range t.Fks { 454 if current, exists := fkMap[fk.Constraint]; exists { 455 current.DeleteRule = "" 456 current.UpdateRule = "" 457 fk.DeleteRule = "" 458 fk.UpdateRule = "" 459 if reflect.DeepEqual(fk, current) { 460 continue 461 } 462 if err := dropAddFk(ctx, tx, fk); err != nil { 463 panic(errors.Wrap(err, caller.NewCaller().String())) 464 } 465 } else { 466 if err := addFk(ctx, tx, fk); err != nil { 467 panic(errors.Wrap(err, caller.NewCaller().String())) 468 } 469 } 470 } 471 472 var constraints []string 473 for _, fk := range t.Fks { 474 constraints = append(constraints, fk.Constraint) 475 } 476 for k, v := range fkMap { 477 if !sliceutils.StringContains(constraints, k) { 478 if err := dropFk(ctx, tx, v); err != nil { 479 panic(errors.Wrap(err, caller.NewCaller().String())) 480 } 481 } 482 } 483 } 484 485 func updateIndexFromStruct(ctx context.Context, tx *sqlx.Tx, t Table, fks []ForeignKey) { 486 var dbIndexes []DbIndex 487 if err := tx.SelectContext(ctx, &dbIndexes, fmt.Sprintf("SHOW INDEXES FROM %s", t.Name)); err != nil { 488 panic(errors.Wrap(err, caller.NewCaller().String())) 489 } 490 491 keyIndexMap := make(map[string][]DbIndex) 492 for _, index := range dbIndexes { 493 if index.KeyName == "PRIMARY" { 494 continue 495 } 496 if val, exists := keyIndexMap[index.KeyName]; exists { 497 val = append(val, index) 498 keyIndexMap[index.KeyName] = val 499 } else { 500 keyIndexMap[index.KeyName] = []DbIndex{index} 501 } 502 } 503 504 for _, idx := range t.Indexes { 505 if current, exists := keyIndexMap[idx.Name]; exists { 506 copied := NewIndexFromDbIndexes(current) 507 if reflect.DeepEqual(idx, copied) { 508 continue 509 } 510 idx.Table = t.Name 511 if err := dropAddIndex(ctx, tx, idx); err != nil { 512 panic(errors.Wrap(err, caller.NewCaller().String())) 513 } 514 } else { 515 idx.Table = t.Name 516 if err := addIndex(ctx, tx, idx); err != nil { 517 panic(errors.Wrap(err, caller.NewCaller().String())) 518 } 519 } 520 } 521 522 var idxKeys []string 523 for _, idx := range t.Indexes { 524 idxKeys = append(idxKeys, idx.Name) 525 } 526 for k, v := range keyIndexMap { 527 if !sliceutils.StringContains(idxKeys, k) { 528 shouldDrop := true 529 if len(v) == 1 { 530 idx := v[0] 531 for _, fk := range fks { 532 if fk.Table == idx.Table && fk.Fk == idx.ColumnName { 533 shouldDrop = false 534 break 535 } 536 } 537 } 538 if shouldDrop { 539 idx := NewIndexFromDbIndexes(v) 540 idx.Table = t.Name 541 if err := dropIndex(ctx, tx, idx); err != nil { 542 panic(errors.Wrap(err, caller.NewCaller().String())) 543 } 544 } 545 } 546 } 547 } 548 549 func Setup() (func(), *sqlx.DB, error) { 550 var terminateContainer func() // variable to store function to terminate container 551 var host string 552 var port int 553 var err error 554 terminateContainer, host, port, err = setupMySQLContainer(zlogger.Logger, pathutils.Abs("../testdata/sql"), "") 555 if err != nil { 556 return nil, nil, errors.Wrap(err, "failed to setup MySQL container") 557 } 558 os.Setenv("DB_HOST", host) 559 os.Setenv("DB_PORT", fmt.Sprint(port)) 560 os.Setenv("DB_USER", "root") 561 os.Setenv("DB_PASSWD", "1234") 562 os.Setenv("DB_SCHEMA", "test") 563 os.Setenv("DB_CHARSET", "utf8mb4") 564 var conf config.DbConfig 565 err = envconfig.Process("db", &conf) 566 if err != nil { 567 return nil, nil, errors.Wrap(err, "[go-doudou] Error processing env") 568 } 569 conn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=%s", 570 conf.User, 571 conf.Passwd, 572 conf.Host, 573 conf.Port, 574 conf.Schema, 575 conf.Charset) 576 conn += `&loc=Asia%2FShanghai&parseTime=True` 577 var db *sqlx.DB 578 db, err = sqlx.Connect("mysql", conn) 579 if err != nil { 580 return nil, nil, errors.Wrap(err, caller.NewCaller().String()) 581 } 582 db.MapperFunc(strcase.ToSnake) 583 db = db.Unsafe() 584 return terminateContainer, db, nil 585 } 586 587 func setupMySQLContainer(logger zerolog.Logger, initdb string, dbname string) (func(), string, int, error) { 588 logger.Info().Msg("setup MySQL Container") 589 ctx := context.Background() 590 if stringutils.IsEmpty(dbname) { 591 dbname = "test" 592 } 593 req := testcontainers.ContainerRequest{ 594 Image: "mysql:latest", 595 ExposedPorts: []string{"3306/tcp", "33060/tcp"}, 596 Env: map[string]string{ 597 "MYSQL_ROOT_PASSWORD": "1234", 598 "MYSQL_DATABASE": dbname, 599 }, 600 BindMounts: map[string]string{ 601 initdb: "/docker-entrypoint-initdb.d", 602 }, 603 WaitingFor: wait.ForLog("port: 3306 MySQL Community Server - GPL").WithStartupTimeout(60 * time.Second), 604 } 605 606 mysqlC, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ 607 ContainerRequest: req, 608 Started: true, 609 }) 610 611 if err != nil { 612 logger.Error().Msgf("error starting mysql container: %s", err) 613 panic(fmt.Sprintf("%v", err)) 614 } 615 616 closeContainer := func() { 617 logger.Info().Msg("terminating container") 618 err := mysqlC.Terminate(ctx) 619 if err != nil { 620 logger.Error().Msgf("error terminating mysql container: %s", err) 621 panic(fmt.Sprintf("%v", err)) 622 } 623 } 624 625 host, _ := mysqlC.Host(ctx) 626 p, _ := mysqlC.MappedPort(ctx, "3306/tcp") 627 port := p.Int() 628 629 return closeContainer, host, port, nil 630 }