github.com/unionj-cloud/go-doudou@v1.3.8-0.20221011095552-0088008e5b31/cmd/internal/ddl/table/table_test.go (about) 1 package table 2 3 import ( 4 "fmt" 5 "github.com/unionj-cloud/go-doudou/cmd/internal/astutils" 6 "github.com/unionj-cloud/go-doudou/cmd/internal/ddl/columnenum" 7 "github.com/unionj-cloud/go-doudou/cmd/internal/ddl/ddlast" 8 "github.com/unionj-cloud/go-doudou/cmd/internal/ddl/extraenum" 9 "github.com/unionj-cloud/go-doudou/cmd/internal/ddl/keyenum" 10 "github.com/unionj-cloud/go-doudou/cmd/internal/ddl/nullenum" 11 "github.com/unionj-cloud/go-doudou/cmd/internal/ddl/sortenum" 12 "github.com/unionj-cloud/go-doudou/toolkit/pathutils" 13 "go/ast" 14 "go/parser" 15 "go/token" 16 "path/filepath" 17 "reflect" 18 "testing" 19 ) 20 21 func ExampleNewTableFromStruct() { 22 testDir := pathutils.Abs("../testdata/domain") 23 var files []string 24 var err error 25 err = filepath.Walk(testDir, astutils.Visit(&files)) 26 if err != nil { 27 panic(err) 28 } 29 sc := astutils.NewStructCollector(astutils.ExprString) 30 for _, file := range files { 31 fset := token.NewFileSet() 32 root, err := parser.ParseFile(fset, file, nil, parser.ParseComments) 33 if err != nil { 34 panic(err) 35 } 36 ast.Walk(sc, root) 37 } 38 flattened := ddlast.FlatEmbed(sc.Structs) 39 40 for _, sm := range flattened { 41 tab := NewTableFromStruct(sm) 42 fmt.Println(len(tab.Indexes)) 43 var statement string 44 if statement, err = tab.CreateSql(); err != nil { 45 panic(err) 46 } 47 fmt.Println(statement) 48 } 49 50 // Output: 51 //0 52 //CREATE TABLE `order` ( 53 //`id` INT NOT NULL AUTO_INCREMENT, 54 //`amount` BIGINT NOT NULL, 55 //`user_id` int NOT NULL, 56 //`create_at` DATETIME NULL DEFAULT CURRENT_TIMESTAMP, 57 //`delete_at` DATETIME NULL, 58 //`update_at` DATETIME NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, 59 //PRIMARY KEY (`id`), 60 //CONSTRAINT `fk_ddl_user` FOREIGN KEY (`user_id`) 61 //REFERENCES `ddl_user`(`id`) 62 //ON DELETE CASCADE ON UPDATE NO ACTION) 63 //5 64 //CREATE TABLE `user` ( 65 //`id` INT NOT NULL AUTO_INCREMENT, 66 //`name` VARCHAR(255) NOT NULL DEFAULT 'jack', 67 //`phone` VARCHAR(255) NOT NULL DEFAULT '13552053960' comment '手机号', 68 //`age` INT NOT NULL, 69 //`no` int NOT NULL, 70 //`unique_col` int NOT NULL, 71 //`unique_col_2` int NOT NULL, 72 //`school` VARCHAR(255) NULL DEFAULT 'harvard' comment '学校', 73 //`is_student` TINYINT NOT NULL, 74 //`rule` varchar(255) NOT NULL comment '链接匹配规则,匹配的链接采用该css规则来爬', 75 //`rule_type` varchar(45) NOT NULL comment '链接匹配规则类型,支持prefix前缀匹配和regex正则匹配', 76 //`arrive_at` datetime NULL comment '到货时间', 77 //`status` tinyint(4) NOT NULL comment '0进行中 78 //1完结 79 //2取消', 80 //`create_at` DATETIME NULL DEFAULT CURRENT_TIMESTAMP, 81 //`delete_at` DATETIME NULL, 82 //`update_at` DATETIME NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, 83 //PRIMARY KEY (`id`), 84 //INDEX `age_idx` (`age` asc), 85 //INDEX `name_phone_idx` (`phone` asc,`name` asc), 86 //UNIQUE INDEX `no_idx` (`no` asc), 87 //UNIQUE INDEX `rule_idx` (`rule` asc), 88 //UNIQUE INDEX `unique_col_idx` (`unique_col` asc,`unique_col_2` asc)) 89 } 90 91 func TestTable_CreateSql(t1 *testing.T) { 92 type fields struct { 93 Name string 94 Columns []Column 95 Pk string 96 UniqueIndexes []Index 97 Indexes []Index 98 } 99 tests := []struct { 100 name string 101 fields fields 102 want string 103 wantErr bool 104 }{ 105 { 106 name: "users", 107 fields: fields{ 108 Name: "users", 109 Columns: []Column{ 110 { 111 Name: "id", 112 Type: columnenum.IntType, 113 Default: "", 114 Pk: true, 115 Nullable: false, 116 Unsigned: false, 117 Autoincrement: true, 118 Extra: "", 119 }, 120 { 121 Name: "name", 122 Type: columnenum.VarcharType, 123 Default: "'wubin'", 124 Pk: false, 125 Nullable: true, 126 Unsigned: false, 127 Autoincrement: false, 128 Extra: "", 129 }, 130 { 131 Name: "phone", 132 Type: columnenum.VarcharType, 133 Default: "'13552053960'", 134 Pk: false, 135 Nullable: true, 136 Unsigned: false, 137 Autoincrement: false, 138 Extra: "comment '手机号'", 139 }, 140 { 141 Name: "age", 142 Type: columnenum.IntType, 143 Default: "", 144 Pk: false, 145 Nullable: true, 146 Unsigned: false, 147 Autoincrement: false, 148 Extra: "", 149 }, 150 { 151 Name: "no", 152 Type: columnenum.IntType, 153 Default: "", 154 Pk: false, 155 Nullable: false, 156 Unsigned: false, 157 Autoincrement: false, 158 Extra: "", 159 }, 160 }, 161 Pk: "id", 162 Indexes: []Index{ 163 { 164 Name: "name_phone_idx", 165 Items: []IndexItem{ 166 { 167 Column: "name", 168 Order: 2, 169 Sort: "asc", 170 }, 171 { 172 Column: "phone", 173 Order: 1, 174 Sort: "desc", 175 }, 176 }, 177 }, 178 { 179 Unique: true, 180 Name: "uni_no", 181 Items: []IndexItem{ 182 { 183 Column: "no", 184 Order: 0, 185 Sort: "asc", 186 }, 187 }, 188 }, 189 }, 190 }, 191 want: "CREATE TABLE `users` (\n`id` INT NOT NULL AUTO_INCREMENT,\n`name` VARCHAR(255) NULL DEFAULT 'wubin',\n`phone` VARCHAR(255) NULL DEFAULT '13552053960' comment '手机号',\n`age` INT NULL,\n`no` INT NOT NULL,\nPRIMARY KEY (`id`),\nINDEX `name_phone_idx` (`name` asc,`phone` desc),\nUNIQUE INDEX `uni_no` (`no` asc))", 192 wantErr: false, 193 }, 194 } 195 for _, tt := range tests { 196 t1.Run(tt.name, func(t1 *testing.T) { 197 t := &Table{ 198 Name: tt.fields.Name, 199 Columns: tt.fields.Columns, 200 Pk: tt.fields.Pk, 201 Indexes: tt.fields.Indexes, 202 } 203 got, err := t.CreateSql() 204 fmt.Println(got) 205 if (err != nil) != tt.wantErr { 206 t1.Errorf("CreateSql() error = %v, wantErr %v", err, tt.wantErr) 207 return 208 } 209 if got != tt.want { 210 t1.Errorf("CreateSql() got = %v, want %v", got, tt.want) 211 } 212 }) 213 } 214 } 215 216 func TestColumn_AlterColumnSql(t *testing.T) { 217 type fields struct { 218 Table string 219 Name string 220 Type columnenum.ColumnType 221 Default string 222 Pk bool 223 Nullable bool 224 Unsigned bool 225 Autoincrement bool 226 Extra extraenum.Extra 227 } 228 tests := []struct { 229 name string 230 fields fields 231 want string 232 wantErr bool 233 }{ 234 { 235 name: "column", 236 fields: fields{ 237 Table: "users", 238 Name: "phone", 239 Type: columnenum.VarcharType, 240 Default: "'13552053960'", 241 Pk: false, 242 Nullable: false, 243 Unsigned: false, 244 Autoincrement: false, 245 Extra: "comment '手机号'", 246 }, 247 want: "ALTER TABLE `users`\nCHANGE COLUMN `phone` `phone` VARCHAR(255) NOT NULL DEFAULT '13552053960' comment '手机号';", 248 wantErr: false, 249 }, 250 } 251 for _, tt := range tests { 252 t.Run(tt.name, func(t *testing.T) { 253 c := &Column{ 254 Table: tt.fields.Table, 255 Name: tt.fields.Name, 256 Type: tt.fields.Type, 257 Default: tt.fields.Default, 258 Pk: tt.fields.Pk, 259 Nullable: tt.fields.Nullable, 260 Unsigned: tt.fields.Unsigned, 261 Autoincrement: tt.fields.Autoincrement, 262 Extra: tt.fields.Extra, 263 } 264 got, err := c.ChangeColumnSql() 265 fmt.Println(got) 266 if (err != nil) != tt.wantErr { 267 t.Errorf("ChangeColumnSql() error = %v, wantErr %v", err, tt.wantErr) 268 return 269 } 270 if got != tt.want { 271 t.Errorf("ChangeColumnSql() got = %v, want %v", got, tt.want) 272 } 273 }) 274 } 275 } 276 277 func TestColumn_AddColumnSql(t *testing.T) { 278 type fields struct { 279 Table string 280 Name string 281 Type columnenum.ColumnType 282 Default string 283 Pk bool 284 Nullable bool 285 Unsigned bool 286 Autoincrement bool 287 Extra extraenum.Extra 288 } 289 tests := []struct { 290 name string 291 fields fields 292 want string 293 wantErr bool 294 }{ 295 { 296 name: "column", 297 fields: fields{ 298 Table: "users", 299 Name: "school", 300 Type: columnenum.VarcharType, 301 Default: "'harvard'", 302 Pk: false, 303 Nullable: false, 304 Unsigned: false, 305 Autoincrement: false, 306 Extra: "comment '学校'", 307 }, 308 want: "ALTER TABLE `users`\nADD COLUMN `school` VARCHAR(255) NOT NULL DEFAULT 'harvard' comment '学校';", 309 wantErr: false, 310 }, 311 } 312 for _, tt := range tests { 313 t.Run(tt.name, func(t *testing.T) { 314 c := &Column{ 315 Table: tt.fields.Table, 316 Name: tt.fields.Name, 317 Type: tt.fields.Type, 318 Default: tt.fields.Default, 319 Pk: tt.fields.Pk, 320 Nullable: tt.fields.Nullable, 321 Unsigned: tt.fields.Unsigned, 322 Autoincrement: tt.fields.Autoincrement, 323 Extra: tt.fields.Extra, 324 } 325 got, err := c.AddColumnSql() 326 fmt.Println(got) 327 if (err != nil) != tt.wantErr { 328 t.Errorf("AddColumnSql() error = %v, wantErr %v", err, tt.wantErr) 329 return 330 } 331 if got != tt.want { 332 t.Errorf("AddColumnSql() got = %v, want %v", got, tt.want) 333 } 334 }) 335 } 336 } 337 338 func Test_toColumnType(t *testing.T) { 339 type args struct { 340 goType string 341 } 342 tests := []struct { 343 name string 344 args args 345 want columnenum.ColumnType 346 }{ 347 { 348 name: "1", 349 args: args{ 350 goType: "float32", 351 }, 352 want: columnenum.FloatType, 353 }, { 354 name: "2", 355 args: args{ 356 goType: "int", 357 }, 358 want: columnenum.IntType, 359 }, { 360 name: "3", 361 args: args{ 362 goType: "bool", 363 }, 364 want: columnenum.TinyintType, 365 }, { 366 name: "4", 367 args: args{ 368 goType: "time.Time", 369 }, 370 want: columnenum.DatetimeType, 371 }, { 372 name: "5", 373 args: args{ 374 goType: "int64", 375 }, 376 want: columnenum.BigintType, 377 }, { 378 name: "6", 379 args: args{ 380 goType: "float64", 381 }, 382 want: columnenum.DoubleType, 383 }, { 384 name: "7", 385 args: args{ 386 goType: "string", 387 }, 388 want: columnenum.VarcharType, 389 }, { 390 name: "8", 391 args: args{ 392 goType: "decimal.Decimal", 393 }, 394 want: "decimal(6,2)", 395 }, 396 } 397 for _, tt := range tests { 398 t.Run(tt.name, func(t *testing.T) { 399 if got := toColumnType(tt.args.goType); got != tt.want { 400 t.Errorf("toColumnType() = %v, want %v", got, tt.want) 401 } 402 }) 403 } 404 } 405 406 func Test_toGoType(t *testing.T) { 407 type args struct { 408 colType columnenum.ColumnType 409 nullable bool 410 } 411 tests := []struct { 412 name string 413 args args 414 want string 415 }{ 416 { 417 name: "1", 418 args: args{ 419 colType: "int", 420 nullable: false, 421 }, 422 want: "int", 423 }, 424 { 425 name: "2", 426 args: args{ 427 colType: "bigint", 428 nullable: false, 429 }, 430 want: "int64", 431 }, 432 { 433 name: "3", 434 args: args{ 435 colType: "float", 436 nullable: false, 437 }, 438 want: "float32", 439 }, 440 { 441 name: "4", 442 args: args{ 443 colType: "double", 444 nullable: false, 445 }, 446 want: "float64", 447 }, 448 { 449 name: "5", 450 args: args{ 451 colType: "varchar", 452 nullable: false, 453 }, 454 want: "string", 455 }, 456 { 457 name: "6", 458 args: args{ 459 colType: "tinyint", 460 nullable: false, 461 }, 462 want: "int8", 463 }, 464 { 465 name: "7", 466 args: args{ 467 colType: "text", 468 nullable: false, 469 }, 470 want: "string", 471 }, 472 { 473 name: "8", 474 args: args{ 475 colType: "datetime", 476 nullable: false, 477 }, 478 want: "time.Time", 479 }, 480 } 481 for _, tt := range tests { 482 t.Run(tt.name, func(t *testing.T) { 483 if got := toGoType(tt.args.colType, tt.args.nullable); got != tt.want { 484 t.Errorf("toGoType() = %v, want %v", got, tt.want) 485 } 486 }) 487 } 488 } 489 490 func TestNewFieldFromColumn(t *testing.T) { 491 type args struct { 492 col Column 493 } 494 tests := []struct { 495 name string 496 args args 497 want astutils.FieldMeta 498 }{ 499 { 500 name: "1", 501 args: args{ 502 col: Column{ 503 Table: "users", 504 Name: "school", 505 Type: columnenum.VarcharType, 506 Default: "harvard", 507 Pk: false, 508 Nullable: false, 509 Unsigned: false, 510 Autoincrement: false, 511 Extra: "comment '学校'", 512 }, 513 }, 514 want: astutils.FieldMeta{ 515 Name: "School", 516 Type: "string", 517 Tag: `dd:"type:VARCHAR(255);default:'harvard';extra:comment '学校'"`, 518 Comments: nil, 519 }, 520 }, { 521 name: "2", 522 args: args{ 523 col: Column{ 524 Table: "users", 525 Name: "favourite", 526 Type: columnenum.VarcharType, 527 Default: "current_timestamp", 528 Pk: false, 529 Nullable: true, 530 Unsigned: false, 531 Autoincrement: true, 532 Extra: "comment '学校'", 533 Indexes: IndexItems{ 534 { 535 Unique: false, 536 Name: "my_index", 537 Column: "favourite", 538 Order: 1, 539 Sort: sortenum.Asc, 540 }, 541 }, 542 }, 543 }, 544 want: astutils.FieldMeta{ 545 Name: "Favourite", 546 Type: "*string", 547 Tag: `dd:"auto;type:VARCHAR(255);default:current_timestamp;extra:comment '学校';index:my_index,1,asc"`, 548 Comments: nil, 549 }, 550 }, 551 } 552 for _, tt := range tests { 553 t.Run(tt.name, func(t *testing.T) { 554 if got := NewFieldFromColumn(tt.args.col); !reflect.DeepEqual(got, tt.want) { 555 t.Errorf("NewFieldFromColumn() = %v, want %v", got, tt.want) 556 } 557 }) 558 } 559 } 560 561 func TestCheckPk(t *testing.T) { 562 type args struct { 563 key keyenum.Key 564 } 565 tests := []struct { 566 name string 567 args args 568 want bool 569 }{ 570 { 571 name: "", 572 args: args{ 573 key: keyenum.Pri, 574 }, 575 want: true, 576 }, 577 } 578 for _, tt := range tests { 579 t.Run(tt.name, func(t *testing.T) { 580 if got := CheckPk(tt.args.key); got != tt.want { 581 t.Errorf("CheckPk() = %v, want %v", got, tt.want) 582 } 583 }) 584 } 585 } 586 587 func TestCheckNull(t *testing.T) { 588 type args struct { 589 null nullenum.Null 590 } 591 tests := []struct { 592 name string 593 args args 594 want bool 595 }{ 596 { 597 name: "", 598 args: args{ 599 null: nullenum.Yes, 600 }, 601 want: true, 602 }, 603 } 604 for _, tt := range tests { 605 t.Run(tt.name, func(t *testing.T) { 606 if got := CheckNull(tt.args.null); got != tt.want { 607 t.Errorf("CheckNull() = %v, want %v", got, tt.want) 608 } 609 }) 610 } 611 } 612 613 func TestCheckUnsigned(t *testing.T) { 614 type args struct { 615 dbColType string 616 } 617 tests := []struct { 618 name string 619 args args 620 want bool 621 }{ 622 { 623 name: "", 624 args: args{ 625 dbColType: "int unsigned", 626 }, 627 want: true, 628 }, 629 } 630 for _, tt := range tests { 631 t.Run(tt.name, func(t *testing.T) { 632 if got := CheckUnsigned(tt.args.dbColType); got != tt.want { 633 t.Errorf("CheckUnsigned() = %v, want %v", got, tt.want) 634 } 635 }) 636 } 637 } 638 639 func TestCheckAutoincrement(t *testing.T) { 640 type args struct { 641 extra string 642 } 643 tests := []struct { 644 name string 645 args args 646 want bool 647 }{ 648 { 649 name: "", 650 args: args{ 651 extra: "auto_increment", 652 }, 653 want: true, 654 }, 655 } 656 for _, tt := range tests { 657 t.Run(tt.name, func(t *testing.T) { 658 if got := CheckAutoincrement(tt.args.extra); got != tt.want { 659 t.Errorf("CheckAutoincrement() = %v, want %v", got, tt.want) 660 } 661 }) 662 } 663 } 664 665 func TestCheckAutoSet(t *testing.T) { 666 type args struct { 667 defaultVal string 668 } 669 tests := []struct { 670 name string 671 args args 672 want bool 673 }{ 674 { 675 name: "", 676 args: args{ 677 defaultVal: "CURRENT_TIMESTAMP", 678 }, 679 want: true, 680 }, 681 } 682 for _, tt := range tests { 683 t.Run(tt.name, func(t *testing.T) { 684 if got := CheckAutoSet(tt.args.defaultVal); got != tt.want { 685 t.Errorf("CheckAutoSet() = %v, want %v", got, tt.want) 686 } 687 }) 688 } 689 } 690 691 func TestNewIndexFromDbIndexes(t *testing.T) { 692 type args struct { 693 dbIndexes []DbIndex 694 } 695 tests := []struct { 696 name string 697 args args 698 want Index 699 }{ 700 { 701 name: "", 702 args: args{ 703 dbIndexes: []DbIndex{ 704 { 705 Table: "ddl_user", 706 NonUnique: false, 707 KeyName: "age_idx", 708 SeqInIndex: 1, 709 ColumnName: "age", 710 Collation: "A", 711 }, 712 }, 713 }, 714 want: Index{ 715 Unique: true, 716 Name: "age_idx", 717 Items: []IndexItem{ 718 { 719 Column: "age", 720 Order: 1, 721 Sort: sortenum.Asc, 722 }, 723 }, 724 }, 725 }, 726 } 727 for _, tt := range tests { 728 t.Run(tt.name, func(t *testing.T) { 729 if got := NewIndexFromDbIndexes(tt.args.dbIndexes); !reflect.DeepEqual(got, tt.want) { 730 t.Errorf("NewIndexFromDbIndexes() = %v, want %v", got, tt.want) 731 } 732 }) 733 } 734 } 735 736 func TestIndex_DropIndexSql(t *testing.T) { 737 type fields struct { 738 Table string 739 Unique bool 740 Name string 741 Items []IndexItem 742 } 743 tests := []struct { 744 name string 745 fields fields 746 want string 747 wantErr bool 748 }{ 749 { 750 name: "", 751 fields: fields{ 752 Table: "ddl_user", 753 Unique: true, 754 Name: "age_idx", 755 Items: []IndexItem{ 756 { 757 Column: "age", 758 Order: 1, 759 Sort: sortenum.Asc, 760 }, 761 }, 762 }, 763 want: "ALTER TABLE `ddl_user` DROP INDEX `age_idx`;", 764 wantErr: false, 765 }, 766 } 767 for _, tt := range tests { 768 t.Run(tt.name, func(t *testing.T) { 769 idx := &Index{ 770 Table: tt.fields.Table, 771 Unique: tt.fields.Unique, 772 Name: tt.fields.Name, 773 Items: tt.fields.Items, 774 } 775 got, err := idx.DropIndexSql() 776 if (err != nil) != tt.wantErr { 777 t.Errorf("DropIndexSql() error = %v, wantErr %v", err, tt.wantErr) 778 return 779 } 780 if got != tt.want { 781 t.Errorf("DropIndexSql() got = %v, want %v", got, tt.want) 782 } 783 }) 784 } 785 } 786 787 func TestIndex_AddIndexSql(t *testing.T) { 788 type fields struct { 789 Table string 790 Unique bool 791 Name string 792 Items []IndexItem 793 } 794 tests := []struct { 795 name string 796 fields fields 797 want string 798 wantErr bool 799 }{ 800 { 801 name: "", 802 fields: fields{ 803 Table: "ddl_user", 804 Unique: true, 805 Name: "age_idx", 806 Items: []IndexItem{ 807 { 808 Column: "age", 809 Order: 1, 810 Sort: sortenum.Asc, 811 }, 812 }, 813 }, 814 want: "ALTER TABLE `ddl_user` ADD UNIQUE INDEX `age_idx` (`age` asc);", 815 wantErr: false, 816 }, 817 } 818 for _, tt := range tests { 819 t.Run(tt.name, func(t *testing.T) { 820 idx := &Index{ 821 Table: tt.fields.Table, 822 Unique: tt.fields.Unique, 823 Name: tt.fields.Name, 824 Items: tt.fields.Items, 825 } 826 got, err := idx.AddIndexSql() 827 if (err != nil) != tt.wantErr { 828 t.Errorf("AddIndexSql() error = %v, wantErr %v", err, tt.wantErr) 829 return 830 } 831 if got != tt.want { 832 t.Errorf("AddIndexSql() got = %v, want %v", got, tt.want) 833 } 834 }) 835 } 836 }