github.com/unionj-cloud/go-doudou/v2@v2.3.5/toolkit/gormgen/internal/generate/table.go (about)

     1  package generate
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  
     7  	"gorm.io/gorm"
     8  
     9  	"github.com/unionj-cloud/go-doudou/v2/toolkit/gormgen/internal/model"
    10  )
    11  
    12  // ITableInfo table info interface
    13  type ITableInfo interface {
    14  	GetTableColumns(schemaName string, tableName string) (result []*model.Column, err error)
    15  
    16  	GetTableIndex(schemaName string, tableName string) (indexes []gorm.Index, err error)
    17  }
    18  
    19  func getTableInfo(db *gorm.DB) ITableInfo {
    20  	return &tableInfo{db}
    21  }
    22  
    23  func getTableColumns(db *gorm.DB, schemaName string, tableName string, indexTag bool) (result []*model.Column, err error) {
    24  	if db == nil {
    25  		return nil, errors.New("gorm db is nil")
    26  	}
    27  
    28  	mt := getTableInfo(db)
    29  	result, err = mt.GetTableColumns(schemaName, tableName)
    30  	if err != nil {
    31  		return nil, err
    32  	}
    33  	if !indexTag || len(result) == 0 {
    34  		return result, nil
    35  	}
    36  
    37  	index, err := mt.GetTableIndex(schemaName, tableName)
    38  	if err != nil { //ignore find index err
    39  		db.Logger.Warn(context.Background(), "GetTableIndex for %s,err=%s", tableName, err.Error())
    40  		return result, nil
    41  	}
    42  	if len(index) == 0 {
    43  		return result, nil
    44  	}
    45  
    46  	im := model.GroupByColumn(index)
    47  	for _, c := range result {
    48  		c.Indexes = im[c.Name()]
    49  	}
    50  	return result, nil
    51  }
    52  
    53  type tableInfo struct{ *gorm.DB }
    54  
    55  // GetTableColumns  struct
    56  func (t *tableInfo) GetTableColumns(schemaName string, tableName string) (result []*model.Column, err error) {
    57  	types, err := t.Migrator().ColumnTypes(tableName)
    58  	if err != nil {
    59  		return nil, err
    60  	}
    61  	for _, column := range types {
    62  		result = append(result, &model.Column{ColumnType: column, TableName: tableName, UseScanType: t.Dialector.Name() != "mysql" && t.Dialector.Name() != "sqlite"})
    63  	}
    64  	return result, nil
    65  }
    66  
    67  // GetTableIndex  index
    68  func (t *tableInfo) GetTableIndex(schemaName string, tableName string) (indexes []gorm.Index, err error) {
    69  	return t.Migrator().GetIndexes(tableName)
    70  }